diff --git a/.github/workflows/cibuildwheel.yml b/.github/workflows/cibuildwheel.yml index c8dcbab..1ab2818 100644 --- a/.github/workflows/cibuildwheel.yml +++ b/.github/workflows/cibuildwheel.yml @@ -21,7 +21,7 @@ jobs: actions: read with: wheel-name-pattern: "srsly-*.whl" - pure-python: false + pure-python: true create-release: ${{ startsWith(github.ref, 'refs/tags/release-') || startsWith(github.ref, 'refs/tags/prerelease-') }} secrets: gh-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index 84b3ceb..c31bff0 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -1,5 +1,5 @@ -# The cibuildwheel action triggers on creation of a release, this -# triggers on publication. +# The cibuildwheel action triggers on all pushes and PRs; +# this action triggers on release publication. # The expected workflow is to create a draft release and let the wheels # upload, and then hit 'publish', which uploads to PyPi. 
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2bfc614..e208c3a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,17 +14,26 @@ concurrency: cancel-in-progress: true env: - MODULE_NAME: 'srsly' RUN_MYPY: 'false' jobs: tests: + name: ${{ matrix.python_version }} ${{ matrix.os }} numpy=${{ matrix.numpy }} strategy: fail-fast: false matrix: os: [ubuntu-latest, windows-latest] - # FIXME: ujson segfault on 3.14 - python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + # Note: ruamel.yaml does not support Python 3.13t + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"] + numpy: [true] + include: + - os: ubuntu-latest + python_version: "3.9" + numpy: false + - os: ubuntu-latest + python_version: "3.14" + numpy: false + runs-on: ${{ matrix.os }} steps: @@ -35,46 +44,26 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python_version }} - architecture: x64 - - - name: Build sdist - run: | - python -m pip install -U build pip setuptools - python -m pip install -U -r requirements.txt - python -m build --sdist - name: Run mypy - shell: bash if: ${{ env.RUN_MYPY == 'true' }} run: | - python -m mypy $MODULE_NAME - - - name: Delete source directory - shell: bash - run: | - rm -rf $MODULE_NAME + python -m pip install mypy + python -m mypy srsly - - name: Uninstall all packages - run: | - python -m pip freeze > installed.txt - python -m pip uninstall -y -r installed.txt + - name: Install package + run: python -m pip install . 
- - name: Install from sdist + - name: Test that numpy was not installed shell: bash - run: | - SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1) - python -m pip install dist/$SDIST + run: if pip list | grep -q '^numpy'; then exit 1; fi - - name: Test import - shell: bash - run: | - python -c "import $MODULE_NAME" -Werror + - name: Install numpy + if: ${{ matrix.numpy }} + run: python -m pip install numpy - name: Install test requirements - run: | - python -m pip install -U -r requirements.txt + run: python -m pip install pytest - name: Run tests - shell: bash - run: | - python -m pytest --pyargs $MODULE_NAME -Werror + run: pytest diff --git a/.gitignore b/.gitignore index a327cc0..b5cbd43 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ .env/ .env* .vscode/ -cythonize.json # Byte-compiled / optimized / DLL files __pycache__/ @@ -108,8 +107,5 @@ venv.bak/ # mypy .mypy_cache/ -# Cython intermediate files -*.cpp - # Vim files *.sw* diff --git a/README.md b/README.md index b5e49d5..56a691d 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # srsly: Modern high-performance serialization utilities for Python This package bundles some of the best Python serialization libraries into one -standalone package, with a high-level API that makes it easy to write code +convenience package, with a high-level API that makes it easy to write code that's correct across platforms and Pythons. This allows us to provide all the serialization utilities we need in a single binary wheel. Currently supports **JSON**, **JSONL**, **MessagePack**, **Pickle** and **YAML**. @@ -24,31 +24,24 @@ wrap the multiple serialization formats we need to support (especially `json`, `msgpack` and `pickle`). These wrapping functions ended up duplicated across our codebases, so we wanted to put them in one place. -At the same time, we noticed that having a lot of small dependencies was making -maintenance harder, and making installation slower. 
To solve this, we've made -`srsly` standalone, by including the component packages directly within it. This -way we can provide all the serialization utilities we need in a single binary -wheel. - -`srsly` currently includes forks of the following packages: +`srsly` currently includes wrappers around the following packages: - [`ujson`](https://github.com/esnme/ultrajson) - [`msgpack`](https://github.com/msgpack/msgpack-python) -- [`msgpack-numpy`](https://github.com/lebedov/msgpack-numpy) - [`cloudpickle`](https://github.com/cloudpipe/cloudpickle) - [`ruamel.yaml`](https://github.com/pycontribs/ruamel-yaml) (without unsafe implementations!) -## Installation +Additionally, it includes a heavily customized fork of +[`msgpack-numpy`](https://github.com/lebedov/msgpack-numpy), with corrected +round-trip behaviour for np.float64 objects. + -> ⚠️ Note that `v2.x` is only compatible with **Python 3.6+**. For 2.7+ -> compatibility, use `v1.x`. +## Installation -`srsly` can be installed from pip. Before installing, make sure that your `pip`, -`setuptools` and `wheel` are up to date. +`srsly` can be installed from pip. ```bash -python -m pip install -U pip setuptools wheel python -m pip install srsly ``` @@ -58,12 +51,15 @@ Or from conda via conda-forge: conda install -c conda-forge srsly ``` -Alternatively, you can also compile the library from source. You'll need to make -sure that you have a development environment with a Python distribution -including header files, a compiler (XCode command-line tools on macOS / OS X or -Visual C++ build tools on Windows), pip and git installed. +This will automatically install/upgrade all dependencies. + +numpy and cupy are optional dependencies for msgpack. +If numpy is installed, numpy objects can be serialized. +If cupy is installed, cupy objects will be automatically converted +to numpy and then serialized. 
+ -Install from source: +Alternatively, you can also install the library from the repository: ```bash # clone the repo @@ -74,10 +70,7 @@ cd srsly python -m venv .env source .env/bin/activate -# update pip -python -m pip install -U pip setuptools wheel - -# compile and install from source +# install from source python -m pip install . ``` @@ -86,7 +79,6 @@ mode without build isolation: ```bash # install in editable mode -python -m pip install -r requirements.txt python -m pip install --no-build-isolation --editable . # run test suite @@ -97,9 +89,6 @@ python -m pytest --pyargs srsly ### JSON -> 📦 The underlying module is exposed via `srsly.ujson`. However, we normally -> interact with it via the utility functions only. - #### function `srsly.json_dumps` Serialize an object to a JSON string. Falls back to `json` if `sort_keys=True` @@ -264,9 +253,6 @@ assert srsly.is_json_serializable(lambda x: x) is False ### msgpack -> 📦 The underlying module is exposed via `srsly.msgpack`. However, we normally -> interact with it via the utility functions only. - #### function `srsly.msgpack_dumps` Serialize an object to a msgpack byte string. @@ -326,9 +312,6 @@ data = srsly.read_msgpack("/path/to/file.msg") ### pickle -> 📦 The underlying module is exposed via `srsly.cloudpickle`. However, we -> normally interact with it via the utility functions only. - #### function `srsly.pickle_dumps` Serialize a Python object with pickle. @@ -360,9 +343,6 @@ data = srsly.pickle_loads(pickled_data) ### YAML -> 📦 The underlying module is exposed via `srsly.ruamel_yaml`. However, we -> normally interact with it via the utility functions only. - #### function `srsly.yaml_dumps` Serialize an object to a YAML string. 
See the diff --git a/pyproject.toml b/pyproject.toml index 483235c..fed528d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,63 +1,3 @@ [build-system] -requires = [ - "setuptools", - "cython>=0.29.1", -] +requires = ["setuptools"] build-backend = "setuptools.build_meta" - -[tool.cibuildwheel] -build = "*" -skip = [ - "cp38*", # Obsolete - "cp314-*", # FIXME ujson segfaults - "cp314t-*", # TODO free-threading support (note: 3.13t is skipped by default) -] -test-skip = "" - -archs = ["native"] - -build-frontend = "default" -config-settings = {} -dependency-versions = "pinned" -environment = {} -environment-pass = [] -build-verbosity = 0 - -before-all = "" -before-build = "" -repair-wheel-command = "" - -test-command = "" -before-test = "" -test-requires = [] -test-extras = [] - -container-engine = "docker" - -manylinux-x86_64-image = "manylinux2014" -manylinux-i686-image = "manylinux2014" -manylinux-aarch64-image = "manylinux2014" -manylinux-ppc64le-image = "manylinux2014" -manylinux-s390x-image = "manylinux2014" -manylinux-pypy_x86_64-image = "manylinux2014" -manylinux-pypy_i686-image = "manylinux2014" -manylinux-pypy_aarch64-image = "manylinux2014" - -musllinux-x86_64-image = "musllinux_1_2" -musllinux-i686-image = "musllinux_1_2" -musllinux-aarch64-image = "musllinux_1_2" -musllinux-ppc64le-image = "musllinux_1_2" -musllinux-s390x-image = "musllinux_1_2" - - -[tool.cibuildwheel.linux] -repair-wheel-command = "auditwheel repair -w {dest_dir} {wheel}" - -[tool.cibuildwheel.macos] -repair-wheel-command = "delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel}" - -[tool.cibuildwheel.windows] - -[tool.cibuildwheel.pyodide] - - diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7de940c..0000000 --- a/requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -catalogue>=2.0.3,<2.1.0 -# Development requirements -cython>=0.29.1 -pytest>=4.6.5 -pytest-timeout>=1.3.3 -mock>=2.0.0,<3.0.0 -numpy>=1.15.0 -psutil diff --git a/setup.cfg 
b/setup.cfg index 2b3679a..8097293 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,27 +12,25 @@ classifiers = Intended Audience :: Developers Intended Audience :: Science/Research License :: OSI Approved :: MIT License - Operating System :: POSIX :: Linux - Operating System :: MacOS :: MacOS X - Operating System :: Microsoft :: Windows - Programming Language :: Cython + Operating System :: OS Independent Programming Language :: Python :: 3 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Programming Language :: Python :: 3.12 Programming Language :: Python :: 3.13 + Programming Language :: Python :: 3.14 Topic :: Scientific/Engineering [options] zip_safe = true include_package_data = true -# FIXME ujson segfaults on 3.14 -python_requires = >=3.9,<3.14 -setup_requires = - cython>=0.29.1 +python_requires = >=3.9 install_requires = - catalogue>=2.0.3,<2.1.0 + cloudpickle >=3.1.2,<4 + msgpack >=1.1,<2 + ruamel.yaml >=0.18.16,<1 + ujson >=5.11.0,<6 [options.entry_points] # If spaCy is installed in the same environment as srsly, it will automatically @@ -44,7 +42,7 @@ spacy_readers = srsly.read_msgpack.v1 = srsly:read_msgpack [bdist_wheel] -universal = false +universal = true [sdist] formats = gztar @@ -55,11 +53,6 @@ max-line-length = 80 select = B,C,E,F,W,T4,B9 exclude = srsly/__init__.py - srsly/msgpack/__init__.py - srsly/cloudpickle/__init__.py [mypy] ignore_missing_imports = True - -[mypy-srsly.cloudpickle.*] -ignore_errors=True diff --git a/setup.py b/setup.py index 7fd4252..4346012 100644 --- a/setup.py +++ b/setup.py @@ -1,141 +1,19 @@ #!/usr/bin/env python -import sys -from setuptools.command.build_ext import build_ext -from sysconfig import get_path -from setuptools import Extension, setup, find_packages +from setuptools import setup, find_packages from pathlib import Path -from Cython.Build import cythonize -from Cython.Compiler import Options -import contextlib -import os - - -# Preserve 
`__doc__` on functions and classes -# http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options -Options.docstrings = True - - -PACKAGE_DATA = {"": ["*.pyx", "*.pxd", "*.c", "*.h", "*.cpp"]} -PACKAGES = find_packages() -# msgpack has this whacky build where it only builds _cmsgpack which textually includes -# _packer and _unpacker. I refactored this. -MOD_NAMES = ["srsly.msgpack._epoch", "srsly.msgpack._packer", "srsly.msgpack._unpacker"] -COMPILE_OPTIONS = { - "msvc": ["/Ox", "/EHsc"], - "mingw32": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"], - "other": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"], -} -COMPILER_DIRECTIVES = { - "language_level": -3, - "embedsignature": True, - "annotation_typing": False, -} -LINK_OPTIONS = {"msvc": [], "mingw32": [], "other": ["-lstdc++", "-lm"]} - -if sys.byteorder == "big": - macros = [("__BIG_ENDIAN__", "1")] -else: - macros = [("__LITTLE_ENDIAN__", "1")] - - -# By subclassing build_extensions we have the actual compiler that will be used -# which is really known only after finalize_options -# http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used -class build_ext_options: - def build_options(self): - if hasattr(self.compiler, "initialize"): - self.compiler.initialize() - self.compiler.platform = sys.platform[:6] - for e in self.extensions: - e.extra_compile_args += COMPILE_OPTIONS.get( - self.compiler.compiler_type, COMPILE_OPTIONS["other"] - ) - e.extra_link_args += LINK_OPTIONS.get( - self.compiler.compiler_type, LINK_OPTIONS["other"] - ) - - -class build_ext_subclass(build_ext, build_ext_options): - def build_extensions(self): - build_ext_options.build_options(self) - build_ext.build_extensions(self) - - -def clean(path): - n_cleaned = 0 - for name in MOD_NAMES: - name = name.replace(".", "/") - for ext in ["so", "html", "cpp", "c"]: - file_path = path / f"{name}.{ext}" - if file_path.exists(): - 
file_path.unlink() - n_cleaned += 1 - print(f"Cleaned {n_cleaned} files") - - -@contextlib.contextmanager -def chdir(new_dir): - old_dir = os.getcwd() - try: - os.chdir(new_dir) - sys.path.insert(0, new_dir) - yield - finally: - del sys.path[0] - os.chdir(old_dir) def setup_package(): root = Path(__file__).parent - - if len(sys.argv) > 1 and sys.argv[1] == "clean": - return clean(root) - with (root / "srsly" / "about.py").open("r") as f: about = {} exec(f.read(), about) - with chdir(str(root)): - include_dirs = [get_path("include"), ".", "srsly"] - ext_modules = [] - for name in MOD_NAMES: - mod_path = name.replace(".", "/") + ".pyx" - ext_modules.append( - Extension( - name, - [mod_path], - language="c++", - include_dirs=include_dirs, - define_macros=macros, - ) - ) - ext_modules.append( - Extension( - "srsly.ujson.ujson", - sources=[ - "./srsly/ujson/ujson.c", - "./srsly/ujson/objToJSON.c", - "./srsly/ujson/JSONtoObj.c", - "./srsly/ujson/lib/ultrajsonenc.c", - "./srsly/ujson/lib/ultrajsondec.c", - ], - include_dirs=["./srsly/ujson", "./srsly/ujson/lib"], - extra_compile_args=["-D_GNU_SOURCE"], - ) - ) - print("Cythonizing sources") - ext_modules = cythonize( - ext_modules, compiler_directives=COMPILER_DIRECTIVES, language_level=2 - ) - - setup( - name="srsly", - packages=PACKAGES, - version=about["__version__"], - ext_modules=ext_modules, - cmdclass={"build_ext": build_ext_subclass}, - package_data=PACKAGE_DATA, - ) + setup( + name="srsly", + packages=find_packages(), + version=about["__version__"], + ) if __name__ == "__main__": diff --git a/srsly/_json_api.py b/srsly/_json_api.py index 24d25fd..e755217 100644 --- a/srsly/_json_api.py +++ b/srsly/_json_api.py @@ -1,9 +1,10 @@ -from typing import Union, Iterable, Sequence, Any, Optional, Iterator +from typing import Union, Iterable, Any, Optional, Iterator import sys import json as _builtin_json import gzip -from . 
import ujson +import ujson + from .util import force_path, force_string, FilePath, JSONInput, JSONOutput diff --git a/srsly/_msgpack_api.py b/srsly/_msgpack_api.py index 3da0fe6..5259380 100644 --- a/srsly/_msgpack_api.py +++ b/srsly/_msgpack_api.py @@ -1,8 +1,103 @@ import gc +from contextlib import contextmanager + +import msgpack -from . import msgpack -from .msgpack import msgpack_encoders, msgpack_decoders # noqa: F401 from .util import force_path, FilePath, JSONInputBin, JSONOutputBin +from ._msgpack_numpy import encode_numpy, decode_numpy + + +class _MsgpackExtensions: + """API for extending msgpack (de)serialization: + + srsly.msgpack_encoders.register(name, func) + srsly.msgpack_decoders.register(name, func) + + where `name` is a unique ID and `func` is a callable that accepts a single + argument: + + - For encoders, the argument is the object to serialize. The callable should + return a new object (typically a dict) or the original object if the callback + does not recognize it. + - For decoders, the argument is the dict to deserialize, as returned by the encoders. + The callable should return a new object or the original dict if the callback + does not recognize it. + """ + + __slots__ = ("_ext",) + + def __init__(self): + self._ext = {} + + def register(self, name, func): + """Register a custom encoder/decoder function""" + self._ext[name] = func + + def deregister(self, name): + del self._ext[name] + + def _run(self, obj): + for func in self._ext.values(): + out = func(obj) + if out is not obj: + return out + return obj + + +class _MsgpackEncoderExtensions(_MsgpackExtensions): + def _run(self, obj): + out = super()._run(obj) + if out is not obj: + return out + + # Convert subtypes of base types and tuples to lists. + # Effectively this undoes the strict_types=True option of msgpack. + # This is needed to support np.float64, which is a subclass of builtin float. + # Run this last to allow the user to register their own handlers first. 
+ + if isinstance(obj, tuple): + return list(obj) + # Note: bool and memoryview can't be subclassed + # set and frozenset are not supported by msgpack + for cls in (int, float, list, dict, str, bytes): + if isinstance(obj, cls): + return cls(obj) + + return obj + + +msgpack_encoders = _MsgpackEncoderExtensions() +msgpack_decoders = _MsgpackExtensions() + + +def encode_complex(obj): + if isinstance(obj, complex): + return {b"complex": True, b"data": repr(obj)} + return obj + + +def decode_complex(obj): + if b"complex" in obj: + return complex(obj[b"data"]) + return obj + + +msgpack_encoders.register("numpy", func=encode_numpy) +msgpack_decoders.register("numpy", func=decode_numpy) +# Note: np.complex128 is a subclass of built-in complex, so +# encode_complex must be registered after encode_numpy. +msgpack_encoders.register("complex", func=encode_complex) +msgpack_decoders.register("complex", func=decode_complex) + + +@contextmanager +def _without_gc(): + """msgpack-python docs suggest disabling gc before unpacking large messages""" + gc.disable() + try: + yield + finally: + gc.enable() def msgpack_dumps(data: JSONInputBin) -> bytes: @@ -11,7 +106,13 @@ def msgpack_dumps(data: JSONInputBin) -> bytes: data: The data to serialize. RETURNS (bytes): The serialized bytes. """ - return msgpack.dumps(data, use_bin_type=True) + return msgpack.dumps( + data, + # strict_types is False for everything except np.float64 + # and np.complex128 (see above) + strict_types=True, + default=msgpack_encoders._run, + ) def msgpack_loads(data: bytes, use_list: bool = True) -> JSONOutputBin: @@ -22,11 +123,10 @@ def msgpack_loads(data: bytes, use_list: bool = True) -> JSONOutputBin: deserialization slower. RETURNS: The deserialized Python object. 
""" - # msgpack-python docs suggest disabling gc before unpacking large messages - gc.disable() - msg = msgpack.loads(data, raw=False, use_list=use_list) - gc.enable() - return msg + with _without_gc(): + return msgpack.loads( + data, raw=False, use_list=use_list, object_hook=msgpack_decoders._run + ) def write_msgpack(path: FilePath, data: JSONInputBin) -> None: @@ -37,7 +137,7 @@ def write_msgpack(path: FilePath, data: JSONInputBin) -> None: """ file_path = force_path(path, require_exists=False) with file_path.open("wb") as f: - msgpack.dump(data, f, use_bin_type=True) + msgpack.dump(data, f, strict_types=True, default=msgpack_encoders._run) def read_msgpack(path: FilePath, use_list: bool = True) -> JSONOutputBin: @@ -49,9 +149,7 @@ def read_msgpack(path: FilePath, use_list: bool = True) -> JSONOutputBin: RETURNS (JSONOutputBin): The loaded and deserialized content. """ file_path = force_path(path) - with file_path.open("rb") as f: - # msgpack-python docs suggest disabling gc before unpacking large messages - gc.disable() - msg = msgpack.load(f, raw=False, use_list=use_list) - gc.enable() - return msg + with file_path.open("rb") as f, _without_gc(): + return msgpack.load( + f, raw=False, use_list=use_list, object_hook=msgpack_decoders._run + ) diff --git a/srsly/_msgpack_numpy.py b/srsly/_msgpack_numpy.py new file mode 100644 index 0000000..a3a842c --- /dev/null +++ b/srsly/_msgpack_numpy.py @@ -0,0 +1,80 @@ +""" +Support for serialization of numpy data types with msgpack. +This is a heavily modified fork of a very old version of msgpack-numpy. +""" + +# Copyright (c) 2013-2018, Lev E. Givon +# All rights reserved. +# Distributed under the terms of the BSD license: +# http://www.opensource.org/licenses/bsd-license +try: + import numpy as np + + has_numpy = True +except ImportError: + has_numpy = False + +try: + import cupy + + has_cupy = True +except ImportError: + has_cupy = False + + +def encode_numpy(obj): + """ + Data encoder for serializing numpy data types. 
+ """ + if not has_numpy: + return obj + if has_cupy and isinstance(obj, cupy.ndarray): + obj = obj.get() + if isinstance(obj, np.ndarray): + # If the dtype is structured, store the interface description; + # otherwise, store the corresponding array protocol type string: + if obj.dtype.kind == "V": + kind = b"V" + descr = obj.dtype.descr + else: + kind = b"" + descr = obj.dtype.str + return { + b"nd": True, + b"type": descr, + b"kind": kind, + b"shape": obj.shape, + b"data": obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes(), + } + if isinstance(obj, (np.bool_, np.number)): + return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data} + + return obj + + +def decode_numpy(obj): + """ + Decoder for deserializing numpy data types. + """ + if b"nd" not in obj: + return obj + + # Crash with a clean ModuleNotFoundError if numpy is not available + # instead of AttributeError + import numpy # noqa: F401 + + if obj[b"nd"]: + # Check if "kind" is in obj to enable decoding of data + # serialized with older versions (#20): + if b"kind" in obj and obj[b"kind"] == b"V": + descr = [ + tuple(t.decode() if type(t) is bytes else t for t in d) + for d in obj[b"type"] + ] + else: + descr = obj[b"type"] + return np.frombuffer(obj[b"data"], dtype=np.dtype(descr)).reshape(obj[b"shape"]) + else: + # NumPy scalar + descr = obj[b"type"] + return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0] diff --git a/srsly/_pickle_api.py b/srsly/_pickle_api.py index 0e894d9..3a2df96 100644 --- a/srsly/_pickle_api.py +++ b/srsly/_pickle_api.py @@ -1,6 +1,7 @@ from typing import Optional -from . 
import cloudpickle +import cloudpickle + from .util import JSONInput, JSONOutput diff --git a/srsly/_yaml_api.py b/srsly/_yaml_api.py index 36f24aa..101d3f5 100644 --- a/srsly/_yaml_api.py +++ b/srsly/_yaml_api.py @@ -2,8 +2,9 @@ from io import StringIO import sys -from .ruamel_yaml import YAML -from .ruamel_yaml.representer import RepresenterError +from ruamel.yaml import YAML +from ruamel.yaml.representer import RepresenterError + from .util import force_path, FilePath, YAMLInput, YAMLOutput diff --git a/srsly/about.py b/srsly/about.py index 667b52f..528787c 100644 --- a/srsly/about.py +++ b/srsly/about.py @@ -1 +1 @@ -__version__ = "2.5.2" +__version__ = "3.0.0" diff --git a/srsly/cloudpickle/__init__.py b/srsly/cloudpickle/__init__.py deleted file mode 100644 index c802221..0000000 --- a/srsly/cloudpickle/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .cloudpickle import * # noqa -from .cloudpickle_fast import CloudPickler, dumps, dump # noqa - -# Conform to the convention used by python serialization libraries, which -# expose their Pickler subclass at top-level under the "Pickler" name. -Pickler = CloudPickler - -__version__ = '2.2.0' diff --git a/srsly/cloudpickle/cloudpickle.py b/srsly/cloudpickle/cloudpickle.py deleted file mode 100644 index 317be69..0000000 --- a/srsly/cloudpickle/cloudpickle.py +++ /dev/null @@ -1,948 +0,0 @@ -""" -This class is defined to override standard pickle functionality - -The goals of it follow: --Serialize lambdas and nested functions to compiled byte code --Deal with main module correctly --Deal with other non-serializable objects - -It does not include an unpickler, as standard python unpickling suffices. - -This module was extracted from the `cloud` package, developed by `PiCloud, Inc. -`_. - -Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. -All rights reserved. 
- -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of the University of California, Berkeley nor the - names of its contributors may be used to endorse or promote - products derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-""" - -import builtins -import dis -import opcode -import platform -import sys -import types -import weakref -import uuid -import threading -import typing -import warnings - -from .compat import pickle -from collections import OrderedDict -from typing import ClassVar, Generic, Union, Tuple, Callable -from pickle import _getattribute -from importlib._bootstrap import _find_spec - -try: # pragma: no branch - import typing_extensions as _typing_extensions - from typing_extensions import Literal, Final -except ImportError: - _typing_extensions = Literal = Final = None - -if sys.version_info >= (3, 8): - from types import CellType -else: - def f(): - a = 1 - - def g(): - return a - return g - CellType = type(f().__closure__[0]) - - -# cloudpickle is meant for inter process communication: we expect all -# communicating processes to run the same Python version hence we favor -# communication speed over compatibility: -DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL - -# Names of modules whose resources should be treated as dynamic. -_PICKLE_BY_VALUE_MODULES = set() - -# Track the provenance of reconstructed dynamic classes to make it possible to -# reconstruct instances from the matching singleton class definition when -# appropriate and preserve the usual "isinstance" semantics of Python objects. 
-_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() -_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() -_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock() - -PYPY = platform.python_implementation() == "PyPy" - -builtin_code_type = None -if PYPY: - # builtin-code objects only exist in pypy - builtin_code_type = type(float.__new__.__code__) - -_extract_code_globals_cache = weakref.WeakKeyDictionary() - - -def _get_or_create_tracker_id(class_def): - with _DYNAMIC_CLASS_TRACKER_LOCK: - class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def) - if class_tracker_id is None: - class_tracker_id = uuid.uuid4().hex - _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id - _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def - return class_tracker_id - - -def _lookup_class_or_track(class_tracker_id, class_def): - if class_tracker_id is not None: - with _DYNAMIC_CLASS_TRACKER_LOCK: - class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault( - class_tracker_id, class_def) - _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id - return class_def - - -def register_pickle_by_value(module): - """Register a module to make it functions and classes picklable by value. - - By default, functions and classes that are attributes of an importable - module are to be pickled by reference, that is relying on re-importing - the attribute from the module at load time. - - If `register_pickle_by_value(module)` is called, all its functions and - classes are subsequently to be pickled by value, meaning that they can - be loaded in Python processes where the module is not importable. - - This is especially useful when developing a module in a distributed - execution environment: restarting the client Python process with the new - source code is enough: there is no need to re-install the new version - of the module on all the worker nodes nor to restart the workers. - - Note: this feature is considered experimental. 
See the cloudpickle - README.md file for more details and limitations. - """ - if not isinstance(module, types.ModuleType): - raise ValueError( - f"Input should be a module object, got {str(module)} instead" - ) - # In the future, cloudpickle may need a way to access any module registered - # for pickling by value in order to introspect relative imports inside - # functions pickled by value. (see - # https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633). - # This access can be ensured by checking that module is present in - # sys.modules at registering time and assuming that it will still be in - # there when accessed during pickling. Another alternative would be to - # store a weakref to the module. Even though cloudpickle does not implement - # this introspection yet, in order to avoid a possible breaking change - # later, we still enforce the presence of module inside sys.modules. - if module.__name__ not in sys.modules: - raise ValueError( - f"{module} was not imported correctly, have you used an " - f"`import` statement to access it?" 
- ) - _PICKLE_BY_VALUE_MODULES.add(module.__name__) - - -def unregister_pickle_by_value(module): - """Unregister that the input module should be pickled by value.""" - if not isinstance(module, types.ModuleType): - raise ValueError( - f"Input should be a module object, got {str(module)} instead" - ) - if module.__name__ not in _PICKLE_BY_VALUE_MODULES: - raise ValueError(f"{module} is not registered for pickle by value") - else: - _PICKLE_BY_VALUE_MODULES.remove(module.__name__) - - -def list_registry_pickle_by_value(): - return _PICKLE_BY_VALUE_MODULES.copy() - - -def _is_registered_pickle_by_value(module): - module_name = module.__name__ - if module_name in _PICKLE_BY_VALUE_MODULES: - return True - while True: - parent_name = module_name.rsplit(".", 1)[0] - if parent_name == module_name: - break - if parent_name in _PICKLE_BY_VALUE_MODULES: - return True - module_name = parent_name - return False - - -def _whichmodule(obj, name): - """Find the module an object belongs to. - - This function differs from ``pickle.whichmodule`` in two ways: - - it does not mangle the cases where obj's module is __main__ and obj was - not found in any module. - - Errors arising during module introspection are ignored, as those errors - are considered unwanted side effects. - """ - if sys.version_info[:2] < (3, 7) and isinstance(obj, typing.TypeVar): # pragma: no branch # noqa - # Workaround bug in old Python versions: prior to Python 3.7, - # T.__module__ would always be set to "typing" even when the TypeVar T - # would be defined in a different module. - if name is not None and getattr(typing, name, None) is obj: - # Built-in TypeVar defined in typing such as AnyStr - return 'typing' - else: - # User defined or third-party TypeVar: __module__ attribute is - # irrelevant, thus trigger a exhaustive search for obj in all - # modules. 
- module_name = None - else: - module_name = getattr(obj, '__module__', None) - - if module_name is not None: - return module_name - # Protect the iteration by using a copy of sys.modules against dynamic - # modules that trigger imports of other modules upon calls to getattr or - # other threads importing at the same time. - for module_name, module in sys.modules.copy().items(): - # Some modules such as coverage can inject non-module objects inside - # sys.modules - if ( - module_name == '__main__' or - module is None or - not isinstance(module, types.ModuleType) - ): - continue - try: - if _getattribute(module, name)[0] is obj: - return module_name - except Exception: - pass - return None - - -def _should_pickle_by_reference(obj, name=None): - """Test whether an function or a class should be pickled by reference - - Pickling by reference means by that the object (typically a function or a - class) is an attribute of a module that is assumed to be importable in the - target Python environment. Loading will therefore rely on importing the - module and then calling `getattr` on it to access the function or class. - - Pickling by reference is the only option to pickle functions and classes - in the standard library. In cloudpickle the alternative option is to - pickle by value (for instance for interactively or locally defined - functions and classes or for attributes of modules that have been - explicitly registered to be pickled by value. - """ - if isinstance(obj, types.FunctionType) or issubclass(type(obj), type): - module_and_name = _lookup_module_and_qualname(obj, name=name) - if module_and_name is None: - return False - module, name = module_and_name - return not _is_registered_pickle_by_value(module) - - elif isinstance(obj, types.ModuleType): - # We assume that sys.modules is primarily used as a cache mechanism for - # the Python import machinery. 
Checking if a module has been added in - # is sys.modules therefore a cheap and simple heuristic to tell us - # whether we can assume that a given module could be imported by name - # in another Python process. - if _is_registered_pickle_by_value(obj): - return False - return obj.__name__ in sys.modules - else: - raise TypeError( - "cannot check importability of {} instances".format( - type(obj).__name__) - ) - - -def _lookup_module_and_qualname(obj, name=None): - if name is None: - name = getattr(obj, '__qualname__', None) - if name is None: # pragma: no cover - # This used to be needed for Python 2.7 support but is probably not - # needed anymore. However we keep the __name__ introspection in case - # users of cloudpickle rely on this old behavior for unknown reasons. - name = getattr(obj, '__name__', None) - - module_name = _whichmodule(obj, name) - - if module_name is None: - # In this case, obj.__module__ is None AND obj was not found in any - # imported module. obj is thus treated as dynamic. - return None - - if module_name == "__main__": - return None - - # Note: if module_name is in sys.modules, the corresponding module is - # assumed importable at unpickling time. See #357 - module = sys.modules.get(module_name, None) - if module is None: - # The main reason why obj's module would not be imported is that this - # module has been dynamically created, using for example - # types.ModuleType. The other possibility is that module was removed - # from sys.modules after obj was created/imported. But this case is not - # supported, as the standard pickle does not support it either. 
- return None - - try: - obj2, parent = _getattribute(module, name) - except AttributeError: - # obj was not found inside the module it points to - return None - if obj2 is not obj: - return None - return module, name - - -def _extract_code_globals(co): - """ - Find all globals names read or written to by codeblock co - """ - out_names = _extract_code_globals_cache.get(co) - if out_names is None: - # We use a dict with None values instead of a set to get a - # deterministic order (assuming Python 3.6+) and avoid introducing - # non-deterministic pickle bytes as a results. - out_names = {name: None for name in _walk_global_ops(co)} - - # Declaring a function inside another one using the "def ..." - # syntax generates a constant code object corresponding to the one - # of the nested function's As the nested function may itself need - # global variables, we need to introspect its code, extract its - # globals, (look for code object in it's co_consts attribute..) and - # add the result to code_globals - if co.co_consts: - for const in co.co_consts: - if isinstance(const, types.CodeType): - out_names.update(_extract_code_globals(const)) - - _extract_code_globals_cache[co] = out_names - - return out_names - - -def _find_imported_submodules(code, top_level_dependencies): - """ - Find currently imported submodules used by a function. - - Submodules used by a function need to be detected and referenced for the - function to work correctly at depickling time. Because submodules can be - referenced as attribute of their parent package (``package.submodule``), we - need a special introspection technique that does not rely on GLOBAL-related - opcodes to find references of them in a code object. 
- - Example: - ``` - import concurrent.futures - import cloudpickle - def func(): - x = concurrent.futures.ThreadPoolExecutor - if __name__ == '__main__': - cloudpickle.dumps(func) - ``` - The globals extracted by cloudpickle in the function's state include the - concurrent package, but not its submodule (here, concurrent.futures), which - is the module used by func. Find_imported_submodules will detect the usage - of concurrent.futures. Saving this module alongside with func will ensure - that calling func once depickled does not fail due to concurrent.futures - not being imported - """ - - subimports = [] - # check if any known dependency is an imported package - for x in top_level_dependencies: - if (isinstance(x, types.ModuleType) and - hasattr(x, '__package__') and x.__package__): - # check if the package has any currently loaded sub-imports - prefix = x.__name__ + '.' - # A concurrent thread could mutate sys.modules, - # make sure we iterate over a copy to avoid exceptions - for name in list(sys.modules): - # Older versions of pytest will add a "None" module to - # sys.modules. - if name is not None and name.startswith(prefix): - # check whether the function can address the sub-module - tokens = set(name[len(prefix):].split('.')) - if not tokens - set(code.co_names): - subimports.append(sys.modules[name]) - return subimports - - -def cell_set(cell, value): - """Set the value of a closure cell. - - The point of this function is to set the cell_contents attribute of a cell - after its creation. This operation is necessary in case the cell contains a - reference to the function the cell belongs to, as when calling the - function's constructor - ``f = types.FunctionType(code, globals, name, argdefs, closure)``, - closure will not be able to contain the yet-to-be-created f. 
- - In Python3.7, cell_contents is writeable, so setting the contents of a cell - can be done simply using - >>> cell.cell_contents = value - - In earlier Python3 versions, the cell_contents attribute of a cell is read - only, but this limitation can be worked around by leveraging the Python 3 - ``nonlocal`` keyword. - - In Python2 however, this attribute is read only, and there is no - ``nonlocal`` keyword. For this reason, we need to come up with more - complicated hacks to set this attribute. - - The chosen approach is to create a function with a STORE_DEREF opcode, - which sets the content of a closure variable. Typically: - - >>> def inner(value): - ... lambda: cell # the lambda makes cell a closure - ... cell = value # cell is a closure, so this triggers a STORE_DEREF - - (Note that in Python2, A STORE_DEREF can never be triggered from an inner - function. The function g for example here - >>> def f(var): - ... def g(): - ... var += 1 - ... return g - - will not modify the closure variable ``var```inplace, but instead try to - load a local variable var and increment it. As g does not assign the local - variable ``var`` any initial value, calling f(1)() will fail at runtime.) - - Our objective is to set the value of a given cell ``cell``. So we need to - somewhat reference our ``cell`` object into the ``inner`` function so that - this object (and not the smoke cell of the lambda function) gets affected - by the STORE_DEREF operation. - - In inner, ``cell`` is referenced as a cell variable (an enclosing variable - that is referenced by the inner function). If we create a new function - cell_set with the exact same code as ``inner``, but with ``cell`` marked as - a free variable instead, the STORE_DEREF will be applied on its closure - - ``cell``, which we can specify explicitly during construction! The new - cell_set variable thus actually sets the contents of a specified cell! 
- - Note: we do not make use of the ``nonlocal`` keyword to set the contents of - a cell in early python3 versions to limit possible syntax errors in case - test and checker libraries decide to parse the whole file. - """ - - if sys.version_info[:2] >= (3, 7): # pragma: no branch - cell.cell_contents = value - else: - _cell_set = types.FunctionType( - _cell_set_template_code, {}, '_cell_set', (), (cell,),) - _cell_set(value) - - -def _make_cell_set_template_code(): - def _cell_set_factory(value): - lambda: cell - cell = value - - co = _cell_set_factory.__code__ - - _cell_set_template_code = types.CodeType( - co.co_argcount, - co.co_kwonlyargcount, # Python 3 only argument - co.co_nlocals, - co.co_stacksize, - co.co_flags, - co.co_code, - co.co_consts, - co.co_names, - co.co_varnames, - co.co_filename, - co.co_name, - co.co_firstlineno, - co.co_lnotab, - co.co_cellvars, # co_freevars is initialized with co_cellvars - (), # co_cellvars is made empty - ) - return _cell_set_template_code - - -if sys.version_info[:2] < (3, 7): - _cell_set_template_code = _make_cell_set_template_code() - -# relevant opcodes -STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] -DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL'] -LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL'] -GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL) -HAVE_ARGUMENT = dis.HAVE_ARGUMENT -EXTENDED_ARG = dis.EXTENDED_ARG - - -_BUILTIN_TYPE_NAMES = {} -for k, v in types.__dict__.items(): - if type(v) is type: - _BUILTIN_TYPE_NAMES[v] = k - - -def _builtin_type(name): - if name == "ClassType": # pragma: no cover - # Backward compat to load pickle files generated with cloudpickle - # < 1.3 even if loading pickle files from older versions is not - # officially supported. - return type - return getattr(types, name) - - -def _walk_global_ops(code): - """ - Yield referenced name for all global-referencing instructions in *code*. 
- """ - for instr in dis.get_instructions(code): - op = instr.opcode - if op in GLOBAL_OPS: - yield instr.argval - - -def _extract_class_dict(cls): - """Retrieve a copy of the dict of a class without the inherited methods""" - clsdict = dict(cls.__dict__) # copy dict proxy to a dict - if len(cls.__bases__) == 1: - inherited_dict = cls.__bases__[0].__dict__ - else: - inherited_dict = {} - for base in reversed(cls.__bases__): - inherited_dict.update(base.__dict__) - to_remove = [] - for name, value in clsdict.items(): - try: - base_value = inherited_dict[name] - if value is base_value: - to_remove.append(name) - except KeyError: - pass - for name in to_remove: - clsdict.pop(name) - return clsdict - - -if sys.version_info[:2] < (3, 7): # pragma: no branch - def _is_parametrized_type_hint(obj): - # This is very cheap but might generate false positives. So try to - # narrow it down is good as possible. - type_module = getattr(type(obj), '__module__', None) - from_typing_extensions = type_module == 'typing_extensions' - from_typing = type_module == 'typing' - - # general typing Constructs - is_typing = getattr(obj, '__origin__', None) is not None - - # typing_extensions.Literal - is_literal = ( - (getattr(obj, '__values__', None) is not None) - and from_typing_extensions - ) - - # typing_extensions.Final - is_final = ( - (getattr(obj, '__type__', None) is not None) - and from_typing_extensions - ) - - # typing.ClassVar - is_classvar = ( - (getattr(obj, '__type__', None) is not None) and from_typing - ) - - # typing.Union/Tuple for old Python 3.5 - is_union = getattr(obj, '__union_params__', None) is not None - is_tuple = getattr(obj, '__tuple_params__', None) is not None - is_callable = ( - getattr(obj, '__result__', None) is not None and - getattr(obj, '__args__', None) is not None - ) - return any((is_typing, is_literal, is_final, is_classvar, is_union, - is_tuple, is_callable)) - - def _create_parametrized_type_hint(origin, args): - return origin[args] -else: - 
_is_parametrized_type_hint = None - _create_parametrized_type_hint = None - - -def parametrized_type_hint_getinitargs(obj): - # The distorted type check sematic for typing construct becomes: - # ``type(obj) is type(TypeHint)``, which means "obj is a - # parametrized TypeHint" - if type(obj) is type(Literal): # pragma: no branch - initargs = (Literal, obj.__values__) - elif type(obj) is type(Final): # pragma: no branch - initargs = (Final, obj.__type__) - elif type(obj) is type(ClassVar): - initargs = (ClassVar, obj.__type__) - elif type(obj) is type(Generic): - initargs = (obj.__origin__, obj.__args__) - elif type(obj) is type(Union): - initargs = (Union, obj.__args__) - elif type(obj) is type(Tuple): - initargs = (Tuple, obj.__args__) - elif type(obj) is type(Callable): - (*args, result) = obj.__args__ - if len(args) == 1 and args[0] is Ellipsis: - args = Ellipsis - else: - args = list(args) - initargs = (Callable, (args, result)) - else: # pragma: no cover - raise pickle.PicklingError( - f"Cloudpickle Error: Unknown type {type(obj)}" - ) - return initargs - - -# Tornado support - -def is_tornado_coroutine(func): - """ - Return whether *func* is a Tornado coroutine function. - Running coroutines are not supported. - """ - if 'tornado.gen' not in sys.modules: - return False - gen = sys.modules['tornado.gen'] - if not hasattr(gen, "is_coroutine_function"): - # Tornado version is too old - return False - return gen.is_coroutine_function(func) - - -def _rebuild_tornado_coroutine(func): - from tornado import gen - return gen.coroutine(func) - - -# including pickles unloading functions in this namespace -load = pickle.load -loads = pickle.loads - - -def subimport(name): - # We cannot do simply: `return __import__(name)`: Indeed, if ``name`` is - # the name of a submodule, __import__ will return the top-level root module - # of this submodule. For instance, __import__('os.path') returns the `os` - # module. 
- __import__(name) - return sys.modules[name] - - -def dynamic_subimport(name, vars): - mod = types.ModuleType(name) - mod.__dict__.update(vars) - mod.__dict__['__builtins__'] = builtins.__dict__ - return mod - - -def _gen_ellipsis(): - return Ellipsis - - -def _gen_not_implemented(): - return NotImplemented - - -def _get_cell_contents(cell): - try: - return cell.cell_contents - except ValueError: - # sentinel used by ``_fill_function`` which will leave the cell empty - return _empty_cell_value - - -def instance(cls): - """Create a new instance of a class. - - Parameters - ---------- - cls : type - The class to create an instance of. - - Returns - ------- - instance : cls - A new instance of ``cls``. - """ - return cls() - - -@instance -class _empty_cell_value: - """sentinel for empty closures - """ - @classmethod - def __reduce__(cls): - return cls.__name__ - - -def _fill_function(*args): - """Fills in the rest of function data into the skeleton function object - - The skeleton itself is create by _make_skel_func(). - """ - if len(args) == 2: - func = args[0] - state = args[1] - elif len(args) == 5: - # Backwards compat for cloudpickle v0.4.0, after which the `module` - # argument was introduced - func = args[0] - keys = ['globals', 'defaults', 'dict', 'closure_values'] - state = dict(zip(keys, args[1:])) - elif len(args) == 6: - # Backwards compat for cloudpickle v0.4.1, after which the function - # state was passed as a dict to the _fill_function it-self. - func = args[0] - keys = ['globals', 'defaults', 'dict', 'module', 'closure_values'] - state = dict(zip(keys, args[1:])) - else: - raise ValueError(f'Unexpected _fill_value arguments: {args!r}') - - # - At pickling time, any dynamic global variable used by func is - # serialized by value (in state['globals']). 
- # - At unpickling time, func's __globals__ attribute is initialized by - # first retrieving an empty isolated namespace that will be shared - # with other functions pickled from the same original module - # by the same CloudPickler instance and then updated with the - # content of state['globals'] to populate the shared isolated - # namespace with all the global variables that are specifically - # referenced for this function. - func.__globals__.update(state['globals']) - - func.__defaults__ = state['defaults'] - func.__dict__ = state['dict'] - if 'annotations' in state: - func.__annotations__ = state['annotations'] - if 'doc' in state: - func.__doc__ = state['doc'] - if 'name' in state: - func.__name__ = state['name'] - if 'module' in state: - func.__module__ = state['module'] - if 'qualname' in state: - func.__qualname__ = state['qualname'] - if 'kwdefaults' in state: - func.__kwdefaults__ = state['kwdefaults'] - # _cloudpickle_subimports is a set of submodules that must be loaded for - # the pickled function to work correctly at unpickling time. Now that these - # submodules are depickled (hence imported), they can be removed from the - # object's state (the object state only served as a reference holder to - # these submodules) - if '_cloudpickle_submodules' in state: - state.pop('_cloudpickle_submodules') - - cells = func.__closure__ - if cells is not None: - for cell, value in zip(cells, state['closure_values']): - if value is not _empty_cell_value: - cell_set(cell, value) - - return func - - -def _make_function(code, globals, name, argdefs, closure): - # Setting __builtins__ in globals is needed for nogil CPython. 
- globals["__builtins__"] = __builtins__ - return types.FunctionType(code, globals, name, argdefs, closure) - - -def _make_empty_cell(): - if False: - # trick the compiler into creating an empty cell in our lambda - cell = None - raise AssertionError('this route should not be executed') - - return (lambda: cell).__closure__[0] - - -def _make_cell(value=_empty_cell_value): - cell = _make_empty_cell() - if value is not _empty_cell_value: - cell_set(cell, value) - return cell - - -def _make_skel_func(code, cell_count, base_globals=None): - """ Creates a skeleton function object that contains just the provided - code and the correct number of cells in func_closure. All other - func attributes (e.g. func_globals) are empty. - """ - # This function is deprecated and should be removed in cloudpickle 1.7 - warnings.warn( - "A pickle file created using an old (<=1.4.1) version of cloudpickle " - "is currently being loaded. This is not supported by cloudpickle and " - "will break in cloudpickle 1.7", category=UserWarning - ) - # This is backward-compatibility code: for cloudpickle versions between - # 0.5.4 and 0.7, base_globals could be a string or None. base_globals - # should now always be a dictionary. - if base_globals is None or isinstance(base_globals, str): - base_globals = {} - - base_globals['__builtins__'] = __builtins__ - - closure = ( - tuple(_make_empty_cell() for _ in range(cell_count)) - if cell_count >= 0 else - None - ) - return types.FunctionType(code, base_globals, None, None, closure) - - -def _make_skeleton_class(type_constructor, name, bases, type_kwargs, - class_tracker_id, extra): - """Build dynamic class with an empty __dict__ to be filled once memoized - - If class_tracker_id is not None, try to lookup an existing class definition - matching that id. If none is found, track a newly reconstructed class - definition under that id so that other instances stemming from the same - class id will also reuse this class definition. 
- - The "extra" variable is meant to be a dict (or None) that can be used for - forward compatibility shall the need arise. - """ - skeleton_class = types.new_class( - name, bases, {'metaclass': type_constructor}, - lambda ns: ns.update(type_kwargs) - ) - return _lookup_class_or_track(class_tracker_id, skeleton_class) - - -def _rehydrate_skeleton_class(skeleton_class, class_dict): - """Put attributes from `class_dict` back on `skeleton_class`. - - See CloudPickler.save_dynamic_class for more info. - """ - registry = None - for attrname, attr in class_dict.items(): - if attrname == "_abc_impl": - registry = attr - else: - setattr(skeleton_class, attrname, attr) - if registry is not None: - for subclass in registry: - skeleton_class.register(subclass) - - return skeleton_class - - -def _make_skeleton_enum(bases, name, qualname, members, module, - class_tracker_id, extra): - """Build dynamic enum with an empty __dict__ to be filled once memoized - - The creation of the enum class is inspired by the code of - EnumMeta._create_. - - If class_tracker_id is not None, try to lookup an existing enum definition - matching that id. If none is found, track a newly reconstructed enum - definition under that id so that other instances stemming from the same - class id will also reuse this enum definition. - - The "extra" variable is meant to be a dict (or None) that can be used for - forward compatibility shall the need arise. 
- """ - # enums always inherit from their base Enum class at the last position in - # the list of base classes: - enum_base = bases[-1] - metacls = enum_base.__class__ - classdict = metacls.__prepare__(name, bases) - - for member_name, member_value in members.items(): - classdict[member_name] = member_value - enum_class = metacls.__new__(metacls, name, bases, classdict) - enum_class.__module__ = module - enum_class.__qualname__ = qualname - - return _lookup_class_or_track(class_tracker_id, enum_class) - - -def _make_typevar(name, bound, constraints, covariant, contravariant, - class_tracker_id): - tv = typing.TypeVar( - name, *constraints, bound=bound, - covariant=covariant, contravariant=contravariant - ) - if class_tracker_id is not None: - return _lookup_class_or_track(class_tracker_id, tv) - else: # pragma: nocover - # Only for Python 3.5.3 compat. - return tv - - -def _decompose_typevar(obj): - return ( - obj.__name__, obj.__bound__, obj.__constraints__, - obj.__covariant__, obj.__contravariant__, - _get_or_create_tracker_id(obj), - ) - - -def _typevar_reduce(obj): - # TypeVar instances require the module information hence why we - # are not using the _should_pickle_by_reference directly - module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__) - - if module_and_name is None: - return (_make_typevar, _decompose_typevar(obj)) - elif _is_registered_pickle_by_value(module_and_name[0]): - return (_make_typevar, _decompose_typevar(obj)) - - return (getattr, module_and_name) - - -def _get_bases(typ): - if '__orig_bases__' in getattr(typ, '__dict__', {}): - # For generic types (see PEP 560) - # Note that simply checking `hasattr(typ, '__orig_bases__')` is not - # correct. Subclasses of a fully-parameterized generic class does not - # have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')` - # will return True because it's defined in the base class. 
- bases_attr = '__orig_bases__' - else: - # For regular class objects - bases_attr = '__bases__' - return getattr(typ, bases_attr) - - -def _make_dict_keys(obj, is_ordered=False): - if is_ordered: - return OrderedDict.fromkeys(obj).keys() - else: - return dict.fromkeys(obj).keys() - - -def _make_dict_values(obj, is_ordered=False): - if is_ordered: - return OrderedDict((i, _) for i, _ in enumerate(obj)).values() - else: - return {i: _ for i, _ in enumerate(obj)}.values() - - -def _make_dict_items(obj, is_ordered=False): - if is_ordered: - return OrderedDict(obj).items() - else: - return obj.items() diff --git a/srsly/cloudpickle/cloudpickle_fast.py b/srsly/cloudpickle/cloudpickle_fast.py deleted file mode 100644 index 8741dcb..0000000 --- a/srsly/cloudpickle/cloudpickle_fast.py +++ /dev/null @@ -1,844 +0,0 @@ -""" -New, fast version of the CloudPickler. - -This new CloudPickler class can now extend the fast C Pickler instead of the -previous Python implementation of the Pickler class. Because this functionality -is only available for Python versions 3.8+, a lot of backward-compatibility -code is also removed. - -Note that the C Pickler subclassing API is CPython-specific. 
Therefore, some -guards present in cloudpickle.py that were written to handle PyPy specificities -are not present in cloudpickle_fast.py -""" -import _collections_abc -import abc -import copyreg -import io -import itertools -import logging -import sys -import struct -import types -import weakref -import typing - -from enum import Enum -from collections import ChainMap, OrderedDict - -from .compat import pickle, Pickler -from .cloudpickle import ( - _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL, - _find_imported_submodules, _get_cell_contents, _should_pickle_by_reference, - _builtin_type, _get_or_create_tracker_id, _make_skeleton_class, - _make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport, - _typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType, - _is_parametrized_type_hint, PYPY, cell_set, - parametrized_type_hint_getinitargs, _create_parametrized_type_hint, - builtin_code_type, - _make_dict_keys, _make_dict_values, _make_dict_items, _make_function, -) - - -if pickle.HIGHEST_PROTOCOL >= 5: - # Shorthands similar to pickle.dump/pickle.dumps - - def dump(obj, file, protocol=None, buffer_callback=None): - """Serialize obj as bytes streamed into file - - protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to - pickle.HIGHEST_PROTOCOL. This setting favors maximum communication - speed between processes running the same Python version. - - Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure - compatibility with older versions of Python. - """ - CloudPickler( - file, protocol=protocol, buffer_callback=buffer_callback - ).dump(obj) - - def dumps(obj, protocol=None, buffer_callback=None): - """Serialize obj as a string of bytes allocated in memory - - protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to - pickle.HIGHEST_PROTOCOL. This setting favors maximum communication - speed between processes running the same Python version. 
- - Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure - compatibility with older versions of Python. - """ - with io.BytesIO() as file: - cp = CloudPickler( - file, protocol=protocol, buffer_callback=buffer_callback - ) - cp.dump(obj) - return file.getvalue() - -else: - # Shorthands similar to pickle.dump/pickle.dumps - def dump(obj, file, protocol=None): - """Serialize obj as bytes streamed into file - - protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to - pickle.HIGHEST_PROTOCOL. This setting favors maximum communication - speed between processes running the same Python version. - - Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure - compatibility with older versions of Python. - """ - CloudPickler(file, protocol=protocol).dump(obj) - - def dumps(obj, protocol=None): - """Serialize obj as a string of bytes allocated in memory - - protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to - pickle.HIGHEST_PROTOCOL. This setting favors maximum communication - speed between processes running the same Python version. - - Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure - compatibility with older versions of Python. 
- """ - with io.BytesIO() as file: - cp = CloudPickler(file, protocol=protocol) - cp.dump(obj) - return file.getvalue() - - -load, loads = pickle.load, pickle.loads - - -# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS -# ------------------------------------------------- - -def _class_getnewargs(obj): - type_kwargs = {} - if "__slots__" in obj.__dict__: - type_kwargs["__slots__"] = obj.__slots__ - - __dict__ = obj.__dict__.get('__dict__', None) - if isinstance(__dict__, property): - type_kwargs['__dict__'] = __dict__ - - return (type(obj), obj.__name__, _get_bases(obj), type_kwargs, - _get_or_create_tracker_id(obj), None) - - -def _enum_getnewargs(obj): - members = {e.name: e.value for e in obj} - return (obj.__bases__, obj.__name__, obj.__qualname__, members, - obj.__module__, _get_or_create_tracker_id(obj), None) - - -# COLLECTION OF OBJECTS RECONSTRUCTORS -# ------------------------------------ -def _file_reconstructor(retval): - return retval - - -# COLLECTION OF OBJECTS STATE GETTERS -# ----------------------------------- -def _function_getstate(func): - # - Put func's dynamic attributes (stored in func.__dict__) in state. These - # attributes will be restored at unpickling time using - # f.__dict__.update(state) - # - Put func's members into slotstate. 
Such attributes will be restored at - # unpickling time by iterating over slotstate and calling setattr(func, - # slotname, slotvalue) - slotstate = { - "__name__": func.__name__, - "__qualname__": func.__qualname__, - "__annotations__": func.__annotations__, - "__kwdefaults__": func.__kwdefaults__, - "__defaults__": func.__defaults__, - "__module__": func.__module__, - "__doc__": func.__doc__, - "__closure__": func.__closure__, - } - - f_globals_ref = _extract_code_globals(func.__code__) - f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in - func.__globals__} - - closure_values = ( - list(map(_get_cell_contents, func.__closure__)) - if func.__closure__ is not None else () - ) - - # Extract currently-imported submodules used by func. Storing these modules - # in a smoke _cloudpickle_subimports attribute of the object's state will - # trigger the side effect of importing these modules at unpickling time - # (which is necessary for func to work correctly once depickled) - slotstate["_cloudpickle_submodules"] = _find_imported_submodules( - func.__code__, itertools.chain(f_globals.values(), closure_values)) - slotstate["__globals__"] = f_globals - - state = func.__dict__ - return state, slotstate - - -def _class_getstate(obj): - clsdict = _extract_class_dict(obj) - clsdict.pop('__weakref__', None) - - if issubclass(type(obj), abc.ABCMeta): - # If obj is an instance of an ABCMeta subclass, don't pickle the - # cache/negative caches populated during isinstance/issubclass - # checks, but pickle the list of registered subclasses of obj. 
- clsdict.pop('_abc_cache', None) - clsdict.pop('_abc_negative_cache', None) - clsdict.pop('_abc_negative_cache_version', None) - registry = clsdict.pop('_abc_registry', None) - if registry is None: - # in Python3.7+, the abc caches and registered subclasses of a - # class are bundled into the single _abc_impl attribute - clsdict.pop('_abc_impl', None) - (registry, _, _, _) = abc._get_dump(obj) - - clsdict["_abc_impl"] = [subclass_weakref() - for subclass_weakref in registry] - else: - # In the above if clause, registry is a set of weakrefs -- in - # this case, registry is a WeakSet - clsdict["_abc_impl"] = [type_ for type_ in registry] - - if "__slots__" in clsdict: - # pickle string length optimization: member descriptors of obj are - # created automatically from obj's __slots__ attribute, no need to - # save them in obj's state - if isinstance(obj.__slots__, str): - clsdict.pop(obj.__slots__) - else: - for k in obj.__slots__: - clsdict.pop(k, None) - - clsdict.pop('__dict__', None) # unpicklable property object - - return (clsdict, {}) - - -def _enum_getstate(obj): - clsdict, slotstate = _class_getstate(obj) - - members = {e.name: e.value for e in obj} - # Cleanup the clsdict that will be passed to _rehydrate_skeleton_class: - # Those attributes are already handled by the metaclass. - for attrname in ["_generate_next_value_", "_member_names_", - "_member_map_", "_member_type_", - "_value2member_map_"]: - clsdict.pop(attrname, None) - for member in members: - clsdict.pop(member) - # Special handling of Enum subclasses - return clsdict, slotstate - - -# COLLECTIONS OF OBJECTS REDUCERS -# ------------------------------- -# A reducer is a function taking a single argument (obj), and that returns a -# tuple with all the necessary data to re-construct obj. Apart from a few -# exceptions (list, dict, bytes, int, etc.), a reducer is necessary to -# correctly pickle an object. 
-# While many built-in objects (Exceptions objects, instances of the "object" -# class, etc), are shipped with their own built-in reducer (invoked using -# obj.__reduce__), some do not. The following methods were created to "fill -# these holes". - -def _code_reduce(obj): - """codeobject reducer""" - # If you are not sure about the order of arguments, take a look at help - # of the specific type from types, for example: - # >>> from types import CodeType - # >>> help(CodeType) - if hasattr(obj, "co_exceptiontable"): # pragma: no branch - # Python 3.11 and later: there are some new attributes - # related to the enhanced exceptions. - args = ( - obj.co_argcount, obj.co_posonlyargcount, - obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, - obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, - obj.co_varnames, obj.co_filename, obj.co_name, obj.co_qualname, - obj.co_firstlineno, obj.co_linetable, obj.co_exceptiontable, - obj.co_freevars, obj.co_cellvars, - ) - elif hasattr(obj, "co_linetable"): # pragma: no branch - # Python 3.10 and later: obj.co_lnotab is deprecated and constructor - # expects obj.co_linetable instead. 
- args = ( - obj.co_argcount, obj.co_posonlyargcount, - obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, - obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, - obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_linetable, obj.co_freevars, - obj.co_cellvars - ) - elif hasattr(obj, "co_nmeta"): # pragma: no cover - # "nogil" Python: modified attributes from 3.9 - args = ( - obj.co_argcount, obj.co_posonlyargcount, - obj.co_kwonlyargcount, obj.co_nlocals, obj.co_framesize, - obj.co_ndefaultargs, obj.co_nmeta, - obj.co_flags, obj.co_code, obj.co_consts, - obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_exc_handlers, - obj.co_jump_table, obj.co_freevars, obj.co_cellvars, - obj.co_free2reg, obj.co_cell2reg - ) - elif hasattr(obj, "co_posonlyargcount"): - # Backward compat for 3.9 and older - args = ( - obj.co_argcount, obj.co_posonlyargcount, - obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, - obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, - obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, - obj.co_cellvars - ) - else: - # Backward compat for even older versions of Python - args = ( - obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, - obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts, - obj.co_names, obj.co_varnames, obj.co_filename, - obj.co_name, obj.co_firstlineno, obj.co_lnotab, - obj.co_freevars, obj.co_cellvars - ) - return types.CodeType, args - - -def _cell_reduce(obj): - """Cell (containing values of a function's free variables) reducer""" - try: - obj.cell_contents - except ValueError: # cell is empty - return _make_empty_cell, () - else: - return _make_cell, (obj.cell_contents, ) - - -def _classmethod_reduce(obj): - orig_func = obj.__func__ - return type(obj), (orig_func,) - - -def _file_reduce(obj): - """Save a file""" - import io - - if not hasattr(obj, "name") or not hasattr(obj, "mode"): - 
raise pickle.PicklingError( - "Cannot pickle files that do not map to an actual file" - ) - if obj is sys.stdout: - return getattr, (sys, "stdout") - if obj is sys.stderr: - return getattr, (sys, "stderr") - if obj is sys.stdin: - raise pickle.PicklingError("Cannot pickle standard input") - if obj.closed: - raise pickle.PicklingError("Cannot pickle closed files") - if hasattr(obj, "isatty") and obj.isatty(): - raise pickle.PicklingError( - "Cannot pickle files that map to tty objects" - ) - if "r" not in obj.mode and "+" not in obj.mode: - raise pickle.PicklingError( - "Cannot pickle files that are not opened for reading: %s" - % obj.mode - ) - - name = obj.name - - retval = io.StringIO() - - try: - # Read the whole file - curloc = obj.tell() - obj.seek(0) - contents = obj.read() - obj.seek(curloc) - except IOError as e: - raise pickle.PicklingError( - "Cannot pickle file %s as it cannot be read" % name - ) from e - retval.write(contents) - retval.seek(curloc) - - retval.name = name - return _file_reconstructor, (retval,) - - -def _getset_descriptor_reduce(obj): - return getattr, (obj.__objclass__, obj.__name__) - - -def _mappingproxy_reduce(obj): - return types.MappingProxyType, (dict(obj),) - - -def _memoryview_reduce(obj): - return bytes, (obj.tobytes(),) - - -def _module_reduce(obj): - if _should_pickle_by_reference(obj): - return subimport, (obj.__name__,) - else: - # Some external libraries can populate the "__builtins__" entry of a - # module's `__dict__` with unpicklable objects (see #316). For that - # reason, we do not attempt to pickle the "__builtins__" entry, and - # restore a default value for it at unpickling time. 
- state = obj.__dict__.copy() - state.pop('__builtins__', None) - return dynamic_subimport, (obj.__name__, state) - - -def _method_reduce(obj): - return (types.MethodType, (obj.__func__, obj.__self__)) - - -def _logger_reduce(obj): - return logging.getLogger, (obj.name,) - - -def _root_logger_reduce(obj): - return logging.getLogger, () - - -def _property_reduce(obj): - return property, (obj.fget, obj.fset, obj.fdel, obj.__doc__) - - -def _weakset_reduce(obj): - return weakref.WeakSet, (list(obj),) - - -def _dynamic_class_reduce(obj): - """ - Save a class that can't be stored as module global. - - This method is used to serialize classes that are defined inside - functions, or that otherwise can't be serialized as attribute lookups - from global modules. - """ - if Enum is not None and issubclass(obj, Enum): - return ( - _make_skeleton_enum, _enum_getnewargs(obj), _enum_getstate(obj), - None, None, _class_setstate - ) - else: - return ( - _make_skeleton_class, _class_getnewargs(obj), _class_getstate(obj), - None, None, _class_setstate - ) - - -def _class_reduce(obj): - """Select the reducer depending on the dynamic nature of the class obj""" - if obj is type(None): # noqa - return type, (None,) - elif obj is type(Ellipsis): - return type, (Ellipsis,) - elif obj is type(NotImplemented): - return type, (NotImplemented,) - elif obj in _BUILTIN_TYPE_NAMES: - return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],) - elif not _should_pickle_by_reference(obj): - return _dynamic_class_reduce(obj) - return NotImplemented - - -def _dict_keys_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_keys, (list(obj), ) - - -def _dict_values_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_values, (list(obj), ) - - -def 
_dict_items_reduce(obj): - return _make_dict_items, (dict(obj), ) - - -def _odict_keys_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_keys, (list(obj), True) - - -def _odict_values_reduce(obj): - # Safer not to ship the full dict as sending the rest might - # be unintended and could potentially cause leaking of - # sensitive information - return _make_dict_values, (list(obj), True) - - -def _odict_items_reduce(obj): - return _make_dict_items, (dict(obj), True) - - -# COLLECTIONS OF OBJECTS STATE SETTERS -# ------------------------------------ -# state setters are called at unpickling time, once the object is created and -# it has to be updated to how it was at unpickling time. - - -def _function_setstate(obj, state): - """Update the state of a dynamic function. - - As __closure__ and __globals__ are readonly attributes of a function, we - cannot rely on the native setstate routine of pickle.load_build, that calls - setattr on items of the slotstate. Instead, we have to modify them inplace. - """ - state, slotstate = state - obj.__dict__.update(state) - - obj_globals = slotstate.pop("__globals__") - obj_closure = slotstate.pop("__closure__") - # _cloudpickle_subimports is a set of submodules that must be loaded for - # the pickled function to work correctly at unpickling time. 
Now that these - # submodules are depickled (hence imported), they can be removed from the - # object's state (the object state only served as a reference holder to - # these submodules) - slotstate.pop("_cloudpickle_submodules") - - obj.__globals__.update(obj_globals) - obj.__globals__["__builtins__"] = __builtins__ - - if obj_closure is not None: - for i, cell in enumerate(obj_closure): - try: - value = cell.cell_contents - except ValueError: # cell is empty - continue - cell_set(obj.__closure__[i], value) - - for k, v in slotstate.items(): - setattr(obj, k, v) - - -def _class_setstate(obj, state): - state, slotstate = state - registry = None - for attrname, attr in state.items(): - if attrname == "_abc_impl": - registry = attr - else: - setattr(obj, attrname, attr) - if registry is not None: - for subclass in registry: - obj.register(subclass) - - return obj - - -class CloudPickler(Pickler): - # set of reducers defined and used by cloudpickle (private) - _dispatch_table = {} - _dispatch_table[classmethod] = _classmethod_reduce - _dispatch_table[io.TextIOWrapper] = _file_reduce - _dispatch_table[logging.Logger] = _logger_reduce - _dispatch_table[logging.RootLogger] = _root_logger_reduce - _dispatch_table[memoryview] = _memoryview_reduce - _dispatch_table[property] = _property_reduce - _dispatch_table[staticmethod] = _classmethod_reduce - _dispatch_table[CellType] = _cell_reduce - _dispatch_table[types.CodeType] = _code_reduce - _dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce - _dispatch_table[types.ModuleType] = _module_reduce - _dispatch_table[types.MethodType] = _method_reduce - _dispatch_table[types.MappingProxyType] = _mappingproxy_reduce - _dispatch_table[weakref.WeakSet] = _weakset_reduce - _dispatch_table[typing.TypeVar] = _typevar_reduce - _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce - _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce - _dispatch_table[_collections_abc.dict_items] = 
_dict_items_reduce - _dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce - _dispatch_table[type(OrderedDict().values())] = _odict_values_reduce - _dispatch_table[type(OrderedDict().items())] = _odict_items_reduce - _dispatch_table[abc.abstractmethod] = _classmethod_reduce - _dispatch_table[abc.abstractclassmethod] = _classmethod_reduce - _dispatch_table[abc.abstractstaticmethod] = _classmethod_reduce - _dispatch_table[abc.abstractproperty] = _property_reduce - - dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table) - - # function reducers are defined as instance methods of CloudPickler - # objects, as they rely on a CloudPickler attribute (globals_ref) - def _dynamic_function_reduce(self, func): - """Reduce a function that is not pickleable via attribute lookup.""" - newargs = self._function_getnewargs(func) - state = _function_getstate(func) - return (_make_function, newargs, state, None, None, - _function_setstate) - - def _function_reduce(self, obj): - """Reducer for function objects. - - If obj is a top-level attribute of a file-backed module, this - reducer returns NotImplemented, making the CloudPickler fallback to - traditional _pickle.Pickler routines to save obj. Otherwise, it reduces - obj using a custom cloudpickle reducer designed specifically to handle - dynamic functions. - - As opposed to cloudpickle.py, There no special handling for builtin - pypy functions because cloudpickle_fast is CPython-specific. - """ - if _should_pickle_by_reference(obj): - return NotImplemented - else: - return self._dynamic_function_reduce(obj) - - def _function_getnewargs(self, func): - code = func.__code__ - - # base_globals represents the future global namespace of func at - # unpickling time. 
Looking it up and storing it in - # CloudpiPickler.globals_ref allow functions sharing the same globals - # at pickling time to also share them once unpickled, at one condition: - # since globals_ref is an attribute of a CloudPickler instance, and - # that a new CloudPickler is created each time pickle.dump or - # pickle.dumps is called, functions also need to be saved within the - # same invocation of cloudpickle.dump/cloudpickle.dumps (for example: - # cloudpickle.dumps([f1, f2])). There is no such limitation when using - # CloudPickler.dump, as long as the multiple invocations are bound to - # the same CloudPickler. - base_globals = self.globals_ref.setdefault(id(func.__globals__), {}) - - if base_globals == {}: - # Add module attributes used to resolve relative imports - # instructions inside func. - for k in ["__package__", "__name__", "__path__", "__file__"]: - if k in func.__globals__: - base_globals[k] = func.__globals__[k] - - # Do not bind the free variables before the function is created to - # avoid infinite recursion. - if func.__closure__ is None: - closure = None - else: - closure = tuple( - _make_empty_cell() for _ in range(len(code.co_freevars))) - - return code, base_globals, None, None, closure - - def dump(self, obj): - try: - return Pickler.dump(self, obj) - except RuntimeError as e: - if "recursion" in e.args[0]: - msg = ( - "Could not pickle object as excessively deep recursion " - "required." - ) - raise pickle.PicklingError(msg) from e - else: - raise - - if pickle.HIGHEST_PROTOCOL >= 5: - def __init__(self, file, protocol=None, buffer_callback=None): - if protocol is None: - protocol = DEFAULT_PROTOCOL - Pickler.__init__( - self, file, protocol=protocol, buffer_callback=buffer_callback - ) - # map functions __globals__ attribute ids, to ensure that functions - # sharing the same global namespace at pickling time also share - # their global namespace at unpickling time. 
- self.globals_ref = {} - self.proto = int(protocol) - else: - def __init__(self, file, protocol=None): - if protocol is None: - protocol = DEFAULT_PROTOCOL - Pickler.__init__(self, file, protocol=protocol) - # map functions __globals__ attribute ids, to ensure that functions - # sharing the same global namespace at pickling time also share - # their global namespace at unpickling time. - self.globals_ref = {} - assert hasattr(self, 'proto') - - if pickle.HIGHEST_PROTOCOL >= 5 and not PYPY: - # Pickler is the C implementation of the CPython pickler and therefore - # we rely on reduce_override method to customize the pickler behavior. - - # `CloudPickler.dispatch` is only left for backward compatibility - note - # that when using protocol 5, `CloudPickler.dispatch` is not an - # extension of `Pickler.dispatch` dictionary, because CloudPickler - # subclasses the C-implemented Pickler, which does not expose a - # `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler - # used `CloudPickler.dispatch` as a class-level attribute storing all - # reducers implemented by cloudpickle, but the attribute name was not a - # great choice given the meaning of `CloudPickler.dispatch` when - # `CloudPickler` extends the pure-python pickler. - dispatch = dispatch_table - - # Implementation of the reducer_override callback, in order to - # efficiently serialize dynamic functions and classes by subclassing - # the C-implemented Pickler. - # TODO: decorrelate reducer_override (which is tied to CPython's - # implementation - would it make sense to backport it to pypy? - and - # pickle's protocol 5 which is implementation agnostic. Currently, the - # availability of both notions coincide on CPython's pickle and the - # pickle5 backport, but it may not be the case anymore when pypy - # implements protocol 5 - - def reducer_override(self, obj): - """Type-agnostic reducing callback for function and classes. 
- - For performance reasons, subclasses of the C _pickle.Pickler class - cannot register custom reducers for functions and classes in the - dispatch_table. Reducer for such types must instead implemented in - the special reducer_override method. - - Note that method will be called for any object except a few - builtin-types (int, lists, dicts etc.), which differs from reducers - in the Pickler's dispatch_table, each of them being invoked for - objects of a specific type only. - - This property comes in handy for classes: although most classes are - instances of the ``type`` metaclass, some of them can be instances - of other custom metaclasses (such as enum.EnumMeta for example). In - particular, the metaclass will likely not be known in advance, and - thus cannot be special-cased using an entry in the dispatch_table. - reducer_override, among other things, allows us to register a - reducer that will be called for any class, independently of its - type. - - - Notes: - - * reducer_override has the priority over dispatch_table-registered - reducers. - * reducer_override can be used to fix other limitations of - cloudpickle for other types that suffered from type-specific - reducers, such as Exceptions. See - https://github.com/cloudpipe/cloudpickle/issues/248 - """ - if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch - return ( - _create_parametrized_type_hint, - parametrized_type_hint_getinitargs(obj) - ) - t = type(obj) - try: - is_anyclass = issubclass(t, type) - except TypeError: # t is not a class (old Boost; see SF #502085) - is_anyclass = False - - if is_anyclass: - return _class_reduce(obj) - elif isinstance(obj, types.FunctionType): - return self._function_reduce(obj) - else: - # fallback to save_global, including the Pickler's - # dispatch_table - return NotImplemented - - else: - # When reducer_override is not available, hack the pure-Python - # Pickler's types.FunctionType and type savers. 
Note: the type saver - # must override Pickler.save_global, because pickle.py contains a - # hard-coded call to save_global when pickling meta-classes. - dispatch = Pickler.dispatch.copy() - - def _save_reduce_pickle5(self, func, args, state=None, listitems=None, - dictitems=None, state_setter=None, obj=None): - save = self.save - write = self.write - self.save_reduce( - func, args, state=None, listitems=listitems, - dictitems=dictitems, obj=obj - ) - # backport of the Python 3.8 state_setter pickle operations - save(state_setter) - save(obj) # simple BINGET opcode as obj is already memoized. - save(state) - write(pickle.TUPLE2) - # Trigger a state_setter(obj, state) function call. - write(pickle.REDUCE) - # The purpose of state_setter is to carry-out an - # inplace modification of obj. We do not care about what the - # method might return, so its output is eventually removed from - # the stack. - write(pickle.POP) - - def save_global(self, obj, name=None, pack=struct.pack): - """ - Save a "global". - - The name of this method is somewhat misleading: all types get - dispatched here. - """ - if obj is type(None): # noqa - return self.save_reduce(type, (None,), obj=obj) - elif obj is type(Ellipsis): - return self.save_reduce(type, (Ellipsis,), obj=obj) - elif obj is type(NotImplemented): - return self.save_reduce(type, (NotImplemented,), obj=obj) - elif obj in _BUILTIN_TYPE_NAMES: - return self.save_reduce( - _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) - - if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch - # Parametrized typing constructs in Python < 3.7 are not - # compatible with type checks and ``isinstance`` semantics. For - # this reason, it is easier to detect them using a - # duck-typing-based check (``_is_parametrized_type_hint``) than - # to populate the Pickler's dispatch with type-specific savers. 
- self.save_reduce( - _create_parametrized_type_hint, - parametrized_type_hint_getinitargs(obj), - obj=obj - ) - elif name is not None: - Pickler.save_global(self, obj, name=name) - elif not _should_pickle_by_reference(obj, name=name): - self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj) - else: - Pickler.save_global(self, obj, name=name) - dispatch[type] = save_global - - def save_function(self, obj, name=None): - """ Registered with the dispatch to handle all function types. - - Determines what kind of function obj is (e.g. lambda, defined at - interactive prompt, etc) and handles the pickling appropriately. - """ - if _should_pickle_by_reference(obj, name=name): - return Pickler.save_global(self, obj, name=name) - elif PYPY and isinstance(obj.__code__, builtin_code_type): - return self.save_pypy_builtin_func(obj) - else: - return self._save_reduce_pickle5( - *self._dynamic_function_reduce(obj), obj=obj - ) - - def save_pypy_builtin_func(self, obj): - """Save pypy equivalent of builtin functions. - PyPy does not have the concept of builtin-functions. Instead, - builtin-functions are simple function instances, but with a - builtin-code attribute. - Most of the time, builtin functions should be pickled by attribute. - But PyPy has flaky support for __qualname__, so some builtin - functions such as float.__new__ will be classified as dynamic. For - this reason only, we created this special routine. Because - builtin-functions are not expected to have closure or globals, - there is no additional hack (compared the one already implemented - in pickle) to protect ourselves from reference cycles. A simple - (reconstructor, newargs, obj.__dict__) tuple is save_reduced. Note - also that PyPy improved their support for __qualname__ in v3.6, so - this routing should be removed when cloudpickle supports only PyPy - 3.6 and later. 
- """ - rv = (types.FunctionType, (obj.__code__, {}, obj.__name__, - obj.__defaults__, obj.__closure__), - obj.__dict__) - self.save_reduce(*rv, obj=obj) - - dispatch[types.FunctionType] = save_function diff --git a/srsly/cloudpickle/compat.py b/srsly/cloudpickle/compat.py deleted file mode 100644 index 5e9b527..0000000 --- a/srsly/cloudpickle/compat.py +++ /dev/null @@ -1,18 +0,0 @@ -import sys - - -if sys.version_info < (3, 8): - try: - import pickle5 as pickle # noqa: F401 - from pickle5 import Pickler # noqa: F401 - except ImportError: - import pickle # noqa: F401 - - # Use the Python pickler for old CPython versions - from pickle import _Pickler as Pickler # noqa: F401 -else: - import pickle # noqa: F401 - - # Pickler will the C implementation in CPython and the Python - # implementation in PyPy - from pickle import Pickler # noqa: F401 diff --git a/srsly/msgpack/__init__.py b/srsly/msgpack/__init__.py deleted file mode 100644 index 9bf2473..0000000 --- a/srsly/msgpack/__init__.py +++ /dev/null @@ -1,93 +0,0 @@ -# coding: utf-8 - -import functools -import catalogue - -# These need to be imported before packer and unpacker -from ._epoch import utc, epoch # noqa - -from ._version import version -from .exceptions import * - -# In msgpack-python these are put under a _cmsgpack module that textually includes -# them. I dislike this so I refactored it. 
-from ._packer import Packer as _Packer -from ._unpacker import unpackb as _unpackb -from ._unpacker import Unpacker as _Unpacker -from .ext import ExtType -from ._msgpack_numpy import encode_numpy as _encode_numpy -from ._msgpack_numpy import decode_numpy as _decode_numpy - - -msgpack_encoders = catalogue.create("srsly", "msgpack_encoders", entry_points=True) -msgpack_decoders = catalogue.create("srsly", "msgpack_decoders", entry_points=True) - -msgpack_encoders.register("numpy", func=_encode_numpy) -msgpack_decoders.register("numpy", func=_decode_numpy) - - -# msgpack_numpy extensions -class Packer(_Packer): - def __init__(self, *args, **kwargs): - default = kwargs.get("default") - for encoder in msgpack_encoders.get_all().values(): - default = functools.partial(encoder, chain=default) - kwargs["default"] = default - super(Packer, self).__init__(*args, **kwargs) - - -class Unpacker(_Unpacker): - def __init__(self, *args, **kwargs): - object_hook = kwargs.get("object_hook") - for decoder in msgpack_decoders.get_all().values(): - object_hook = functools.partial(decoder, chain=object_hook) - kwargs["object_hook"] = object_hook - super(Unpacker, self).__init__(*args, **kwargs) - - -def pack(o, stream, **kwargs): - """ - Pack an object and write it to a stream. - """ - packer = Packer(**kwargs) - stream.write(packer.pack(o)) - - -def packb(o, **kwargs): - """ - Pack an object and return the packed bytes. - """ - return Packer(**kwargs).pack(o) - - -def unpack(stream, **kwargs): - """ - Unpack a packed object from a stream. - """ - if "object_pairs_hook" not in kwargs: - object_hook = kwargs.get("object_hook") - for decoder in msgpack_decoders.get_all().values(): - object_hook = functools.partial(decoder, chain=object_hook) - kwargs["object_hook"] = object_hook - data = stream.read() - return _unpackb(data, **kwargs) - - -def unpackb(packed, **kwargs): - """ - Unpack a packed object. 
- """ - if "object_pairs_hook" not in kwargs: - object_hook = kwargs.get("object_hook") - for decoder in msgpack_decoders.get_all().values(): - object_hook = functools.partial(decoder, chain=object_hook) - kwargs["object_hook"] = object_hook - return _unpackb(packed, **kwargs) - - -# alias for compatibility to simplejson/marshal/pickle. -load = unpack -loads = unpackb - -dump = pack -dumps = packb diff --git a/srsly/msgpack/_epoch.pyx b/srsly/msgpack/_epoch.pyx deleted file mode 100644 index 27b4bea..0000000 --- a/srsly/msgpack/_epoch.pyx +++ /dev/null @@ -1,7 +0,0 @@ -from cpython.datetime cimport import_datetime, datetime_new - -import_datetime() -import datetime - -utc = datetime.timezone.utc -epoch = datetime_new(1970, 1, 1, 0, 0, 0, 0, tz=utc) diff --git a/srsly/msgpack/_msgpack_numpy.py b/srsly/msgpack/_msgpack_numpy.py deleted file mode 100644 index 4748bed..0000000 --- a/srsly/msgpack/_msgpack_numpy.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python - -""" -Support for serialization of numpy data types with msgpack. -""" - -# Copyright (c) 2013-2018, Lev E. Givon -# All rights reserved. -# Distributed under the terms of the BSD license: -# http://www.opensource.org/licenses/bsd-license -try: - import numpy as np - - has_numpy = True -except ImportError: - has_numpy = False - -try: - import cupy - - has_cupy = True -except ImportError: - has_cupy = False - - -def encode_numpy(obj, chain=None): - """ - Data encoder for serializing numpy data types. 
- """ - if not has_numpy: - return obj if chain is None else chain(obj) - if has_cupy and isinstance(obj, cupy.ndarray): - obj = obj.get() - if isinstance(obj, np.ndarray): - # If the dtype is structured, store the interface description; - # otherwise, store the corresponding array protocol type string: - if obj.dtype.kind == "V": - kind = b"V" - descr = obj.dtype.descr - else: - kind = b"" - descr = obj.dtype.str - return { - b"nd": True, - b"type": descr, - b"kind": kind, - b"shape": obj.shape, - b"data": obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes(), - } - elif isinstance(obj, (np.bool_, np.number)): - return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data} - elif isinstance(obj, complex): - return {b"complex": True, b"data": obj.__repr__()} - else: - return obj if chain is None else chain(obj) - - -def tostr(x): - if isinstance(x, bytes): - return x.decode() - else: - return str(x) - - -def decode_numpy(obj, chain=None): - """ - Decoder for deserializing numpy data types. 
- """ - - try: - if b"nd" in obj: - if obj[b"nd"] is True: - - # Check if b'kind' is in obj to enable decoding of data - # serialized with older versions (#20): - if b"kind" in obj and obj[b"kind"] == b"V": - descr = [ - tuple(tostr(t) if type(t) is bytes else t for t in d) - for d in obj[b"type"] - ] - else: - descr = obj[b"type"] - return np.frombuffer(obj[b"data"], dtype=np.dtype(descr)).reshape( - obj[b"shape"] - ) - else: - descr = obj[b"type"] - return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0] - elif b"complex" in obj: - return complex(tostr(obj[b"data"])) - else: - return obj if chain is None else chain(obj) - except KeyError: - return obj if chain is None else chain(obj) diff --git a/srsly/msgpack/_packer.pyx b/srsly/msgpack/_packer.pyx deleted file mode 100644 index 25f10a8..0000000 --- a/srsly/msgpack/_packer.pyx +++ /dev/null @@ -1,376 +0,0 @@ -# coding: utf-8 - -from cpython cimport * -from cpython.bytearray cimport PyByteArray_Check, PyByteArray_CheckExact -from cpython.datetime cimport ( - PyDateTime_CheckExact, PyDelta_CheckExact, - datetime_tzinfo, timedelta_days, timedelta_seconds, timedelta_microseconds, -) -from ._epoch import utc, epoch - -cdef ExtType -cdef Timestamp - -from .ext import ExtType, Timestamp -from .util import ensure_bytes - - -cdef extern from "Python.h": - - int PyMemoryView_Check(object obj) - -cdef extern from "pack.h": - struct msgpack_packer: - char* buf - size_t length - size_t buf_size - bint use_bin_type - - int msgpack_pack_nil(msgpack_packer* pk) except -1 - int msgpack_pack_true(msgpack_packer* pk) except -1 - int msgpack_pack_false(msgpack_packer* pk) except -1 - int msgpack_pack_long_long(msgpack_packer* pk, long long d) except -1 - int msgpack_pack_unsigned_long_long(msgpack_packer* pk, unsigned long long d) except -1 - int msgpack_pack_float(msgpack_packer* pk, float d) except -1 - int msgpack_pack_double(msgpack_packer* pk, double d) except -1 - int msgpack_pack_array(msgpack_packer* pk, size_t l) 
except -1 - int msgpack_pack_map(msgpack_packer* pk, size_t l) except -1 - int msgpack_pack_raw(msgpack_packer* pk, size_t l) except -1 - int msgpack_pack_bin(msgpack_packer* pk, size_t l) except -1 - int msgpack_pack_raw_body(msgpack_packer* pk, char* body, size_t l) except -1 - int msgpack_pack_ext(msgpack_packer* pk, char typecode, size_t l) except -1 - int msgpack_pack_timestamp(msgpack_packer* x, long long seconds, unsigned long nanoseconds) except -1 - - -cdef int DEFAULT_RECURSE_LIMIT=511 -cdef long long ITEM_LIMIT = (2**32)-1 - - -cdef inline int PyBytesLike_Check(object o): - return PyBytes_Check(o) or PyByteArray_Check(o) - - -cdef inline int PyBytesLike_CheckExact(object o): - return PyBytes_CheckExact(o) or PyByteArray_CheckExact(o) - - -cdef class Packer: - """ - MessagePack Packer - - Usage:: - - packer = Packer() - astream.write(packer.pack(a)) - astream.write(packer.pack(b)) - - Packer's constructor has some keyword arguments: - - :param default: - When specified, it should be callable. - Convert user type to builtin type that Packer supports. - See also simplejson's document. - - :param bool use_single_float: - Use single precision float type for float. (default: False) - - :param bool autoreset: - Reset buffer after each pack and return its content as `bytes`. (default: True). - If set this to false, use `bytes()` to get content and `.reset()` to clear buffer. - - :param bool use_bin_type: - Use bin type introduced in msgpack spec 2.0 for bytes. - It also enables str8 type for unicode. (default: True) - - :param bool strict_types: - If set to true, types will be checked to be exact. Derived classes - from serializeable types will not be serialized and will be - treated as unsupported type and forwarded to default. - Additionally tuples will not be serialized as lists. - This is useful when trying to implement accurate serialization - for python types. - - :param bool datetime: - If set to true, datetime with tzinfo is packed into Timestamp type. 
- Note that the tzinfo is stripped in the timestamp. - You can get UTC datetime with `timestamp=3` option of the Unpacker. - - :param str unicode_errors: - The error handler for encoding unicode. (default: 'strict') - DO NOT USE THIS!! This option is kept for very specific usage. - - :param int buf_size: - The size of the internal buffer. (default: 256*1024) - Useful if serialisation size can be correctly estimated, - avoid unnecessary reallocations. - """ - cdef msgpack_packer pk - cdef object _default - cdef size_t exports # number of exported buffers - cdef bint strict_types - cdef bint use_float - cdef bint autoreset - cdef bint datetime - cdef object _bencoding - cdef object _berrors - cdef const char *encoding - cdef const char *unicode_errors - - - def __cinit__(self, buf_size=256*1024, **_kwargs): - self.pk.buf = PyMem_Malloc(buf_size) - if self.pk.buf == NULL: - raise MemoryError("Unable to allocate internal buffer.") - self.pk.buf_size = buf_size - self.pk.length = 0 - self.exports = 0 - - def __dealloc__(self): - PyMem_Free(self.pk.buf) - self.pk.buf = NULL - assert self.exports == 0 - - cdef _check_exports(self): - if self.exports > 0: - raise BufferError("Existing exports of data: Packer cannot be changed") - - def __init__(self, *, default=None, encoding=None, - bint use_single_float=False, bint autoreset=True, bint use_bin_type=False, - bint strict_types=False, bint datetime=False, unicode_errors=None, - buf_size=256*1024): - self.use_float = use_single_float - self.strict_types = strict_types - self.autoreset = autoreset - self.datetime = datetime - self.pk.use_bin_type = use_bin_type - if default is not None: - if not PyCallable_Check(default): - raise TypeError("default must be a callable.") - self._default = default - - if encoding is None: - if PY_MAJOR_VERSION < 3: - encoding = 'utf-8' - if encoding is None: - self._bencoding = None - self.encoding = NULL - else: - self._bencoding = ensure_bytes(encoding) - self.encoding = self._bencoding - 
else: - self._bencoding = ensure_bytes(encoding) - self.encoding = self._bencoding - unicode_errors = ensure_bytes(unicode_errors) - self._berrors = unicode_errors - if unicode_errors is None: - self.unicode_errors = NULL - else: - self.unicode_errors = self._berrors - - # returns -2 when default should(o) be called - cdef int _pack_inner(self, object o, bint will_default, int nest_limit) except -1: - cdef long long llval - cdef unsigned long long ullval - cdef unsigned long ulval - cdef const char* rawval - cdef Py_ssize_t L - cdef Py_buffer view - cdef bint strict = self.strict_types - - if o is None: - msgpack_pack_nil(&self.pk) - elif o is True: - msgpack_pack_true(&self.pk) - elif o is False: - msgpack_pack_false(&self.pk) - elif PyLong_CheckExact(o) if strict else PyLong_Check(o): - try: - if o > 0: - ullval = o - msgpack_pack_unsigned_long_long(&self.pk, ullval) - else: - llval = o - msgpack_pack_long_long(&self.pk, llval) - except OverflowError as oe: - if will_default: - return -2 - else: - raise OverflowError("Integer value out of range") - elif PyFloat_CheckExact(o) if strict else PyFloat_Check(o): - if self.use_float: - msgpack_pack_float(&self.pk, o) - else: - msgpack_pack_double(&self.pk, o) - elif PyBytesLike_CheckExact(o) if strict else PyBytesLike_Check(o): - L = Py_SIZE(o) - if L > ITEM_LIMIT: - PyErr_Format(ValueError, b"%.200s object is too large", Py_TYPE(o).tp_name) - rawval = o - msgpack_pack_bin(&self.pk, L) - msgpack_pack_raw_body(&self.pk, rawval, L) - elif PyUnicode_CheckExact(o) if strict else PyUnicode_Check(o): - if self.unicode_errors == NULL and self.encoding == NULL: - rawval = PyUnicode_AsUTF8AndSize(o, &L) - if L >ITEM_LIMIT: - raise ValueError("unicode string is too large") - else: - o = PyUnicode_AsEncodedString(o, self.encoding, self.unicode_errors) - L = Py_SIZE(o) - if L > ITEM_LIMIT: - raise ValueError("unicode string is too large") - rawval = o - msgpack_pack_raw(&self.pk, L) - msgpack_pack_raw_body(&self.pk, rawval, L) - 
elif PyDict_CheckExact(o) if strict else PyDict_Check(o): - L = len(o) - if L > ITEM_LIMIT: - raise ValueError("dict is too large") - msgpack_pack_map(&self.pk, L) - for k, v in o.items(): - self._pack(k, nest_limit) - self._pack(v, nest_limit) - elif type(o) is ExtType if strict else isinstance(o, ExtType): - # This should be before Tuple because ExtType is namedtuple. - rawval = o.data - L = len(o.data) - if L > ITEM_LIMIT: - raise ValueError("EXT data is too large") - msgpack_pack_ext(&self.pk, o.code, L) - msgpack_pack_raw_body(&self.pk, rawval, L) - elif type(o) is Timestamp: - llval = o.seconds - ulval = o.nanoseconds - msgpack_pack_timestamp(&self.pk, llval, ulval) - elif PyList_CheckExact(o) if strict else (PyTuple_Check(o) or PyList_Check(o)): - L = Py_SIZE(o) - if L > ITEM_LIMIT: - raise ValueError("list is too large") - msgpack_pack_array(&self.pk, L) - for v in o: - self._pack(v, nest_limit) - elif PyMemoryView_Check(o): - PyObject_GetBuffer(o, &view, PyBUF_SIMPLE) - L = view.len - if L > ITEM_LIMIT: - PyBuffer_Release(&view); - raise ValueError("memoryview is too large") - try: - msgpack_pack_bin(&self.pk, L) - msgpack_pack_raw_body(&self.pk, view.buf, L) - finally: - PyBuffer_Release(&view); - elif self.datetime and PyDateTime_CheckExact(o) and datetime_tzinfo(o) is not None: - delta = o - epoch - if not PyDelta_CheckExact(delta): - raise ValueError("failed to calculate delta") - llval = timedelta_days(delta) * (24*60*60) + timedelta_seconds(delta) - ulval = timedelta_microseconds(delta) * 1000 - msgpack_pack_timestamp(&self.pk, llval, ulval) - elif will_default: - return -2 - elif self.datetime and PyDateTime_CheckExact(o): - # this should be later than will_default - PyErr_Format(ValueError, b"can not serialize '%.200s' object where tzinfo=None", Py_TYPE(o).tp_name) - else: - PyErr_Format(TypeError, b"can not serialize '%.200s' object", Py_TYPE(o).tp_name) - - cdef int _pack(self, object o, int nest_limit=DEFAULT_RECURSE_LIMIT) except -1: - cdef int 
ret - if nest_limit < 0: - raise ValueError("recursion limit exceeded.") - nest_limit -= 1 - if self._default is not None: - ret = self._pack_inner(o, 1, nest_limit) - if ret == -2: - o = self._default(o) - else: - return ret - return self._pack_inner(o, 0, nest_limit) - - def pack(self, object obj): - cdef int ret - self._check_exports() - try: - ret = self._pack(obj, DEFAULT_RECURSE_LIMIT) - except: - self.pk.length = 0 - raise - if ret: # should not happen. - raise RuntimeError("internal error") - if self.autoreset: - buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) - self.pk.length = 0 - return buf - - def pack_ext_type(self, typecode, data): - self._check_exports() - if len(data) > ITEM_LIMIT: - raise ValueError("ext data too large") - msgpack_pack_ext(&self.pk, typecode, len(data)) - msgpack_pack_raw_body(&self.pk, data, len(data)) - - def pack_array_header(self, long long size): - self._check_exports() - if size > ITEM_LIMIT: - raise ValueError("array too large") - msgpack_pack_array(&self.pk, size) - if self.autoreset: - buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) - self.pk.length = 0 - return buf - - def pack_map_header(self, long long size): - self._check_exports() - if size > ITEM_LIMIT: - raise ValueError("map too learge") - msgpack_pack_map(&self.pk, size) - if self.autoreset: - buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) - self.pk.length = 0 - return buf - - def pack_map_pairs(self, object pairs): - """ - Pack *pairs* as msgpack map type. - - *pairs* should be a sequence of pairs. - (`len(pairs)` and `for k, v in pairs:` should be supported.) - """ - self._check_exports() - size = len(pairs) - if size > ITEM_LIMIT: - raise ValueError("map too large") - msgpack_pack_map(&self.pk, size) - for k, v in pairs: - self._pack(k) - self._pack(v) - if self.autoreset: - buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) - self.pk.length = 0 - return buf - - def reset(self): - """Reset internal buffer. 
- - This method is useful only when autoreset=False. - """ - self._check_exports() - self.pk.length = 0 - - def bytes(self): - """Return internal buffer contents as bytes object""" - return PyBytes_FromStringAndSize(self.pk.buf, self.pk.length) - - def getbuffer(self): - """Return memoryview of internal buffer. - - Note: Packer now supports buffer protocol. You can use memoryview(packer). - """ - return memoryview(self) - - def __getbuffer__(self, Py_buffer *buffer, int flags): - PyBuffer_FillInfo(buffer, self, self.pk.buf, self.pk.length, 1, flags) - self.exports += 1 - - def __releasebuffer__(self, Py_buffer *buffer): - self.exports -= 1 diff --git a/srsly/msgpack/_unpacker.pyx b/srsly/msgpack/_unpacker.pyx deleted file mode 100644 index 631b9b5..0000000 --- a/srsly/msgpack/_unpacker.pyx +++ /dev/null @@ -1,560 +0,0 @@ -# coding: utf-8 - -from cpython cimport * -cdef extern from "Python.h": - ctypedef struct PyObject - object PyMemoryView_GetContiguous(object obj, int buffertype, char order) - -from libc.stdlib cimport * -from libc.string cimport * -from libc.limits cimport * -from libc.stdint cimport uint64_t - -from .exceptions import ( - BufferFull, - OutOfData, - ExtraData, - FormatError, - StackError, -) -from .ext import ExtType, Timestamp -from .util import ensure_bytes -from ._epoch import utc, epoch - -cdef object giga = 1_000_000_000 - - -cdef extern from "unpack.h": - ctypedef struct msgpack_user: - bint use_list - bint raw - bint has_pairs_hook # call object_hook with k-v pairs - bint strict_map_key - int timestamp - PyObject* object_hook - PyObject* list_hook - PyObject* ext_hook - PyObject* timestamp_t - PyObject *giga; - PyObject *utc; - const char *unicode_errors - const char *encoding - Py_ssize_t max_str_len - Py_ssize_t max_bin_len - Py_ssize_t max_array_len - Py_ssize_t max_map_len - Py_ssize_t max_ext_len - - ctypedef struct unpack_context: - msgpack_user user - PyObject* obj - Py_ssize_t count - - ctypedef int (*execute_fn)(unpack_context* 
ctx, const char* data, - Py_ssize_t len, Py_ssize_t* off) except? -1 - execute_fn unpack_construct - execute_fn unpack_skip - execute_fn read_array_header - execute_fn read_map_header - void unpack_init(unpack_context* ctx) - object unpack_data(unpack_context* ctx) - void unpack_clear(unpack_context* ctx) - -cdef inline init_ctx(unpack_context *ctx, - object object_hook, object object_pairs_hook, - object list_hook, object ext_hook, - bint use_list, bint raw, int timestamp, - bint strict_map_key, - const char* encoding, - const char* unicode_errors, - Py_ssize_t max_str_len, Py_ssize_t max_bin_len, - Py_ssize_t max_array_len, Py_ssize_t max_map_len, - Py_ssize_t max_ext_len): - unpack_init(ctx) - ctx.user.use_list = use_list - ctx.user.raw = raw - ctx.user.strict_map_key = strict_map_key - ctx.user.object_hook = ctx.user.list_hook = NULL - ctx.user.max_str_len = max_str_len - ctx.user.max_bin_len = max_bin_len - ctx.user.max_array_len = max_array_len - ctx.user.max_map_len = max_map_len - ctx.user.max_ext_len = max_ext_len - - if object_hook is not None and object_pairs_hook is not None: - raise TypeError("object_pairs_hook and object_hook are mutually exclusive.") - - if object_hook is not None: - if not PyCallable_Check(object_hook): - raise TypeError("object_hook must be a callable.") - ctx.user.object_hook = object_hook - - if object_pairs_hook is None: - ctx.user.has_pairs_hook = False - else: - if not PyCallable_Check(object_pairs_hook): - raise TypeError("object_pairs_hook must be a callable.") - ctx.user.object_hook = object_pairs_hook - ctx.user.has_pairs_hook = True - - if list_hook is not None: - if not PyCallable_Check(list_hook): - raise TypeError("list_hook must be a callable.") - ctx.user.list_hook = list_hook - - if ext_hook is not None: - if not PyCallable_Check(ext_hook): - raise TypeError("ext_hook must be a callable.") - ctx.user.ext_hook = ext_hook - - if timestamp < 0 or 3 < timestamp: - raise ValueError("timestamp must be 0..3") - - # Add 
Timestamp type to the user object so it may be used in unpack.h - ctx.user.timestamp = timestamp - ctx.user.timestamp_t = Timestamp - ctx.user.giga = giga - ctx.user.utc = utc - ctx.user.unicode_errors = unicode_errors - ctx.user.encoding = encoding - -def default_read_extended_type(typecode, data): - raise NotImplementedError("Cannot decode extended type with typecode=%d" % typecode) - -cdef inline int get_data_from_buffer(object obj, - Py_buffer *view, - char **buf, - Py_ssize_t *buffer_len) except 0: - cdef object contiguous - cdef Py_buffer tmp - if PyObject_GetBuffer(obj, view, PyBUF_FULL_RO) == -1: - raise - if view.itemsize != 1: - PyBuffer_Release(view) - raise BufferError("cannot unpack from multi-byte object") - if PyBuffer_IsContiguous(view, b'A') == 0: - PyBuffer_Release(view) - # create a contiguous copy and get buffer - contiguous = PyMemoryView_GetContiguous(obj, PyBUF_READ, b'C') - PyObject_GetBuffer(contiguous, view, PyBUF_SIMPLE) - # view must hold the only reference to contiguous, - # so memory is freed when view is released - Py_DECREF(contiguous) - buffer_len[0] = view.len - buf[0] = view.buf - return 1 - - -def unpackb(object packed, *, object object_hook=None, object list_hook=None, - bint use_list=True, bint raw=True, int timestamp=0, bint strict_map_key=False, - encoding=None, - unicode_errors=None, - object_pairs_hook=None, ext_hook=ExtType, - Py_ssize_t max_str_len=-1, - Py_ssize_t max_bin_len=-1, - Py_ssize_t max_array_len=-1, - Py_ssize_t max_map_len=-1, - Py_ssize_t max_ext_len=-1): - """ - Unpack packed_bytes to object. Returns an unpacked object. - - Raises ``ExtraData`` when *packed* contains extra bytes. - Raises ``ValueError`` when *packed* is incomplete. - Raises ``FormatError`` when *packed* is not valid msgpack. - Raises ``StackError`` when *packed* contains too nested. - Other exceptions can be raised during unpacking. - - See :class:`Unpacker` for options. 
- - *max_xxx_len* options are configured automatically from ``len(packed)``. - """ - cdef unpack_context ctx - cdef Py_ssize_t off = 0 - cdef int ret - - cdef Py_buffer view - cdef char* buf = NULL - cdef Py_ssize_t buf_len - cdef const char* cenc = NULL - cdef const char* cerr = NULL - - if encoding is not None: - encoding = ensure_bytes(encoding) - cenc = encoding - - if unicode_errors is not None: - unicode_errors = ensure_bytes(unicode_errors) - cerr = unicode_errors - - get_data_from_buffer(packed, &view, &buf, &buf_len) - - if max_str_len == -1: - max_str_len = buf_len - if max_bin_len == -1: - max_bin_len = buf_len - if max_array_len == -1: - max_array_len = buf_len - if max_map_len == -1: - max_map_len = buf_len//2 - if max_ext_len == -1: - max_ext_len = buf_len - - try: - init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook, - use_list, raw, timestamp, strict_map_key, cenc, cerr, - max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len) - ret = unpack_construct(&ctx, buf, buf_len, &off) - finally: - PyBuffer_Release(&view); - - if ret == 1: - obj = unpack_data(&ctx) - if off < buf_len: - raise ExtraData(obj, PyBytes_FromStringAndSize(buf+off, buf_len-off)) - return obj - unpack_clear(&ctx) - if ret == 0: - raise ValueError("Unpack failed: incomplete input") - elif ret == -2: - raise FormatError - elif ret == -3: - raise StackError - raise ValueError("Unpack failed: error = %d" % (ret,)) - - -cdef class Unpacker: - """Streaming unpacker. - - Arguments: - - :param file_like: - File-like object having `.read(n)` method. - If specified, unpacker reads serialized data from it and `.feed()` is not usable. - - :param int read_size: - Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`) - - :param bool use_list: - If true, unpack msgpack array to Python list. - Otherwise, unpack to Python tuple. (default: True) - - :param bool raw: - If true, unpack msgpack raw to Python bytes. 
- Otherwise, unpack to Python str by decoding with UTF-8 encoding (default). - - :param int timestamp: - Control how timestamp type is unpacked: - - 0 - Timestamp - 1 - float (Seconds from the EPOCH) - 2 - int (Nanoseconds from the EPOCH) - 3 - datetime.datetime (UTC). - - :param bool strict_map_key: - If true (default), only str or bytes are accepted for map (dict) keys. - - :param object_hook: - When specified, it should be callable. - Unpacker calls it with a dict argument after unpacking msgpack map. - (See also simplejson) - - :param object_pairs_hook: - When specified, it should be callable. - Unpacker calls it with a list of key-value pairs after unpacking msgpack map. - (See also simplejson) - - :param str unicode_errors: - The error handler for decoding unicode. (default: 'strict') - This option should be used only when you have msgpack data which - contains invalid UTF-8 string. - - :param int max_buffer_size: - Limits size of data waiting unpacked. 0 means 2**32-1. - The default value is 100*1024*1024 (100MiB). - Raises `BufferFull` exception when it is insufficient. - You should set this parameter when unpacking data from untrusted source. - - :param int max_str_len: - Deprecated, use *max_buffer_size* instead. - Limits max length of str. (default: max_buffer_size) - - :param int max_bin_len: - Deprecated, use *max_buffer_size* instead. - Limits max length of bin. (default: max_buffer_size) - - :param int max_array_len: - Limits max length of array. - (default: max_buffer_size) - - :param int max_map_len: - Limits max length of map. - (default: max_buffer_size//2) - - :param int max_ext_len: - Deprecated, use *max_buffer_size* instead. - Limits max size of ext type. 
(default: max_buffer_size) - - Example of streaming deserialize from file-like object:: - - unpacker = Unpacker(file_like) - for o in unpacker: - process(o) - - Example of streaming deserialize from socket:: - - unpacker = Unpacker() - while True: - buf = sock.recv(1024**2) - if not buf: - break - unpacker.feed(buf) - for o in unpacker: - process(o) - - Raises ``ExtraData`` when *packed* contains extra bytes. - Raises ``OutOfData`` when *packed* is incomplete. - Raises ``FormatError`` when *packed* is not valid msgpack. - Raises ``StackError`` when *packed* contains too nested. - Other exceptions can be raised during unpacking. - """ - cdef unpack_context ctx - cdef char* buf - cdef Py_ssize_t buf_size, buf_head, buf_tail - cdef object file_like - cdef object file_like_read - cdef Py_ssize_t read_size - # To maintain refcnt. - cdef object object_hook, object_pairs_hook, list_hook, ext_hook - cdef object unicode_errors - cdef Py_ssize_t max_buffer_size - cdef uint64_t stream_offset - - def __cinit__(self): - self.buf = NULL - - def __dealloc__(self): - PyMem_Free(self.buf) - self.buf = NULL - - def __init__(self, file_like=None, *, Py_ssize_t read_size=0, - bint use_list=True, bint raw=True, int timestamp=0, bint strict_map_key=False, - object object_hook=None, object object_pairs_hook=None, object list_hook=None, - unicode_errors=None, Py_ssize_t max_buffer_size=100*1024*1024, - object ext_hook=ExtType, - Py_ssize_t max_str_len=-1, - Py_ssize_t max_bin_len=-1, - Py_ssize_t max_array_len=-1, - Py_ssize_t max_map_len=-1, - Py_ssize_t max_ext_len=-1): - cdef const char* cerr = NULL - cdef const char* cenc = NULL - - self.object_hook = object_hook - self.object_pairs_hook = object_pairs_hook - self.list_hook = list_hook - self.ext_hook = ext_hook - - self.file_like = file_like - if file_like: - self.file_like_read = file_like.read - if not PyCallable_Check(self.file_like_read): - raise TypeError("`file_like.read` must be a callable.") - - if not max_buffer_size: - 
max_buffer_size = INT_MAX - if max_str_len == -1: - max_str_len = max_buffer_size - if max_bin_len == -1: - max_bin_len = max_buffer_size - if max_array_len == -1: - max_array_len = max_buffer_size - if max_map_len == -1: - max_map_len = max_buffer_size//2 - if max_ext_len == -1: - max_ext_len = max_buffer_size - - if read_size > max_buffer_size: - raise ValueError("read_size should be less or equal to max_buffer_size") - if not read_size: - read_size = min(max_buffer_size, 1024**2) - - self.max_buffer_size = max_buffer_size - self.read_size = read_size - self.buf = PyMem_Malloc(read_size) - if self.buf == NULL: - raise MemoryError("Unable to allocate internal buffer.") - self.buf_size = read_size - self.buf_head = 0 - self.buf_tail = 0 - self.stream_offset = 0 - - if unicode_errors is not None: - self.unicode_errors = unicode_errors - cerr = unicode_errors - - init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook, - ext_hook, use_list, raw, timestamp, strict_map_key, cenc, cerr, - max_str_len, max_bin_len, max_array_len, - max_map_len, max_ext_len) - - def feed(self, object next_bytes): - """Append `next_bytes` to internal buffer.""" - cdef Py_buffer pybuff - cdef char* buf - cdef Py_ssize_t buf_len - - if self.file_like is not None: - raise AssertionError( - "unpacker.feed() is not be able to use with `file_like`.") - - get_data_from_buffer(next_bytes, &pybuff, &buf, &buf_len) - try: - self.append_buffer(buf, buf_len) - finally: - PyBuffer_Release(&pybuff) - - cdef append_buffer(self, void* _buf, Py_ssize_t _buf_len): - cdef: - char* buf = self.buf - char* new_buf - Py_ssize_t head = self.buf_head - Py_ssize_t tail = self.buf_tail - Py_ssize_t buf_size = self.buf_size - Py_ssize_t new_size - - if tail + _buf_len > buf_size: - if ((tail - head) + _buf_len) <= buf_size: - # move to front. - memmove(buf, buf + head, tail - head) - tail -= head - head = 0 - else: - # expand buffer. 
- new_size = (tail-head) + _buf_len - if new_size > self.max_buffer_size: - raise BufferFull - new_size = min(new_size*2, self.max_buffer_size) - new_buf = PyMem_Malloc(new_size) - if new_buf == NULL: - # self.buf still holds old buffer and will be freed during - # obj destruction - raise MemoryError("Unable to enlarge internal buffer.") - memcpy(new_buf, buf + head, tail - head) - PyMem_Free(buf) - - buf = new_buf - buf_size = new_size - tail -= head - head = 0 - - memcpy(buf + tail, (_buf), _buf_len) - self.buf = buf - self.buf_head = head - self.buf_size = buf_size - self.buf_tail = tail + _buf_len - - cdef int read_from_file(self) except -1: - cdef Py_ssize_t remains = self.max_buffer_size - (self.buf_tail - self.buf_head) - if remains <= 0: - raise BufferFull - - next_bytes = self.file_like_read(min(self.read_size, remains)) - if next_bytes: - self.append_buffer(PyBytes_AsString(next_bytes), PyBytes_Size(next_bytes)) - else: - self.file_like = None - return 0 - - cdef object _unpack(self, execute_fn execute, bint iter=0): - cdef int ret - cdef object obj - cdef Py_ssize_t prev_head - - while 1: - prev_head = self.buf_head - if prev_head < self.buf_tail: - ret = execute(&self.ctx, self.buf, self.buf_tail, &self.buf_head) - self.stream_offset += self.buf_head - prev_head - else: - ret = 0 - - if ret == 1: - obj = unpack_data(&self.ctx) - unpack_init(&self.ctx) - return obj - elif ret == 0: - if self.file_like is not None: - self.read_from_file() - continue - if iter: - raise StopIteration("No more data to unpack.") - else: - raise OutOfData("No more data to unpack.") - elif ret == -2: - raise FormatError - elif ret == -3: - raise StackError - else: - raise ValueError("Unpack failed: error = %d" % (ret,)) - - def read_bytes(self, Py_ssize_t nbytes): - """Read a specified number of raw bytes from the stream""" - cdef Py_ssize_t nread - nread = min(self.buf_tail - self.buf_head, nbytes) - ret = PyBytes_FromStringAndSize(self.buf + self.buf_head, nread) - 
self.buf_head += nread - if nread < nbytes and self.file_like is not None: - ret += self.file_like.read(nbytes - nread) - nread = len(ret) - self.stream_offset += nread - return ret - - def unpack(self): - """Unpack one object - - Raises `OutOfData` when there are no more bytes to unpack. - """ - return self._unpack(unpack_construct) - - def skip(self): - """Read and ignore one object, returning None - - Raises `OutOfData` when there are no more bytes to unpack. - """ - return self._unpack(unpack_skip) - - def read_array_header(self): - """assuming the next object is an array, return its size n, such that - the next n unpack() calls will iterate over its contents. - - Raises `OutOfData` when there are no more bytes to unpack. - """ - return self._unpack(read_array_header) - - def read_map_header(self): - """assuming the next object is a map, return its size n, such that the - next n * 2 unpack() calls will iterate over its key-value pairs. - - Raises `OutOfData` when there are no more bytes to unpack. - """ - return self._unpack(read_map_header) - - def tell(self): - """Returns the current position of the Unpacker in bytes, i.e., the - number of bytes that were read from the input, also the starting - position of the next object. - """ - return self.stream_offset - - def __iter__(self): - return self - - def __next__(self): - return self._unpack(unpack_construct, 1) - - # for debug. 
- #def _buf(self): - # return PyString_FromStringAndSize(self.buf, self.buf_tail) - - #def _off(self): - # return self.buf_head diff --git a/srsly/msgpack/_version.py b/srsly/msgpack/_version.py deleted file mode 100644 index 1b75e0a..0000000 --- a/srsly/msgpack/_version.py +++ /dev/null @@ -1 +0,0 @@ -version = (1, 1, 0) diff --git a/srsly/msgpack/exceptions.py b/srsly/msgpack/exceptions.py deleted file mode 100644 index d6d2615..0000000 --- a/srsly/msgpack/exceptions.py +++ /dev/null @@ -1,48 +0,0 @@ -class UnpackException(Exception): - """Base class for some exceptions raised while unpacking. - - NOTE: unpack may raise exception other than subclass of - UnpackException. If you want to catch all error, catch - Exception instead. - """ - - -class BufferFull(UnpackException): - pass - - -class OutOfData(UnpackException): - pass - - -class FormatError(ValueError, UnpackException): - """Invalid msgpack format""" - - -class StackError(ValueError, UnpackException): - """Too nested""" - - -# Deprecated. Use ValueError instead -UnpackValueError = ValueError - - -class ExtraData(UnpackValueError): - """ExtraData is raised when there is trailing data. - - This exception is raised while only one-shot (not streaming) - unpack. - """ - - def __init__(self, unpacked, extra): - self.unpacked = unpacked - self.extra = extra - - def __str__(self): - return "unpack(b) received extra data." - - -# Deprecated. Use Exception instead to catch all exception during packing. 
-PackException = Exception -PackValueError = ValueError -PackOverflowError = OverflowError diff --git a/srsly/msgpack/ext.py b/srsly/msgpack/ext.py deleted file mode 100644 index 9694819..0000000 --- a/srsly/msgpack/ext.py +++ /dev/null @@ -1,170 +0,0 @@ -import datetime -import struct -from collections import namedtuple - - -class ExtType(namedtuple("ExtType", "code data")): - """ExtType represents ext type in msgpack.""" - - def __new__(cls, code, data): - if not isinstance(code, int): - raise TypeError("code must be int") - if not isinstance(data, bytes): - raise TypeError("data must be bytes") - if not 0 <= code <= 127: - raise ValueError("code must be 0~127") - return super().__new__(cls, code, data) - - -class Timestamp: - """Timestamp represents the Timestamp extension type in msgpack. - - When built with Cython, msgpack uses C methods to pack and unpack `Timestamp`. - When using pure-Python msgpack, :func:`to_bytes` and :func:`from_bytes` are used to pack and - unpack `Timestamp`. - - This class is immutable: Do not override seconds and nanoseconds. - """ - - __slots__ = ["seconds", "nanoseconds"] - - def __init__(self, seconds, nanoseconds=0): - """Initialize a Timestamp object. - - :param int seconds: - Number of seconds since the UNIX epoch (00:00:00 UTC Jan 1 1970, minus leap seconds). - May be negative. - - :param int nanoseconds: - Number of nanoseconds to add to `seconds` to get fractional time. - Maximum is 999_999_999. Default is 0. - - Note: Negative times (before the UNIX epoch) are represented as neg. seconds + pos. ns. 
- """ - if not isinstance(seconds, int): - raise TypeError("seconds must be an integer") - if not isinstance(nanoseconds, int): - raise TypeError("nanoseconds must be an integer") - if not (0 <= nanoseconds < 10**9): - raise ValueError("nanoseconds must be a non-negative integer less than 999999999.") - self.seconds = seconds - self.nanoseconds = nanoseconds - - def __repr__(self): - """String representation of Timestamp.""" - return f"Timestamp(seconds={self.seconds}, nanoseconds={self.nanoseconds})" - - def __eq__(self, other): - """Check for equality with another Timestamp object""" - if type(other) is self.__class__: - return self.seconds == other.seconds and self.nanoseconds == other.nanoseconds - return False - - def __ne__(self, other): - """not-equals method (see :func:`__eq__()`)""" - return not self.__eq__(other) - - def __hash__(self): - return hash((self.seconds, self.nanoseconds)) - - @staticmethod - def from_bytes(b): - """Unpack bytes into a `Timestamp` object. - - Used for pure-Python msgpack unpacking. - - :param b: Payload from msgpack ext message with code -1 - :type b: bytes - - :returns: Timestamp object unpacked from msgpack ext payload - :rtype: Timestamp - """ - if len(b) == 4: - seconds = struct.unpack("!L", b)[0] - nanoseconds = 0 - elif len(b) == 8: - data64 = struct.unpack("!Q", b)[0] - seconds = data64 & 0x00000003FFFFFFFF - nanoseconds = data64 >> 34 - elif len(b) == 12: - nanoseconds, seconds = struct.unpack("!Iq", b) - else: - raise ValueError( - "Timestamp type can only be created from 32, 64, or 96-bit byte objects" - ) - return Timestamp(seconds, nanoseconds) - - def to_bytes(self): - """Pack this Timestamp object into bytes. - - Used for pure-Python msgpack packing. 
- - :returns data: Payload for EXT message with code -1 (timestamp type) - :rtype: bytes - """ - if (self.seconds >> 34) == 0: # seconds is non-negative and fits in 34 bits - data64 = self.nanoseconds << 34 | self.seconds - if data64 & 0xFFFFFFFF00000000 == 0: - # nanoseconds is zero and seconds < 2**32, so timestamp 32 - data = struct.pack("!L", data64) - else: - # timestamp 64 - data = struct.pack("!Q", data64) - else: - # timestamp 96 - data = struct.pack("!Iq", self.nanoseconds, self.seconds) - return data - - @staticmethod - def from_unix(unix_sec): - """Create a Timestamp from posix timestamp in seconds. - - :param unix_float: Posix timestamp in seconds. - :type unix_float: int or float - """ - seconds = int(unix_sec // 1) - nanoseconds = int((unix_sec % 1) * 10**9) - return Timestamp(seconds, nanoseconds) - - def to_unix(self): - """Get the timestamp as a floating-point value. - - :returns: posix timestamp - :rtype: float - """ - return self.seconds + self.nanoseconds / 1e9 - - @staticmethod - def from_unix_nano(unix_ns): - """Create a Timestamp from posix timestamp in nanoseconds. - - :param int unix_ns: Posix timestamp in nanoseconds. - :rtype: Timestamp - """ - return Timestamp(*divmod(unix_ns, 10**9)) - - def to_unix_nano(self): - """Get the timestamp as a unixtime in nanoseconds. - - :returns: posix timestamp in nanoseconds - :rtype: int - """ - return self.seconds * 10**9 + self.nanoseconds - - def to_datetime(self): - """Get the timestamp as a UTC datetime. - - :rtype: `datetime.datetime` - """ - utc = datetime.timezone.utc - return datetime.datetime.fromtimestamp(0, utc) + datetime.timedelta( - seconds=self.seconds, microseconds=self.nanoseconds // 1000 - ) - - @staticmethod - def from_datetime(dt): - """Create a Timestamp from datetime with tzinfo. 
- - :rtype: Timestamp - """ - return Timestamp(seconds=int(dt.timestamp()), nanoseconds=dt.microsecond * 1000) diff --git a/srsly/msgpack/fallback.py b/srsly/msgpack/fallback.py deleted file mode 100644 index b02e47c..0000000 --- a/srsly/msgpack/fallback.py +++ /dev/null @@ -1,929 +0,0 @@ -"""Fallback pure Python implementation of msgpack""" - -import struct -import sys -from datetime import datetime as _DateTime - -if hasattr(sys, "pypy_version_info"): - from __pypy__ import newlist_hint - from __pypy__.builders import BytesBuilder - - _USING_STRINGBUILDER = True - - class BytesIO: - def __init__(self, s=b""): - if s: - self.builder = BytesBuilder(len(s)) - self.builder.append(s) - else: - self.builder = BytesBuilder() - - def write(self, s): - if isinstance(s, memoryview): - s = s.tobytes() - elif isinstance(s, bytearray): - s = bytes(s) - self.builder.append(s) - - def getvalue(self): - return self.builder.build() - -else: - from io import BytesIO - - _USING_STRINGBUILDER = False - - def newlist_hint(size): - return [] - - -from .exceptions import BufferFull, ExtraData, FormatError, OutOfData, StackError -from .ext import ExtType, Timestamp - -EX_SKIP = 0 -EX_CONSTRUCT = 1 -EX_READ_ARRAY_HEADER = 2 -EX_READ_MAP_HEADER = 3 - -TYPE_IMMEDIATE = 0 -TYPE_ARRAY = 1 -TYPE_MAP = 2 -TYPE_RAW = 3 -TYPE_BIN = 4 -TYPE_EXT = 5 - -DEFAULT_RECURSE_LIMIT = 511 - - -def _check_type_strict(obj, t, type=type, tuple=tuple): - if type(t) is tuple: - return type(obj) in t - else: - return type(obj) is t - - -def _get_data_from_buffer(obj): - view = memoryview(obj) - if view.itemsize != 1: - raise ValueError("cannot unpack from multi-byte object") - return view - - -def unpackb(packed, **kwargs): - """ - Unpack an object from `packed`. - - Raises ``ExtraData`` when *packed* contains extra bytes. - Raises ``ValueError`` when *packed* is incomplete. - Raises ``FormatError`` when *packed* is not valid msgpack. - Raises ``StackError`` when *packed* contains too nested. 
- Other exceptions can be raised during unpacking. - - See :class:`Unpacker` for options. - """ - unpacker = Unpacker(None, max_buffer_size=len(packed), **kwargs) - unpacker.feed(packed) - try: - ret = unpacker._unpack() - except OutOfData: - raise ValueError("Unpack failed: incomplete input") - except RecursionError: - raise StackError - if unpacker._got_extradata(): - raise ExtraData(ret, unpacker._get_extradata()) - return ret - - -_NO_FORMAT_USED = "" -_MSGPACK_HEADERS = { - 0xC4: (1, _NO_FORMAT_USED, TYPE_BIN), - 0xC5: (2, ">H", TYPE_BIN), - 0xC6: (4, ">I", TYPE_BIN), - 0xC7: (2, "Bb", TYPE_EXT), - 0xC8: (3, ">Hb", TYPE_EXT), - 0xC9: (5, ">Ib", TYPE_EXT), - 0xCA: (4, ">f"), - 0xCB: (8, ">d"), - 0xCC: (1, _NO_FORMAT_USED), - 0xCD: (2, ">H"), - 0xCE: (4, ">I"), - 0xCF: (8, ">Q"), - 0xD0: (1, "b"), - 0xD1: (2, ">h"), - 0xD2: (4, ">i"), - 0xD3: (8, ">q"), - 0xD4: (1, "b1s", TYPE_EXT), - 0xD5: (2, "b2s", TYPE_EXT), - 0xD6: (4, "b4s", TYPE_EXT), - 0xD7: (8, "b8s", TYPE_EXT), - 0xD8: (16, "b16s", TYPE_EXT), - 0xD9: (1, _NO_FORMAT_USED, TYPE_RAW), - 0xDA: (2, ">H", TYPE_RAW), - 0xDB: (4, ">I", TYPE_RAW), - 0xDC: (2, ">H", TYPE_ARRAY), - 0xDD: (4, ">I", TYPE_ARRAY), - 0xDE: (2, ">H", TYPE_MAP), - 0xDF: (4, ">I", TYPE_MAP), -} - - -class Unpacker: - """Streaming unpacker. - - Arguments: - - :param file_like: - File-like object having `.read(n)` method. - If specified, unpacker reads serialized data from it and `.feed()` is not usable. - - :param int read_size: - Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`) - - :param bool use_list: - If true, unpack msgpack array to Python list. - Otherwise, unpack to Python tuple. (default: True) - - :param bool raw: - If true, unpack msgpack raw to Python bytes. - Otherwise, unpack to Python str by decoding with UTF-8 encoding (default). 
- - :param int timestamp: - Control how timestamp type is unpacked: - - 0 - Timestamp - 1 - float (Seconds from the EPOCH) - 2 - int (Nanoseconds from the EPOCH) - 3 - datetime.datetime (UTC). - - :param bool strict_map_key: - If true (default), only str or bytes are accepted for map (dict) keys. - - :param object_hook: - When specified, it should be callable. - Unpacker calls it with a dict argument after unpacking msgpack map. - (See also simplejson) - - :param object_pairs_hook: - When specified, it should be callable. - Unpacker calls it with a list of key-value pairs after unpacking msgpack map. - (See also simplejson) - - :param str unicode_errors: - The error handler for decoding unicode. (default: 'strict') - This option should be used only when you have msgpack data which - contains invalid UTF-8 string. - - :param int max_buffer_size: - Limits size of data waiting unpacked. 0 means 2**32-1. - The default value is 100*1024*1024 (100MiB). - Raises `BufferFull` exception when it is insufficient. - You should set this parameter when unpacking data from untrusted source. - - :param int max_str_len: - Deprecated, use *max_buffer_size* instead. - Limits max length of str. (default: max_buffer_size) - - :param int max_bin_len: - Deprecated, use *max_buffer_size* instead. - Limits max length of bin. (default: max_buffer_size) - - :param int max_array_len: - Limits max length of array. - (default: max_buffer_size) - - :param int max_map_len: - Limits max length of map. - (default: max_buffer_size//2) - - :param int max_ext_len: - Deprecated, use *max_buffer_size* instead. - Limits max size of ext type. 
(default: max_buffer_size) - - Example of streaming deserialize from file-like object:: - - unpacker = Unpacker(file_like) - for o in unpacker: - process(o) - - Example of streaming deserialize from socket:: - - unpacker = Unpacker() - while True: - buf = sock.recv(1024**2) - if not buf: - break - unpacker.feed(buf) - for o in unpacker: - process(o) - - Raises ``ExtraData`` when *packed* contains extra bytes. - Raises ``OutOfData`` when *packed* is incomplete. - Raises ``FormatError`` when *packed* is not valid msgpack. - Raises ``StackError`` when *packed* contains too nested. - Other exceptions can be raised during unpacking. - """ - - def __init__( - self, - file_like=None, - *, - read_size=0, - use_list=True, - raw=False, - timestamp=0, - strict_map_key=True, - object_hook=None, - object_pairs_hook=None, - list_hook=None, - unicode_errors=None, - max_buffer_size=100 * 1024 * 1024, - ext_hook=ExtType, - max_str_len=-1, - max_bin_len=-1, - max_array_len=-1, - max_map_len=-1, - max_ext_len=-1, - ): - if unicode_errors is None: - unicode_errors = "strict" - - if file_like is None: - self._feeding = True - else: - if not callable(file_like.read): - raise TypeError("`file_like.read` must be callable") - self.file_like = file_like - self._feeding = False - - #: array of bytes fed. - self._buffer = bytearray() - #: Which position we currently reads - self._buff_i = 0 - - # When Unpacker is used as an iterable, between the calls to next(), - # the buffer is not "consumed" completely, for efficiency sake. - # Instead, it is done sloppily. To make sure we raise BufferFull at - # the correct moments, we have to keep track of how sloppy we were. - # Furthermore, when the buffer is incomplete (that is: in the case - # we raise an OutOfData) we need to rollback the buffer to the correct - # state, which _buf_checkpoint records. 
- self._buf_checkpoint = 0 - - if not max_buffer_size: - max_buffer_size = 2**31 - 1 - if max_str_len == -1: - max_str_len = max_buffer_size - if max_bin_len == -1: - max_bin_len = max_buffer_size - if max_array_len == -1: - max_array_len = max_buffer_size - if max_map_len == -1: - max_map_len = max_buffer_size // 2 - if max_ext_len == -1: - max_ext_len = max_buffer_size - - self._max_buffer_size = max_buffer_size - if read_size > self._max_buffer_size: - raise ValueError("read_size must be smaller than max_buffer_size") - self._read_size = read_size or min(self._max_buffer_size, 16 * 1024) - self._raw = bool(raw) - self._strict_map_key = bool(strict_map_key) - self._unicode_errors = unicode_errors - self._use_list = use_list - if not (0 <= timestamp <= 3): - raise ValueError("timestamp must be 0..3") - self._timestamp = timestamp - self._list_hook = list_hook - self._object_hook = object_hook - self._object_pairs_hook = object_pairs_hook - self._ext_hook = ext_hook - self._max_str_len = max_str_len - self._max_bin_len = max_bin_len - self._max_array_len = max_array_len - self._max_map_len = max_map_len - self._max_ext_len = max_ext_len - self._stream_offset = 0 - - if list_hook is not None and not callable(list_hook): - raise TypeError("`list_hook` is not callable") - if object_hook is not None and not callable(object_hook): - raise TypeError("`object_hook` is not callable") - if object_pairs_hook is not None and not callable(object_pairs_hook): - raise TypeError("`object_pairs_hook` is not callable") - if object_hook is not None and object_pairs_hook is not None: - raise TypeError("object_pairs_hook and object_hook are mutually exclusive") - if not callable(ext_hook): - raise TypeError("`ext_hook` is not callable") - - def feed(self, next_bytes): - assert self._feeding - view = _get_data_from_buffer(next_bytes) - if len(self._buffer) - self._buff_i + len(view) > self._max_buffer_size: - raise BufferFull - - # Strip buffer before checkpoint before reading file. 
- if self._buf_checkpoint > 0: - del self._buffer[: self._buf_checkpoint] - self._buff_i -= self._buf_checkpoint - self._buf_checkpoint = 0 - - # Use extend here: INPLACE_ADD += doesn't reliably typecast memoryview in jython - self._buffer.extend(view) - view.release() - - def _consume(self): - """Gets rid of the used parts of the buffer.""" - self._stream_offset += self._buff_i - self._buf_checkpoint - self._buf_checkpoint = self._buff_i - - def _got_extradata(self): - return self._buff_i < len(self._buffer) - - def _get_extradata(self): - return self._buffer[self._buff_i :] - - def read_bytes(self, n): - ret = self._read(n, raise_outofdata=False) - self._consume() - return ret - - def _read(self, n, raise_outofdata=True): - # (int) -> bytearray - self._reserve(n, raise_outofdata=raise_outofdata) - i = self._buff_i - ret = self._buffer[i : i + n] - self._buff_i = i + len(ret) - return ret - - def _reserve(self, n, raise_outofdata=True): - remain_bytes = len(self._buffer) - self._buff_i - n - - # Fast path: buffer has n bytes already - if remain_bytes >= 0: - return - - if self._feeding: - self._buff_i = self._buf_checkpoint - raise OutOfData - - # Strip buffer before checkpoint before reading file. 
- if self._buf_checkpoint > 0: - del self._buffer[: self._buf_checkpoint] - self._buff_i -= self._buf_checkpoint - self._buf_checkpoint = 0 - - # Read from file - remain_bytes = -remain_bytes - if remain_bytes + len(self._buffer) > self._max_buffer_size: - raise BufferFull - while remain_bytes > 0: - to_read_bytes = max(self._read_size, remain_bytes) - read_data = self.file_like.read(to_read_bytes) - if not read_data: - break - assert isinstance(read_data, bytes) - self._buffer += read_data - remain_bytes -= len(read_data) - - if len(self._buffer) < n + self._buff_i and raise_outofdata: - self._buff_i = 0 # rollback - raise OutOfData - - def _read_header(self): - typ = TYPE_IMMEDIATE - n = 0 - obj = None - self._reserve(1) - b = self._buffer[self._buff_i] - self._buff_i += 1 - if b & 0b10000000 == 0: - obj = b - elif b & 0b11100000 == 0b11100000: - obj = -1 - (b ^ 0xFF) - elif b & 0b11100000 == 0b10100000: - n = b & 0b00011111 - typ = TYPE_RAW - if n > self._max_str_len: - raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})") - obj = self._read(n) - elif b & 0b11110000 == 0b10010000: - n = b & 0b00001111 - typ = TYPE_ARRAY - if n > self._max_array_len: - raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})") - elif b & 0b11110000 == 0b10000000: - n = b & 0b00001111 - typ = TYPE_MAP - if n > self._max_map_len: - raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})") - elif b == 0xC0: - obj = None - elif b == 0xC2: - obj = False - elif b == 0xC3: - obj = True - elif 0xC4 <= b <= 0xC6: - size, fmt, typ = _MSGPACK_HEADERS[b] - self._reserve(size) - if len(fmt) > 0: - n = struct.unpack_from(fmt, self._buffer, self._buff_i)[0] - else: - n = self._buffer[self._buff_i] - self._buff_i += size - if n > self._max_bin_len: - raise ValueError(f"{n} exceeds max_bin_len({self._max_bin_len})") - obj = self._read(n) - elif 0xC7 <= b <= 0xC9: - size, fmt, typ = _MSGPACK_HEADERS[b] - self._reserve(size) - L, n = struct.unpack_from(fmt, 
self._buffer, self._buff_i) - self._buff_i += size - if L > self._max_ext_len: - raise ValueError(f"{L} exceeds max_ext_len({self._max_ext_len})") - obj = self._read(L) - elif 0xCA <= b <= 0xD3: - size, fmt = _MSGPACK_HEADERS[b] - self._reserve(size) - if len(fmt) > 0: - obj = struct.unpack_from(fmt, self._buffer, self._buff_i)[0] - else: - obj = self._buffer[self._buff_i] - self._buff_i += size - elif 0xD4 <= b <= 0xD8: - size, fmt, typ = _MSGPACK_HEADERS[b] - if self._max_ext_len < size: - raise ValueError(f"{size} exceeds max_ext_len({self._max_ext_len})") - self._reserve(size + 1) - n, obj = struct.unpack_from(fmt, self._buffer, self._buff_i) - self._buff_i += size + 1 - elif 0xD9 <= b <= 0xDB: - size, fmt, typ = _MSGPACK_HEADERS[b] - self._reserve(size) - if len(fmt) > 0: - (n,) = struct.unpack_from(fmt, self._buffer, self._buff_i) - else: - n = self._buffer[self._buff_i] - self._buff_i += size - if n > self._max_str_len: - raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})") - obj = self._read(n) - elif 0xDC <= b <= 0xDD: - size, fmt, typ = _MSGPACK_HEADERS[b] - self._reserve(size) - (n,) = struct.unpack_from(fmt, self._buffer, self._buff_i) - self._buff_i += size - if n > self._max_array_len: - raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})") - elif 0xDE <= b <= 0xDF: - size, fmt, typ = _MSGPACK_HEADERS[b] - self._reserve(size) - (n,) = struct.unpack_from(fmt, self._buffer, self._buff_i) - self._buff_i += size - if n > self._max_map_len: - raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})") - else: - raise FormatError("Unknown header: 0x%x" % b) - return typ, n, obj - - def _unpack(self, execute=EX_CONSTRUCT): - typ, n, obj = self._read_header() - - if execute == EX_READ_ARRAY_HEADER: - if typ != TYPE_ARRAY: - raise ValueError("Expected array") - return n - if execute == EX_READ_MAP_HEADER: - if typ != TYPE_MAP: - raise ValueError("Expected map") - return n - # TODO should we eliminate the recursion? 
- if typ == TYPE_ARRAY: - if execute == EX_SKIP: - for i in range(n): - # TODO check whether we need to call `list_hook` - self._unpack(EX_SKIP) - return - ret = newlist_hint(n) - for i in range(n): - ret.append(self._unpack(EX_CONSTRUCT)) - if self._list_hook is not None: - ret = self._list_hook(ret) - # TODO is the interaction between `list_hook` and `use_list` ok? - return ret if self._use_list else tuple(ret) - if typ == TYPE_MAP: - if execute == EX_SKIP: - for i in range(n): - # TODO check whether we need to call hooks - self._unpack(EX_SKIP) - self._unpack(EX_SKIP) - return - if self._object_pairs_hook is not None: - ret = self._object_pairs_hook( - (self._unpack(EX_CONSTRUCT), self._unpack(EX_CONSTRUCT)) for _ in range(n) - ) - else: - ret = {} - for _ in range(n): - key = self._unpack(EX_CONSTRUCT) - if self._strict_map_key and type(key) not in (str, bytes): - raise ValueError("%s is not allowed for map key" % str(type(key))) - if isinstance(key, str): - key = sys.intern(key) - ret[key] = self._unpack(EX_CONSTRUCT) - if self._object_hook is not None: - ret = self._object_hook(ret) - return ret - if execute == EX_SKIP: - return - if typ == TYPE_RAW: - if self._raw: - obj = bytes(obj) - else: - obj = obj.decode("utf_8", self._unicode_errors) - return obj - if typ == TYPE_BIN: - return bytes(obj) - if typ == TYPE_EXT: - if n == -1: # timestamp - ts = Timestamp.from_bytes(bytes(obj)) - if self._timestamp == 1: - return ts.to_unix() - elif self._timestamp == 2: - return ts.to_unix_nano() - elif self._timestamp == 3: - return ts.to_datetime() - else: - return ts - else: - return self._ext_hook(n, bytes(obj)) - assert typ == TYPE_IMMEDIATE - return obj - - def __iter__(self): - return self - - def __next__(self): - try: - ret = self._unpack(EX_CONSTRUCT) - self._consume() - return ret - except OutOfData: - self._consume() - raise StopIteration - except RecursionError: - raise StackError - - next = __next__ - - def skip(self): - self._unpack(EX_SKIP) - 
self._consume() - - def unpack(self): - try: - ret = self._unpack(EX_CONSTRUCT) - except RecursionError: - raise StackError - self._consume() - return ret - - def read_array_header(self): - ret = self._unpack(EX_READ_ARRAY_HEADER) - self._consume() - return ret - - def read_map_header(self): - ret = self._unpack(EX_READ_MAP_HEADER) - self._consume() - return ret - - def tell(self): - return self._stream_offset - - -class Packer: - """ - MessagePack Packer - - Usage:: - - packer = Packer() - astream.write(packer.pack(a)) - astream.write(packer.pack(b)) - - Packer's constructor has some keyword arguments: - - :param default: - When specified, it should be callable. - Convert user type to builtin type that Packer supports. - See also simplejson's document. - - :param bool use_single_float: - Use single precision float type for float. (default: False) - - :param bool autoreset: - Reset buffer after each pack and return its content as `bytes`. (default: True). - If set this to false, use `bytes()` to get content and `.reset()` to clear buffer. - - :param bool use_bin_type: - Use bin type introduced in msgpack spec 2.0 for bytes. - It also enables str8 type for unicode. (default: True) - - :param bool strict_types: - If set to true, types will be checked to be exact. Derived classes - from serializable types will not be serialized and will be - treated as unsupported type and forwarded to default. - Additionally tuples will not be serialized as lists. - This is useful when trying to implement accurate serialization - for python types. - - :param bool datetime: - If set to true, datetime with tzinfo is packed into Timestamp type. - Note that the tzinfo is stripped in the timestamp. - You can get UTC datetime with `timestamp=3` option of the Unpacker. - - :param str unicode_errors: - The error handler for encoding unicode. (default: 'strict') - DO NOT USE THIS!! This option is kept for very specific usage. - - :param int buf_size: - Internal buffer size. 
This option is used only for C implementation. - """ - - def __init__( - self, - *, - default=None, - use_single_float=False, - autoreset=True, - use_bin_type=True, - strict_types=False, - datetime=False, - unicode_errors=None, - buf_size=None, - ): - self._strict_types = strict_types - self._use_float = use_single_float - self._autoreset = autoreset - self._use_bin_type = use_bin_type - self._buffer = BytesIO() - self._datetime = bool(datetime) - self._unicode_errors = unicode_errors or "strict" - if default is not None and not callable(default): - raise TypeError("default must be callable") - self._default = default - - def _pack( - self, - obj, - nest_limit=DEFAULT_RECURSE_LIMIT, - check=isinstance, - check_type_strict=_check_type_strict, - ): - default_used = False - if self._strict_types: - check = check_type_strict - list_types = list - else: - list_types = (list, tuple) - while True: - if nest_limit < 0: - raise ValueError("recursion limit exceeded") - if obj is None: - return self._buffer.write(b"\xc0") - if check(obj, bool): - if obj: - return self._buffer.write(b"\xc3") - return self._buffer.write(b"\xc2") - if check(obj, int): - if 0 <= obj < 0x80: - return self._buffer.write(struct.pack("B", obj)) - if -0x20 <= obj < 0: - return self._buffer.write(struct.pack("b", obj)) - if 0x80 <= obj <= 0xFF: - return self._buffer.write(struct.pack("BB", 0xCC, obj)) - if -0x80 <= obj < 0: - return self._buffer.write(struct.pack(">Bb", 0xD0, obj)) - if 0xFF < obj <= 0xFFFF: - return self._buffer.write(struct.pack(">BH", 0xCD, obj)) - if -0x8000 <= obj < -0x80: - return self._buffer.write(struct.pack(">Bh", 0xD1, obj)) - if 0xFFFF < obj <= 0xFFFFFFFF: - return self._buffer.write(struct.pack(">BI", 0xCE, obj)) - if -0x80000000 <= obj < -0x8000: - return self._buffer.write(struct.pack(">Bi", 0xD2, obj)) - if 0xFFFFFFFF < obj <= 0xFFFFFFFFFFFFFFFF: - return self._buffer.write(struct.pack(">BQ", 0xCF, obj)) - if -0x8000000000000000 <= obj < -0x80000000: - return 
self._buffer.write(struct.pack(">Bq", 0xD3, obj)) - if not default_used and self._default is not None: - obj = self._default(obj) - default_used = True - continue - raise OverflowError("Integer value out of range") - if check(obj, (bytes, bytearray)): - n = len(obj) - if n >= 2**32: - raise ValueError("%s is too large" % type(obj).__name__) - self._pack_bin_header(n) - return self._buffer.write(obj) - if check(obj, str): - obj = obj.encode("utf-8", self._unicode_errors) - n = len(obj) - if n >= 2**32: - raise ValueError("String is too large") - self._pack_raw_header(n) - return self._buffer.write(obj) - if check(obj, memoryview): - n = obj.nbytes - if n >= 2**32: - raise ValueError("Memoryview is too large") - self._pack_bin_header(n) - return self._buffer.write(obj) - if check(obj, float): - if self._use_float: - return self._buffer.write(struct.pack(">Bf", 0xCA, obj)) - return self._buffer.write(struct.pack(">Bd", 0xCB, obj)) - if check(obj, (ExtType, Timestamp)): - if check(obj, Timestamp): - code = -1 - data = obj.to_bytes() - else: - code = obj.code - data = obj.data - assert isinstance(code, int) - assert isinstance(data, bytes) - L = len(data) - if L == 1: - self._buffer.write(b"\xd4") - elif L == 2: - self._buffer.write(b"\xd5") - elif L == 4: - self._buffer.write(b"\xd6") - elif L == 8: - self._buffer.write(b"\xd7") - elif L == 16: - self._buffer.write(b"\xd8") - elif L <= 0xFF: - self._buffer.write(struct.pack(">BB", 0xC7, L)) - elif L <= 0xFFFF: - self._buffer.write(struct.pack(">BH", 0xC8, L)) - else: - self._buffer.write(struct.pack(">BI", 0xC9, L)) - self._buffer.write(struct.pack("b", code)) - self._buffer.write(data) - return - if check(obj, list_types): - n = len(obj) - self._pack_array_header(n) - for i in range(n): - self._pack(obj[i], nest_limit - 1) - return - if check(obj, dict): - return self._pack_map_pairs(len(obj), obj.items(), nest_limit - 1) - - if self._datetime and check(obj, _DateTime) and obj.tzinfo is not None: - obj = 
Timestamp.from_datetime(obj) - default_used = 1 - continue - - if not default_used and self._default is not None: - obj = self._default(obj) - default_used = 1 - continue - - if self._datetime and check(obj, _DateTime): - raise ValueError(f"Cannot serialize {obj!r} where tzinfo=None") - - raise TypeError(f"Cannot serialize {obj!r}") - - def pack(self, obj): - try: - self._pack(obj) - except: - self._buffer = BytesIO() # force reset - raise - if self._autoreset: - ret = self._buffer.getvalue() - self._buffer = BytesIO() - return ret - - def pack_map_pairs(self, pairs): - self._pack_map_pairs(len(pairs), pairs) - if self._autoreset: - ret = self._buffer.getvalue() - self._buffer = BytesIO() - return ret - - def pack_array_header(self, n): - if n >= 2**32: - raise ValueError - self._pack_array_header(n) - if self._autoreset: - ret = self._buffer.getvalue() - self._buffer = BytesIO() - return ret - - def pack_map_header(self, n): - if n >= 2**32: - raise ValueError - self._pack_map_header(n) - if self._autoreset: - ret = self._buffer.getvalue() - self._buffer = BytesIO() - return ret - - def pack_ext_type(self, typecode, data): - if not isinstance(typecode, int): - raise TypeError("typecode must have int type.") - if not 0 <= typecode <= 127: - raise ValueError("typecode should be 0-127") - if not isinstance(data, bytes): - raise TypeError("data must have bytes type") - L = len(data) - if L > 0xFFFFFFFF: - raise ValueError("Too large data") - if L == 1: - self._buffer.write(b"\xd4") - elif L == 2: - self._buffer.write(b"\xd5") - elif L == 4: - self._buffer.write(b"\xd6") - elif L == 8: - self._buffer.write(b"\xd7") - elif L == 16: - self._buffer.write(b"\xd8") - elif L <= 0xFF: - self._buffer.write(b"\xc7" + struct.pack("B", L)) - elif L <= 0xFFFF: - self._buffer.write(b"\xc8" + struct.pack(">H", L)) - else: - self._buffer.write(b"\xc9" + struct.pack(">I", L)) - self._buffer.write(struct.pack("B", typecode)) - self._buffer.write(data) - - def _pack_array_header(self, 
n): - if n <= 0x0F: - return self._buffer.write(struct.pack("B", 0x90 + n)) - if n <= 0xFFFF: - return self._buffer.write(struct.pack(">BH", 0xDC, n)) - if n <= 0xFFFFFFFF: - return self._buffer.write(struct.pack(">BI", 0xDD, n)) - raise ValueError("Array is too large") - - def _pack_map_header(self, n): - if n <= 0x0F: - return self._buffer.write(struct.pack("B", 0x80 + n)) - if n <= 0xFFFF: - return self._buffer.write(struct.pack(">BH", 0xDE, n)) - if n <= 0xFFFFFFFF: - return self._buffer.write(struct.pack(">BI", 0xDF, n)) - raise ValueError("Dict is too large") - - def _pack_map_pairs(self, n, pairs, nest_limit=DEFAULT_RECURSE_LIMIT): - self._pack_map_header(n) - for k, v in pairs: - self._pack(k, nest_limit - 1) - self._pack(v, nest_limit - 1) - - def _pack_raw_header(self, n): - if n <= 0x1F: - self._buffer.write(struct.pack("B", 0xA0 + n)) - elif self._use_bin_type and n <= 0xFF: - self._buffer.write(struct.pack(">BB", 0xD9, n)) - elif n <= 0xFFFF: - self._buffer.write(struct.pack(">BH", 0xDA, n)) - elif n <= 0xFFFFFFFF: - self._buffer.write(struct.pack(">BI", 0xDB, n)) - else: - raise ValueError("Raw is too large") - - def _pack_bin_header(self, n): - if not self._use_bin_type: - return self._pack_raw_header(n) - elif n <= 0xFF: - return self._buffer.write(struct.pack(">BB", 0xC4, n)) - elif n <= 0xFFFF: - return self._buffer.write(struct.pack(">BH", 0xC5, n)) - elif n <= 0xFFFFFFFF: - return self._buffer.write(struct.pack(">BI", 0xC6, n)) - else: - raise ValueError("Bin is too large") - - def bytes(self): - """Return internal buffer contents as bytes object""" - return self._buffer.getvalue() - - def reset(self): - """Reset internal buffer. - - This method is useful only when autoreset=False. 
- """ - self._buffer = BytesIO() - - def getbuffer(self): - """Return view of internal buffer.""" - if _USING_STRINGBUILDER: - return memoryview(self.bytes()) - else: - return self._buffer.getbuffer() diff --git a/srsly/msgpack/pack.h b/srsly/msgpack/pack.h deleted file mode 100644 index edf3a3f..0000000 --- a/srsly/msgpack/pack.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * MessagePack for Python packing routine - * - * Copyright (C) 2009 Naoki INADA - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include "sysdep.h" -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct msgpack_packer { - char *buf; - size_t length; - size_t buf_size; - bool use_bin_type; -} msgpack_packer; - -typedef struct Packer Packer; - -static inline int msgpack_pack_write(msgpack_packer* pk, const char *data, size_t l) -{ - char* buf = pk->buf; - size_t bs = pk->buf_size; - size_t len = pk->length; - - if (len + l > bs) { - bs = (len + l) * 2; - buf = (char*)PyMem_Realloc(buf, bs); - if (!buf) { - PyErr_NoMemory(); - return -1; - } - } - memcpy(buf + len, data, l); - len += l; - - pk->buf = buf; - pk->buf_size = bs; - pk->length = len; - return 0; -} - -#define msgpack_pack_append_buffer(user, buf, len) \ - return msgpack_pack_write(user, (const char*)buf, len) - -#include "pack_template.h" - -#ifdef __cplusplus -} -#endif diff --git a/srsly/msgpack/pack_template.h b/srsly/msgpack/pack_template.h deleted file mode 100644 index b8959f0..0000000 --- a/srsly/msgpack/pack_template.h +++ /dev/null @@ -1,596 +0,0 @@ -/* - * MessagePack packing routine template - * - * Copyright (C) 2008-2010 FURUHASHI Sadayuki - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#if defined(__LITTLE_ENDIAN__) -#define TAKE8_8(d) ((uint8_t*)&d)[0] -#define TAKE8_16(d) ((uint8_t*)&d)[0] -#define TAKE8_32(d) ((uint8_t*)&d)[0] -#define TAKE8_64(d) ((uint8_t*)&d)[0] -#elif defined(__BIG_ENDIAN__) -#define TAKE8_8(d) ((uint8_t*)&d)[0] -#define TAKE8_16(d) ((uint8_t*)&d)[1] -#define TAKE8_32(d) ((uint8_t*)&d)[3] -#define TAKE8_64(d) ((uint8_t*)&d)[7] -#endif - -#ifndef msgpack_pack_append_buffer -#error msgpack_pack_append_buffer callback is not defined -#endif - - -/* - * Integer - */ - -#define msgpack_pack_real_uint16(x, d) \ -do { \ - if(d < (1<<7)) { \ - /* fixnum */ \ - msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \ - } else if(d < (1<<8)) { \ - /* unsigned 8 */ \ - unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } else { \ - /* unsigned 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } \ -} while(0) - -#define msgpack_pack_real_uint32(x, d) \ -do { \ - if(d < (1<<8)) { \ - if(d < (1<<7)) { \ - /* fixnum */ \ - msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \ - } else { \ - /* unsigned 8 */ \ - unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } \ - } else { \ - if(d < (1<<16)) { \ - /* unsigned 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } else { \ - /* unsigned 32 */ \ - unsigned char buf[5]; \ - buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \ - msgpack_pack_append_buffer(x, buf, 5); \ - } \ - } \ -} while(0) - -#define msgpack_pack_real_uint64(x, d) \ -do { \ - if(d < (1ULL<<8)) { \ - if(d < (1ULL<<7)) { \ - /* fixnum */ \ - msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \ - } else { \ - /* unsigned 8 */ \ - unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } \ - } else { \ - if(d < (1ULL<<16)) { \ - /* unsigned 
16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } else if(d < (1ULL<<32)) { \ - /* unsigned 32 */ \ - unsigned char buf[5]; \ - buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \ - msgpack_pack_append_buffer(x, buf, 5); \ - } else { \ - /* unsigned 64 */ \ - unsigned char buf[9]; \ - buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \ - msgpack_pack_append_buffer(x, buf, 9); \ - } \ - } \ -} while(0) - -#define msgpack_pack_real_int16(x, d) \ -do { \ - if(d < -(1<<5)) { \ - if(d < -(1<<7)) { \ - /* signed 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } else { \ - /* signed 8 */ \ - unsigned char buf[2] = {0xd0, TAKE8_16(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } \ - } else if(d < (1<<7)) { \ - /* fixnum */ \ - msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \ - } else { \ - if(d < (1<<8)) { \ - /* unsigned 8 */ \ - unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } else { \ - /* unsigned 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } \ - } \ -} while(0) - -#define msgpack_pack_real_int32(x, d) \ -do { \ - if(d < -(1<<5)) { \ - if(d < -(1<<15)) { \ - /* signed 32 */ \ - unsigned char buf[5]; \ - buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \ - msgpack_pack_append_buffer(x, buf, 5); \ - } else if(d < -(1<<7)) { \ - /* signed 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } else { \ - /* signed 8 */ \ - unsigned char buf[2] = {0xd0, TAKE8_32(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } \ - } else if(d < (1<<7)) { \ - /* fixnum */ \ - msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \ - } else { \ - if(d < (1<<8)) { \ - /* unsigned 8 */ \ - 
unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } else if(d < (1<<16)) { \ - /* unsigned 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } else { \ - /* unsigned 32 */ \ - unsigned char buf[5]; \ - buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \ - msgpack_pack_append_buffer(x, buf, 5); \ - } \ - } \ -} while(0) - -#define msgpack_pack_real_int64(x, d) \ -do { \ - if(d < -(1LL<<5)) { \ - if(d < -(1LL<<15)) { \ - if(d < -(1LL<<31)) { \ - /* signed 64 */ \ - unsigned char buf[9]; \ - buf[0] = 0xd3; _msgpack_store64(&buf[1], d); \ - msgpack_pack_append_buffer(x, buf, 9); \ - } else { \ - /* signed 32 */ \ - unsigned char buf[5]; \ - buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \ - msgpack_pack_append_buffer(x, buf, 5); \ - } \ - } else { \ - if(d < -(1<<7)) { \ - /* signed 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } else { \ - /* signed 8 */ \ - unsigned char buf[2] = {0xd0, TAKE8_64(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } \ - } \ - } else if(d < (1<<7)) { \ - /* fixnum */ \ - msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \ - } else { \ - if(d < (1LL<<16)) { \ - if(d < (1<<8)) { \ - /* unsigned 8 */ \ - unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \ - msgpack_pack_append_buffer(x, buf, 2); \ - } else { \ - /* unsigned 16 */ \ - unsigned char buf[3]; \ - buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \ - msgpack_pack_append_buffer(x, buf, 3); \ - } \ - } else { \ - if(d < (1LL<<32)) { \ - /* unsigned 32 */ \ - unsigned char buf[5]; \ - buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \ - msgpack_pack_append_buffer(x, buf, 5); \ - } else { \ - /* unsigned 64 */ \ - unsigned char buf[9]; \ - buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \ - msgpack_pack_append_buffer(x, buf, 9); \ - } \ - } \ - } \ -} 
while(0) - - -static inline int msgpack_pack_short(msgpack_packer* x, short d) -{ -#if defined(SIZEOF_SHORT) -#if SIZEOF_SHORT == 2 - msgpack_pack_real_int16(x, d); -#elif SIZEOF_SHORT == 4 - msgpack_pack_real_int32(x, d); -#else - msgpack_pack_real_int64(x, d); -#endif - -#elif defined(SHRT_MAX) -#if SHRT_MAX == 0x7fff - msgpack_pack_real_int16(x, d); -#elif SHRT_MAX == 0x7fffffff - msgpack_pack_real_int32(x, d); -#else - msgpack_pack_real_int64(x, d); -#endif - -#else -if(sizeof(short) == 2) { - msgpack_pack_real_int16(x, d); -} else if(sizeof(short) == 4) { - msgpack_pack_real_int32(x, d); -} else { - msgpack_pack_real_int64(x, d); -} -#endif -} - -static inline int msgpack_pack_int(msgpack_packer* x, int d) -{ -#if defined(SIZEOF_INT) -#if SIZEOF_INT == 2 - msgpack_pack_real_int16(x, d); -#elif SIZEOF_INT == 4 - msgpack_pack_real_int32(x, d); -#else - msgpack_pack_real_int64(x, d); -#endif - -#elif defined(INT_MAX) -#if INT_MAX == 0x7fff - msgpack_pack_real_int16(x, d); -#elif INT_MAX == 0x7fffffff - msgpack_pack_real_int32(x, d); -#else - msgpack_pack_real_int64(x, d); -#endif - -#else -if(sizeof(int) == 2) { - msgpack_pack_real_int16(x, d); -} else if(sizeof(int) == 4) { - msgpack_pack_real_int32(x, d); -} else { - msgpack_pack_real_int64(x, d); -} -#endif -} - -static inline int msgpack_pack_long(msgpack_packer* x, long d) -{ -#if defined(SIZEOF_LONG) -#if SIZEOF_LONG == 4 - msgpack_pack_real_int32(x, d); -#else - msgpack_pack_real_int64(x, d); -#endif - -#elif defined(LONG_MAX) -#if LONG_MAX == 0x7fffffffL - msgpack_pack_real_int32(x, d); -#else - msgpack_pack_real_int64(x, d); -#endif - -#else - if (sizeof(long) == 4) { - msgpack_pack_real_int32(x, d); - } else { - msgpack_pack_real_int64(x, d); - } -#endif -} - -static inline int msgpack_pack_long_long(msgpack_packer* x, long long d) -{ - msgpack_pack_real_int64(x, d); -} - -static inline int msgpack_pack_unsigned_long_long(msgpack_packer* x, unsigned long long d) -{ - msgpack_pack_real_uint64(x, d); -} - 
- -/* - * Float - */ - -static inline int msgpack_pack_float(msgpack_packer* x, float d) -{ - unsigned char buf[5]; - buf[0] = 0xca; - -#if PY_VERSION_HEX >= 0x030B00A7 - PyFloat_Pack4(d, (char *)&buf[1], 0); -#else - _PyFloat_Pack4(d, &buf[1], 0); -#endif - msgpack_pack_append_buffer(x, buf, 5); -} - -static inline int msgpack_pack_double(msgpack_packer* x, double d) -{ - unsigned char buf[9]; - buf[0] = 0xcb; -#if PY_VERSION_HEX >= 0x030B00A7 - PyFloat_Pack8(d, (char *)&buf[1], 0); -#else - _PyFloat_Pack8(d, &buf[1], 0); -#endif - msgpack_pack_append_buffer(x, buf, 9); -} - - -/* - * Nil - */ - -static inline int msgpack_pack_nil(msgpack_packer* x) -{ - static const unsigned char d = 0xc0; - msgpack_pack_append_buffer(x, &d, 1); -} - - -/* - * Boolean - */ - -static inline int msgpack_pack_true(msgpack_packer* x) -{ - static const unsigned char d = 0xc3; - msgpack_pack_append_buffer(x, &d, 1); -} - -static inline int msgpack_pack_false(msgpack_packer* x) -{ - static const unsigned char d = 0xc2; - msgpack_pack_append_buffer(x, &d, 1); -} - - -/* - * Array - */ - -static inline int msgpack_pack_array(msgpack_packer* x, unsigned int n) -{ - if(n < 16) { - unsigned char d = 0x90 | n; - msgpack_pack_append_buffer(x, &d, 1); - } else if(n < 65536) { - unsigned char buf[3]; - buf[0] = 0xdc; _msgpack_store16(&buf[1], (uint16_t)n); - msgpack_pack_append_buffer(x, buf, 3); - } else { - unsigned char buf[5]; - buf[0] = 0xdd; _msgpack_store32(&buf[1], (uint32_t)n); - msgpack_pack_append_buffer(x, buf, 5); - } -} - - -/* - * Map - */ - -static inline int msgpack_pack_map(msgpack_packer* x, unsigned int n) -{ - if(n < 16) { - unsigned char d = 0x80 | n; - msgpack_pack_append_buffer(x, &TAKE8_8(d), 1); - } else if(n < 65536) { - unsigned char buf[3]; - buf[0] = 0xde; _msgpack_store16(&buf[1], (uint16_t)n); - msgpack_pack_append_buffer(x, buf, 3); - } else { - unsigned char buf[5]; - buf[0] = 0xdf; _msgpack_store32(&buf[1], (uint32_t)n); - msgpack_pack_append_buffer(x, buf, 5); 
- } -} - - -/* - * Raw - */ - -static inline int msgpack_pack_raw(msgpack_packer* x, size_t l) -{ - if (l < 32) { - unsigned char d = 0xa0 | (uint8_t)l; - msgpack_pack_append_buffer(x, &TAKE8_8(d), 1); - } else if (x->use_bin_type && l < 256) { // str8 is new format introduced with bin. - unsigned char buf[2] = {0xd9, (uint8_t)l}; - msgpack_pack_append_buffer(x, buf, 2); - } else if (l < 65536) { - unsigned char buf[3]; - buf[0] = 0xda; _msgpack_store16(&buf[1], (uint16_t)l); - msgpack_pack_append_buffer(x, buf, 3); - } else { - unsigned char buf[5]; - buf[0] = 0xdb; _msgpack_store32(&buf[1], (uint32_t)l); - msgpack_pack_append_buffer(x, buf, 5); - } -} - -/* - * bin - */ -static inline int msgpack_pack_bin(msgpack_packer *x, size_t l) -{ - if (!x->use_bin_type) { - return msgpack_pack_raw(x, l); - } - if (l < 256) { - unsigned char buf[2] = {0xc4, (unsigned char)l}; - msgpack_pack_append_buffer(x, buf, 2); - } else if (l < 65536) { - unsigned char buf[3] = {0xc5}; - _msgpack_store16(&buf[1], (uint16_t)l); - msgpack_pack_append_buffer(x, buf, 3); - } else { - unsigned char buf[5] = {0xc6}; - _msgpack_store32(&buf[1], (uint32_t)l); - msgpack_pack_append_buffer(x, buf, 5); - } -} - -static inline int msgpack_pack_raw_body(msgpack_packer* x, const void* b, size_t l) -{ - if (l > 0) msgpack_pack_append_buffer(x, (const unsigned char*)b, l); - return 0; -} - -/* - * Ext - */ -static inline int msgpack_pack_ext(msgpack_packer* x, char typecode, size_t l) -{ - if (l == 1) { - unsigned char buf[2]; - buf[0] = 0xd4; - buf[1] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 2); - } - else if(l == 2) { - unsigned char buf[2]; - buf[0] = 0xd5; - buf[1] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 2); - } - else if(l == 4) { - unsigned char buf[2]; - buf[0] = 0xd6; - buf[1] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 2); - } - else if(l == 8) { - unsigned char buf[2]; - buf[0] = 0xd7; - buf[1] = (unsigned char)typecode; - 
msgpack_pack_append_buffer(x, buf, 2); - } - else if(l == 16) { - unsigned char buf[2]; - buf[0] = 0xd8; - buf[1] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 2); - } - else if(l < 256) { - unsigned char buf[3]; - buf[0] = 0xc7; - buf[1] = l; - buf[2] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 3); - } else if(l < 65536) { - unsigned char buf[4]; - buf[0] = 0xc8; - _msgpack_store16(&buf[1], (uint16_t)l); - buf[3] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 4); - } else { - unsigned char buf[6]; - buf[0] = 0xc9; - _msgpack_store32(&buf[1], (uint32_t)l); - buf[5] = (unsigned char)typecode; - msgpack_pack_append_buffer(x, buf, 6); - } - -} - -/* - * Pack Timestamp extension type. Follows msgpack-c pack_template.h. - */ -static inline int msgpack_pack_timestamp(msgpack_packer* x, int64_t seconds, uint32_t nanoseconds) -{ - if ((seconds >> 34) == 0) { - /* seconds is unsigned and fits in 34 bits */ - uint64_t data64 = ((uint64_t)nanoseconds << 34) | (uint64_t)seconds; - if ((data64 & 0xffffffff00000000L) == 0) { - /* no nanoseconds and seconds is 32bits or smaller. timestamp32. 
*/ - unsigned char buf[4]; - uint32_t data32 = (uint32_t)data64; - msgpack_pack_ext(x, -1, 4); - _msgpack_store32(buf, data32); - msgpack_pack_raw_body(x, buf, 4); - } else { - /* timestamp64 */ - unsigned char buf[8]; - msgpack_pack_ext(x, -1, 8); - _msgpack_store64(buf, data64); - msgpack_pack_raw_body(x, buf, 8); - - } - } else { - /* seconds is signed or >34bits */ - unsigned char buf[12]; - _msgpack_store32(&buf[0], nanoseconds); - _msgpack_store64(&buf[4], seconds); - msgpack_pack_ext(x, -1, 12); - msgpack_pack_raw_body(x, buf, 12); - } - return 0; -} - - -#undef msgpack_pack_append_buffer - -#undef TAKE8_8 -#undef TAKE8_16 -#undef TAKE8_32 -#undef TAKE8_64 - -#undef msgpack_pack_real_uint16 -#undef msgpack_pack_real_uint32 -#undef msgpack_pack_real_uint64 -#undef msgpack_pack_real_int16 -#undef msgpack_pack_real_int32 -#undef msgpack_pack_real_int64 diff --git a/srsly/msgpack/sysdep.h b/srsly/msgpack/sysdep.h deleted file mode 100644 index 7067300..0000000 --- a/srsly/msgpack/sysdep.h +++ /dev/null @@ -1,194 +0,0 @@ -/* - * MessagePack system dependencies - * - * Copyright (C) 2008-2010 FURUHASHI Sadayuki - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MSGPACK_SYSDEP_H__ -#define MSGPACK_SYSDEP_H__ - -#include -#include -#if defined(_MSC_VER) && _MSC_VER < 1600 -typedef __int8 int8_t; -typedef unsigned __int8 uint8_t; -typedef __int16 int16_t; -typedef unsigned __int16 uint16_t; -typedef __int32 int32_t; -typedef unsigned __int32 uint32_t; -typedef __int64 int64_t; -typedef unsigned __int64 uint64_t; -#elif defined(_MSC_VER) // && _MSC_VER >= 1600 -#include -#else -#include -#include -#endif - -#ifdef _WIN32 -#define _msgpack_atomic_counter_header -typedef long _msgpack_atomic_counter_t; -#define _msgpack_sync_decr_and_fetch(ptr) InterlockedDecrement(ptr) -#define _msgpack_sync_incr_and_fetch(ptr) InterlockedIncrement(ptr) -#elif defined(__GNUC__) && ((__GNUC__*10 + __GNUC_MINOR__) < 41) -#define _msgpack_atomic_counter_header "gcc_atomic.h" -#else -typedef unsigned int _msgpack_atomic_counter_t; -#define _msgpack_sync_decr_and_fetch(ptr) __sync_sub_and_fetch(ptr, 1) -#define _msgpack_sync_incr_and_fetch(ptr) __sync_add_and_fetch(ptr, 1) -#endif - -#ifdef _WIN32 - -#ifdef __cplusplus -/* numeric_limits::min,max */ -#ifdef max -#undef max -#endif -#ifdef min -#undef min -#endif -#endif - -#else /* _WIN32 */ -#include /* ntohs, ntohl */ -#endif - -#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__) -#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ -#define __LITTLE_ENDIAN__ -#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#define __BIG_ENDIAN__ -#elif _WIN32 -#define __LITTLE_ENDIAN__ -#endif -#endif - - -#ifdef __LITTLE_ENDIAN__ - -#ifdef _WIN32 -# if defined(ntohs) -# define _msgpack_be16(x) ntohs(x) -# elif defined(_byteswap_ushort) || (defined(_MSC_VER) && _MSC_VER >= 1400) -# define _msgpack_be16(x) ((uint16_t)_byteswap_ushort((unsigned short)x)) -# else -# define _msgpack_be16(x) ( \ - ((((uint16_t)x) << 8) ) | \ - ((((uint16_t)x) >> 8) ) ) -# endif -#else -# define _msgpack_be16(x) ntohs(x) -#endif - -#ifdef _WIN32 -# if defined(ntohl) -# define _msgpack_be32(x) ntohl(x) -# elif 
defined(_byteswap_ulong) || defined(_MSC_VER) -# define _msgpack_be32(x) ((uint32_t)_byteswap_ulong((unsigned long)x)) -# else -# define _msgpack_be32(x) \ - ( ((((uint32_t)x) << 24) ) | \ - ((((uint32_t)x) << 8) & 0x00ff0000U ) | \ - ((((uint32_t)x) >> 8) & 0x0000ff00U ) | \ - ((((uint32_t)x) >> 24) ) ) -# endif -#else -# define _msgpack_be32(x) ntohl(x) -#endif - -#if defined(_byteswap_uint64) || defined(_MSC_VER) -# define _msgpack_be64(x) (_byteswap_uint64(x)) -#elif defined(bswap_64) -# define _msgpack_be64(x) bswap_64(x) -#elif defined(__DARWIN_OSSwapInt64) -# define _msgpack_be64(x) __DARWIN_OSSwapInt64(x) -#else -#define _msgpack_be64(x) \ - ( ((((uint64_t)x) << 56) ) | \ - ((((uint64_t)x) << 40) & 0x00ff000000000000ULL ) | \ - ((((uint64_t)x) << 24) & 0x0000ff0000000000ULL ) | \ - ((((uint64_t)x) << 8) & 0x000000ff00000000ULL ) | \ - ((((uint64_t)x) >> 8) & 0x00000000ff000000ULL ) | \ - ((((uint64_t)x) >> 24) & 0x0000000000ff0000ULL ) | \ - ((((uint64_t)x) >> 40) & 0x000000000000ff00ULL ) | \ - ((((uint64_t)x) >> 56) ) ) -#endif - -#define _msgpack_load16(cast, from) ((cast)( \ - (((uint16_t)((uint8_t*)(from))[0]) << 8) | \ - (((uint16_t)((uint8_t*)(from))[1]) ) )) - -#define _msgpack_load32(cast, from) ((cast)( \ - (((uint32_t)((uint8_t*)(from))[0]) << 24) | \ - (((uint32_t)((uint8_t*)(from))[1]) << 16) | \ - (((uint32_t)((uint8_t*)(from))[2]) << 8) | \ - (((uint32_t)((uint8_t*)(from))[3]) ) )) - -#define _msgpack_load64(cast, from) ((cast)( \ - (((uint64_t)((uint8_t*)(from))[0]) << 56) | \ - (((uint64_t)((uint8_t*)(from))[1]) << 48) | \ - (((uint64_t)((uint8_t*)(from))[2]) << 40) | \ - (((uint64_t)((uint8_t*)(from))[3]) << 32) | \ - (((uint64_t)((uint8_t*)(from))[4]) << 24) | \ - (((uint64_t)((uint8_t*)(from))[5]) << 16) | \ - (((uint64_t)((uint8_t*)(from))[6]) << 8) | \ - (((uint64_t)((uint8_t*)(from))[7]) ) )) - -#else - -#define _msgpack_be16(x) (x) -#define _msgpack_be32(x) (x) -#define _msgpack_be64(x) (x) - -#define _msgpack_load16(cast, from) 
((cast)( \ - (((uint16_t)((uint8_t*)from)[0]) << 8) | \ - (((uint16_t)((uint8_t*)from)[1]) ) )) - -#define _msgpack_load32(cast, from) ((cast)( \ - (((uint32_t)((uint8_t*)from)[0]) << 24) | \ - (((uint32_t)((uint8_t*)from)[1]) << 16) | \ - (((uint32_t)((uint8_t*)from)[2]) << 8) | \ - (((uint32_t)((uint8_t*)from)[3]) ) )) - -#define _msgpack_load64(cast, from) ((cast)( \ - (((uint64_t)((uint8_t*)from)[0]) << 56) | \ - (((uint64_t)((uint8_t*)from)[1]) << 48) | \ - (((uint64_t)((uint8_t*)from)[2]) << 40) | \ - (((uint64_t)((uint8_t*)from)[3]) << 32) | \ - (((uint64_t)((uint8_t*)from)[4]) << 24) | \ - (((uint64_t)((uint8_t*)from)[5]) << 16) | \ - (((uint64_t)((uint8_t*)from)[6]) << 8) | \ - (((uint64_t)((uint8_t*)from)[7]) ) )) -#endif - - -#define _msgpack_store16(to, num) \ - do { uint16_t val = _msgpack_be16(num); memcpy(to, &val, 2); } while(0) -#define _msgpack_store32(to, num) \ - do { uint32_t val = _msgpack_be32(num); memcpy(to, &val, 4); } while(0) -#define _msgpack_store64(to, num) \ - do { uint64_t val = _msgpack_be64(num); memcpy(to, &val, 8); } while(0) - -/* -#define _msgpack_load16(cast, from) \ - ({ cast val; memcpy(&val, (char*)from, 2); _msgpack_be16(val); }) -#define _msgpack_load32(cast, from) \ - ({ cast val; memcpy(&val, (char*)from, 4); _msgpack_be32(val); }) -#define _msgpack_load64(cast, from) \ - ({ cast val; memcpy(&val, (char*)from, 8); _msgpack_be64(val); }) -*/ - - -#endif /* msgpack/sysdep.h */ diff --git a/srsly/msgpack/unpack.h b/srsly/msgpack/unpack.h deleted file mode 100644 index dabb5c1..0000000 --- a/srsly/msgpack/unpack.h +++ /dev/null @@ -1,393 +0,0 @@ -/* - * MessagePack for Python unpacking routine - * - * Copyright (C) 2009 Naoki INADA - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define MSGPACK_EMBED_STACK_SIZE (1024) -#include "unpack_define.h" - -typedef struct unpack_user { - bool use_list; - bool raw; - bool has_pairs_hook; - bool strict_map_key; - int timestamp; - PyObject *object_hook; - PyObject *list_hook; - PyObject *ext_hook; - PyObject *timestamp_t; - PyObject *giga; - PyObject *utc; - const char *encoding; - const char *unicode_errors; - Py_ssize_t max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len; -} unpack_user; - -typedef PyObject* msgpack_unpack_object; -struct unpack_context; -typedef struct unpack_context unpack_context; -typedef int (*execute_fn)(unpack_context *ctx, const char* data, Py_ssize_t len, Py_ssize_t* off); - -static inline msgpack_unpack_object unpack_callback_root(unpack_user* u) -{ - return NULL; -} - -static inline int unpack_callback_uint16(unpack_user* u, uint16_t d, msgpack_unpack_object* o) -{ - PyObject *p = PyLong_FromLong((long)d); - if (!p) - return -1; - *o = p; - return 0; -} -static inline int unpack_callback_uint8(unpack_user* u, uint8_t d, msgpack_unpack_object* o) -{ - return unpack_callback_uint16(u, d, o); -} - - -static inline int unpack_callback_uint32(unpack_user* u, uint32_t d, msgpack_unpack_object* o) -{ - PyObject *p = PyLong_FromSize_t((size_t)d); - if (!p) - return -1; - *o = p; - return 0; -} - -static inline int unpack_callback_uint64(unpack_user* u, uint64_t d, msgpack_unpack_object* o) -{ - PyObject *p; - if (d > LONG_MAX) { - p = PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG)d); - } else { - p = PyLong_FromLong((long)d); - } - if 
(!p) - return -1; - *o = p; - return 0; -} - -static inline int unpack_callback_int32(unpack_user* u, int32_t d, msgpack_unpack_object* o) -{ - PyObject *p = PyLong_FromLong(d); - if (!p) - return -1; - *o = p; - return 0; -} - -static inline int unpack_callback_int16(unpack_user* u, int16_t d, msgpack_unpack_object* o) -{ - return unpack_callback_int32(u, d, o); -} - -static inline int unpack_callback_int8(unpack_user* u, int8_t d, msgpack_unpack_object* o) -{ - return unpack_callback_int32(u, d, o); -} - -static inline int unpack_callback_int64(unpack_user* u, int64_t d, msgpack_unpack_object* o) -{ - PyObject *p; - if (d > LONG_MAX || d < LONG_MIN) { - p = PyLong_FromLongLong((PY_LONG_LONG)d); - } else { - p = PyLong_FromLong((long)d); - } - *o = p; - return 0; -} - -static inline int unpack_callback_double(unpack_user* u, double d, msgpack_unpack_object* o) -{ - PyObject *p = PyFloat_FromDouble(d); - if (!p) - return -1; - *o = p; - return 0; -} - -static inline int unpack_callback_float(unpack_user* u, float d, msgpack_unpack_object* o) -{ - return unpack_callback_double(u, d, o); -} - -static inline int unpack_callback_nil(unpack_user* u, msgpack_unpack_object* o) -{ Py_INCREF(Py_None); *o = Py_None; return 0; } - -static inline int unpack_callback_true(unpack_user* u, msgpack_unpack_object* o) -{ Py_INCREF(Py_True); *o = Py_True; return 0; } - -static inline int unpack_callback_false(unpack_user* u, msgpack_unpack_object* o) -{ Py_INCREF(Py_False); *o = Py_False; return 0; } - -static inline int unpack_callback_array(unpack_user* u, unsigned int n, msgpack_unpack_object* o) -{ - if (n > u->max_array_len) { - PyErr_Format(PyExc_ValueError, "%u exceeds max_array_len(%zd)", n, u->max_array_len); - return -1; - } - PyObject *p = u->use_list ? 
PyList_New(n) : PyTuple_New(n); - - if (!p) - return -1; - *o = p; - return 0; -} - -static inline int unpack_callback_array_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object o) -{ - if (u->use_list) - PyList_SET_ITEM(*c, current, o); - else - PyTuple_SET_ITEM(*c, current, o); - return 0; -} - -static inline int unpack_callback_array_end(unpack_user* u, msgpack_unpack_object* c) -{ - if (u->list_hook) { - PyObject *new_c = PyObject_CallFunctionObjArgs(u->list_hook, *c, NULL); - if (!new_c) - return -1; - Py_DECREF(*c); - *c = new_c; - } - return 0; -} - -static inline int unpack_callback_map(unpack_user* u, unsigned int n, msgpack_unpack_object* o) -{ - if (n > u->max_map_len) { - PyErr_Format(PyExc_ValueError, "%u exceeds max_map_len(%zd)", n, u->max_map_len); - return -1; - } - PyObject *p; - if (u->has_pairs_hook) { - p = PyList_New(n); // Or use tuple? - } - else { - p = PyDict_New(); - } - if (!p) - return -1; - *o = p; - return 0; -} - -static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v) -{ - if (u->strict_map_key && !PyUnicode_CheckExact(k) && !PyBytes_CheckExact(k)) { - PyErr_Format(PyExc_ValueError, "%.100s is not allowed for map key when strict_map_key=True", Py_TYPE(k)->tp_name); - return -1; - } - if (PyUnicode_CheckExact(k)) { - PyUnicode_InternInPlace(&k); - } - if (u->has_pairs_hook) { - msgpack_unpack_object item = PyTuple_Pack(2, k, v); - if (!item) - return -1; - Py_DECREF(k); - Py_DECREF(v); - PyList_SET_ITEM(*c, current, item); - return 0; - } - else if (PyDict_SetItem(*c, k, v) == 0) { - Py_DECREF(k); - Py_DECREF(v); - return 0; - } - return -1; -} - -static inline int unpack_callback_map_end(unpack_user* u, msgpack_unpack_object* c) -{ - if (u->object_hook) { - PyObject *new_c = PyObject_CallFunctionObjArgs(u->object_hook, *c, NULL); - if (!new_c) - return -1; - - Py_DECREF(*c); - *c = new_c; - } - 
return 0; -} - -static inline int unpack_callback_raw(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o) -{ - if (l > u->max_str_len) { - PyErr_Format(PyExc_ValueError, "%u exceeds max_str_len(%zd)", l, u->max_str_len); - return -1; - } - - PyObject *py; - if (u->encoding) { - py = PyUnicode_Decode(p, l, u->encoding, u->unicode_errors); - } else if (u->raw) { - py = PyBytes_FromStringAndSize(p, l); - } else { - py = PyUnicode_DecodeUTF8(p, l, u->unicode_errors); - } - if (!py) - return -1; - *o = py; - return 0; -} - -static inline int unpack_callback_bin(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o) -{ - if (l > u->max_bin_len) { - PyErr_Format(PyExc_ValueError, "%u exceeds max_bin_len(%zd)", l, u->max_bin_len); - return -1; - } - - PyObject *py = PyBytes_FromStringAndSize(p, l); - if (!py) - return -1; - *o = py; - return 0; -} - -typedef struct msgpack_timestamp { - int64_t tv_sec; - uint32_t tv_nsec; -} msgpack_timestamp; - -/* - * Unpack ext buffer to a timestamp. Pulled from msgpack-c timestamp.h. 
- */ -static int unpack_timestamp(const char* buf, unsigned int buflen, msgpack_timestamp* ts) { - switch (buflen) { - case 4: - ts->tv_nsec = 0; - { - uint32_t v = _msgpack_load32(uint32_t, buf); - ts->tv_sec = (int64_t)v; - } - return 0; - case 8: { - uint64_t value =_msgpack_load64(uint64_t, buf); - ts->tv_nsec = (uint32_t)(value >> 34); - ts->tv_sec = value & 0x00000003ffffffffLL; - return 0; - } - case 12: - ts->tv_nsec = _msgpack_load32(uint32_t, buf); - ts->tv_sec = _msgpack_load64(int64_t, buf + 4); - return 0; - default: - return -1; - } -} - -#include "datetime.h" - -static int unpack_callback_ext(unpack_user* u, const char* base, const char* pos, - unsigned int length, msgpack_unpack_object* o) -{ - int8_t typecode = (int8_t)*pos++; - if (!u->ext_hook) { - PyErr_SetString(PyExc_AssertionError, "u->ext_hook cannot be NULL"); - return -1; - } - if (length-1 > u->max_ext_len) { - PyErr_Format(PyExc_ValueError, "%u exceeds max_ext_len(%zd)", length, u->max_ext_len); - return -1; - } - - PyObject *py = NULL; - // length also includes the typecode, so the actual data is length-1 - if (typecode == -1) { - msgpack_timestamp ts; - if (unpack_timestamp(pos, length-1, &ts) < 0) { - return -1; - } - - if (u->timestamp == 2) { // int - PyObject *a = PyLong_FromLongLong(ts.tv_sec); - if (a == NULL) return -1; - - PyObject *c = PyNumber_Multiply(a, u->giga); - Py_DECREF(a); - if (c == NULL) { - return -1; - } - - PyObject *b = PyLong_FromUnsignedLong(ts.tv_nsec); - if (b == NULL) { - Py_DECREF(c); - return -1; - } - - py = PyNumber_Add(c, b); - Py_DECREF(c); - Py_DECREF(b); - } - else if (u->timestamp == 0) { // Timestamp - py = PyObject_CallFunction(u->timestamp_t, "(Lk)", ts.tv_sec, ts.tv_nsec); - } - else if (u->timestamp == 3) { // datetime - // Calculate datetime using epoch + delta - // due to limitations PyDateTime_FromTimestamp on Windows with negative timestamps - PyObject *epoch = PyDateTimeAPI->DateTime_FromDateAndTime(1970, 1, 1, 0, 0, 0, 0, u->utc, 
PyDateTimeAPI->DateTimeType); - if (epoch == NULL) { - return -1; - } - - PyObject* d = PyDelta_FromDSU(ts.tv_sec/(24*3600), ts.tv_sec%(24*3600), ts.tv_nsec / 1000); - if (d == NULL) { - Py_DECREF(epoch); - return -1; - } - - py = PyNumber_Add(epoch, d); - - Py_DECREF(epoch); - Py_DECREF(d); - } - else { // float - PyObject *a = PyFloat_FromDouble((double)ts.tv_nsec); - if (a == NULL) return -1; - - PyObject *b = PyNumber_TrueDivide(a, u->giga); - Py_DECREF(a); - if (b == NULL) return -1; - - PyObject *c = PyLong_FromLongLong(ts.tv_sec); - if (c == NULL) { - Py_DECREF(b); - return -1; - } - - a = PyNumber_Add(b, c); - Py_DECREF(b); - Py_DECREF(c); - py = a; - } - } else { - py = PyObject_CallFunction(u->ext_hook, "(iy#)", (int)typecode, pos, (Py_ssize_t)length-1); - } - if (!py) - return -1; - *o = py; - return 0; -} - -#include "unpack_template.h" diff --git a/srsly/msgpack/unpack_container_header.h b/srsly/msgpack/unpack_container_header.h deleted file mode 100644 index c14a3c2..0000000 --- a/srsly/msgpack/unpack_container_header.h +++ /dev/null @@ -1,51 +0,0 @@ -static inline int unpack_container_header(unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off) -{ - assert(len >= *off); - uint32_t size; - const unsigned char *const p = (unsigned char*)data + *off; - -#define inc_offset(inc) \ - if (len - *off < inc) \ - return 0; \ - *off += inc; - - switch (*p) { - case var_offset: - inc_offset(3); - size = _msgpack_load16(uint16_t, p + 1); - break; - case var_offset + 1: - inc_offset(5); - size = _msgpack_load32(uint32_t, p + 1); - break; -#ifdef USE_CASE_RANGE - case fixed_offset + 0x0 ... 
fixed_offset + 0xf: -#else - case fixed_offset + 0x0: - case fixed_offset + 0x1: - case fixed_offset + 0x2: - case fixed_offset + 0x3: - case fixed_offset + 0x4: - case fixed_offset + 0x5: - case fixed_offset + 0x6: - case fixed_offset + 0x7: - case fixed_offset + 0x8: - case fixed_offset + 0x9: - case fixed_offset + 0xa: - case fixed_offset + 0xb: - case fixed_offset + 0xc: - case fixed_offset + 0xd: - case fixed_offset + 0xe: - case fixed_offset + 0xf: -#endif - ++*off; - size = ((unsigned int)*p) & 0x0f; - break; - default: - PyErr_SetString(PyExc_ValueError, "Unexpected type header on stream"); - return -1; - } - unpack_callback_uint32(&ctx->user, size, &ctx->stack[0].obj); - return 1; -} - diff --git a/srsly/msgpack/unpack_define.h b/srsly/msgpack/unpack_define.h deleted file mode 100644 index 0dd708d..0000000 --- a/srsly/msgpack/unpack_define.h +++ /dev/null @@ -1,95 +0,0 @@ -/* - * MessagePack unpacking routine template - * - * Copyright (C) 2008-2010 FURUHASHI Sadayuki - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MSGPACK_UNPACK_DEFINE_H__ -#define MSGPACK_UNPACK_DEFINE_H__ - -#include "msgpack/sysdep.h" -#include -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - - -#ifndef MSGPACK_EMBED_STACK_SIZE -#define MSGPACK_EMBED_STACK_SIZE 32 -#endif - - -// CS is first byte & 0x1f -typedef enum { - CS_HEADER = 0x00, // nil - - //CS_ = 0x01, - //CS_ = 0x02, // false - //CS_ = 0x03, // true - - CS_BIN_8 = 0x04, - CS_BIN_16 = 0x05, - CS_BIN_32 = 0x06, - - CS_EXT_8 = 0x07, - CS_EXT_16 = 0x08, - CS_EXT_32 = 0x09, - - CS_FLOAT = 0x0a, - CS_DOUBLE = 0x0b, - CS_UINT_8 = 0x0c, - CS_UINT_16 = 0x0d, - CS_UINT_32 = 0x0e, - CS_UINT_64 = 0x0f, - CS_INT_8 = 0x10, - CS_INT_16 = 0x11, - CS_INT_32 = 0x12, - CS_INT_64 = 0x13, - - //CS_FIXEXT1 = 0x14, - //CS_FIXEXT2 = 0x15, - //CS_FIXEXT4 = 0x16, - //CS_FIXEXT8 = 0x17, - //CS_FIXEXT16 = 0x18, - - CS_RAW_8 = 0x19, - CS_RAW_16 = 0x1a, - CS_RAW_32 = 0x1b, - CS_ARRAY_16 = 0x1c, - CS_ARRAY_32 = 0x1d, - CS_MAP_16 = 0x1e, - CS_MAP_32 = 0x1f, - - ACS_RAW_VALUE, - ACS_BIN_VALUE, - ACS_EXT_VALUE, -} msgpack_unpack_state; - - -typedef enum { - CT_ARRAY_ITEM, - CT_MAP_KEY, - CT_MAP_VALUE, -} msgpack_container_type; - - -#ifdef __cplusplus -} -#endif - -#endif /* msgpack/unpack_define.h */ diff --git a/srsly/msgpack/unpack_template.h b/srsly/msgpack/unpack_template.h deleted file mode 100644 index cce29e7..0000000 --- a/srsly/msgpack/unpack_template.h +++ /dev/null @@ -1,423 +0,0 @@ -/* - * MessagePack unpacking routine template - * - * Copyright (C) 2008-2010 FURUHASHI Sadayuki - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef USE_CASE_RANGE -#if !defined(_MSC_VER) -#define USE_CASE_RANGE -#endif -#endif - -typedef struct unpack_stack { - PyObject* obj; - Py_ssize_t size; - Py_ssize_t count; - unsigned int ct; - PyObject* map_key; -} unpack_stack; - -struct unpack_context { - unpack_user user; - unsigned int cs; - unsigned int trail; - unsigned int top; - /* - unpack_stack* stack; - unsigned int stack_size; - unpack_stack embed_stack[MSGPACK_EMBED_STACK_SIZE]; - */ - unpack_stack stack[MSGPACK_EMBED_STACK_SIZE]; -}; - - -static inline void unpack_init(unpack_context* ctx) -{ - ctx->cs = CS_HEADER; - ctx->trail = 0; - ctx->top = 0; - /* - ctx->stack = ctx->embed_stack; - ctx->stack_size = MSGPACK_EMBED_STACK_SIZE; - */ - ctx->stack[0].obj = unpack_callback_root(&ctx->user); -} - -/* -static inline void unpack_destroy(unpack_context* ctx) -{ - if(ctx->stack_size != MSGPACK_EMBED_STACK_SIZE) { - free(ctx->stack); - } -} -*/ - -static inline PyObject* unpack_data(unpack_context* ctx) -{ - return (ctx)->stack[0].obj; -} - -static inline void unpack_clear(unpack_context *ctx) -{ - Py_CLEAR(ctx->stack[0].obj); -} - -static inline int unpack_execute(bool construct, unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off) -{ - assert(len >= *off); - - const unsigned char* p = (unsigned char*)data + *off; - const unsigned char* const pe = (unsigned char*)data + len; - const void* n = p; - - unsigned int trail = ctx->trail; - unsigned int cs = ctx->cs; - unsigned int top = ctx->top; - unpack_stack* stack = ctx->stack; - /* - unsigned int stack_size = 
ctx->stack_size; - */ - unpack_user* user = &ctx->user; - - PyObject* obj = NULL; - unpack_stack* c = NULL; - - int ret; - -#define construct_cb(name) \ - construct && unpack_callback ## name - -#define push_simple_value(func) \ - if(construct_cb(func)(user, &obj) < 0) { goto _failed; } \ - goto _push -#define push_fixed_value(func, arg) \ - if(construct_cb(func)(user, arg, &obj) < 0) { goto _failed; } \ - goto _push -#define push_variable_value(func, base, pos, len) \ - if(construct_cb(func)(user, \ - (const char*)base, (const char*)pos, len, &obj) < 0) { goto _failed; } \ - goto _push - -#define again_fixed_trail(_cs, trail_len) \ - trail = trail_len; \ - cs = _cs; \ - goto _fixed_trail_again -#define again_fixed_trail_if_zero(_cs, trail_len, ifzero) \ - trail = trail_len; \ - if(trail == 0) { goto ifzero; } \ - cs = _cs; \ - goto _fixed_trail_again - -#define start_container(func, count_, ct_) \ - if(top >= MSGPACK_EMBED_STACK_SIZE) { ret = -3; goto _end; } \ - if(construct_cb(func)(user, count_, &stack[top].obj) < 0) { goto _failed; } \ - if((count_) == 0) { obj = stack[top].obj; \ - if (construct_cb(func##_end)(user, &obj) < 0) { goto _failed; } \ - goto _push; } \ - stack[top].ct = ct_; \ - stack[top].size = count_; \ - stack[top].count = 0; \ - ++top; \ - goto _header_again - -#define NEXT_CS(p) ((unsigned int)*p & 0x1f) - -#ifdef USE_CASE_RANGE -#define SWITCH_RANGE_BEGIN switch(*p) { -#define SWITCH_RANGE(FROM, TO) case FROM ... 
TO: -#define SWITCH_RANGE_DEFAULT default: -#define SWITCH_RANGE_END } -#else -#define SWITCH_RANGE_BEGIN { if(0) { -#define SWITCH_RANGE(FROM, TO) } else if(FROM <= *p && *p <= TO) { -#define SWITCH_RANGE_DEFAULT } else { -#define SWITCH_RANGE_END } } -#endif - - if(p == pe) { goto _out; } - do { - switch(cs) { - case CS_HEADER: - SWITCH_RANGE_BEGIN - SWITCH_RANGE(0x00, 0x7f) // Positive Fixnum - push_fixed_value(_uint8, *(uint8_t*)p); - SWITCH_RANGE(0xe0, 0xff) // Negative Fixnum - push_fixed_value(_int8, *(int8_t*)p); - SWITCH_RANGE(0xc0, 0xdf) // Variable - switch(*p) { - case 0xc0: // nil - push_simple_value(_nil); - //case 0xc1: // never used - case 0xc2: // false - push_simple_value(_false); - case 0xc3: // true - push_simple_value(_true); - case 0xc4: // bin 8 - again_fixed_trail(NEXT_CS(p), 1); - case 0xc5: // bin 16 - again_fixed_trail(NEXT_CS(p), 2); - case 0xc6: // bin 32 - again_fixed_trail(NEXT_CS(p), 4); - case 0xc7: // ext 8 - again_fixed_trail(NEXT_CS(p), 1); - case 0xc8: // ext 16 - again_fixed_trail(NEXT_CS(p), 2); - case 0xc9: // ext 32 - again_fixed_trail(NEXT_CS(p), 4); - case 0xca: // float - case 0xcb: // double - case 0xcc: // unsigned int 8 - case 0xcd: // unsigned int 16 - case 0xce: // unsigned int 32 - case 0xcf: // unsigned int 64 - case 0xd0: // signed int 8 - case 0xd1: // signed int 16 - case 0xd2: // signed int 32 - case 0xd3: // signed int 64 - again_fixed_trail(NEXT_CS(p), 1 << (((unsigned int)*p) & 0x03)); - case 0xd4: // fixext 1 - case 0xd5: // fixext 2 - case 0xd6: // fixext 4 - case 0xd7: // fixext 8 - again_fixed_trail_if_zero(ACS_EXT_VALUE, - (1 << (((unsigned int)*p) & 0x03))+1, - _ext_zero); - case 0xd8: // fixext 16 - again_fixed_trail_if_zero(ACS_EXT_VALUE, 16+1, _ext_zero); - case 0xd9: // str 8 - again_fixed_trail(NEXT_CS(p), 1); - case 0xda: // raw 16 - case 0xdb: // raw 32 - case 0xdc: // array 16 - case 0xdd: // array 32 - case 0xde: // map 16 - case 0xdf: // map 32 - again_fixed_trail(NEXT_CS(p), 2 << (((unsigned 
int)*p) & 0x01)); - default: - ret = -2; - goto _end; - } - SWITCH_RANGE(0xa0, 0xbf) // FixRaw - again_fixed_trail_if_zero(ACS_RAW_VALUE, ((unsigned int)*p & 0x1f), _raw_zero); - SWITCH_RANGE(0x90, 0x9f) // FixArray - start_container(_array, ((unsigned int)*p) & 0x0f, CT_ARRAY_ITEM); - SWITCH_RANGE(0x80, 0x8f) // FixMap - start_container(_map, ((unsigned int)*p) & 0x0f, CT_MAP_KEY); - - SWITCH_RANGE_DEFAULT - ret = -2; - goto _end; - SWITCH_RANGE_END - // end CS_HEADER - - - _fixed_trail_again: - ++p; - - default: - if((size_t)(pe - p) < trail) { goto _out; } - n = p; p += trail - 1; - switch(cs) { - case CS_EXT_8: - again_fixed_trail_if_zero(ACS_EXT_VALUE, *(uint8_t*)n+1, _ext_zero); - case CS_EXT_16: - again_fixed_trail_if_zero(ACS_EXT_VALUE, - _msgpack_load16(uint16_t,n)+1, - _ext_zero); - case CS_EXT_32: - again_fixed_trail_if_zero(ACS_EXT_VALUE, - _msgpack_load32(uint32_t,n)+1, - _ext_zero); - case CS_FLOAT: { - double f; -#if PY_VERSION_HEX >= 0x030B00A7 - f = PyFloat_Unpack4((const char*)n, 0); -#else - f = _PyFloat_Unpack4((unsigned char*)n, 0); -#endif - push_fixed_value(_float, f); } - case CS_DOUBLE: { - double f; -#if PY_VERSION_HEX >= 0x030B00A7 - f = PyFloat_Unpack8((const char*)n, 0); -#else - f = _PyFloat_Unpack8((unsigned char*)n, 0); -#endif - push_fixed_value(_double, f); } - case CS_UINT_8: - push_fixed_value(_uint8, *(uint8_t*)n); - case CS_UINT_16: - push_fixed_value(_uint16, _msgpack_load16(uint16_t,n)); - case CS_UINT_32: - push_fixed_value(_uint32, _msgpack_load32(uint32_t,n)); - case CS_UINT_64: - push_fixed_value(_uint64, _msgpack_load64(uint64_t,n)); - - case CS_INT_8: - push_fixed_value(_int8, *(int8_t*)n); - case CS_INT_16: - push_fixed_value(_int16, _msgpack_load16(int16_t,n)); - case CS_INT_32: - push_fixed_value(_int32, _msgpack_load32(int32_t,n)); - case CS_INT_64: - push_fixed_value(_int64, _msgpack_load64(int64_t,n)); - - case CS_BIN_8: - again_fixed_trail_if_zero(ACS_BIN_VALUE, *(uint8_t*)n, _bin_zero); - case CS_BIN_16: - 
again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load16(uint16_t,n), _bin_zero); - case CS_BIN_32: - again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load32(uint32_t,n), _bin_zero); - case ACS_BIN_VALUE: - _bin_zero: - push_variable_value(_bin, data, n, trail); - - case CS_RAW_8: - again_fixed_trail_if_zero(ACS_RAW_VALUE, *(uint8_t*)n, _raw_zero); - case CS_RAW_16: - again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load16(uint16_t,n), _raw_zero); - case CS_RAW_32: - again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load32(uint32_t,n), _raw_zero); - case ACS_RAW_VALUE: - _raw_zero: - push_variable_value(_raw, data, n, trail); - - case ACS_EXT_VALUE: - _ext_zero: - push_variable_value(_ext, data, n, trail); - - case CS_ARRAY_16: - start_container(_array, _msgpack_load16(uint16_t,n), CT_ARRAY_ITEM); - case CS_ARRAY_32: - /* FIXME security guard */ - start_container(_array, _msgpack_load32(uint32_t,n), CT_ARRAY_ITEM); - - case CS_MAP_16: - start_container(_map, _msgpack_load16(uint16_t,n), CT_MAP_KEY); - case CS_MAP_32: - /* FIXME security guard */ - start_container(_map, _msgpack_load32(uint32_t,n), CT_MAP_KEY); - - default: - goto _failed; - } - } - -_push: - if(top == 0) { goto _finish; } - c = &stack[top-1]; - switch(c->ct) { - case CT_ARRAY_ITEM: - if(construct_cb(_array_item)(user, c->count, &c->obj, obj) < 0) { goto _failed; } - if(++c->count == c->size) { - obj = c->obj; - if (construct_cb(_array_end)(user, &obj) < 0) { goto _failed; } - --top; - /*printf("stack pop %d\n", top);*/ - goto _push; - } - goto _header_again; - case CT_MAP_KEY: - c->map_key = obj; - c->ct = CT_MAP_VALUE; - goto _header_again; - case CT_MAP_VALUE: - if(construct_cb(_map_item)(user, c->count, &c->obj, c->map_key, obj) < 0) { goto _failed; } - if(++c->count == c->size) { - obj = c->obj; - if (construct_cb(_map_end)(user, &obj) < 0) { goto _failed; } - --top; - /*printf("stack pop %d\n", top);*/ - goto _push; - } - c->ct = CT_MAP_KEY; - goto _header_again; - - default: - goto _failed; 
- } - -_header_again: - cs = CS_HEADER; - ++p; - } while(p != pe); - goto _out; - - -_finish: - if (!construct) - unpack_callback_nil(user, &obj); - stack[0].obj = obj; - ++p; - ret = 1; - /*printf("-- finish --\n"); */ - goto _end; - -_failed: - /*printf("** FAILED **\n"); */ - ret = -1; - goto _end; - -_out: - ret = 0; - goto _end; - -_end: - ctx->cs = cs; - ctx->trail = trail; - ctx->top = top; - *off = p - (const unsigned char*)data; - - return ret; -#undef construct_cb -} - -#undef NEXT_CS -#undef SWITCH_RANGE_BEGIN -#undef SWITCH_RANGE -#undef SWITCH_RANGE_DEFAULT -#undef SWITCH_RANGE_END -#undef push_simple_value -#undef push_fixed_value -#undef push_variable_value -#undef again_fixed_trail -#undef again_fixed_trail_if_zero -#undef start_container - -static int unpack_construct(unpack_context *ctx, const char *data, Py_ssize_t len, Py_ssize_t *off) { - return unpack_execute(1, ctx, data, len, off); -} -static int unpack_skip(unpack_context *ctx, const char *data, Py_ssize_t len, Py_ssize_t *off) { - return unpack_execute(0, ctx, data, len, off); -} - -#define unpack_container_header read_array_header -#define fixed_offset 0x90 -#define var_offset 0xdc -#include "unpack_container_header.h" -#undef unpack_container_header -#undef fixed_offset -#undef var_offset - -#define unpack_container_header read_map_header -#define fixed_offset 0x80 -#define var_offset 0xde -#include "unpack_container_header.h" -#undef unpack_container_header -#undef fixed_offset -#undef var_offset - -/* vim: set ts=4 sw=4 sts=4 expandtab */ diff --git a/srsly/msgpack/util.py b/srsly/msgpack/util.py deleted file mode 100644 index 967d66d..0000000 --- a/srsly/msgpack/util.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import unicode_literals - -try: - unicode -except NameError: - unicode = str - - -def ensure_bytes(string): - """Ensure a string is returned as a bytes object, encoded as utf8.""" - if isinstance(string, unicode): - return string.encode("utf8") - else: - return string 
diff --git a/srsly/ruamel_yaml/LICENSE b/srsly/ruamel_yaml/LICENSE deleted file mode 100755 index 5b863d3..0000000 --- a/srsly/ruamel_yaml/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ - The MIT License (MIT) - - Copyright (c) 2014-2020 Anthon van der Neut, Ruamel bvba - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. 
diff --git a/srsly/ruamel_yaml/__init__.py b/srsly/ruamel_yaml/__init__.py deleted file mode 100755 index b7678ec..0000000 --- a/srsly/ruamel_yaml/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -__with_libyaml__ = False - -from .main import * # NOQA - -version_info = (0, 16, 7) -__version__ = "0.16.7" diff --git a/srsly/ruamel_yaml/anchor.py b/srsly/ruamel_yaml/anchor.py deleted file mode 100755 index aa649f5..0000000 --- a/srsly/ruamel_yaml/anchor.py +++ /dev/null @@ -1,20 +0,0 @@ - -if False: # MYPY - from typing import Any, Dict, Optional, List, Union, Optional, Iterator # NOQA - -anchor_attrib = '_yaml_anchor' - - -class Anchor(object): - __slots__ = 'value', 'always_dump' - attrib = anchor_attrib - - def __init__(self): - # type: () -> None - self.value = None - self.always_dump = False - - def __repr__(self): - # type: () -> Any - ad = ', (always dump)' if self.always_dump else "" - return 'Anchor({!r}{})'.format(self.value, ad) diff --git a/srsly/ruamel_yaml/comments.py b/srsly/ruamel_yaml/comments.py deleted file mode 100755 index 9090298..0000000 --- a/srsly/ruamel_yaml/comments.py +++ /dev/null @@ -1,1153 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import, print_function - -""" -stuff to deal with comments and formatting on dict/list/ordereddict/set -these are not really related, formatting could be factored out as -a separate base -""" - -import sys -import copy - - -from .compat import ordereddict # type: ignore -from .compat import PY2, string_types, MutableSliceableSequence -from .scalarstring import ScalarString -from .anchor import Anchor - -if PY2: - from collections import MutableSet, Sized, Set, Mapping -else: - from collections.abc import MutableSet, Sized, Set, Mapping - -if False: # MYPY - from typing import Any, Dict, Optional, List, Union, Optional, Iterator # NOQA - -# fmt: off -__all__ = ['CommentedSeq', 'CommentedKeySeq', - 'CommentedMap', 'CommentedOrderedMap', - 'CommentedSet', 'comment_attrib', 'merge_attrib'] -# fmt: on - 
-comment_attrib = "_yaml_comment" -format_attrib = "_yaml_format" -line_col_attrib = "_yaml_line_col" -merge_attrib = "_yaml_merge" -tag_attrib = "_yaml_tag" - - -class Comment(object): - # sys.getsize tested the Comment objects, __slots__ makes them bigger - # and adding self.end did not matter - __slots__ = "comment", "_items", "_end", "_start" - attrib = comment_attrib - - def __init__(self): - # type: () -> None - self.comment = None # [post, [pre]] - # map key (mapping/omap/dict) or index (sequence/list) to a list of - # dict: post_key, pre_key, post_value, pre_value - # list: pre item, post item - self._items = {} # type: Dict[Any, Any] - # self._start = [] # should not put these on first item - self._end = [] # type: List[Any] # end of document comments - - def __str__(self): - # type: () -> str - if bool(self._end): - end = ",\n end=" + str(self._end) - else: - end = "" - return "Comment(comment={0},\n items={1}{2})".format( - self.comment, self._items, end - ) - - @property - def items(self): - # type: () -> Any - return self._items - - @property - def end(self): - # type: () -> Any - return self._end - - @end.setter - def end(self, value): - # type: (Any) -> None - self._end = value - - @property - def start(self): - # type: () -> Any - return self._start - - @start.setter - def start(self, value): - # type: (Any) -> None - self._start = value - - -# to distinguish key from None -def NoComment(): - # type: () -> None - pass - - -class Format(object): - __slots__ = ("_flow_style",) - attrib = format_attrib - - def __init__(self): - # type: () -> None - self._flow_style = None # type: Any - - def set_flow_style(self): - # type: () -> None - self._flow_style = True - - def set_block_style(self): - # type: () -> None - self._flow_style = False - - def flow_style(self, default=None): - # type: (Optional[Any]) -> Any - """if default (the flow_style) is None, the flow style tacked on to - the object explicitly will be taken. 
If that is None as well the - default flow style rules the format down the line, or the type - of the constituent values (simple -> flow, map/list -> block)""" - if self._flow_style is None: - return default - return self._flow_style - - -class LineCol(object): - attrib = line_col_attrib - - def __init__(self): - # type: () -> None - self.line = None - self.col = None - self.data = None # type: Optional[Dict[Any, Any]] - - def add_kv_line_col(self, key, data): - # type: (Any, Any) -> None - if self.data is None: - self.data = {} - self.data[key] = data - - def key(self, k): - # type: (Any) -> Any - return self._kv(k, 0, 1) - - def value(self, k): - # type: (Any) -> Any - return self._kv(k, 2, 3) - - def _kv(self, k, x0, x1): - # type: (Any, Any, Any) -> Any - if self.data is None: - return None - data = self.data[k] - return data[x0], data[x1] - - def item(self, idx): - # type: (Any) -> Any - if self.data is None: - return None - return self.data[idx][0], self.data[idx][1] - - def add_idx_line_col(self, key, data): - # type: (Any, Any) -> None - if self.data is None: - self.data = {} - self.data[key] = data - - -class Tag(object): - """store tag information for roundtripping""" - - __slots__ = ("value",) - attrib = tag_attrib - - def __init__(self): - # type: () -> None - self.value = None - - def __repr__(self): - # type: () -> Any - return "{0.__class__.__name__}({0.value!r})".format(self) - - -class CommentedBase(object): - @property - def ca(self): - # type: () -> Any - if not hasattr(self, Comment.attrib): - setattr(self, Comment.attrib, Comment()) - return getattr(self, Comment.attrib) - - def yaml_end_comment_extend(self, comment, clear=False): - # type: (Any, bool) -> None - if comment is None: - return - if clear or self.ca.end is None: - self.ca.end = [] - self.ca.end.extend(comment) - - def yaml_key_comment_extend(self, key, comment, clear=False): - # type: (Any, Any, bool) -> None - r = self.ca._items.setdefault(key, [None, None, None, None]) - if clear 
or r[1] is None: - if comment[1] is not None: - assert isinstance(comment[1], list) - r[1] = comment[1] - else: - r[1].extend(comment[0]) - r[0] = comment[0] - - def yaml_value_comment_extend(self, key, comment, clear=False): - # type: (Any, Any, bool) -> None - r = self.ca._items.setdefault(key, [None, None, None, None]) - if clear or r[3] is None: - if comment[1] is not None: - assert isinstance(comment[1], list) - r[3] = comment[1] - else: - r[3].extend(comment[0]) - r[2] = comment[0] - - def yaml_set_start_comment(self, comment, indent=0): - # type: (Any, Any) -> None - """overwrites any preceding comment lines on an object - expects comment to be without `#` and possible have multiple lines - """ - from .error import CommentMark - from .tokens import CommentToken - - pre_comments = self._yaml_get_pre_comment() - if comment[-1] == "\n": - comment = comment[:-1] # strip final newline if there - start_mark = CommentMark(indent) - for com in comment.split("\n"): - pre_comments.append(CommentToken("# " + com + "\n", start_mark, None)) - - def yaml_set_comment_before_after_key( - self, key, before=None, indent=0, after=None, after_indent=None - ): - # type: (Any, Any, Any, Any, Any) -> None - """ - expects comment (before/after) to be without `#` and possible have multiple lines - """ - from srsly.ruamel_yaml.error import CommentMark - from srsly.ruamel_yaml.tokens import CommentToken - - def comment_token(s, mark): - # type: (Any, Any) -> Any - # handle empty lines as having no comment - return CommentToken(("# " if s else "") + s + "\n", mark, None) - - if after_indent is None: - after_indent = indent + 2 - if before and (len(before) > 1) and before[-1] == "\n": - before = before[:-1] # strip final newline if there - if after and after[-1] == "\n": - after = after[:-1] # strip final newline if there - start_mark = CommentMark(indent) - c = self.ca.items.setdefault(key, [None, [], None, None]) - if before == "\n": - c[1].append(comment_token("", start_mark)) - elif 
before: - for com in before.split("\n"): - c[1].append(comment_token(com, start_mark)) - if after: - start_mark = CommentMark(after_indent) - if c[3] is None: - c[3] = [] - for com in after.split("\n"): - c[3].append(comment_token(com, start_mark)) # type: ignore - - @property - def fa(self): - # type: () -> Any - """format attribute - - set_flow_style()/set_block_style()""" - if not hasattr(self, Format.attrib): - setattr(self, Format.attrib, Format()) - return getattr(self, Format.attrib) - - def yaml_add_eol_comment(self, comment, key=NoComment, column=None): - # type: (Any, Optional[Any], Optional[Any]) -> None - """ - there is a problem as eol comments should start with ' #' - (but at the beginning of the line the space doesn't have to be before - the #. The column index is for the # mark - """ - from .tokens import CommentToken - from .error import CommentMark - - if column is None: - try: - column = self._yaml_get_column(key) - except AttributeError: - column = 0 - if comment[0] != "#": - comment = "# " + comment - if column is None: - if comment[0] == "#": - comment = " " + comment - column = 0 - start_mark = CommentMark(column) - ct = [CommentToken(comment, start_mark, None), None] - self._yaml_add_eol_comment(ct, key=key) - - @property - def lc(self): - # type: () -> Any - if not hasattr(self, LineCol.attrib): - setattr(self, LineCol.attrib, LineCol()) - return getattr(self, LineCol.attrib) - - def _yaml_set_line_col(self, line, col): - # type: (Any, Any) -> None - self.lc.line = line - self.lc.col = col - - def _yaml_set_kv_line_col(self, key, data): - # type: (Any, Any) -> None - self.lc.add_kv_line_col(key, data) - - def _yaml_set_idx_line_col(self, key, data): - # type: (Any, Any) -> None - self.lc.add_idx_line_col(key, data) - - @property - def anchor(self): - # type: () -> Any - if not hasattr(self, Anchor.attrib): - setattr(self, Anchor.attrib, Anchor()) - return getattr(self, Anchor.attrib) - - def yaml_anchor(self): - # type: () -> Any - if not 
hasattr(self, Anchor.attrib): - return None - return self.anchor - - def yaml_set_anchor(self, value, always_dump=False): - # type: (Any, bool) -> None - self.anchor.value = value - self.anchor.always_dump = always_dump - - @property - def tag(self): - # type: () -> Any - if not hasattr(self, Tag.attrib): - setattr(self, Tag.attrib, Tag()) - return getattr(self, Tag.attrib) - - def yaml_set_tag(self, value): - # type: (Any) -> None - self.tag.value = value - - def copy_attributes(self, t, memo=None): - # type: (Any, Any) -> None - # fmt: off - for a in [Comment.attrib, Format.attrib, LineCol.attrib, Anchor.attrib, - Tag.attrib, merge_attrib]: - if hasattr(self, a): - if memo is not None: - setattr(t, a, copy.deepcopy(getattr(self, a, memo))) - else: - setattr(t, a, getattr(self, a)) - # fmt: on - - def _yaml_add_eol_comment(self, comment, key): - # type: (Any, Any) -> None - raise NotImplementedError - - def _yaml_get_pre_comment(self): - # type: () -> Any - raise NotImplementedError - - def _yaml_get_column(self, key): - # type: (Any) -> Any - raise NotImplementedError - - -class CommentedSeq(MutableSliceableSequence, list, CommentedBase): # type: ignore - __slots__ = (Comment.attrib, "_lst") - - def __init__(self, *args, **kw): - # type: (Any, Any) -> None - list.__init__(self, *args, **kw) - - def __getsingleitem__(self, idx): - # type: (Any) -> Any - return list.__getitem__(self, idx) - - def __setsingleitem__(self, idx, value): - # type: (Any, Any) -> None - # try to preserve the scalarstring type if setting an existing key to a new value - if idx < len(self): - if ( - isinstance(value, string_types) - and not isinstance(value, ScalarString) - and isinstance(self[idx], ScalarString) - ): - value = type(self[idx])(value) - list.__setitem__(self, idx, value) - - def __delsingleitem__(self, idx=None): - # type: (Any) -> Any - list.__delitem__(self, idx) - self.ca.items.pop(idx, None) # might not be there -> default value - for list_index in sorted(self.ca.items): 
- if list_index < idx: - continue - self.ca.items[list_index - 1] = self.ca.items.pop(list_index) - - def __len__(self): - # type: () -> int - return list.__len__(self) - - def insert(self, idx, val): - # type: (Any, Any) -> None - """the comments after the insertion have to move forward""" - list.insert(self, idx, val) - for list_index in sorted(self.ca.items, reverse=True): - if list_index < idx: - break - self.ca.items[list_index + 1] = self.ca.items.pop(list_index) - - def extend(self, val): - # type: (Any) -> None - list.extend(self, val) - - def __eq__(self, other): - # type: (Any) -> bool - return list.__eq__(self, other) - - def _yaml_add_comment(self, comment, key=NoComment): - # type: (Any, Optional[Any]) -> None - if key is not NoComment: - self.yaml_key_comment_extend(key, comment) - else: - self.ca.comment = comment - - def _yaml_add_eol_comment(self, comment, key): - # type: (Any, Any) -> None - self._yaml_add_comment(comment, key=key) - - def _yaml_get_columnX(self, key): - # type: (Any) -> Any - return self.ca.items[key][0].start_mark.column - - def _yaml_get_column(self, key): - # type: (Any) -> Any - column = None - sel_idx = None - pre, post = key - 1, key + 1 - if pre in self.ca.items: - sel_idx = pre - elif post in self.ca.items: - sel_idx = post - else: - # self.ca.items is not ordered - for row_idx, _k1 in enumerate(self): - if row_idx >= key: - break - if row_idx not in self.ca.items: - continue - sel_idx = row_idx - if sel_idx is not None: - column = self._yaml_get_columnX(sel_idx) - return column - - def _yaml_get_pre_comment(self): - # type: () -> Any - pre_comments = [] # type: List[Any] - if self.ca.comment is None: - self.ca.comment = [None, pre_comments] - else: - self.ca.comment[1] = pre_comments - return pre_comments - - def __deepcopy__(self, memo): - # type: (Any) -> Any - res = self.__class__() - memo[id(self)] = res - for k in self: - res.append(copy.deepcopy(k, memo)) - self.copy_attributes(res, memo=memo) - return res - - def 
__add__(self, other): - # type: (Any) -> Any - return list.__add__(self, other) - - def sort(self, key=None, reverse=False): # type: ignore - # type: (Any, bool) -> None - if key is None: - tmp_lst = sorted(zip(self, range(len(self))), reverse=reverse) - list.__init__(self, [x[0] for x in tmp_lst]) - else: - tmp_lst = sorted( - zip(map(key, list.__iter__(self)), range(len(self))), reverse=reverse - ) - list.__init__(self, [list.__getitem__(self, x[1]) for x in tmp_lst]) - itm = self.ca.items - self.ca._items = {} - for idx, x in enumerate(tmp_lst): - old_index = x[1] - if old_index in itm: - self.ca.items[idx] = itm[old_index] - - def __repr__(self): - # type: () -> Any - return list.__repr__(self) - - -class CommentedKeySeq(tuple, CommentedBase): # type: ignore - """This primarily exists to be able to roundtrip keys that are sequences""" - - def _yaml_add_comment(self, comment, key=NoComment): - # type: (Any, Optional[Any]) -> None - if key is not NoComment: - self.yaml_key_comment_extend(key, comment) - else: - self.ca.comment = comment - - def _yaml_add_eol_comment(self, comment, key): - # type: (Any, Any) -> None - self._yaml_add_comment(comment, key=key) - - def _yaml_get_columnX(self, key): - # type: (Any) -> Any - return self.ca.items[key][0].start_mark.column - - def _yaml_get_column(self, key): - # type: (Any) -> Any - column = None - sel_idx = None - pre, post = key - 1, key + 1 - if pre in self.ca.items: - sel_idx = pre - elif post in self.ca.items: - sel_idx = post - else: - # self.ca.items is not ordered - for row_idx, _k1 in enumerate(self): - if row_idx >= key: - break - if row_idx not in self.ca.items: - continue - sel_idx = row_idx - if sel_idx is not None: - column = self._yaml_get_columnX(sel_idx) - return column - - def _yaml_get_pre_comment(self): - # type: () -> Any - pre_comments = [] # type: List[Any] - if self.ca.comment is None: - self.ca.comment = [None, pre_comments] - else: - self.ca.comment[1] = pre_comments - return pre_comments - - 
-class CommentedMapView(Sized): - __slots__ = ("_mapping",) - - def __init__(self, mapping): - # type: (Any) -> None - self._mapping = mapping - - def __len__(self): - # type: () -> int - count = len(self._mapping) - return count - - -class CommentedMapKeysView(CommentedMapView, Set): # type: ignore - __slots__ = () - - @classmethod - def _from_iterable(self, it): - # type: (Any) -> Any - return set(it) - - def __contains__(self, key): - # type: (Any) -> Any - return key in self._mapping - - def __iter__(self): - # type: () -> Any # yield from self._mapping # not in py27, pypy - # for x in self._mapping._keys(): - for x in self._mapping: - yield x - - -class CommentedMapItemsView(CommentedMapView, Set): # type: ignore - __slots__ = () - - @classmethod - def _from_iterable(self, it): - # type: (Any) -> Any - return set(it) - - def __contains__(self, item): - # type: (Any) -> Any - key, value = item - try: - v = self._mapping[key] - except KeyError: - return False - else: - return v == value - - def __iter__(self): - # type: () -> Any - for key in self._mapping._keys(): - yield (key, self._mapping[key]) - - -class CommentedMapValuesView(CommentedMapView): - __slots__ = () - - def __contains__(self, value): - # type: (Any) -> Any - for key in self._mapping: - if value == self._mapping[key]: - return True - return False - - def __iter__(self): - # type: () -> Any - for key in self._mapping._keys(): - yield self._mapping[key] - - -class CommentedMap(ordereddict, CommentedBase): # type: ignore - __slots__ = (Comment.attrib, "_ok", "_ref") - - def __init__(self, *args, **kw): - # type: (Any, Any) -> None - self._ok = set() # type: MutableSet[Any] # own keys - self._ref = [] # type: List[CommentedMap] - ordereddict.__init__(self, *args, **kw) - - def _yaml_add_comment(self, comment, key=NoComment, value=NoComment): - # type: (Any, Optional[Any], Optional[Any]) -> None - """values is set to key to indicate a value attachment of comment""" - if key is not NoComment: - 
self.yaml_key_comment_extend(key, comment) - return - if value is not NoComment: - self.yaml_value_comment_extend(value, comment) - else: - self.ca.comment = comment - - def _yaml_add_eol_comment(self, comment, key): - # type: (Any, Any) -> None - """add on the value line, with value specified by the key""" - self._yaml_add_comment(comment, value=key) - - def _yaml_get_columnX(self, key): - # type: (Any) -> Any - return self.ca.items[key][2].start_mark.column - - def _yaml_get_column(self, key): - # type: (Any) -> Any - column = None - sel_idx = None - pre, post, last = None, None, None - for x in self: - if pre is not None and x != key: - post = x - break - if x == key: - pre = last - last = x - if pre in self.ca.items: - sel_idx = pre - elif post in self.ca.items: - sel_idx = post - else: - # self.ca.items is not ordered - for k1 in self: - if k1 >= key: - break - if k1 not in self.ca.items: - continue - sel_idx = k1 - if sel_idx is not None: - column = self._yaml_get_columnX(sel_idx) - return column - - def _yaml_get_pre_comment(self): - # type: () -> Any - pre_comments = [] # type: List[Any] - if self.ca.comment is None: - self.ca.comment = [None, pre_comments] - else: - self.ca.comment[1] = pre_comments - return pre_comments - - def update(self, vals): - # type: (Any) -> None - try: - ordereddict.update(self, vals) - except TypeError: - # probably a dict that is used - for x in vals: - self[x] = vals[x] - try: - self._ok.update(vals.keys()) # type: ignore - except AttributeError: - # assume a list/tuple of two element lists/tuples - for x in vals: - self._ok.add(x[0]) - - def insert(self, pos, key, value, comment=None): - # type: (Any, Any, Any, Optional[Any]) -> None - """insert key value into given position - attach comment if provided - """ - ordereddict.insert(self, pos, key, value) - self._ok.add(key) - if comment is not None: - self.yaml_add_eol_comment(comment, key=key) - - def mlget(self, key, default=None, list_ok=False): - # type: (Any, Any, Any) -> 
Any - """multi-level get that expects dicts within dicts""" - if not isinstance(key, list): - return self.get(key, default) - # assume that the key is a list of recursively accessible dicts - - def get_one_level(key_list, level, d): - # type: (Any, Any, Any) -> Any - if not list_ok: - assert isinstance(d, dict) - if level >= len(key_list): - if level > len(key_list): - raise IndexError - return d[key_list[level - 1]] - return get_one_level(key_list, level + 1, d[key_list[level - 1]]) - - try: - return get_one_level(key, 1, self) - except KeyError: - return default - except (TypeError, IndexError): - if not list_ok: - raise - return default - - def __getitem__(self, key): - # type: (Any) -> Any - try: - return ordereddict.__getitem__(self, key) - except KeyError: - for merged in getattr(self, merge_attrib, []): - if key in merged[1]: - return merged[1][key] - raise - - def __setitem__(self, key, value): - # type: (Any, Any) -> None - # try to preserve the scalarstring type if setting an existing key to a new value - if key in self: - if ( - isinstance(value, string_types) - and not isinstance(value, ScalarString) - and isinstance(self[key], ScalarString) - ): - value = type(self[key])(value) - ordereddict.__setitem__(self, key, value) - self._ok.add(key) - - def _unmerged_contains(self, key): - # type: (Any) -> Any - if key in self._ok: - return True - return None - - def __contains__(self, key): - # type: (Any) -> bool - return bool(ordereddict.__contains__(self, key)) - - def get(self, key, default=None): - # type: (Any, Any) -> Any - try: - return self.__getitem__(key) - except: # NOQA - return default - - def __repr__(self): - # type: () -> Any - return ordereddict.__repr__(self).replace("CommentedMap", "ordereddict") - - def non_merged_items(self): - # type: () -> Any - for x in ordereddict.__iter__(self): - if x in self._ok: - yield x, ordereddict.__getitem__(self, x) - - def __delitem__(self, key): - # type: (Any) -> None - # for merged in getattr(self, 
merge_attrib, []): - # if key in merged[1]: - # value = merged[1][key] - # break - # else: - # # not found in merged in stuff - # ordereddict.__delitem__(self, key) - # for referer in self._ref: - # referer.update_key_value(key) - # return - # - # ordereddict.__setitem__(self, key, value) # merge might have different value - # self._ok.discard(key) - self._ok.discard(key) - ordereddict.__delitem__(self, key) - for referer in self._ref: - referer.update_key_value(key) - - def __iter__(self): - # type: () -> Any - for x in ordereddict.__iter__(self): - yield x - - def _keys(self): - # type: () -> Any - for x in ordereddict.__iter__(self): - yield x - - def __len__(self): - # type: () -> int - return int(ordereddict.__len__(self)) - - def __eq__(self, other): - # type: (Any) -> bool - return bool(dict(self) == other) - - if PY2: - - def keys(self): - # type: () -> Any - return list(self._keys()) - - def iterkeys(self): - # type: () -> Any - return self._keys() - - def viewkeys(self): - # type: () -> Any - return CommentedMapKeysView(self) - - else: - - def keys(self): - # type: () -> Any - return CommentedMapKeysView(self) - - if PY2: - - def _values(self): - # type: () -> Any - for x in ordereddict.__iter__(self): - yield ordereddict.__getitem__(self, x) - - def values(self): - # type: () -> Any - return list(self._values()) - - def itervalues(self): - # type: () -> Any - return self._values() - - def viewvalues(self): - # type: () -> Any - return CommentedMapValuesView(self) - - else: - - def values(self): - # type: () -> Any - return CommentedMapValuesView(self) - - def _items(self): - # type: () -> Any - for x in ordereddict.__iter__(self): - yield x, ordereddict.__getitem__(self, x) - - if PY2: - - def items(self): - # type: () -> Any - return list(self._items()) - - def iteritems(self): - # type: () -> Any - return self._items() - - def viewitems(self): - # type: () -> Any - return CommentedMapItemsView(self) - - else: - - def items(self): - # type: () -> Any - 
return CommentedMapItemsView(self) - - @property - def merge(self): - # type: () -> Any - if not hasattr(self, merge_attrib): - setattr(self, merge_attrib, []) - return getattr(self, merge_attrib) - - def copy(self): - # type: () -> Any - x = type(self)() # update doesn't work - for k, v in self._items(): - x[k] = v - self.copy_attributes(x) - return x - - def add_referent(self, cm): - # type: (Any) -> None - if cm not in self._ref: - self._ref.append(cm) - - def add_yaml_merge(self, value): - # type: (Any) -> None - for v in value: - v[1].add_referent(self) - for k, v in v[1].items(): - if ordereddict.__contains__(self, k): - continue - ordereddict.__setitem__(self, k, v) - self.merge.extend(value) - - def update_key_value(self, key): - # type: (Any) -> None - if key in self._ok: - return - for v in self.merge: - if key in v[1]: - ordereddict.__setitem__(self, key, v[1][key]) - return - ordereddict.__delitem__(self, key) - - def __deepcopy__(self, memo): - # type: (Any) -> Any - res = self.__class__() - memo[id(self)] = res - for k in self: - res[k] = copy.deepcopy(self[k], memo) - self.copy_attributes(res, memo=memo) - return res - - -# based on brownie mappings -@classmethod # type: ignore -def raise_immutable(cls, *args, **kwargs): - # type: (Any, *Any, **Any) -> None - raise TypeError("{} objects are immutable".format(cls.__name__)) - - -class CommentedKeyMap(CommentedBase, Mapping): # type: ignore - __slots__ = Comment.attrib, "_od" - """This primarily exists to be able to roundtrip keys that are mappings""" - - def __init__(self, *args, **kw): - # type: (Any, Any) -> None - if hasattr(self, "_od"): - raise_immutable(self) - try: - self._od = ordereddict(*args, **kw) - except TypeError: - if PY2: - self._od = ordereddict(args[0].items()) - else: - raise - - __delitem__ = ( - __setitem__ - ) = clear = pop = popitem = setdefault = update = raise_immutable - - # need to implement __getitem__, __iter__ and __len__ - def __getitem__(self, index): - # type: (Any) 
-> Any - return self._od[index] - - def __iter__(self): - # type: () -> Iterator[Any] - for x in self._od.__iter__(): - yield x - - def __len__(self): - # type: () -> int - return len(self._od) - - def __hash__(self): - # type: () -> Any - return hash(tuple(self.items())) - - def __repr__(self): - # type: () -> Any - if not hasattr(self, merge_attrib): - return self._od.__repr__() - return "ordereddict(" + repr(list(self._od.items())) + ")" - - @classmethod - def fromkeys(keys, v=None): - # type: (Any, Any) -> Any - return CommentedKeyMap(dict.fromkeys(keys, v)) - - def _yaml_add_comment(self, comment, key=NoComment): - # type: (Any, Optional[Any]) -> None - if key is not NoComment: - self.yaml_key_comment_extend(key, comment) - else: - self.ca.comment = comment - - def _yaml_add_eol_comment(self, comment, key): - # type: (Any, Any) -> None - self._yaml_add_comment(comment, key=key) - - def _yaml_get_columnX(self, key): - # type: (Any) -> Any - return self.ca.items[key][0].start_mark.column - - def _yaml_get_column(self, key): - # type: (Any) -> Any - column = None - sel_idx = None - pre, post = key - 1, key + 1 - if pre in self.ca.items: - sel_idx = pre - elif post in self.ca.items: - sel_idx = post - else: - # self.ca.items is not ordered - for row_idx, _k1 in enumerate(self): - if row_idx >= key: - break - if row_idx not in self.ca.items: - continue - sel_idx = row_idx - if sel_idx is not None: - column = self._yaml_get_columnX(sel_idx) - return column - - def _yaml_get_pre_comment(self): - # type: () -> Any - pre_comments = [] # type: List[Any] - if self.ca.comment is None: - self.ca.comment = [None, pre_comments] - else: - self.ca.comment[1] = pre_comments - return pre_comments - - -class CommentedOrderedMap(CommentedMap): - __slots__ = (Comment.attrib,) - - -class CommentedSet(MutableSet, CommentedBase): # type: ignore # NOQA - __slots__ = Comment.attrib, "odict" - - def __init__(self, values=None): - # type: (Any) -> None - self.odict = ordereddict() - 
MutableSet.__init__(self) - if values is not None: - self |= values # type: ignore - - def _yaml_add_comment(self, comment, key=NoComment, value=NoComment): - # type: (Any, Optional[Any], Optional[Any]) -> None - """values is set to key to indicate a value attachment of comment""" - if key is not NoComment: - self.yaml_key_comment_extend(key, comment) - return - if value is not NoComment: - self.yaml_value_comment_extend(value, comment) - else: - self.ca.comment = comment - - def _yaml_add_eol_comment(self, comment, key): - # type: (Any, Any) -> None - """add on the value line, with value specified by the key""" - self._yaml_add_comment(comment, value=key) - - def add(self, value): - # type: (Any) -> None - """Add an element.""" - self.odict[value] = None - - def discard(self, value): - # type: (Any) -> None - """Remove an element. Do not raise an exception if absent.""" - del self.odict[value] - - def __contains__(self, x): - # type: (Any) -> Any - return x in self.odict - - def __iter__(self): - # type: () -> Any - for x in self.odict: - yield x - - def __len__(self): - # type: () -> int - return len(self.odict) - - def __repr__(self): - # type: () -> str - return "set({0!r})".format(self.odict.keys()) - - -class TaggedScalar(CommentedBase): - # the value and style attributes are set during roundtrip construction - def __init__(self, value=None, style=None, tag=None): - # type: (Any, Any, Any) -> None - self.value = value - self.style = style - if tag is not None: - self.yaml_set_tag(tag) - - def __str__(self): - # type: () -> Any - return self.value - - -def dump_comments(d, name="", sep=".", out=sys.stdout): - # type: (Any, str, str, Any) -> None - """ - recursively dump comments, all but the toplevel preceded by the path - in dotted form x.0.a - """ - if isinstance(d, dict) and hasattr(d, "ca"): - if name: - sys.stdout.write("{}\n".format(name)) - out.write("{}\n".format(d.ca)) # type: ignore - for k in d: - dump_comments(d[k], name=(name + sep + k) if name 
else k, sep=sep, out=out) - elif isinstance(d, list) and hasattr(d, "ca"): - if name: - sys.stdout.write("{}\n".format(name)) - out.write("{}\n".format(d.ca)) # type: ignore - for idx, k in enumerate(d): - dump_comments( - k, name=(name + sep + str(idx)) if name else str(idx), sep=sep, out=out - ) diff --git a/srsly/ruamel_yaml/compat.py b/srsly/ruamel_yaml/compat.py deleted file mode 100755 index 5544205..0000000 --- a/srsly/ruamel_yaml/compat.py +++ /dev/null @@ -1,328 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -# partially from package six by Benjamin Peterson - -import sys -import os -import types -import traceback -from abc import abstractmethod -from collections import OrderedDict # type: ignore - - -# fmt: off -if False: # MYPY - from typing import Any, Dict, Optional, List, Union, BinaryIO, IO, Text, Tuple # NOQA - from typing import Optional # NOQA -# fmt: on - -_DEFAULT_YAML_VERSION = (1, 2) - - -class ordereddict(OrderedDict): # type: ignore - if not hasattr(OrderedDict, "insert"): - - def insert(self, pos, key, value): - # type: (int, Any, Any) -> None - if pos >= len(self): - self[key] = value - return - od = ordereddict() - od.update(self) - for k in od: - del self[k] - for index, old_key in enumerate(od): - if pos == index: - self[key] = value - self[old_key] = od[old_key] - - -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - - -if PY3: - - def utf8(s): - # type: (str) -> str - return s - - def to_str(s): - # type: (str) -> str - return s - - def to_unicode(s): - # type: (str) -> str - return s - - -else: - if False: - unicode = str - - def utf8(s): - # type: (unicode) -> str - return s.encode("utf-8") - - def to_str(s): - # type: (str) -> str - return str(s) - - def to_unicode(s): - # type: (str) -> unicode - return unicode(s) # NOQA - - -if PY3: - string_types = str - integer_types = int - class_types = type - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize - unichr = chr - import io - - 
StringIO = io.StringIO - BytesIO = io.BytesIO - # have unlimited precision - no_limit_int = int - from collections.abc import ( - Hashable, - MutableSequence, - MutableMapping, - Mapping, - ) # NOQA - -else: - string_types = basestring # NOQA - integer_types = (int, long) # NOQA - class_types = (type, types.ClassType) - text_type = unicode # NOQA - binary_type = str - - # to allow importing - unichr = unichr - from StringIO import StringIO as _StringIO - - StringIO = _StringIO - import cStringIO - - BytesIO = cStringIO.StringIO - # have unlimited precision - no_limit_int = long # NOQA not available on Python 3 - from collections import Hashable, MutableSequence, MutableMapping, Mapping # NOQA - -if False: # MYPY - # StreamType = Union[BinaryIO, IO[str], IO[unicode], StringIO] - # StreamType = Union[BinaryIO, IO[str], StringIO] # type: ignore - StreamType = Any - - StreamTextType = StreamType # Union[Text, StreamType] - VersionType = Union[List[int], str, Tuple[int, int]] - -if PY3: - builtins_module = "builtins" -else: - builtins_module = "__builtin__" - -UNICODE_SIZE = 4 if sys.maxunicode > 65535 else 2 - - -def with_metaclass(meta, *bases): - # type: (Any, Any) -> Any - """Create a base class with a metaclass.""" - return meta("NewBase", bases, {}) - - -DBG_TOKEN = 1 -DBG_EVENT = 2 -DBG_NODE = 4 - - -_debug = None # type: Optional[int] -if "RUAMELDEBUG" in os.environ: - _debugx = os.environ.get("RUAMELDEBUG") - if _debugx is None: - _debug = 0 - else: - _debug = int(_debugx) - - -if bool(_debug): - - class ObjectCounter(object): - def __init__(self): - # type: () -> None - self.map = {} # type: Dict[Any, Any] - - def __call__(self, k): - # type: (Any) -> None - self.map[k] = self.map.get(k, 0) + 1 - - def dump(self): - # type: () -> None - for k in sorted(self.map): - sys.stdout.write("{} -> {}".format(k, self.map[k])) - - object_counter = ObjectCounter() - - -# used from yaml util when testing -def dbg(val=None): - # type: (Any) -> Any - global _debug - if 
_debug is None: - # set to true or false - _debugx = os.environ.get("YAMLDEBUG") - if _debugx is None: - _debug = 0 - else: - _debug = int(_debugx) - if val is None: - return _debug - return _debug & val - - -class Nprint(object): - def __init__(self, file_name=None): - # type: (Any) -> None - self._max_print = None # type: Any - self._count = None # type: Any - self._file_name = file_name - - def __call__(self, *args, **kw): - # type: (Any, Any) -> None - if not bool(_debug): - return - out = sys.stdout if self._file_name is None else open(self._file_name, "a") - dbgprint = print # to fool checking for print statements by dv utility - kw1 = kw.copy() - kw1["file"] = out - dbgprint(*args, **kw1) - out.flush() - if self._max_print is not None: - if self._count is None: - self._count = self._max_print - self._count -= 1 - if self._count == 0: - dbgprint("forced exit\n") - traceback.print_stack() - out.flush() - sys.exit(0) - if self._file_name: - out.close() - - def set_max_print(self, i): - # type: (int) -> None - self._max_print = i - self._count = None - - -nprint = Nprint() -nprintf = Nprint("/var/tmp/srsly.ruamel_yaml.log") - -# char checkers following production rules - - -def check_namespace_char(ch): - # type: (Any) -> bool - if u"\x21" <= ch <= u"\x7E": # ! to ~ - return True - if u"\xA0" <= ch <= u"\uD7FF": - return True - if (u"\uE000" <= ch <= u"\uFFFD") and ch != u"\uFEFF": # excl. 
byte order mark - return True - if u"\U00010000" <= ch <= u"\U0010FFFF": - return True - return False - - -def check_anchorname_char(ch): - # type: (Any) -> bool - if ch in u",[]{}": - return False - return check_namespace_char(ch) - - -def version_tnf(t1, t2=None): - # type: (Any, Any) -> Any - """ - return True if srsly.ruamel_yaml version_info < t1, None if t2 is specified and bigger else False - """ - from srsly.ruamel_yaml import version_info # NOQA - - if version_info < t1: - return True - if t2 is not None and version_info < t2: - return None - return False - - -class MutableSliceableSequence(MutableSequence): # type: ignore - __slots__ = () - - def __getitem__(self, index): - # type: (Any) -> Any - if not isinstance(index, slice): - return self.__getsingleitem__(index) - return type(self)( - [self[i] for i in range(*index.indices(len(self)))] - ) # type: ignore - - def __setitem__(self, index, value): - # type: (Any, Any) -> None - if not isinstance(index, slice): - return self.__setsingleitem__(index, value) - assert iter(value) - # nprint(index.start, index.stop, index.step, index.indices(len(self))) - if index.step is None: - del self[index.start : index.stop] - for elem in reversed(value): - self.insert(0 if index.start is None else index.start, elem) - else: - range_parms = index.indices(len(self)) - nr_assigned_items = (range_parms[1] - range_parms[0] - 1) // range_parms[ - 2 - ] + 1 - # need to test before changing, in case TypeError is caught - if nr_assigned_items < len(value): - raise TypeError( - "too many elements in value {} < {}".format( - nr_assigned_items, len(value) - ) - ) - elif nr_assigned_items > len(value): - raise TypeError( - "not enough elements in value {} > {}".format( - nr_assigned_items, len(value) - ) - ) - for idx, i in enumerate(range(*range_parms)): - self[i] = value[idx] - - def __delitem__(self, index): - # type: (Any) -> None - if not isinstance(index, slice): - return self.__delsingleitem__(index) - # nprint(index.start, 
index.stop, index.step, index.indices(len(self))) - for i in reversed(range(*index.indices(len(self)))): - del self[i] - - @abstractmethod - def __getsingleitem__(self, index): - # type: (Any) -> Any - raise IndexError - - @abstractmethod - def __setsingleitem__(self, index, value): - # type: (Any, Any) -> None - raise IndexError - - @abstractmethod - def __delsingleitem__(self, index): - # type: (Any) -> None - raise IndexError diff --git a/srsly/ruamel_yaml/composer.py b/srsly/ruamel_yaml/composer.py deleted file mode 100755 index 27d5a48..0000000 --- a/srsly/ruamel_yaml/composer.py +++ /dev/null @@ -1,243 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import, print_function - -import warnings - -from .error import MarkedYAMLError, ReusedAnchorWarning -from .compat import utf8, nprint, nprintf # NOQA - -from .events import ( - StreamStartEvent, - StreamEndEvent, - MappingStartEvent, - MappingEndEvent, - SequenceStartEvent, - SequenceEndEvent, - AliasEvent, - ScalarEvent, -) -from .nodes import MappingNode, ScalarNode, SequenceNode - -if False: # MYPY - from typing import Any, Dict, Optional, List # NOQA - -__all__ = ["Composer", "ComposerError"] - - -class ComposerError(MarkedYAMLError): - pass - - -class Composer(object): - def __init__(self, loader=None): - # type: (Any) -> None - self.loader = loader - if self.loader is not None and getattr(self.loader, "_composer", None) is None: - self.loader._composer = self - self.anchors = {} # type: Dict[Any, Any] - - @property - def parser(self): - # type: () -> Any - if hasattr(self.loader, "typ"): - self.loader.parser - return self.loader._parser - - @property - def resolver(self): - # type: () -> Any - # assert self.loader._resolver is not None - if hasattr(self.loader, "typ"): - self.loader.resolver - return self.loader._resolver - - def check_node(self): - # type: () -> Any - # Drop the STREAM-START event. 
- if self.parser.check_event(StreamStartEvent): - self.parser.get_event() - - # If there are more documents available? - return not self.parser.check_event(StreamEndEvent) - - def get_node(self): - # type: () -> Any - # Get the root node of the next document. - if not self.parser.check_event(StreamEndEvent): - return self.compose_document() - - def get_single_node(self): - # type: () -> Any - # Drop the STREAM-START event. - self.parser.get_event() - - # Compose a document if the stream is not empty. - document = None # type: Any - if not self.parser.check_event(StreamEndEvent): - document = self.compose_document() - - # Ensure that the stream contains no more documents. - if not self.parser.check_event(StreamEndEvent): - event = self.parser.get_event() - raise ComposerError( - "expected a single document in the stream", - document.start_mark, - "but found another document", - event.start_mark, - ) - - # Drop the STREAM-END event. - self.parser.get_event() - - return document - - def compose_document(self): - # type: (Any) -> Any - # Drop the DOCUMENT-START event. - self.parser.get_event() - - # Compose the root node. - node = self.compose_node(None, None) - - # Drop the DOCUMENT-END event. 
- self.parser.get_event() - - self.anchors = {} - return node - - def compose_node(self, parent, index): - # type: (Any, Any) -> Any - if self.parser.check_event(AliasEvent): - event = self.parser.get_event() - alias = event.anchor - if alias not in self.anchors: - raise ComposerError( - None, - None, - "found undefined alias %r" % utf8(alias), - event.start_mark, - ) - return self.anchors[alias] - event = self.parser.peek_event() - anchor = event.anchor - if anchor is not None: # have an anchor - if anchor in self.anchors: - # raise ComposerError( - # "found duplicate anchor %r; first occurrence" - # % utf8(anchor), self.anchors[anchor].start_mark, - # "second occurrence", event.start_mark) - ws = ( - "\nfound duplicate anchor {!r}\nfirst occurrence {}\nsecond occurrence " - "{}".format( - (anchor), self.anchors[anchor].start_mark, event.start_mark - ) - ) - warnings.warn(ws, ReusedAnchorWarning) - self.resolver.descend_resolver(parent, index) - if self.parser.check_event(ScalarEvent): - node = self.compose_scalar_node(anchor) - elif self.parser.check_event(SequenceStartEvent): - node = self.compose_sequence_node(anchor) - elif self.parser.check_event(MappingStartEvent): - node = self.compose_mapping_node(anchor) - self.resolver.ascend_resolver() - return node - - def compose_scalar_node(self, anchor): - # type: (Any) -> Any - event = self.parser.get_event() - tag = event.tag - if tag is None or tag == u"!": - tag = self.resolver.resolve(ScalarNode, event.value, event.implicit) - node = ScalarNode( - tag, - event.value, - event.start_mark, - event.end_mark, - style=event.style, - comment=event.comment, - anchor=anchor, - ) - if anchor is not None: - self.anchors[anchor] = node - return node - - def compose_sequence_node(self, anchor): - # type: (Any) -> Any - start_event = self.parser.get_event() - tag = start_event.tag - if tag is None or tag == u"!": - tag = self.resolver.resolve(SequenceNode, None, start_event.implicit) - node = SequenceNode( - tag, - [], - 
start_event.start_mark, - None, - flow_style=start_event.flow_style, - comment=start_event.comment, - anchor=anchor, - ) - if anchor is not None: - self.anchors[anchor] = node - index = 0 - while not self.parser.check_event(SequenceEndEvent): - node.value.append(self.compose_node(node, index)) - index += 1 - end_event = self.parser.get_event() - if node.flow_style is True and end_event.comment is not None: - if node.comment is not None: - nprint( - "Warning: unexpected end_event commment in sequence " - "node {}".format(node.flow_style) - ) - node.comment = end_event.comment - node.end_mark = end_event.end_mark - self.check_end_doc_comment(end_event, node) - return node - - def compose_mapping_node(self, anchor): - # type: (Any) -> Any - start_event = self.parser.get_event() - tag = start_event.tag - if tag is None or tag == u"!": - tag = self.resolver.resolve(MappingNode, None, start_event.implicit) - node = MappingNode( - tag, - [], - start_event.start_mark, - None, - flow_style=start_event.flow_style, - comment=start_event.comment, - anchor=anchor, - ) - if anchor is not None: - self.anchors[anchor] = node - while not self.parser.check_event(MappingEndEvent): - # key_event = self.parser.peek_event() - item_key = self.compose_node(node, None) - # if item_key in node.value: - # raise ComposerError("while composing a mapping", - # start_event.start_mark, - # "found duplicate key", key_event.start_mark) - item_value = self.compose_node(node, item_key) - # node.value[item_key] = item_value - node.value.append((item_key, item_value)) - end_event = self.parser.get_event() - if node.flow_style is True and end_event.comment is not None: - node.comment = end_event.comment - node.end_mark = end_event.end_mark - self.check_end_doc_comment(end_event, node) - return node - - def check_end_doc_comment(self, end_event, node): - # type: (Any, Any) -> None - if end_event.comment and end_event.comment[1]: - # pre comments on an end_event, no following to move to - if node.comment 
is None: - node.comment = [None, None] - assert not isinstance(node, ScalarEvent) - # this is a post comment on a mapping node, add as third element - # in the list - node.comment.append(end_event.comment[1]) - end_event.comment[1] = None diff --git a/srsly/ruamel_yaml/configobjwalker.py b/srsly/ruamel_yaml/configobjwalker.py deleted file mode 100755 index 511184e..0000000 --- a/srsly/ruamel_yaml/configobjwalker.py +++ /dev/null @@ -1,16 +0,0 @@ -# coding: utf-8 - -import warnings - -from .util import configobj_walker as new_configobj_walker - -if False: # MYPY - from typing import Any # NOQA - - -def configobj_walker(cfg): - # type: (Any) -> Any - warnings.warn( - "configobj_walker has moved to srsly.ruamel_yaml.util, please update your code" - ) - return new_configobj_walker(cfg) diff --git a/srsly/ruamel_yaml/constructor.py b/srsly/ruamel_yaml/constructor.py deleted file mode 100755 index b600339..0000000 --- a/srsly/ruamel_yaml/constructor.py +++ /dev/null @@ -1,1706 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division - -import datetime -import base64 -import binascii -import re -import sys -import types -import warnings - -# fmt: off -from .error import (MarkedYAMLError, MarkedYAMLFutureWarning, - MantissaNoDotYAML1_1Warning) -from .nodes import * # NOQA -from .nodes import (SequenceNode, MappingNode, ScalarNode) -from .compat import (utf8, builtins_module, to_str, PY2, PY3, # NOQA - text_type, nprint, nprintf, version_tnf) -from .compat import ordereddict, Hashable, MutableSequence # type: ignore -from .compat import MutableMapping # type: ignore - -from .comments import * # NOQA -from .comments import (CommentedMap, CommentedOrderedMap, CommentedSet, - CommentedKeySeq, CommentedSeq, TaggedScalar, - CommentedKeyMap) -from .scalarstring import (SingleQuotedScalarString, DoubleQuotedScalarString, - LiteralScalarString, FoldedScalarString, - PlainScalarString, ScalarString,) -from .scalarint import ScalarInt, BinaryInt, 
OctalInt, HexInt, HexCapsInt -from .scalarfloat import ScalarFloat -from .scalarbool import ScalarBoolean -from .timestamp import TimeStamp -from .util import RegExp - -if False: # MYPY - from typing import Any, Dict, List, Set, Generator, Union, Optional # NOQA - - -__all__ = ['BaseConstructor', 'SafeConstructor', 'Constructor', - 'ConstructorError', 'RoundTripConstructor'] -# fmt: on - - -class ConstructorError(MarkedYAMLError): - pass - - -class DuplicateKeyFutureWarning(MarkedYAMLFutureWarning): - pass - - -class DuplicateKeyError(MarkedYAMLFutureWarning): - pass - - -class BaseConstructor(object): - - yaml_constructors = {} # type: Dict[Any, Any] - yaml_multi_constructors = {} # type: Dict[Any, Any] - - def __init__(self, preserve_quotes=None, loader=None): - # type: (Optional[bool], Any) -> None - self.loader = loader - if ( - self.loader is not None - and getattr(self.loader, "_constructor", None) is None - ): - self.loader._constructor = self - self.loader = loader - self.yaml_base_dict_type = dict - self.yaml_base_list_type = list - self.constructed_objects = {} # type: Dict[Any, Any] - self.recursive_objects = {} # type: Dict[Any, Any] - self.state_generators = [] # type: List[Any] - self.deep_construct = False - self._preserve_quotes = preserve_quotes - self.allow_duplicate_keys = version_tnf((0, 15, 1), (0, 16)) - - @property - def composer(self): - # type: () -> Any - if hasattr(self.loader, "typ"): - return self.loader.composer - try: - return self.loader._composer - except AttributeError: - sys.stdout.write("slt {}\n".format(type(self))) - sys.stdout.write("slc {}\n".format(self.loader._composer)) - sys.stdout.write("{}\n".format(dir(self))) - raise - - @property - def resolver(self): - # type: () -> Any - if hasattr(self.loader, "typ"): - return self.loader.resolver - return self.loader._resolver - - def check_data(self): - # type: () -> Any - # If there are more documents available? 
- return self.composer.check_node() - - def get_data(self): - # type: () -> Any - # Construct and return the next document. - if self.composer.check_node(): - return self.construct_document(self.composer.get_node()) - - def get_single_data(self): - # type: () -> Any - # Ensure that the stream contains a single document and construct it. - node = self.composer.get_single_node() - if node is not None: - return self.construct_document(node) - return None - - def construct_document(self, node): - # type: (Any) -> Any - data = self.construct_object(node) - while bool(self.state_generators): - state_generators = self.state_generators - self.state_generators = [] - for generator in state_generators: - for _dummy in generator: - pass - self.constructed_objects = {} - self.recursive_objects = {} - self.deep_construct = False - return data - - def construct_object(self, node, deep=False): - # type: (Any, bool) -> Any - """deep is True when creating an object/mapping recursively, - in that case want the underlying elements available during construction - """ - if node in self.constructed_objects: - return self.constructed_objects[node] - if deep: - old_deep = self.deep_construct - self.deep_construct = True - if node in self.recursive_objects: - return self.recursive_objects[node] - # raise ConstructorError( - # None, None, 'found unconstructable recursive node', node.start_mark - # ) - self.recursive_objects[node] = None - data = self.construct_non_recursive_object(node) - - self.constructed_objects[node] = data - del self.recursive_objects[node] - if deep: - self.deep_construct = old_deep - return data - - def construct_non_recursive_object(self, node, tag=None): - # type: (Any, Optional[str]) -> Any - constructor = None # type: Any - tag_suffix = None - if tag is None: - tag = node.tag - if tag in self.yaml_constructors: - constructor = self.yaml_constructors[tag] - else: - for tag_prefix in self.yaml_multi_constructors: - if tag.startswith(tag_prefix): - tag_suffix = 
tag[len(tag_prefix) :] - constructor = self.yaml_multi_constructors[tag_prefix] - break - else: - if None in self.yaml_multi_constructors: - tag_suffix = tag - constructor = self.yaml_multi_constructors[None] - elif None in self.yaml_constructors: - constructor = self.yaml_constructors[None] - elif isinstance(node, ScalarNode): - constructor = self.__class__.construct_scalar - elif isinstance(node, SequenceNode): - constructor = self.__class__.construct_sequence - elif isinstance(node, MappingNode): - constructor = self.__class__.construct_mapping - if tag_suffix is None: - data = constructor(self, node) - else: - data = constructor(self, tag_suffix, node) - if isinstance(data, types.GeneratorType): - generator = data - data = next(generator) - if self.deep_construct: - for _dummy in generator: - pass - else: - self.state_generators.append(generator) - return data - - def construct_scalar(self, node): - # type: (Any) -> Any - if not isinstance(node, ScalarNode): - raise ConstructorError( - None, - None, - "expected a scalar node, but found %s" % node.id, - node.start_mark, - ) - return node.value - - def construct_sequence(self, node, deep=False): - # type: (Any, bool) -> Any - """deep is True when creating an object/mapping recursively, - in that case want the underlying elements available during construction - """ - if not isinstance(node, SequenceNode): - raise ConstructorError( - None, - None, - "expected a sequence node, but found %s" % node.id, - node.start_mark, - ) - return [self.construct_object(child, deep=deep) for child in node.value] - - def construct_mapping(self, node, deep=False): - # type: (Any, bool) -> Any - """deep is True when creating an object/mapping recursively, - in that case want the underlying elements available during construction - """ - if not isinstance(node, MappingNode): - raise ConstructorError( - None, - None, - "expected a mapping node, but found %s" % node.id, - node.start_mark, - ) - total_mapping = self.yaml_base_dict_type() 
- if getattr(node, "merge", None) is not None: - todo = [(node.merge, False), (node.value, False)] - else: - todo = [(node.value, True)] - for values, check in todo: - mapping = self.yaml_base_dict_type() # type: Dict[Any, Any] - for key_node, value_node in values: - # keys can be list -> deep - key = self.construct_object(key_node, deep=True) - # lists are not hashable, but tuples are - if not isinstance(key, Hashable): - if isinstance(key, list): - key = tuple(key) - if PY2: - try: - hash(key) - except TypeError as exc: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found unacceptable key (%s)" % exc, - key_node.start_mark, - ) - else: - if not isinstance(key, Hashable): - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found unhashable key", - key_node.start_mark, - ) - - value = self.construct_object(value_node, deep=deep) - if check: - if self.check_mapping_key(node, key_node, mapping, key, value): - mapping[key] = value - else: - mapping[key] = value - total_mapping.update(mapping) - return total_mapping - - def check_mapping_key(self, node, key_node, mapping, key, value): - # type: (Any, Any, Any, Any, Any) -> bool - """return True if key is unique""" - if key in mapping: - if not self.allow_duplicate_keys: - mk = mapping.get(key) - if PY2: - if isinstance(key, unicode): - key = key.encode("utf-8") - if isinstance(value, unicode): - value = value.encode("utf-8") - if isinstance(mk, unicode): - mk = mk.encode("utf-8") - args = [ - "while constructing a mapping", - node.start_mark, - 'found duplicate key "{}" with value "{}" ' - '(original value: "{}")'.format(key, value, mk), - key_node.start_mark, - """ - To suppress this check see: - http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys - """, - """\ - Duplicate keys will become an error in future releases, and are errors - by default when using the new API. 
- """, - ] - if self.allow_duplicate_keys is None: - warnings.warn(DuplicateKeyFutureWarning(*args)) - else: - raise DuplicateKeyError(*args) - return False - return True - - def check_set_key(self, node, key_node, setting, key): - # type: (Any, Any, Any, Any, Any) -> None - if key in setting: - if not self.allow_duplicate_keys: - if PY2: - if isinstance(key, unicode): - key = key.encode("utf-8") - args = [ - "while constructing a set", - node.start_mark, - 'found duplicate key "{}"'.format(key), - key_node.start_mark, - """ - To suppress this check see: - http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys - """, - """\ - Duplicate keys will become an error in future releases, and are errors - by default when using the new API. - """, - ] - if self.allow_duplicate_keys is None: - warnings.warn(DuplicateKeyFutureWarning(*args)) - else: - raise DuplicateKeyError(*args) - - def construct_pairs(self, node, deep=False): - # type: (Any, bool) -> Any - if not isinstance(node, MappingNode): - raise ConstructorError( - None, - None, - "expected a mapping node, but found %s" % node.id, - node.start_mark, - ) - pairs = [] - for key_node, value_node in node.value: - key = self.construct_object(key_node, deep=deep) - value = self.construct_object(value_node, deep=deep) - pairs.append((key, value)) - return pairs - - @classmethod - def add_constructor(cls, tag, constructor): - # type: (Any, Any) -> None - if "yaml_constructors" not in cls.__dict__: - cls.yaml_constructors = cls.yaml_constructors.copy() - cls.yaml_constructors[tag] = constructor - - @classmethod - def add_multi_constructor(cls, tag_prefix, multi_constructor): - # type: (Any, Any) -> None - if "yaml_multi_constructors" not in cls.__dict__: - cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy() - cls.yaml_multi_constructors[tag_prefix] = multi_constructor - - -class SafeConstructor(BaseConstructor): - def construct_scalar(self, node): - # type: (Any) -> Any - if isinstance(node, MappingNode): 
- for key_node, value_node in node.value: - if key_node.tag == u"tag:yaml.org,2002:value": - return self.construct_scalar(value_node) - return BaseConstructor.construct_scalar(self, node) - - def flatten_mapping(self, node): - # type: (Any) -> Any - """ - This implements the merge key feature http://yaml.org/type/merge.html - by inserting keys from the merge dict/list of dicts if not yet - available in this node - """ - merge = [] # type: List[Any] - index = 0 - while index < len(node.value): - key_node, value_node = node.value[index] - if key_node.tag == u"tag:yaml.org,2002:merge": - if merge: # double << key - if self.allow_duplicate_keys: - del node.value[index] - index += 1 - continue - args = [ - "while constructing a mapping", - node.start_mark, - 'found duplicate key "{}"'.format(key_node.value), - key_node.start_mark, - """ - To suppress this check see: - http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys - """, - """\ - Duplicate keys will become an error in future releases, and are errors - by default when using the new API. 
- """, - ] - if self.allow_duplicate_keys is None: - warnings.warn(DuplicateKeyFutureWarning(*args)) - else: - raise DuplicateKeyError(*args) - del node.value[index] - if isinstance(value_node, MappingNode): - self.flatten_mapping(value_node) - merge.extend(value_node.value) - elif isinstance(value_node, SequenceNode): - submerge = [] - for subnode in value_node.value: - if not isinstance(subnode, MappingNode): - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "expected a mapping for merging, but found %s" - % subnode.id, - subnode.start_mark, - ) - self.flatten_mapping(subnode) - submerge.append(subnode.value) - submerge.reverse() - for value in submerge: - merge.extend(value) - else: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "expected a mapping or list of mappings for merging, " - "but found %s" % value_node.id, - value_node.start_mark, - ) - elif key_node.tag == u"tag:yaml.org,2002:value": - key_node.tag = u"tag:yaml.org,2002:str" - index += 1 - else: - index += 1 - if bool(merge): - node.merge = ( - merge - ) # separate merge keys to be able to update without duplicate - node.value = merge + node.value - - def construct_mapping(self, node, deep=False): - # type: (Any, bool) -> Any - """deep is True when creating an object/mapping recursively, - in that case want the underlying elements available during construction - """ - if isinstance(node, MappingNode): - self.flatten_mapping(node) - return BaseConstructor.construct_mapping(self, node, deep=deep) - - def construct_yaml_null(self, node): - # type: (Any) -> Any - self.construct_scalar(node) - return None - - # YAML 1.2 spec doesn't mention yes/no etc any more, 1.1 does - bool_values = { - u"yes": True, - u"no": False, - u"y": True, - u"n": False, - u"true": True, - u"false": False, - u"on": True, - u"off": False, - } - - def construct_yaml_bool(self, node): - # type: (Any) -> bool - value = self.construct_scalar(node) - return 
self.bool_values[value.lower()] - - def construct_yaml_int(self, node): - # type: (Any) -> int - value_s = to_str(self.construct_scalar(node)) - value_s = value_s.replace("_", "") - sign = +1 - if value_s[0] == "-": - sign = -1 - if value_s[0] in "+-": - value_s = value_s[1:] - if value_s == "0": - return 0 - elif value_s.startswith("0b"): - return sign * int(value_s[2:], 2) - elif value_s.startswith("0x"): - return sign * int(value_s[2:], 16) - elif value_s.startswith("0o"): - return sign * int(value_s[2:], 8) - elif self.resolver.processing_version == (1, 1) and value_s[0] == "0": - return sign * int(value_s, 8) - elif self.resolver.processing_version == (1, 1) and ":" in value_s: - digits = [int(part) for part in value_s.split(":")] - digits.reverse() - base = 1 - value = 0 - for digit in digits: - value += digit * base - base *= 60 - return sign * value - else: - return sign * int(value_s) - - inf_value = 1e300 - while inf_value != inf_value * inf_value: - inf_value *= inf_value - nan_value = -inf_value / inf_value # Trying to make a quiet NaN (like C99). - - def construct_yaml_float(self, node): - # type: (Any) -> float - value_so = to_str(self.construct_scalar(node)) - value_s = value_so.replace("_", "").lower() - sign = +1 - if value_s[0] == "-": - sign = -1 - if value_s[0] in "+-": - value_s = value_s[1:] - if value_s == ".inf": - return sign * self.inf_value - elif value_s == ".nan": - return self.nan_value - elif self.resolver.processing_version != (1, 2) and ":" in value_s: - digits = [float(part) for part in value_s.split(":")] - digits.reverse() - base = 1 - value = 0.0 - for digit in digits: - value += digit * base - base *= 60 - return sign * value - else: - if self.resolver.processing_version != (1, 2) and "e" in value_s: - # value_s is lower case independent of input - mantissa, exponent = value_s.split("e") - if "." 
not in mantissa: - warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so)) - return sign * float(value_s) - - if PY3: - - def construct_yaml_binary(self, node): - # type: (Any) -> Any - try: - value = self.construct_scalar(node).encode("ascii") - except UnicodeEncodeError as exc: - raise ConstructorError( - None, - None, - "failed to convert base64 data into ascii: %s" % exc, - node.start_mark, - ) - try: - if hasattr(base64, "decodebytes"): - return base64.decodebytes(value) - else: - return base64.decodestring(value) - except binascii.Error as exc: - raise ConstructorError( - None, - None, - "failed to decode base64 data: %s" % exc, - node.start_mark, - ) - - else: - - def construct_yaml_binary(self, node): - # type: (Any) -> Any - value = self.construct_scalar(node) - try: - return to_str(value).decode("base64") - except (binascii.Error, UnicodeEncodeError) as exc: - raise ConstructorError( - None, - None, - "failed to decode base64 data: %s" % exc, - node.start_mark, - ) - - timestamp_regexp = RegExp( - u"""^(?P[0-9][0-9][0-9][0-9]) - -(?P[0-9][0-9]?) - -(?P[0-9][0-9]?) - (?:((?P[Tt])|[ \\t]+) # explictly not retaining extra spaces - (?P[0-9][0-9]?) - :(?P[0-9][0-9]) - :(?P[0-9][0-9]) - (?:\\.(?P[0-9]*))? - (?:[ \\t]*(?PZ|(?P[-+])(?P[0-9][0-9]?) 
- (?::(?P[0-9][0-9]))?))?)?$""", - re.X, - ) - - def construct_yaml_timestamp(self, node, values=None): - # type: (Any, Any) -> Any - if values is None: - try: - match = self.timestamp_regexp.match(node.value) - except TypeError: - match = None - if match is None: - raise ConstructorError( - None, - None, - 'failed to construct timestamp from "{}"'.format(node.value), - node.start_mark, - ) - values = match.groupdict() - year = int(values["year"]) - month = int(values["month"]) - day = int(values["day"]) - if not values["hour"]: - return datetime.date(year, month, day) - hour = int(values["hour"]) - minute = int(values["minute"]) - second = int(values["second"]) - fraction = 0 - if values["fraction"]: - fraction_s = values["fraction"][:6] - while len(fraction_s) < 6: - fraction_s += "0" - fraction = int(fraction_s) - if len(values["fraction"]) > 6 and int(values["fraction"][6]) > 4: - fraction += 1 - delta = None - if values["tz_sign"]: - tz_hour = int(values["tz_hour"]) - minutes = values["tz_minute"] - tz_minute = int(minutes) if minutes else 0 - delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute) - if values["tz_sign"] == "-": - delta = -delta - # should do something else instead (or hook this up to the preceding if statement - # in reverse - # if delta is None: - # return datetime.datetime(year, month, day, hour, minute, second, fraction) - # return datetime.datetime(year, month, day, hour, minute, second, fraction, - # datetime.timezone.utc) - # the above is not good enough though, should provide tzinfo. 
In Python3 that is easily - # doable drop that kind of support for Python2 as it has not native tzinfo - data = datetime.datetime(year, month, day, hour, minute, second, fraction) - if delta: - data -= delta - return data - - def construct_yaml_omap(self, node): - # type: (Any) -> Any - # Note: we do now check for duplicate keys - omap = ordereddict() - yield omap - if not isinstance(node, SequenceNode): - raise ConstructorError( - "while constructing an ordered map", - node.start_mark, - "expected a sequence, but found %s" % node.id, - node.start_mark, - ) - for subnode in node.value: - if not isinstance(subnode, MappingNode): - raise ConstructorError( - "while constructing an ordered map", - node.start_mark, - "expected a mapping of length 1, but found %s" % subnode.id, - subnode.start_mark, - ) - if len(subnode.value) != 1: - raise ConstructorError( - "while constructing an ordered map", - node.start_mark, - "expected a single mapping item, but found %d items" - % len(subnode.value), - subnode.start_mark, - ) - key_node, value_node = subnode.value[0] - key = self.construct_object(key_node) - assert key not in omap - value = self.construct_object(value_node) - omap[key] = value - - def construct_yaml_pairs(self, node): - # type: (Any) -> Any - # Note: the same code as `construct_yaml_omap`. 
- pairs = [] # type: List[Any] - yield pairs - if not isinstance(node, SequenceNode): - raise ConstructorError( - "while constructing pairs", - node.start_mark, - "expected a sequence, but found %s" % node.id, - node.start_mark, - ) - for subnode in node.value: - if not isinstance(subnode, MappingNode): - raise ConstructorError( - "while constructing pairs", - node.start_mark, - "expected a mapping of length 1, but found %s" % subnode.id, - subnode.start_mark, - ) - if len(subnode.value) != 1: - raise ConstructorError( - "while constructing pairs", - node.start_mark, - "expected a single mapping item, but found %d items" - % len(subnode.value), - subnode.start_mark, - ) - key_node, value_node = subnode.value[0] - key = self.construct_object(key_node) - value = self.construct_object(value_node) - pairs.append((key, value)) - - def construct_yaml_set(self, node): - # type: (Any) -> Any - data = set() # type: Set[Any] - yield data - value = self.construct_mapping(node) - data.update(value) - - def construct_yaml_str(self, node): - # type: (Any) -> Any - value = self.construct_scalar(node) - if PY3: - return value - try: - return value.encode("ascii") - except UnicodeEncodeError: - return value - - def construct_yaml_seq(self, node): - # type: (Any) -> Any - data = self.yaml_base_list_type() # type: List[Any] - yield data - data.extend(self.construct_sequence(node)) - - def construct_yaml_map(self, node): - # type: (Any) -> Any - data = self.yaml_base_dict_type() # type: Dict[Any, Any] - yield data - value = self.construct_mapping(node) - data.update(value) - - def construct_yaml_object(self, node, cls): - # type: (Any, Any) -> Any - data = cls.__new__(cls) - yield data - if hasattr(data, "__setstate__"): - state = self.construct_mapping(node, deep=True) - data.__setstate__(state) - else: - state = self.construct_mapping(node) - data.__dict__.update(state) - - def construct_undefined(self, node): - # type: (Any) -> None - raise ConstructorError( - None, - None, - 
"could not determine a constructor for the tag %r" % utf8(node.tag), - node.start_mark, - ) - - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:null", SafeConstructor.construct_yaml_null -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:bool", SafeConstructor.construct_yaml_bool -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:int", SafeConstructor.construct_yaml_int -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:float", SafeConstructor.construct_yaml_float -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:binary", SafeConstructor.construct_yaml_binary -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:timestamp", SafeConstructor.construct_yaml_timestamp -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:omap", SafeConstructor.construct_yaml_omap -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:pairs", SafeConstructor.construct_yaml_pairs -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:set", SafeConstructor.construct_yaml_set -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:str", SafeConstructor.construct_yaml_str -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:seq", SafeConstructor.construct_yaml_seq -) - -SafeConstructor.add_constructor( - u"tag:yaml.org,2002:map", SafeConstructor.construct_yaml_map -) - -SafeConstructor.add_constructor(None, SafeConstructor.construct_undefined) - -if PY2: - - class classobj: - pass - - -class Constructor(SafeConstructor): - def construct_python_str(self, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_unicode(self, node): - raise ValueError("Unsafe constructor not implemented in this library") - - if PY3: - - def construct_python_bytes(self, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_long(self, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def 
construct_python_complex(self, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_tuple(self, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def find_python_module(self, name, mark): - raise ValueError("Unsafe constructor not implemented in this library") - - def find_python_name(self, name, mark): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_name(self, suffix, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_module(self, suffix, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False): - raise ValueError("Unsafe constructor not implemented in this library") - - def set_python_instance_state(self, instance, state): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_object(self, suffix, node): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_object_apply(self, suffix, node, newobj=False): - raise ValueError("Unsafe constructor not implemented in this library") - - def construct_python_object_new(self, suffix, node): - raise ValueError("Unsafe constructor not implemented in this library") - - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/none", Constructor.construct_yaml_null -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/bool", Constructor.construct_yaml_bool -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/str", Constructor.construct_python_str -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/unicode", Constructor.construct_python_unicode -) - -if PY3: - Constructor.add_constructor( - u"tag:yaml.org,2002:python/bytes", Constructor.construct_python_bytes - ) - -Constructor.add_constructor( - 
u"tag:yaml.org,2002:python/int", Constructor.construct_yaml_int -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/long", Constructor.construct_python_long -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/float", Constructor.construct_yaml_float -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/complex", Constructor.construct_python_complex -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/list", Constructor.construct_yaml_seq -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/tuple", Constructor.construct_python_tuple -) - -Constructor.add_constructor( - u"tag:yaml.org,2002:python/dict", Constructor.construct_yaml_map -) - -Constructor.add_multi_constructor( - u"tag:yaml.org,2002:python/name:", Constructor.construct_python_name -) - -Constructor.add_multi_constructor( - u"tag:yaml.org,2002:python/module:", Constructor.construct_python_module -) - -Constructor.add_multi_constructor( - u"tag:yaml.org,2002:python/object:", Constructor.construct_python_object -) - -Constructor.add_multi_constructor( - u"tag:yaml.org,2002:python/object/apply:", Constructor.construct_python_object_apply -) - -Constructor.add_multi_constructor( - u"tag:yaml.org,2002:python/object/new:", Constructor.construct_python_object_new -) - - -class RoundTripConstructor(SafeConstructor): - """need to store the comments on the node itself, - as well as on the items - """ - - def construct_scalar(self, node): - # type: (Any) -> Any - if not isinstance(node, ScalarNode): - raise ConstructorError( - None, - None, - "expected a scalar node, but found %s" % node.id, - node.start_mark, - ) - - if node.style == "|" and isinstance(node.value, text_type): - lss = LiteralScalarString(node.value, anchor=node.anchor) - if node.comment and node.comment[1]: - lss.comment = node.comment[1][0] # type: ignore - return lss - if node.style == ">" and isinstance(node.value, text_type): - fold_positions = [] # type: List[int] - idx = -1 - while True: 
- idx = node.value.find("\a", idx + 1) - if idx < 0: - break - fold_positions.append(idx - len(fold_positions)) - fss = FoldedScalarString(node.value.replace("\a", ""), anchor=node.anchor) - if node.comment and node.comment[1]: - fss.comment = node.comment[1][0] # type: ignore - if fold_positions: - fss.fold_pos = fold_positions # type: ignore - return fss - elif bool(self._preserve_quotes) and isinstance(node.value, text_type): - if node.style == "'": - return SingleQuotedScalarString(node.value, anchor=node.anchor) - if node.style == '"': - return DoubleQuotedScalarString(node.value, anchor=node.anchor) - if node.anchor: - return PlainScalarString(node.value, anchor=node.anchor) - return node.value - - def construct_yaml_int(self, node): - # type: (Any) -> Any - width = None # type: Any - value_su = to_str(self.construct_scalar(node)) - try: - sx = value_su.rstrip("_") - underscore = [len(sx) - sx.rindex("_") - 1, False, False] # type: Any - except ValueError: - underscore = None - except IndexError: - underscore = None - value_s = value_su.replace("_", "") - sign = +1 - if value_s[0] == "-": - sign = -1 - if value_s[0] in "+-": - value_s = value_s[1:] - if value_s == "0": - return 0 - elif value_s.startswith("0b"): - if self.resolver.processing_version > (1, 1) and value_s[2] == "0": - width = len(value_s[2:]) - if underscore is not None: - underscore[1] = value_su[2] == "_" - underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_" - return BinaryInt( - sign * int(value_s[2:], 2), - width=width, - underscore=underscore, - anchor=node.anchor, - ) - elif value_s.startswith("0x"): - # default to lower-case if no a-fA-F in string - if self.resolver.processing_version > (1, 1) and value_s[2] == "0": - width = len(value_s[2:]) - hex_fun = HexInt # type: Any - for ch in value_s[2:]: - if ch in "ABCDEF": # first non-digit is capital - hex_fun = HexCapsInt - break - if ch in "abcdef": - break - if underscore is not None: - underscore[1] = value_su[2] == "_" - 
underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_" - return hex_fun( - sign * int(value_s[2:], 16), - width=width, - underscore=underscore, - anchor=node.anchor, - ) - elif value_s.startswith("0o"): - if self.resolver.processing_version > (1, 1) and value_s[2] == "0": - width = len(value_s[2:]) - if underscore is not None: - underscore[1] = value_su[2] == "_" - underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_" - return OctalInt( - sign * int(value_s[2:], 8), - width=width, - underscore=underscore, - anchor=node.anchor, - ) - elif self.resolver.processing_version != (1, 2) and value_s[0] == "0": - return sign * int(value_s, 8) - elif self.resolver.processing_version != (1, 2) and ":" in value_s: - digits = [int(part) for part in value_s.split(":")] - digits.reverse() - base = 1 - value = 0 - for digit in digits: - value += digit * base - base *= 60 - return sign * value - elif self.resolver.processing_version > (1, 1) and value_s[0] == "0": - # not an octal, an integer with leading zero(s) - if underscore is not None: - # cannot have a leading underscore - underscore[2] = len(value_su) > 1 and value_su[-1] == "_" - return ScalarInt( - sign * int(value_s), width=len(value_s), underscore=underscore - ) - elif underscore: - # cannot have a leading underscore - underscore[2] = len(value_su) > 1 and value_su[-1] == "_" - return ScalarInt( - sign * int(value_s), - width=None, - underscore=underscore, - anchor=node.anchor, - ) - elif node.anchor: - return ScalarInt(sign * int(value_s), width=None, anchor=node.anchor) - else: - return sign * int(value_s) - - def construct_yaml_float(self, node): - # type: (Any) -> Any - def leading_zeros(v): - # type: (Any) -> int - lead0 = 0 - idx = 0 - while idx < len(v) and v[idx] in "0.": - if v[idx] == "0": - lead0 += 1 - idx += 1 - return lead0 - - # underscore = None - m_sign = False # type: Any - value_so = to_str(self.construct_scalar(node)) - value_s = value_so.replace("_", "").lower() - sign = +1 - if 
value_s[0] == "-": - sign = -1 - if value_s[0] in "+-": - m_sign = value_s[0] - value_s = value_s[1:] - if value_s == ".inf": - return sign * self.inf_value - if value_s == ".nan": - return self.nan_value - if self.resolver.processing_version != (1, 2) and ":" in value_s: - digits = [float(part) for part in value_s.split(":")] - digits.reverse() - base = 1 - value = 0.0 - for digit in digits: - value += digit * base - base *= 60 - return sign * value - if "e" in value_s: - try: - mantissa, exponent = value_so.split("e") - exp = "e" - except ValueError: - mantissa, exponent = value_so.split("E") - exp = "E" - if self.resolver.processing_version != (1, 2): - # value_s is lower case independent of input - if "." not in mantissa: - warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so)) - lead0 = leading_zeros(mantissa) - width = len(mantissa) - prec = mantissa.find(".") - if m_sign: - width -= 1 - e_width = len(exponent) - e_sign = exponent[0] in "+-" - # nprint('sf', width, prec, m_sign, exp, e_width, e_sign) - return ScalarFloat( - sign * float(value_s), - width=width, - prec=prec, - m_sign=m_sign, - m_lead0=lead0, - exp=exp, - e_width=e_width, - e_sign=e_sign, - anchor=node.anchor, - ) - width = len(value_so) - prec = value_so.index( - "." - ) # you can use index, this would not be float without dot - lead0 = leading_zeros(value_so) - return ScalarFloat( - sign * float(value_s), - width=width, - prec=prec, - m_sign=m_sign, - m_lead0=lead0, - anchor=node.anchor, - ) - - def construct_yaml_str(self, node): - # type: (Any) -> Any - value = self.construct_scalar(node) - if isinstance(value, ScalarString): - return value - if PY3: - return value - try: - return value.encode("ascii") - except AttributeError: - # in case you replace the node dynamically e.g. 
with a dict - return value - except UnicodeEncodeError: - return value - - def construct_rt_sequence(self, node, seqtyp, deep=False): - # type: (Any, Any, bool) -> Any - if not isinstance(node, SequenceNode): - raise ConstructorError( - None, - None, - "expected a sequence node, but found %s" % node.id, - node.start_mark, - ) - ret_val = [] - if node.comment: - seqtyp._yaml_add_comment(node.comment[:2]) - if len(node.comment) > 2: - seqtyp.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: - from .serializer import templated_id - - if not templated_id(node.anchor): - seqtyp.yaml_set_anchor(node.anchor) - for idx, child in enumerate(node.value): - if child.comment: - seqtyp._yaml_add_comment(child.comment, key=idx) - child.comment = None # if moved to sequence remove from child - ret_val.append(self.construct_object(child, deep=deep)) - seqtyp._yaml_set_idx_line_col( - idx, [child.start_mark.line, child.start_mark.column] - ) - return ret_val - - def flatten_mapping(self, node): - # type: (Any) -> Any - """ - This implements the merge key feature http://yaml.org/type/merge.html - by inserting keys from the merge dict/list of dicts if not yet - available in this node - """ - - def constructed(value_node): - # type: (Any) -> Any - # If the contents of a merge are defined within the - # merge marker, then they won't have been constructed - # yet. But if they were already constructed, we need to use - # the existing object. 
- if value_node in self.constructed_objects: - value = self.constructed_objects[value_node] - else: - value = self.construct_object(value_node, deep=False) - return value - - # merge = [] - merge_map_list = [] # type: List[Any] - index = 0 - while index < len(node.value): - key_node, value_node = node.value[index] - if key_node.tag == u"tag:yaml.org,2002:merge": - if merge_map_list: # double << key - if self.allow_duplicate_keys: - del node.value[index] - index += 1 - continue - args = [ - "while constructing a mapping", - node.start_mark, - 'found duplicate key "{}"'.format(key_node.value), - key_node.start_mark, - """ - To suppress this check see: - http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys - """, - """\ - Duplicate keys will become an error in future releases, and are errors - by default when using the new API. - """, - ] - if self.allow_duplicate_keys is None: - warnings.warn(DuplicateKeyFutureWarning(*args)) - else: - raise DuplicateKeyError(*args) - del node.value[index] - if isinstance(value_node, MappingNode): - merge_map_list.append((index, constructed(value_node))) - # self.flatten_mapping(value_node) - # merge.extend(value_node.value) - elif isinstance(value_node, SequenceNode): - # submerge = [] - for subnode in value_node.value: - if not isinstance(subnode, MappingNode): - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "expected a mapping for merging, but found %s" - % subnode.id, - subnode.start_mark, - ) - merge_map_list.append((index, constructed(subnode))) - # self.flatten_mapping(subnode) - # submerge.append(subnode.value) - # submerge.reverse() - # for value in submerge: - # merge.extend(value) - else: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "expected a mapping or list of mappings for merging, " - "but found %s" % value_node.id, - value_node.start_mark, - ) - elif key_node.tag == u"tag:yaml.org,2002:value": - key_node.tag = u"tag:yaml.org,2002:str" - index 
+= 1 - else: - index += 1 - return merge_map_list - # if merge: - # node.value = merge + node.value - - def _sentinel(self): - # type: () -> None - pass - - def construct_mapping(self, node, maptyp, deep=False): # type: ignore - # type: (Any, Any, bool) -> Any - if not isinstance(node, MappingNode): - raise ConstructorError( - None, - None, - "expected a mapping node, but found %s" % node.id, - node.start_mark, - ) - merge_map = self.flatten_mapping(node) - # mapping = {} - if node.comment: - maptyp._yaml_add_comment(node.comment[:2]) - if len(node.comment) > 2: - maptyp.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: - from .serializer import templated_id - - if not templated_id(node.anchor): - maptyp.yaml_set_anchor(node.anchor) - last_key, last_value = None, self._sentinel - for key_node, value_node in node.value: - # keys can be list -> deep - key = self.construct_object(key_node, deep=True) - # lists are not hashable, but tuples are - if not isinstance(key, Hashable): - if isinstance(key, MutableSequence): - key_s = CommentedKeySeq(key) - if key_node.flow_style is True: - key_s.fa.set_flow_style() - elif key_node.flow_style is False: - key_s.fa.set_block_style() - key = key_s - elif isinstance(key, MutableMapping): - key_m = CommentedKeyMap(key) - if key_node.flow_style is True: - key_m.fa.set_flow_style() - elif key_node.flow_style is False: - key_m.fa.set_block_style() - key = key_m - if PY2: - try: - hash(key) - except TypeError as exc: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found unacceptable key (%s)" % exc, - key_node.start_mark, - ) - else: - if not isinstance(key, Hashable): - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found unhashable key", - key_node.start_mark, - ) - value = self.construct_object(value_node, deep=deep) - if self.check_mapping_key(node, key_node, maptyp, key, value): - if ( - key_node.comment - and len(key_node.comment) > 4 - and 
key_node.comment[4] - ): - if last_value is None: - key_node.comment[0] = key_node.comment.pop(4) - maptyp._yaml_add_comment(key_node.comment, value=last_key) - else: - key_node.comment[2] = key_node.comment.pop(4) - maptyp._yaml_add_comment(key_node.comment, key=key) - key_node.comment = None - if key_node.comment: - maptyp._yaml_add_comment(key_node.comment, key=key) - if value_node.comment: - maptyp._yaml_add_comment(value_node.comment, value=key) - maptyp._yaml_set_kv_line_col( - key, - [ - key_node.start_mark.line, - key_node.start_mark.column, - value_node.start_mark.line, - value_node.start_mark.column, - ], - ) - maptyp[key] = value - last_key, last_value = key, value # could use indexing - # do this last, or <<: before a key will prevent insertion in instances - # of collections.OrderedDict (as they have no __contains__ - if merge_map: - maptyp.add_yaml_merge(merge_map) - - def construct_setting(self, node, typ, deep=False): - # type: (Any, Any, bool) -> Any - if not isinstance(node, MappingNode): - raise ConstructorError( - None, - None, - "expected a mapping node, but found %s" % node.id, - node.start_mark, - ) - if node.comment: - typ._yaml_add_comment(node.comment[:2]) - if len(node.comment) > 2: - typ.yaml_end_comment_extend(node.comment[2], clear=True) - if node.anchor: - from .serializer import templated_id - - if not templated_id(node.anchor): - typ.yaml_set_anchor(node.anchor) - for key_node, value_node in node.value: - # keys can be list -> deep - key = self.construct_object(key_node, deep=True) - # lists are not hashable, but tuples are - if not isinstance(key, Hashable): - if isinstance(key, list): - key = tuple(key) - if PY2: - try: - hash(key) - except TypeError as exc: - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found unacceptable key (%s)" % exc, - key_node.start_mark, - ) - else: - if not isinstance(key, Hashable): - raise ConstructorError( - "while constructing a mapping", - node.start_mark, - "found 
unhashable key", - key_node.start_mark, - ) - # construct but should be null - value = self.construct_object(value_node, deep=deep) # NOQA - self.check_set_key(node, key_node, typ, key) - if key_node.comment: - typ._yaml_add_comment(key_node.comment, key=key) - if value_node.comment: - typ._yaml_add_comment(value_node.comment, value=key) - typ.add(key) - - def construct_yaml_seq(self, node): - # type: (Any) -> Any - data = CommentedSeq() - data._yaml_set_line_col(node.start_mark.line, node.start_mark.column) - if node.comment: - data._yaml_add_comment(node.comment) - yield data - data.extend(self.construct_rt_sequence(node, data)) - self.set_collection_style(data, node) - - def construct_yaml_map(self, node): - # type: (Any) -> Any - data = CommentedMap() - data._yaml_set_line_col(node.start_mark.line, node.start_mark.column) - yield data - self.construct_mapping(node, data, deep=True) - self.set_collection_style(data, node) - - def set_collection_style(self, data, node): - # type: (Any, Any) -> None - if len(data) == 0: - return - if node.flow_style is True: - data.fa.set_flow_style() - elif node.flow_style is False: - data.fa.set_block_style() - - def construct_yaml_object(self, node, cls): - # type: (Any, Any) -> Any - data = cls.__new__(cls) - yield data - if hasattr(data, "__setstate__"): - state = SafeConstructor.construct_mapping(self, node, deep=True) - data.__setstate__(state) - else: - state = SafeConstructor.construct_mapping(self, node) - data.__dict__.update(state) - - def construct_yaml_omap(self, node): - # type: (Any) -> Any - # Note: we do now check for duplicate keys - omap = CommentedOrderedMap() - omap._yaml_set_line_col(node.start_mark.line, node.start_mark.column) - if node.flow_style is True: - omap.fa.set_flow_style() - elif node.flow_style is False: - omap.fa.set_block_style() - yield omap - if node.comment: - omap._yaml_add_comment(node.comment[:2]) - if len(node.comment) > 2: - omap.yaml_end_comment_extend(node.comment[2], clear=True) - 
if not isinstance(node, SequenceNode): - raise ConstructorError( - "while constructing an ordered map", - node.start_mark, - "expected a sequence, but found %s" % node.id, - node.start_mark, - ) - for subnode in node.value: - if not isinstance(subnode, MappingNode): - raise ConstructorError( - "while constructing an ordered map", - node.start_mark, - "expected a mapping of length 1, but found %s" % subnode.id, - subnode.start_mark, - ) - if len(subnode.value) != 1: - raise ConstructorError( - "while constructing an ordered map", - node.start_mark, - "expected a single mapping item, but found %d items" - % len(subnode.value), - subnode.start_mark, - ) - key_node, value_node = subnode.value[0] - key = self.construct_object(key_node) - assert key not in omap - value = self.construct_object(value_node) - if key_node.comment: - omap._yaml_add_comment(key_node.comment, key=key) - if subnode.comment: - omap._yaml_add_comment(subnode.comment, key=key) - if value_node.comment: - omap._yaml_add_comment(value_node.comment, value=key) - omap[key] = value - - def construct_yaml_set(self, node): - # type: (Any) -> Any - data = CommentedSet() - data._yaml_set_line_col(node.start_mark.line, node.start_mark.column) - yield data - self.construct_setting(node, data) - - def construct_undefined(self, node): - # type: (Any) -> Any - try: - if isinstance(node, MappingNode): - data = CommentedMap() - data._yaml_set_line_col(node.start_mark.line, node.start_mark.column) - if node.flow_style is True: - data.fa.set_flow_style() - elif node.flow_style is False: - data.fa.set_block_style() - data.yaml_set_tag(node.tag) - yield data - if node.anchor: - data.yaml_set_anchor(node.anchor) - self.construct_mapping(node, data) - return - elif isinstance(node, ScalarNode): - data2 = TaggedScalar() - data2.value = self.construct_scalar(node) - data2.style = node.style - data2.yaml_set_tag(node.tag) - yield data2 - if node.anchor: - data2.yaml_set_anchor(node.anchor, always_dump=True) - return - elif 
isinstance(node, SequenceNode): - data3 = CommentedSeq() - data3._yaml_set_line_col(node.start_mark.line, node.start_mark.column) - if node.flow_style is True: - data3.fa.set_flow_style() - elif node.flow_style is False: - data3.fa.set_block_style() - data3.yaml_set_tag(node.tag) - yield data3 - if node.anchor: - data3.yaml_set_anchor(node.anchor) - data3.extend(self.construct_sequence(node)) - return - except: # NOQA - pass - raise ConstructorError( - None, - None, - "could not determine a constructor for the tag %r" % utf8(node.tag), - node.start_mark, - ) - - def construct_yaml_timestamp(self, node, values=None): - # type: (Any, Any) -> Any - try: - match = self.timestamp_regexp.match(node.value) - except TypeError: - match = None - if match is None: - raise ConstructorError( - None, - None, - 'failed to construct timestamp from "{}"'.format(node.value), - node.start_mark, - ) - values = match.groupdict() - if not values["hour"]: - return SafeConstructor.construct_yaml_timestamp(self, node, values) - for part in ["t", "tz_sign", "tz_hour", "tz_minute"]: - if values[part]: - break - else: - return SafeConstructor.construct_yaml_timestamp(self, node, values) - year = int(values["year"]) - month = int(values["month"]) - day = int(values["day"]) - hour = int(values["hour"]) - minute = int(values["minute"]) - second = int(values["second"]) - fraction = 0 - if values["fraction"]: - fraction_s = values["fraction"][:6] - while len(fraction_s) < 6: - fraction_s += "0" - fraction = int(fraction_s) - if len(values["fraction"]) > 6 and int(values["fraction"][6]) > 4: - fraction += 1 - delta = None - if values["tz_sign"]: - tz_hour = int(values["tz_hour"]) - minutes = values["tz_minute"] - tz_minute = int(minutes) if minutes else 0 - delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute) - if values["tz_sign"] == "-": - delta = -delta - if delta: - dt = datetime.datetime(year, month, day, hour, minute) - dt -= delta - data = TimeStamp( - dt.year, dt.month, dt.day, 
dt.hour, dt.minute, second, fraction - ) - data._yaml["delta"] = delta - tz = values["tz_sign"] + values["tz_hour"] - if values["tz_minute"]: - tz += ":" + values["tz_minute"] - data._yaml["tz"] = tz - else: - data = TimeStamp(year, month, day, hour, minute, second, fraction) - if values["tz"]: # no delta - data._yaml["tz"] = values["tz"] - - if values["t"]: - data._yaml["t"] = True - return data - - def construct_yaml_bool(self, node): - # type: (Any) -> Any - b = SafeConstructor.construct_yaml_bool(self, node) - if node.anchor: - return ScalarBoolean(b, anchor=node.anchor) - return b - - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:null", RoundTripConstructor.construct_yaml_null -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:bool", RoundTripConstructor.construct_yaml_bool -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:int", RoundTripConstructor.construct_yaml_int -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:float", RoundTripConstructor.construct_yaml_float -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:binary", RoundTripConstructor.construct_yaml_binary -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:timestamp", RoundTripConstructor.construct_yaml_timestamp -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:omap", RoundTripConstructor.construct_yaml_omap -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:pairs", RoundTripConstructor.construct_yaml_pairs -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:set", RoundTripConstructor.construct_yaml_set -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:str", RoundTripConstructor.construct_yaml_str -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:seq", RoundTripConstructor.construct_yaml_seq -) - -RoundTripConstructor.add_constructor( - u"tag:yaml.org,2002:map", RoundTripConstructor.construct_yaml_map -) - 
-RoundTripConstructor.add_constructor(None, RoundTripConstructor.construct_undefined) diff --git a/srsly/ruamel_yaml/cyaml.py b/srsly/ruamel_yaml/cyaml.py deleted file mode 100755 index 3375944..0000000 --- a/srsly/ruamel_yaml/cyaml.py +++ /dev/null @@ -1,192 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -from _ruamel_yaml import CParser, CEmitter # type: ignore - -from .constructor import Constructor, BaseConstructor, SafeConstructor -from .representer import Representer, SafeRepresenter, BaseRepresenter -from .resolver import Resolver, BaseResolver - -if False: # MYPY - from typing import Any, Union, Optional # NOQA - from .compat import StreamTextType, StreamType, VersionType # NOQA - -__all__ = [ - "CBaseLoader", - "CSafeLoader", - "CLoader", - "CBaseDumper", - "CSafeDumper", - "CDumper", -] - - -# this includes some hacks to solve the usage of resolver by lower level -# parts of the parser - - -class CBaseLoader(CParser, BaseConstructor, BaseResolver): # type: ignore - def __init__(self, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None - CParser.__init__(self, stream) - self._parser = self._composer = self - BaseConstructor.__init__(self, loader=self) - BaseResolver.__init__(self, loadumper=self) - # self.descend_resolver = self._resolver.descend_resolver - # self.ascend_resolver = self._resolver.ascend_resolver - # self.resolve = self._resolver.resolve - - -class CSafeLoader(CParser, SafeConstructor, Resolver): # type: ignore - def __init__(self, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None - CParser.__init__(self, stream) - self._parser = self._composer = self - SafeConstructor.__init__(self, loader=self) - Resolver.__init__(self, loadumper=self) - # self.descend_resolver = self._resolver.descend_resolver - # self.ascend_resolver = self._resolver.ascend_resolver - # self.resolve = 
self._resolver.resolve - - -class CLoader(CParser, Constructor, Resolver): # type: ignore - def __init__(self, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None - CParser.__init__(self, stream) - self._parser = self._composer = self - Constructor.__init__(self, loader=self) - Resolver.__init__(self, loadumper=self) - # self.descend_resolver = self._resolver.descend_resolver - # self.ascend_resolver = self._resolver.ascend_resolver - # self.resolve = self._resolver.resolve - - -class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): # type: ignore - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - CEmitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - encoding=encoding, - allow_unicode=allow_unicode, - line_break=line_break, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - ) - self._emitter = self._serializer = self._representer = self - BaseRepresenter.__init__( - self, - default_style=default_style, - default_flow_style=default_flow_style, - dumper=self, - ) - BaseResolver.__init__(self, loadumper=self) - - -class CSafeDumper(CEmitter, SafeRepresenter, Resolver): # type: ignore - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - 
block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - self._emitter = self._serializer = self._representer = self - CEmitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - encoding=encoding, - allow_unicode=allow_unicode, - line_break=line_break, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - ) - self._emitter = self._serializer = self._representer = self - SafeRepresenter.__init__( - self, default_style=default_style, default_flow_style=default_flow_style - ) - Resolver.__init__(self) - - -class CDumper(CEmitter, Representer, Resolver): # type: ignore - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - CEmitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - encoding=encoding, - allow_unicode=allow_unicode, - line_break=line_break, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - ) - self._emitter = self._serializer = self._representer = self - Representer.__init__( - self, default_style=default_style, default_flow_style=default_flow_style - ) - Resolver.__init__(self) diff --git a/srsly/ruamel_yaml/dumper.py b/srsly/ruamel_yaml/dumper.py deleted file mode 100755 index 04d3b4d..0000000 --- a/srsly/ruamel_yaml/dumper.py +++ 
/dev/null @@ -1,221 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -from .emitter import Emitter -from .serializer import Serializer -from .representer import ( - Representer, - SafeRepresenter, - BaseRepresenter, - RoundTripRepresenter, -) -from .resolver import Resolver, BaseResolver, VersionedResolver - -if False: # MYPY - from typing import Any, Dict, List, Union, Optional # NOQA - from .compat import StreamType, VersionType # NOQA - -__all__ = ["BaseDumper", "SafeDumper", "Dumper", "RoundTripDumper"] - - -class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver): - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (Any, StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - Emitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - block_seq_indent=block_seq_indent, - dumper=self, - ) - Serializer.__init__( - self, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - dumper=self, - ) - BaseRepresenter.__init__( - self, - default_style=default_style, - default_flow_style=default_flow_style, - dumper=self, - ) - BaseResolver.__init__(self, loadumper=self) - - -class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver): - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - 
tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - Emitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - block_seq_indent=block_seq_indent, - dumper=self, - ) - Serializer.__init__( - self, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - dumper=self, - ) - SafeRepresenter.__init__( - self, - default_style=default_style, - default_flow_style=default_flow_style, - dumper=self, - ) - Resolver.__init__(self, loadumper=self) - - -class Dumper(Emitter, Serializer, Representer, Resolver): - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - Emitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - block_seq_indent=block_seq_indent, - dumper=self, - ) - Serializer.__init__( - self, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - dumper=self, - ) - Representer.__init__( - self, - default_style=default_style, - default_flow_style=default_flow_style, - dumper=self, - ) - Resolver.__init__(self, loadumper=self) - - -class RoundTripDumper(Emitter, Serializer, RoundTripRepresenter, 
VersionedResolver): - def __init__( - self, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Optional[bool], Optional[int], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - Emitter.__init__( - self, - stream, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - block_seq_indent=block_seq_indent, - top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon, - dumper=self, - ) - Serializer.__init__( - self, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - dumper=self, - ) - RoundTripRepresenter.__init__( - self, - default_style=default_style, - default_flow_style=default_flow_style, - dumper=self, - ) - VersionedResolver.__init__(self, loader=self) diff --git a/srsly/ruamel_yaml/emitter.py b/srsly/ruamel_yaml/emitter.py deleted file mode 100755 index 74a1660..0000000 --- a/srsly/ruamel_yaml/emitter.py +++ /dev/null @@ -1,1727 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import -from __future__ import print_function - -# Emitter expects events obeying the following grammar: -# stream ::= STREAM-START document* STREAM-END -# document ::= DOCUMENT-START node DOCUMENT-END -# node ::= SCALAR | sequence | mapping -# sequence ::= SEQUENCE-START node* SEQUENCE-END -# mapping ::= MAPPING-START (node node)* MAPPING-END - -import sys -from .error import YAMLError, YAMLStreamError -from .events import * # NOQA - -# fmt: off -from .compat import utf8, text_type, PY2, nprint, dbg, DBG_EVENT, \ - check_anchorname_char -# fmt: on - -if False: # MYPY - from 
typing import Any, Dict, List, Union, Text, Tuple, Optional # NOQA - from .compat import StreamType # NOQA - -__all__ = ["Emitter", "EmitterError"] - - -class EmitterError(YAMLError): - pass - - -class ScalarAnalysis(object): - def __init__( - self, - scalar, - empty, - multiline, - allow_flow_plain, - allow_block_plain, - allow_single_quoted, - allow_double_quoted, - allow_block, - ): - # type: (Any, Any, Any, bool, bool, bool, bool, bool) -> None - self.scalar = scalar - self.empty = empty - self.multiline = multiline - self.allow_flow_plain = allow_flow_plain - self.allow_block_plain = allow_block_plain - self.allow_single_quoted = allow_single_quoted - self.allow_double_quoted = allow_double_quoted - self.allow_block = allow_block - - -class Indents(object): - # replacement for the list based stack of None/int - def __init__(self): - # type: () -> None - self.values = [] # type: List[Tuple[int, bool]] - - def append(self, val, seq): - # type: (Any, Any) -> None - self.values.append((val, seq)) - - def pop(self): - # type: () -> Any - return self.values.pop()[0] - - def last_seq(self): - # type: () -> bool - # return the seq(uence) value for the element added before the last one - # in increase_indent() - try: - return self.values[-2][1] - except IndexError: - return False - - def seq_flow_align(self, seq_indent, column): - # type: (int, int) -> int - # extra spaces because of dash - if len(self.values) < 2 or not self.values[-1][1]: - return 0 - # -1 for the dash - base = self.values[-1][0] if self.values[-1][0] is not None else 0 - return base + seq_indent - column - 1 - - def __len__(self): - # type: () -> int - return len(self.values) - - -class Emitter(object): - # fmt: off - DEFAULT_TAG_PREFIXES = { - u'!': u'!', - u'tag:yaml.org,2002:': u'!!', - } - # fmt: on - - MAX_SIMPLE_KEY_LENGTH = 128 - - def __init__( - self, - stream, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - block_seq_indent=None, - 
top_level_colon_align=None, - prefix_colon=None, - brace_single_entry_mapping_in_flow_sequence=None, - dumper=None, - ): - # type: (StreamType, Any, Optional[int], Optional[int], Optional[bool], Any, Optional[int], Optional[bool], Any, Optional[bool], Any) -> None # NOQA - self.dumper = dumper - if self.dumper is not None and getattr(self.dumper, "_emitter", None) is None: - self.dumper._emitter = self - self.stream = stream - - # Encoding can be overriden by STREAM-START. - self.encoding = None # type: Optional[Text] - self.allow_space_break = None - - # Emitter is a state machine with a stack of states to handle nested - # structures. - self.states = [] # type: List[Any] - self.state = self.expect_stream_start # type: Any - - # Current event and the event queue. - self.events = [] # type: List[Any] - self.event = None # type: Any - - # The current indentation level and the stack of previous indents. - self.indents = Indents() - self.indent = None # type: Optional[int] - - # flow_context is an expanding/shrinking list consisting of '{' and '[' - # for each unclosed flow context. If empty list that means block context - self.flow_context = [] # type: List[Text] - - # Contexts. - self.root_context = False - self.sequence_context = False - self.mapping_context = False - self.simple_key_context = False - - # Characteristics of the last emitted character: - # - current position. - # - is it a whitespace? - # - is it an indention character - # (indentation space, '-', '?', or ':')? - self.line = 0 - self.column = 0 - self.whitespace = True - self.indention = True - self.compact_seq_seq = True # dash after dash - self.compact_seq_map = True # key after dash - # self.compact_ms = False # dash after key, only when excplicit key with ? 
- self.no_newline = None # type: Optional[bool] # set if directly after `- ` - - # Whether the document requires an explicit document end indicator - self.open_ended = False - - # colon handling - self.colon = u":" - self.prefixed_colon = ( - self.colon if prefix_colon is None else prefix_colon + self.colon - ) - # single entry mappings in flow sequence - self.brace_single_entry_mapping_in_flow_sequence = ( - brace_single_entry_mapping_in_flow_sequence - ) # NOQA - - # Formatting details. - self.canonical = canonical - self.allow_unicode = allow_unicode - # set to False to get "\Uxxxxxxxx" for non-basic unicode like emojis - self.unicode_supplementary = sys.maxunicode > 0xFFFF - self.sequence_dash_offset = block_seq_indent if block_seq_indent else 0 - self.top_level_colon_align = top_level_colon_align - self.best_sequence_indent = 2 - self.requested_indent = indent # specific for literal zero indent - if indent and 1 < indent < 10: - self.best_sequence_indent = indent - self.best_map_indent = self.best_sequence_indent - # if self.best_sequence_indent < self.sequence_dash_offset + 1: - # self.best_sequence_indent = self.sequence_dash_offset + 1 - self.best_width = 80 - if width and width > self.best_sequence_indent * 2: - self.best_width = width - self.best_line_break = u"\n" # type: Any - if line_break in [u"\r", u"\n", u"\r\n"]: - self.best_line_break = line_break - - # Tag prefixes. - self.tag_prefixes = None # type: Any - - # Prepared anchor and tag. - self.prepared_anchor = None # type: Any - self.prepared_tag = None # type: Any - - # Scalar analysis and style. 
- self.analysis = None # type: Any - self.style = None # type: Any - - self.scalar_after_indicator = True # write a scalar on the same line as `---` - - @property - def stream(self): - # type: () -> Any - try: - return self._stream - except AttributeError: - raise YAMLStreamError("output stream needs to specified") - - @stream.setter - def stream(self, val): - # type: (Any) -> None - if val is None: - return - if not hasattr(val, "write"): - raise YAMLStreamError("stream argument needs to have a write() method") - self._stream = val - - @property - def serializer(self): - # type: () -> Any - try: - if hasattr(self.dumper, "typ"): - return self.dumper.serializer - return self.dumper._serializer - except AttributeError: - return self # cyaml - - @property - def flow_level(self): - # type: () -> int - return len(self.flow_context) - - def dispose(self): - # type: () -> None - # Reset the state attributes (to clear self-references) - self.states = [] - self.state = None - - def emit(self, event): - # type: (Any) -> None - if dbg(DBG_EVENT): - nprint(event) - self.events.append(event) - while not self.need_more_events(): - self.event = self.events.pop(0) - self.state() - self.event = None - - # In some cases, we wait for a few next events before emitting. 
- - def need_more_events(self): - # type: () -> bool - if not self.events: - return True - event = self.events[0] - if isinstance(event, DocumentStartEvent): - return self.need_events(1) - elif isinstance(event, SequenceStartEvent): - return self.need_events(2) - elif isinstance(event, MappingStartEvent): - return self.need_events(3) - else: - return False - - def need_events(self, count): - # type: (int) -> bool - level = 0 - for event in self.events[1:]: - if isinstance(event, (DocumentStartEvent, CollectionStartEvent)): - level += 1 - elif isinstance(event, (DocumentEndEvent, CollectionEndEvent)): - level -= 1 - elif isinstance(event, StreamEndEvent): - level = -1 - if level < 0: - return False - return len(self.events) < count + 1 - - def increase_indent(self, flow=False, sequence=None, indentless=False): - # type: (bool, Optional[bool], bool) -> None - self.indents.append(self.indent, sequence) - if self.indent is None: # top level - if flow: - # self.indent = self.best_sequence_indent if self.indents.last_seq() else \ - # self.best_map_indent - # self.indent = self.best_sequence_indent - self.indent = self.requested_indent - else: - self.indent = 0 - elif not indentless: - self.indent += ( - self.best_sequence_indent - if self.indents.last_seq() - else self.best_map_indent - ) - # if self.indents.last_seq(): - # if self.indent == 0: # top level block sequence - # self.indent = self.best_sequence_indent - self.sequence_dash_offset - # else: - # self.indent += self.best_sequence_indent - # else: - # self.indent += self.best_map_indent - - # States. - - # Stream handlers. 
- - def expect_stream_start(self): - # type: () -> None - if isinstance(self.event, StreamStartEvent): - if PY2: - if self.event.encoding and not getattr(self.stream, "encoding", None): - self.encoding = self.event.encoding - else: - if self.event.encoding and not hasattr(self.stream, "encoding"): - self.encoding = self.event.encoding - self.write_stream_start() - self.state = self.expect_first_document_start - else: - raise EmitterError("expected StreamStartEvent, but got %s" % (self.event,)) - - def expect_nothing(self): - # type: () -> None - raise EmitterError("expected nothing, but got %s" % (self.event,)) - - # Document handlers. - - def expect_first_document_start(self): - # type: () -> Any - return self.expect_document_start(first=True) - - def expect_document_start(self, first=False): - # type: (bool) -> None - if isinstance(self.event, DocumentStartEvent): - if (self.event.version or self.event.tags) and self.open_ended: - self.write_indicator(u"...", True) - self.write_indent() - if self.event.version: - version_text = self.prepare_version(self.event.version) - self.write_version_directive(version_text) - self.tag_prefixes = self.DEFAULT_TAG_PREFIXES.copy() - if self.event.tags: - handles = sorted(self.event.tags.keys()) - for handle in handles: - prefix = self.event.tags[handle] - self.tag_prefixes[prefix] = handle - handle_text = self.prepare_tag_handle(handle) - prefix_text = self.prepare_tag_prefix(prefix) - self.write_tag_directive(handle_text, prefix_text) - implicit = ( - first - and not self.event.explicit - and not self.canonical - and not self.event.version - and not self.event.tags - and not self.check_empty_document() - ) - if not implicit: - self.write_indent() - self.write_indicator(u"---", True) - if self.canonical: - self.write_indent() - self.state = self.expect_document_root - elif isinstance(self.event, StreamEndEvent): - if self.open_ended: - self.write_indicator(u"...", True) - self.write_indent() - self.write_stream_end() - 
self.state = self.expect_nothing - else: - raise EmitterError( - "expected DocumentStartEvent, but got %s" % (self.event,) - ) - - def expect_document_end(self): - # type: () -> None - if isinstance(self.event, DocumentEndEvent): - self.write_indent() - if self.event.explicit: - self.write_indicator(u"...", True) - self.write_indent() - self.flush_stream() - self.state = self.expect_document_start - else: - raise EmitterError("expected DocumentEndEvent, but got %s" % (self.event,)) - - def expect_document_root(self): - # type: () -> None - self.states.append(self.expect_document_end) - self.expect_node(root=True) - - # Node handlers. - - def expect_node(self, root=False, sequence=False, mapping=False, simple_key=False): - # type: (bool, bool, bool, bool) -> None - self.root_context = root - self.sequence_context = sequence # not used in PyYAML - self.mapping_context = mapping - self.simple_key_context = simple_key - if isinstance(self.event, AliasEvent): - self.expect_alias() - elif isinstance(self.event, (ScalarEvent, CollectionStartEvent)): - if ( - self.process_anchor(u"&") - and isinstance(self.event, ScalarEvent) - and self.sequence_context - ): - self.sequence_context = False - if ( - root - and isinstance(self.event, ScalarEvent) - and not self.scalar_after_indicator - ): - self.write_indent() - self.process_tag() - if isinstance(self.event, ScalarEvent): - # nprint('@', self.indention, self.no_newline, self.column) - self.expect_scalar() - elif isinstance(self.event, SequenceStartEvent): - # nprint('@', self.indention, self.no_newline, self.column) - i2, n2 = self.indention, self.no_newline # NOQA - if self.event.comment: - if self.event.flow_style is False and self.event.comment: - if self.write_post_comment(self.event): - self.indention = False - self.no_newline = True - if self.write_pre_comment(self.event): - self.indention = i2 - self.no_newline = not self.indention - if ( - self.flow_level - or self.canonical - or self.event.flow_style - or 
self.check_empty_sequence() - ): - self.expect_flow_sequence() - else: - self.expect_block_sequence() - elif isinstance(self.event, MappingStartEvent): - if self.event.flow_style is False and self.event.comment: - self.write_post_comment(self.event) - if self.event.comment and self.event.comment[1]: - self.write_pre_comment(self.event) - if ( - self.flow_level - or self.canonical - or self.event.flow_style - or self.check_empty_mapping() - ): - self.expect_flow_mapping(single=self.event.nr_items == 1) - else: - self.expect_block_mapping() - else: - raise EmitterError("expected NodeEvent, but got %s" % (self.event,)) - - def expect_alias(self): - # type: () -> None - if self.event.anchor is None: - raise EmitterError("anchor is not specified for alias") - self.process_anchor(u"*") - self.state = self.states.pop() - - def expect_scalar(self): - # type: () -> None - self.increase_indent(flow=True) - self.process_scalar() - self.indent = self.indents.pop() - self.state = self.states.pop() - - # Flow sequence handlers. 
- - def expect_flow_sequence(self): - # type: () -> None - ind = self.indents.seq_flow_align(self.best_sequence_indent, self.column) - self.write_indicator(u" " * ind + u"[", True, whitespace=True) - self.increase_indent(flow=True, sequence=True) - self.flow_context.append("[") - self.state = self.expect_first_flow_sequence_item - - def expect_first_flow_sequence_item(self): - # type: () -> None - if isinstance(self.event, SequenceEndEvent): - self.indent = self.indents.pop() - popped = self.flow_context.pop() - assert popped == "[" - self.write_indicator(u"]", False) - if self.event.comment and self.event.comment[0]: - # eol comment on empty flow sequence - self.write_post_comment(self.event) - elif self.flow_level == 0: - self.write_line_break() - self.state = self.states.pop() - else: - if self.canonical or self.column > self.best_width: - self.write_indent() - self.states.append(self.expect_flow_sequence_item) - self.expect_node(sequence=True) - - def expect_flow_sequence_item(self): - # type: () -> None - if isinstance(self.event, SequenceEndEvent): - self.indent = self.indents.pop() - popped = self.flow_context.pop() - assert popped == "[" - if self.canonical: - self.write_indicator(u",", False) - self.write_indent() - self.write_indicator(u"]", False) - if self.event.comment and self.event.comment[0]: - # eol comment on flow sequence - self.write_post_comment(self.event) - else: - self.no_newline = False - self.state = self.states.pop() - else: - self.write_indicator(u",", False) - if self.canonical or self.column > self.best_width: - self.write_indent() - self.states.append(self.expect_flow_sequence_item) - self.expect_node(sequence=True) - - # Flow mapping handlers. 
- - def expect_flow_mapping(self, single=False): - # type: (Optional[bool]) -> None - ind = self.indents.seq_flow_align(self.best_sequence_indent, self.column) - map_init = u"{" - if ( - single - and self.flow_level - and self.flow_context[-1] == "[" - and not self.canonical - and not self.brace_single_entry_mapping_in_flow_sequence - ): - # single map item with flow context, no curly braces necessary - map_init = u"" - self.write_indicator(u" " * ind + map_init, True, whitespace=True) - self.flow_context.append(map_init) - self.increase_indent(flow=True, sequence=False) - self.state = self.expect_first_flow_mapping_key - - def expect_first_flow_mapping_key(self): - # type: () -> None - if isinstance(self.event, MappingEndEvent): - self.indent = self.indents.pop() - popped = self.flow_context.pop() - assert popped == "{" # empty flow mapping - self.write_indicator(u"}", False) - if self.event.comment and self.event.comment[0]: - # eol comment on empty mapping - self.write_post_comment(self.event) - elif self.flow_level == 0: - self.write_line_break() - self.state = self.states.pop() - else: - if self.canonical or self.column > self.best_width: - self.write_indent() - if not self.canonical and self.check_simple_key(): - self.states.append(self.expect_flow_mapping_simple_value) - self.expect_node(mapping=True, simple_key=True) - else: - self.write_indicator(u"?", True) - self.states.append(self.expect_flow_mapping_value) - self.expect_node(mapping=True) - - def expect_flow_mapping_key(self): - # type: () -> None - if isinstance(self.event, MappingEndEvent): - # if self.event.comment and self.event.comment[1]: - # self.write_pre_comment(self.event) - self.indent = self.indents.pop() - popped = self.flow_context.pop() - assert popped in [u"{", u""] - if self.canonical: - self.write_indicator(u",", False) - self.write_indent() - if popped != u"": - self.write_indicator(u"}", False) - if self.event.comment and self.event.comment[0]: - # eol comment on flow mapping, never 
reached on empty mappings - self.write_post_comment(self.event) - else: - self.no_newline = False - self.state = self.states.pop() - else: - self.write_indicator(u",", False) - if self.canonical or self.column > self.best_width: - self.write_indent() - if not self.canonical and self.check_simple_key(): - self.states.append(self.expect_flow_mapping_simple_value) - self.expect_node(mapping=True, simple_key=True) - else: - self.write_indicator(u"?", True) - self.states.append(self.expect_flow_mapping_value) - self.expect_node(mapping=True) - - def expect_flow_mapping_simple_value(self): - # type: () -> None - self.write_indicator(self.prefixed_colon, False) - self.states.append(self.expect_flow_mapping_key) - self.expect_node(mapping=True) - - def expect_flow_mapping_value(self): - # type: () -> None - if self.canonical or self.column > self.best_width: - self.write_indent() - self.write_indicator(self.prefixed_colon, True) - self.states.append(self.expect_flow_mapping_key) - self.expect_node(mapping=True) - - # Block sequence handlers. - - def expect_block_sequence(self): - # type: () -> None - if self.mapping_context: - indentless = not self.indention - else: - indentless = False - if not self.compact_seq_seq and self.column != 0: - self.write_line_break() - self.increase_indent(flow=False, sequence=True, indentless=indentless) - self.state = self.expect_first_block_sequence_item - - def expect_first_block_sequence_item(self): - # type: () -> Any - return self.expect_block_sequence_item(first=True) - - def expect_block_sequence_item(self, first=False): - # type: (bool) -> None - if not first and isinstance(self.event, SequenceEndEvent): - if self.event.comment and self.event.comment[1]: - # final comments on a block list e.g. 
empty line - self.write_pre_comment(self.event) - self.indent = self.indents.pop() - self.state = self.states.pop() - self.no_newline = False - else: - if self.event.comment and self.event.comment[1]: - self.write_pre_comment(self.event) - nonl = self.no_newline if self.column == 0 else False - self.write_indent() - ind = self.sequence_dash_offset # if len(self.indents) > 1 else 0 - self.write_indicator(u" " * ind + u"-", True, indention=True) - if nonl or self.sequence_dash_offset + 2 > self.best_sequence_indent: - self.no_newline = True - self.states.append(self.expect_block_sequence_item) - self.expect_node(sequence=True) - - # Block mapping handlers. - - def expect_block_mapping(self): - # type: () -> None - if not self.mapping_context and not (self.compact_seq_map or self.column == 0): - self.write_line_break() - self.increase_indent(flow=False, sequence=False) - self.state = self.expect_first_block_mapping_key - - def expect_first_block_mapping_key(self): - # type: () -> None - return self.expect_block_mapping_key(first=True) - - def expect_block_mapping_key(self, first=False): - # type: (Any) -> None - if not first and isinstance(self.event, MappingEndEvent): - if self.event.comment and self.event.comment[1]: - # final comments from a doc - self.write_pre_comment(self.event) - self.indent = self.indents.pop() - self.state = self.states.pop() - else: - if self.event.comment and self.event.comment[1]: - # final comments from a doc - self.write_pre_comment(self.event) - self.write_indent() - if self.check_simple_key(): - if not isinstance( - self.event, (SequenceStartEvent, MappingStartEvent) - ): # sequence keys - try: - if self.event.style == "?": - self.write_indicator(u"?", True, indention=True) - except AttributeError: # aliases have no style - pass - self.states.append(self.expect_block_mapping_simple_value) - self.expect_node(mapping=True, simple_key=True) - if isinstance(self.event, AliasEvent): - self.stream.write(u" ") - else: - 
self.write_indicator(u"?", True, indention=True) - self.states.append(self.expect_block_mapping_value) - self.expect_node(mapping=True) - - def expect_block_mapping_simple_value(self): - # type: () -> None - if getattr(self.event, "style", None) != "?": - # prefix = u'' - if self.indent == 0 and self.top_level_colon_align is not None: - # write non-prefixed colon - c = u" " * (self.top_level_colon_align - self.column) + self.colon - else: - c = self.prefixed_colon - self.write_indicator(c, False) - self.states.append(self.expect_block_mapping_key) - self.expect_node(mapping=True) - - def expect_block_mapping_value(self): - # type: () -> None - self.write_indent() - self.write_indicator(self.prefixed_colon, True, indention=True) - self.states.append(self.expect_block_mapping_key) - self.expect_node(mapping=True) - - # Checkers. - - def check_empty_sequence(self): - # type: () -> bool - return ( - isinstance(self.event, SequenceStartEvent) - and bool(self.events) - and isinstance(self.events[0], SequenceEndEvent) - ) - - def check_empty_mapping(self): - # type: () -> bool - return ( - isinstance(self.event, MappingStartEvent) - and bool(self.events) - and isinstance(self.events[0], MappingEndEvent) - ) - - def check_empty_document(self): - # type: () -> bool - if not isinstance(self.event, DocumentStartEvent) or not self.events: - return False - event = self.events[0] - return ( - isinstance(event, ScalarEvent) - and event.anchor is None - and event.tag is None - and event.implicit - and event.value == "" - ) - - def check_simple_key(self): - # type: () -> bool - length = 0 - if isinstance(self.event, NodeEvent) and self.event.anchor is not None: - if self.prepared_anchor is None: - self.prepared_anchor = self.prepare_anchor(self.event.anchor) - length += len(self.prepared_anchor) - if ( - isinstance(self.event, (ScalarEvent, CollectionStartEvent)) - and self.event.tag is not None - ): - if self.prepared_tag is None: - self.prepared_tag = 
self.prepare_tag(self.event.tag) - length += len(self.prepared_tag) - if isinstance(self.event, ScalarEvent): - if self.analysis is None: - self.analysis = self.analyze_scalar(self.event.value) - length += len(self.analysis.scalar) - return length < self.MAX_SIMPLE_KEY_LENGTH and ( - isinstance(self.event, AliasEvent) - or ( - isinstance(self.event, SequenceStartEvent) - and self.event.flow_style is True - ) - or ( - isinstance(self.event, MappingStartEvent) - and self.event.flow_style is True - ) - or ( - isinstance(self.event, ScalarEvent) - # if there is an explicit style for an empty string, it is a simple key - and not (self.analysis.empty and self.style and self.style not in "'\"") - and not self.analysis.multiline - ) - or self.check_empty_sequence() - or self.check_empty_mapping() - ) - - # Anchor, Tag, and Scalar processors. - - def process_anchor(self, indicator): - # type: (Any) -> bool - if self.event.anchor is None: - self.prepared_anchor = None - return False - if self.prepared_anchor is None: - self.prepared_anchor = self.prepare_anchor(self.event.anchor) - if self.prepared_anchor: - self.write_indicator(indicator + self.prepared_anchor, True) - # issue 288 - self.no_newline = False - self.prepared_anchor = None - return True - - def process_tag(self): - # type: () -> None - tag = self.event.tag - if isinstance(self.event, ScalarEvent): - if self.style is None: - self.style = self.choose_scalar_style() - if (not self.canonical or tag is None) and ( - (self.style == "" and self.event.implicit[0]) - or (self.style != "" and self.event.implicit[1]) - ): - self.prepared_tag = None - return - if self.event.implicit[0] and tag is None: - tag = u"!" 
- self.prepared_tag = None - else: - if (not self.canonical or tag is None) and self.event.implicit: - self.prepared_tag = None - return - if tag is None: - raise EmitterError("tag is not specified") - if self.prepared_tag is None: - self.prepared_tag = self.prepare_tag(tag) - if self.prepared_tag: - self.write_indicator(self.prepared_tag, True) - if ( - self.sequence_context - and not self.flow_level - and isinstance(self.event, ScalarEvent) - ): - self.no_newline = True - self.prepared_tag = None - - def choose_scalar_style(self): - # type: () -> Any - if self.analysis is None: - self.analysis = self.analyze_scalar(self.event.value) - if self.event.style == '"' or self.canonical: - return '"' - if (not self.event.style or self.event.style == "?") and ( - self.event.implicit[0] or not self.event.implicit[2] - ): - if not ( - self.simple_key_context - and (self.analysis.empty or self.analysis.multiline) - ) and ( - self.flow_level - and self.analysis.allow_flow_plain - or (not self.flow_level and self.analysis.allow_block_plain) - ): - return "" - self.analysis.allow_block = True - if self.event.style and self.event.style in "|>": - if ( - not self.flow_level - and not self.simple_key_context - and self.analysis.allow_block - ): - return self.event.style - if not self.event.style and self.analysis.allow_double_quoted: - if "'" in self.event.value or "\n" in self.event.value: - return '"' - if not self.event.style or self.event.style == "'": - if self.analysis.allow_single_quoted and not ( - self.simple_key_context and self.analysis.multiline - ): - return "'" - return '"' - - def process_scalar(self): - # type: () -> None - if self.analysis is None: - self.analysis = self.analyze_scalar(self.event.value) - if self.style is None: - self.style = self.choose_scalar_style() - split = not self.simple_key_context - # if self.analysis.multiline and split \ - # and (not self.style or self.style in '\'\"'): - # self.write_indent() - # nprint('xx', self.sequence_context, 
self.flow_level) - if self.sequence_context and not self.flow_level: - self.write_indent() - if self.style == '"': - self.write_double_quoted(self.analysis.scalar, split) - elif self.style == "'": - self.write_single_quoted(self.analysis.scalar, split) - elif self.style == ">": - self.write_folded(self.analysis.scalar) - elif self.style == "|": - self.write_literal(self.analysis.scalar, self.event.comment) - else: - self.write_plain(self.analysis.scalar, split) - self.analysis = None - self.style = None - if self.event.comment: - self.write_post_comment(self.event) - - # Analyzers. - - def prepare_version(self, version): - # type: (Any) -> Any - major, minor = version - if major != 1: - raise EmitterError("unsupported YAML version: %d.%d" % (major, minor)) - return u"%d.%d" % (major, minor) - - def prepare_tag_handle(self, handle): - # type: (Any) -> Any - if not handle: - raise EmitterError("tag handle must not be empty") - if handle[0] != u"!" or handle[-1] != u"!": - raise EmitterError( - "tag handle must start and end with '!': %r" % (utf8(handle)) - ) - for ch in handle[1:-1]: - if not ( - u"0" <= ch <= u"9" - or u"A" <= ch <= u"Z" - or u"a" <= ch <= u"z" - or ch in u"-_" - ): - raise EmitterError( - "invalid character %r in the tag handle: %r" - % (utf8(ch), utf8(handle)) - ) - return handle - - def prepare_tag_prefix(self, prefix): - # type: (Any) -> Any - if not prefix: - raise EmitterError("tag prefix must not be empty") - chunks = [] # type: List[Any] - start = end = 0 - if prefix[0] == u"!": - end = 1 - ch_set = u"-;/?:@&=+$,_.~*'()[]" - if self.dumper: - version = getattr(self.dumper, "version", (1, 2)) - if version is None or version >= (1, 2): - ch_set += u"#" - while end < len(prefix): - ch = prefix[end] - if ( - u"0" <= ch <= u"9" - or u"A" <= ch <= u"Z" - or u"a" <= ch <= u"z" - or ch in ch_set - ): - end += 1 - else: - if start < end: - chunks.append(prefix[start:end]) - start = end = end + 1 - data = utf8(ch) - for ch in data: - 
chunks.append(u"%%%02X" % ord(ch)) - if start < end: - chunks.append(prefix[start:end]) - return "".join(chunks) - - def prepare_tag(self, tag): - # type: (Any) -> Any - if not tag: - raise EmitterError("tag must not be empty") - if tag == u"!": - return tag - handle = None - suffix = tag - prefixes = sorted(self.tag_prefixes.keys()) - for prefix in prefixes: - if tag.startswith(prefix) and (prefix == u"!" or len(prefix) < len(tag)): - handle = self.tag_prefixes[prefix] - suffix = tag[len(prefix) :] - chunks = [] # type: List[Any] - start = end = 0 - ch_set = u"-;/?:@&=+$,_.~*'()[]" - if self.dumper: - version = getattr(self.dumper, "version", (1, 2)) - if version is None or version >= (1, 2): - ch_set += u"#" - while end < len(suffix): - ch = suffix[end] - if ( - u"0" <= ch <= u"9" - or u"A" <= ch <= u"Z" - or u"a" <= ch <= u"z" - or ch in ch_set - or (ch == u"!" and handle != u"!") - ): - end += 1 - else: - if start < end: - chunks.append(suffix[start:end]) - start = end = end + 1 - data = utf8(ch) - for ch in data: - chunks.append(u"%%%02X" % ord(ch)) - if start < end: - chunks.append(suffix[start:end]) - suffix_text = "".join(chunks) - if handle: - return u"%s%s" % (handle, suffix_text) - else: - return u"!<%s>" % suffix_text - - def prepare_anchor(self, anchor): - # type: (Any) -> Any - if not anchor: - raise EmitterError("anchor must not be empty") - for ch in anchor: - if not check_anchorname_char(ch): - raise EmitterError( - "invalid character %r in the anchor: %r" % (utf8(ch), utf8(anchor)) - ) - return anchor - - def analyze_scalar(self, scalar): - # type: (Any) -> Any - # Empty scalar is a special case. - if not scalar: - return ScalarAnalysis( - scalar=scalar, - empty=True, - multiline=False, - allow_flow_plain=False, - allow_block_plain=True, - allow_single_quoted=True, - allow_double_quoted=True, - allow_block=False, - ) - - # Indicators and special characters. 
- block_indicators = False - flow_indicators = False - line_breaks = False - special_characters = False - - # Important whitespace combinations. - leading_space = False - leading_break = False - trailing_space = False - trailing_break = False - break_space = False - space_break = False - - # Check document indicators. - if scalar.startswith(u"---") or scalar.startswith(u"..."): - block_indicators = True - flow_indicators = True - - # First character or preceded by a whitespace. - preceeded_by_whitespace = True - - # Last character or followed by a whitespace. - followed_by_whitespace = ( - len(scalar) == 1 or scalar[1] in u"\0 \t\r\n\x85\u2028\u2029" - ) - - # The previous character is a space. - previous_space = False - - # The previous character is a break. - previous_break = False - - index = 0 - while index < len(scalar): - ch = scalar[index] - - # Check for indicators. - if index == 0: - # Leading indicators are special characters. - if ch in u"#,[]{}&*!|>'\"%@`": - flow_indicators = True - block_indicators = True - if ch in u"?:": # ToDo - if self.serializer.use_version == (1, 1): - flow_indicators = True - elif len(scalar) == 1: # single character - flow_indicators = True - if followed_by_whitespace: - block_indicators = True - if ch == u"-" and followed_by_whitespace: - flow_indicators = True - block_indicators = True - else: - # Some indicators cannot appear within a scalar as well. - if ch in u",[]{}": # http://yaml.org/spec/1.2/spec.html#id2788859 - flow_indicators = True - if ch == u"?" and self.serializer.use_version == (1, 1): - flow_indicators = True - if ch == u":": - if followed_by_whitespace: - flow_indicators = True - block_indicators = True - if ch == u"#" and preceeded_by_whitespace: - flow_indicators = True - block_indicators = True - - # Check for line breaks, special, and unicode characters. 
- if ch in u"\n\x85\u2028\u2029": - line_breaks = True - if not (ch == u"\n" or u"\x20" <= ch <= u"\x7E"): - if ( - ch == u"\x85" - or u"\xA0" <= ch <= u"\uD7FF" - or u"\uE000" <= ch <= u"\uFFFD" - or ( - self.unicode_supplementary - and (u"\U00010000" <= ch <= u"\U0010FFFF") - ) - ) and ch != u"\uFEFF": - # unicode_characters = True - if not self.allow_unicode: - special_characters = True - else: - special_characters = True - - # Detect important whitespace combinations. - if ch == u" ": - if index == 0: - leading_space = True - if index == len(scalar) - 1: - trailing_space = True - if previous_break: - break_space = True - previous_space = True - previous_break = False - elif ch in u"\n\x85\u2028\u2029": - if index == 0: - leading_break = True - if index == len(scalar) - 1: - trailing_break = True - if previous_space: - space_break = True - previous_space = False - previous_break = True - else: - previous_space = False - previous_break = False - - # Prepare for the next character. - index += 1 - preceeded_by_whitespace = ch in u"\0 \t\r\n\x85\u2028\u2029" - followed_by_whitespace = ( - index + 1 >= len(scalar) - or scalar[index + 1] in u"\0 \t\r\n\x85\u2028\u2029" - ) - - # Let's decide what styles are allowed. - allow_flow_plain = True - allow_block_plain = True - allow_single_quoted = True - allow_double_quoted = True - allow_block = True - - # Leading and trailing whitespaces are bad for plain scalars. - if leading_space or leading_break or trailing_space or trailing_break: - allow_flow_plain = allow_block_plain = False - - # We do not permit trailing spaces for block scalars. - if trailing_space: - allow_block = False - - # Spaces at the beginning of a new line are only acceptable for block - # scalars. - if break_space: - allow_flow_plain = allow_block_plain = allow_single_quoted = False - - # Spaces followed by breaks, as well as special character are only - # allowed for double quoted scalars. 
- if special_characters: - allow_flow_plain = ( - allow_block_plain - ) = allow_single_quoted = allow_block = False - elif space_break: - allow_flow_plain = allow_block_plain = allow_single_quoted = False - if not self.allow_space_break: - allow_block = False - - # Although the plain scalar writer supports breaks, we never emit - # multiline plain scalars. - if line_breaks: - allow_flow_plain = allow_block_plain = False - - # Flow indicators are forbidden for flow plain scalars. - if flow_indicators: - allow_flow_plain = False - - # Block indicators are forbidden for block plain scalars. - if block_indicators: - allow_block_plain = False - - return ScalarAnalysis( - scalar=scalar, - empty=False, - multiline=line_breaks, - allow_flow_plain=allow_flow_plain, - allow_block_plain=allow_block_plain, - allow_single_quoted=allow_single_quoted, - allow_double_quoted=allow_double_quoted, - allow_block=allow_block, - ) - - # Writers. - - def flush_stream(self): - # type: () -> None - if hasattr(self.stream, "flush"): - self.stream.flush() - - def write_stream_start(self): - # type: () -> None - # Write BOM if needed. 
- if self.encoding and self.encoding.startswith("utf-16"): - self.stream.write(u"\uFEFF".encode(self.encoding)) - - def write_stream_end(self): - # type: () -> None - self.flush_stream() - - def write_indicator( - self, indicator, need_whitespace, whitespace=False, indention=False - ): - # type: (Any, Any, bool, bool) -> None - if self.whitespace or not need_whitespace: - data = indicator - else: - data = u" " + indicator - self.whitespace = whitespace - self.indention = self.indention and indention - self.column += len(data) - self.open_ended = False - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - - def write_indent(self): - # type: () -> None - indent = self.indent or 0 - if ( - not self.indention - or self.column > indent - or (self.column == indent and not self.whitespace) - ): - if bool(self.no_newline): - self.no_newline = False - else: - self.write_line_break() - if self.column < indent: - self.whitespace = True - data = u" " * (indent - self.column) - self.column = indent - if self.encoding: - data = data.encode(self.encoding) - self.stream.write(data) - - def write_line_break(self, data=None): - # type: (Any) -> None - if data is None: - data = self.best_line_break - self.whitespace = True - self.indention = True - self.line += 1 - self.column = 0 - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - - def write_version_directive(self, version_text): - # type: (Any) -> None - data = u"%%YAML %s" % version_text - if self.encoding: - data = data.encode(self.encoding) - self.stream.write(data) - self.write_line_break() - - def write_tag_directive(self, handle_text, prefix_text): - # type: (Any, Any) -> None - data = u"%%TAG %s %s" % (handle_text, prefix_text) - if self.encoding: - data = data.encode(self.encoding) - self.stream.write(data) - self.write_line_break() - - # Scalar streams. 
- - def write_single_quoted(self, text, split=True): - # type: (Any, Any) -> None - if self.root_context: - if self.requested_indent is not None: - self.write_line_break() - if self.requested_indent != 0: - self.write_indent() - self.write_indicator(u"'", True) - spaces = False - breaks = False - start = end = 0 - while end <= len(text): - ch = None - if end < len(text): - ch = text[end] - if spaces: - if ch is None or ch != u" ": - if ( - start + 1 == end - and self.column > self.best_width - and split - and start != 0 - and end != len(text) - ): - self.write_indent() - else: - data = text[start:end] - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - start = end - elif breaks: - if ch is None or ch not in u"\n\x85\u2028\u2029": - if text[start] == u"\n": - self.write_line_break() - for br in text[start:end]: - if br == u"\n": - self.write_line_break() - else: - self.write_line_break(br) - self.write_indent() - start = end - else: - if ch is None or ch in u" \n\x85\u2028\u2029" or ch == u"'": - if start < end: - data = text[start:end] - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - start = end - if ch == u"'": - data = u"''" - self.column += 2 - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - start = end + 1 - if ch is not None: - spaces = ch == u" " - breaks = ch in u"\n\x85\u2028\u2029" - end += 1 - self.write_indicator(u"'", False) - - ESCAPE_REPLACEMENTS = { - u"\0": u"0", - u"\x07": u"a", - u"\x08": u"b", - u"\x09": u"t", - u"\x0A": u"n", - u"\x0B": u"v", - u"\x0C": u"f", - u"\x0D": u"r", - u"\x1B": u"e", - u'"': u'"', - u"\\": u"\\", - u"\x85": u"N", - u"\xA0": u"_", - u"\u2028": u"L", - u"\u2029": u"P", - } - - def write_double_quoted(self, text, split=True): - # type: (Any, Any) -> None - if self.root_context: - if self.requested_indent is not None: - self.write_line_break() - if 
self.requested_indent != 0: - self.write_indent() - self.write_indicator(u'"', True) - start = end = 0 - while end <= len(text): - ch = None - if end < len(text): - ch = text[end] - if ( - ch is None - or ch in u'"\\\x85\u2028\u2029\uFEFF' - or not ( - u"\x20" <= ch <= u"\x7E" - or ( - self.allow_unicode - and (u"\xA0" <= ch <= u"\uD7FF" or u"\uE000" <= ch <= u"\uFFFD") - ) - ) - ): - if start < end: - data = text[start:end] - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - start = end - if ch is not None: - if ch in self.ESCAPE_REPLACEMENTS: - data = u"\\" + self.ESCAPE_REPLACEMENTS[ch] - elif ch <= u"\xFF": - data = u"\\x%02X" % ord(ch) - elif ch <= u"\uFFFF": - data = u"\\u%04X" % ord(ch) - else: - data = u"\\U%08X" % ord(ch) - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - start = end + 1 - if ( - 0 < end < len(text) - 1 - and (ch == u" " or start >= end) - and self.column + (end - start) > self.best_width - and split - ): - data = text[start:end] + u"\\" - if start < end: - start = end - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - self.write_indent() - self.whitespace = False - self.indention = False - if text[start] == u" ": - data = u"\\" - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - end += 1 - self.write_indicator(u'"', False) - - def determine_block_hints(self, text): - # type: (Any) -> Any - indent = 0 - indicator = u"" - hints = u"" - if text: - if text[0] in u" \n\x85\u2028\u2029": - indent = self.best_sequence_indent - hints += text_type(indent) - elif self.root_context: - for end in ["\n---", "\n..."]: - pos = 0 - while True: - pos = text.find(end, pos) - if pos == -1: - break - try: - if text[pos + 4] in " \r\n": - break - except IndexError: - pass - pos += 1 - if pos > -1: - 
break - if pos > 0: - indent = self.best_sequence_indent - if text[-1] not in u"\n\x85\u2028\u2029": - indicator = u"-" - elif len(text) == 1 or text[-2] in u"\n\x85\u2028\u2029": - indicator = u"+" - hints += indicator - return hints, indent, indicator - - def write_folded(self, text): - # type: (Any) -> None - hints, _indent, _indicator = self.determine_block_hints(text) - self.write_indicator(u">" + hints, True) - if _indicator == u"+": - self.open_ended = True - self.write_line_break() - leading_space = True - spaces = False - breaks = True - start = end = 0 - while end <= len(text): - ch = None - if end < len(text): - ch = text[end] - if breaks: - if ch is None or ch not in u"\n\x85\u2028\u2029\a": - if ( - not leading_space - and ch is not None - and ch != u" " - and text[start] == u"\n" - ): - self.write_line_break() - leading_space = ch == u" " - for br in text[start:end]: - if br == u"\n": - self.write_line_break() - else: - self.write_line_break(br) - if ch is not None: - self.write_indent() - start = end - elif spaces: - if ch != u" ": - if start + 1 == end and self.column > self.best_width: - self.write_indent() - else: - data = text[start:end] - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - start = end - else: - if ch is None or ch in u" \n\x85\u2028\u2029\a": - data = text[start:end] - self.column += len(data) - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - if ch == u"\a": - if end < (len(text) - 1) and not text[end + 2].isspace(): - self.write_line_break() - self.write_indent() - end += 2 # \a and the space that is inserted on the fold - else: - raise EmitterError( - "unexcpected fold indicator \\a before space" - ) - if ch is None: - self.write_line_break() - start = end - if ch is not None: - breaks = ch in u"\n\x85\u2028\u2029" - spaces = ch == u" " - end += 1 - - def write_literal(self, text, comment=None): - # type: (Any, Any) -> None - 
hints, _indent, _indicator = self.determine_block_hints(text) - self.write_indicator(u"|" + hints, True) - try: - comment = comment[1][0] - if comment: - self.stream.write(comment) - except (TypeError, IndexError): - pass - if _indicator == u"+": - self.open_ended = True - self.write_line_break() - breaks = True - start = end = 0 - while end <= len(text): - ch = None - if end < len(text): - ch = text[end] - if breaks: - if ch is None or ch not in u"\n\x85\u2028\u2029": - for br in text[start:end]: - if br == u"\n": - self.write_line_break() - else: - self.write_line_break(br) - if ch is not None: - if self.root_context: - idnx = self.indent if self.indent is not None else 0 - self.stream.write(u" " * (_indent + idnx)) - else: - self.write_indent() - start = end - else: - if ch is None or ch in u"\n\x85\u2028\u2029": - data = text[start:end] - if bool(self.encoding): - data = data.encode(self.encoding) - self.stream.write(data) - if ch is None: - self.write_line_break() - start = end - if ch is not None: - breaks = ch in u"\n\x85\u2028\u2029" - end += 1 - - def write_plain(self, text, split=True): - # type: (Any, Any) -> None - if self.root_context: - if self.requested_indent is not None: - self.write_line_break() - if self.requested_indent != 0: - self.write_indent() - else: - self.open_ended = True - if not text: - return - if not self.whitespace: - data = u" " - self.column += len(data) - if self.encoding: - data = data.encode(self.encoding) - self.stream.write(data) - self.whitespace = False - self.indention = False - spaces = False - breaks = False - start = end = 0 - while end <= len(text): - ch = None - if end < len(text): - ch = text[end] - if spaces: - if ch != u" ": - if start + 1 == end and self.column > self.best_width and split: - self.write_indent() - self.whitespace = False - self.indention = False - else: - data = text[start:end] - self.column += len(data) - if self.encoding: - data = data.encode(self.encoding) - self.stream.write(data) - start = end 
- elif breaks: - if ch not in u"\n\x85\u2028\u2029": # type: ignore - if text[start] == u"\n": - self.write_line_break() - for br in text[start:end]: - if br == u"\n": - self.write_line_break() - else: - self.write_line_break(br) - self.write_indent() - self.whitespace = False - self.indention = False - start = end - else: - if ch is None or ch in u" \n\x85\u2028\u2029": - data = text[start:end] - self.column += len(data) - if self.encoding: - data = data.encode(self.encoding) - try: - self.stream.write(data) - except: # NOQA - sys.stdout.write(repr(data) + "\n") - raise - start = end - if ch is not None: - spaces = ch == u" " - breaks = ch in u"\n\x85\u2028\u2029" - end += 1 - - def write_comment(self, comment, pre=False): - # type: (Any, bool) -> None - value = comment.value - # nprintf('{:02d} {:02d} {!r}'.format(self.column, comment.start_mark.column, value)) - if not pre and value[-1] == "\n": - value = value[:-1] - try: - # get original column position - col = comment.start_mark.column - if comment.value and comment.value.startswith("\n"): - # never inject extra spaces if the comment starts with a newline - # and not a real comment (e.g. 
if you have an empty line following a key-value - col = self.column - elif col < self.column + 1: - ValueError - except ValueError: - col = self.column + 1 - # nprint('post_comment', self.line, self.column, value) - try: - # at least one space if the current column >= the start column of the comment - # but not at the start of a line - nr_spaces = col - self.column - if self.column and value.strip() and nr_spaces < 1 and value[0] != "\n": - nr_spaces = 1 - value = " " * nr_spaces + value - try: - if bool(self.encoding): - value = value.encode(self.encoding) - except UnicodeDecodeError: - pass - self.stream.write(value) - except TypeError: - raise - if not pre: - self.write_line_break() - - def write_pre_comment(self, event): - # type: (Any) -> bool - comments = event.comment[1] - if comments is None: - return False - try: - start_events = (MappingStartEvent, SequenceStartEvent) - for comment in comments: - if isinstance(event, start_events) and getattr( - comment, "pre_done", None - ): - continue - if self.column != 0: - self.write_line_break() - self.write_comment(comment, pre=True) - if isinstance(event, start_events): - comment.pre_done = True - except TypeError: - sys.stdout.write("eventtt {} {}".format(type(event), event)) - raise - return True - - def write_post_comment(self, event): - # type: (Any) -> bool - if self.event.comment[0] is None: - return False - comment = event.comment[0] - self.write_comment(comment) - return True diff --git a/srsly/ruamel_yaml/error.py b/srsly/ruamel_yaml/error.py deleted file mode 100755 index 93040ac..0000000 --- a/srsly/ruamel_yaml/error.py +++ /dev/null @@ -1,321 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -import warnings -import textwrap - -from .compat import utf8 - -if False: # MYPY - from typing import Any, Dict, Optional, List, Text # NOQA - - -__all__ = [ - "FileMark", - "StringMark", - "CommentMark", - "YAMLError", - "MarkedYAMLError", - "ReusedAnchorWarning", - "UnsafeLoaderWarning", - 
"MarkedYAMLWarning", - "MarkedYAMLFutureWarning", -] - - -class StreamMark(object): - __slots__ = "name", "index", "line", "column" - - def __init__(self, name, index, line, column): - # type: (Any, int, int, int) -> None - self.name = name - self.index = index - self.line = line - self.column = column - - def __str__(self): - # type: () -> Any - where = ' in "%s", line %d, column %d' % ( - self.name, - self.line + 1, - self.column + 1, - ) - return where - - def __eq__(self, other): - # type: (Any) -> bool - if self.line != other.line or self.column != other.column: - return False - if self.name != other.name or self.index != other.index: - return False - return True - - def __ne__(self, other): - # type: (Any) -> bool - return not self.__eq__(other) - - -class FileMark(StreamMark): - __slots__ = () - - -class StringMark(StreamMark): - __slots__ = "name", "index", "line", "column", "buffer", "pointer" - - def __init__(self, name, index, line, column, buffer, pointer): - # type: (Any, int, int, int, Any, Any) -> None - StreamMark.__init__(self, name, index, line, column) - self.buffer = buffer - self.pointer = pointer - - def get_snippet(self, indent=4, max_length=75): - # type: (int, int) -> Any - if self.buffer is None: # always False - return None - head = "" - start = self.pointer - while start > 0 and self.buffer[start - 1] not in u"\0\r\n\x85\u2028\u2029": - start -= 1 - if self.pointer - start > max_length / 2 - 1: - head = " ... " - start += 5 - break - tail = "" - end = self.pointer - while ( - end < len(self.buffer) and self.buffer[end] not in u"\0\r\n\x85\u2028\u2029" - ): - end += 1 - if end - self.pointer > max_length / 2 - 1: - tail = " ... 
" - end -= 5 - break - snippet = utf8(self.buffer[start:end]) - caret = "^" - caret = "^ (line: {})".format(self.line + 1) - return ( - " " * indent - + head - + snippet - + tail - + "\n" - + " " * (indent + self.pointer - start + len(head)) - + caret - ) - - def __str__(self): - # type: () -> Any - snippet = self.get_snippet() - where = ' in "%s", line %d, column %d' % ( - self.name, - self.line + 1, - self.column + 1, - ) - if snippet is not None: - where += ":\n" + snippet - return where - - -class CommentMark(object): - __slots__ = ("column",) - - def __init__(self, column): - # type: (Any) -> None - self.column = column - - -class YAMLError(Exception): - pass - - -class MarkedYAMLError(YAMLError): - def __init__( - self, - context=None, - context_mark=None, - problem=None, - problem_mark=None, - note=None, - warn=None, - ): - # type: (Any, Any, Any, Any, Any, Any) -> None - self.context = context - self.context_mark = context_mark - self.problem = problem - self.problem_mark = problem_mark - self.note = note - # warn is ignored - - def __str__(self): - # type: () -> Any - lines = [] # type: List[str] - if self.context is not None: - lines.append(self.context) - if self.context_mark is not None and ( - self.problem is None - or self.problem_mark is None - or self.context_mark.name != self.problem_mark.name - or self.context_mark.line != self.problem_mark.line - or self.context_mark.column != self.problem_mark.column - ): - lines.append(str(self.context_mark)) - if self.problem is not None: - lines.append(self.problem) - if self.problem_mark is not None: - lines.append(str(self.problem_mark)) - if self.note is not None and self.note: - note = textwrap.dedent(self.note) - lines.append(note) - return "\n".join(lines) - - -class YAMLStreamError(Exception): - pass - - -class YAMLWarning(Warning): - pass - - -class MarkedYAMLWarning(YAMLWarning): - def __init__( - self, - context=None, - context_mark=None, - problem=None, - problem_mark=None, - note=None, - 
warn=None, - ): - # type: (Any, Any, Any, Any, Any, Any) -> None - self.context = context - self.context_mark = context_mark - self.problem = problem - self.problem_mark = problem_mark - self.note = note - self.warn = warn - - def __str__(self): - # type: () -> Any - lines = [] # type: List[str] - if self.context is not None: - lines.append(self.context) - if self.context_mark is not None and ( - self.problem is None - or self.problem_mark is None - or self.context_mark.name != self.problem_mark.name - or self.context_mark.line != self.problem_mark.line - or self.context_mark.column != self.problem_mark.column - ): - lines.append(str(self.context_mark)) - if self.problem is not None: - lines.append(self.problem) - if self.problem_mark is not None: - lines.append(str(self.problem_mark)) - if self.note is not None and self.note: - note = textwrap.dedent(self.note) - lines.append(note) - if self.warn is not None and self.warn: - warn = textwrap.dedent(self.warn) - lines.append(warn) - return "\n".join(lines) - - -class ReusedAnchorWarning(YAMLWarning): - pass - - -class UnsafeLoaderWarning(YAMLWarning): - text = """ -The default 'Loader' for 'load(stream)' without further arguments can be unsafe. -Use 'load(stream, Loader=srsly.ruamel_yaml.Loader)' explicitly if that is OK. -Alternatively include the following in your code: - - import warnings - warnings.simplefilter('ignore', srsly.ruamel_yaml.error.UnsafeLoaderWarning) - -In most other cases you should consider using 'safe_load(stream)'""" - pass - - -warnings.simplefilter("once", UnsafeLoaderWarning) - - -class MantissaNoDotYAML1_1Warning(YAMLWarning): - def __init__(self, node, flt_str): - # type: (Any, Any) -> None - self.node = node - self.flt = flt_str - - def __str__(self): - # type: () -> Any - line = self.node.start_mark.line - col = self.node.start_mark.column - return """ -In YAML 1.1 floating point values should have a dot ('.') in their mantissa. 
-See the Floating-Point Language-Independent Type for YAML™ Version 1.1 specification -( http://yaml.org/type/float.html ). This dot is not required for JSON nor for YAML 1.2 - -Correct your float: "{}" on line: {}, column: {} - -or alternatively include the following in your code: - - import warnings - warnings.simplefilter('ignore', srsly.ruamel_yaml.error.MantissaNoDotYAML1_1Warning) - -""".format( - self.flt, line, col - ) - - -warnings.simplefilter("once", MantissaNoDotYAML1_1Warning) - - -class YAMLFutureWarning(Warning): - pass - - -class MarkedYAMLFutureWarning(YAMLFutureWarning): - def __init__( - self, - context=None, - context_mark=None, - problem=None, - problem_mark=None, - note=None, - warn=None, - ): - # type: (Any, Any, Any, Any, Any, Any) -> None - self.context = context - self.context_mark = context_mark - self.problem = problem - self.problem_mark = problem_mark - self.note = note - self.warn = warn - - def __str__(self): - # type: () -> Any - lines = [] # type: List[str] - if self.context is not None: - lines.append(self.context) - - if self.context_mark is not None and ( - self.problem is None - or self.problem_mark is None - or self.context_mark.name != self.problem_mark.name - or self.context_mark.line != self.problem_mark.line - or self.context_mark.column != self.problem_mark.column - ): - lines.append(str(self.context_mark)) - if self.problem is not None: - lines.append(self.problem) - if self.problem_mark is not None: - lines.append(str(self.problem_mark)) - if self.note is not None and self.note: - note = textwrap.dedent(self.note) - lines.append(note) - if self.warn is not None and self.warn: - warn = textwrap.dedent(self.warn) - lines.append(warn) - return "\n".join(lines) diff --git a/srsly/ruamel_yaml/events.py b/srsly/ruamel_yaml/events.py deleted file mode 100755 index 58b2121..0000000 --- a/srsly/ruamel_yaml/events.py +++ /dev/null @@ -1,157 +0,0 @@ -# coding: utf-8 - -# Abstract classes. 
- -if False: # MYPY - from typing import Any, Dict, Optional, List # NOQA - - -def CommentCheck(): - # type: () -> None - pass - - -class Event(object): - __slots__ = 'start_mark', 'end_mark', 'comment' - - def __init__(self, start_mark=None, end_mark=None, comment=CommentCheck): - # type: (Any, Any, Any) -> None - self.start_mark = start_mark - self.end_mark = end_mark - # assert comment is not CommentCheck - if comment is CommentCheck: - comment = None - self.comment = comment - - def __repr__(self): - # type: () -> Any - attributes = [ - key - for key in ['anchor', 'tag', 'implicit', 'value', 'flow_style', 'style'] - if hasattr(self, key) - ] - arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) for key in attributes]) - if self.comment not in [None, CommentCheck]: - arguments += ', comment={!r}'.format(self.comment) - return '%s(%s)' % (self.__class__.__name__, arguments) - - -class NodeEvent(Event): - __slots__ = ('anchor',) - - def __init__(self, anchor, start_mark=None, end_mark=None, comment=None): - # type: (Any, Any, Any, Any) -> None - Event.__init__(self, start_mark, end_mark, comment) - self.anchor = anchor - - -class CollectionStartEvent(NodeEvent): - __slots__ = 'tag', 'implicit', 'flow_style', 'nr_items' - - def __init__( - self, - anchor, - tag, - implicit, - start_mark=None, - end_mark=None, - flow_style=None, - comment=None, - nr_items=None, - ): - # type: (Any, Any, Any, Any, Any, Any, Any, Optional[int]) -> None - NodeEvent.__init__(self, anchor, start_mark, end_mark, comment) - self.tag = tag - self.implicit = implicit - self.flow_style = flow_style - self.nr_items = nr_items - - -class CollectionEndEvent(Event): - __slots__ = () - - -# Implementations. 
- - -class StreamStartEvent(Event): - __slots__ = ('encoding',) - - def __init__(self, start_mark=None, end_mark=None, encoding=None, comment=None): - # type: (Any, Any, Any, Any) -> None - Event.__init__(self, start_mark, end_mark, comment) - self.encoding = encoding - - -class StreamEndEvent(Event): - __slots__ = () - - -class DocumentStartEvent(Event): - __slots__ = 'explicit', 'version', 'tags' - - def __init__( - self, - start_mark=None, - end_mark=None, - explicit=None, - version=None, - tags=None, - comment=None, - ): - # type: (Any, Any, Any, Any, Any, Any) -> None - Event.__init__(self, start_mark, end_mark, comment) - self.explicit = explicit - self.version = version - self.tags = tags - - -class DocumentEndEvent(Event): - __slots__ = ('explicit',) - - def __init__(self, start_mark=None, end_mark=None, explicit=None, comment=None): - # type: (Any, Any, Any, Any) -> None - Event.__init__(self, start_mark, end_mark, comment) - self.explicit = explicit - - -class AliasEvent(NodeEvent): - __slots__ = () - - -class ScalarEvent(NodeEvent): - __slots__ = 'tag', 'implicit', 'value', 'style' - - def __init__( - self, - anchor, - tag, - implicit, - value, - start_mark=None, - end_mark=None, - style=None, - comment=None, - ): - # type: (Any, Any, Any, Any, Any, Any, Any, Any) -> None - NodeEvent.__init__(self, anchor, start_mark, end_mark, comment) - self.tag = tag - self.implicit = implicit - self.value = value - self.style = style - - -class SequenceStartEvent(CollectionStartEvent): - __slots__ = () - - -class SequenceEndEvent(CollectionEndEvent): - __slots__ = () - - -class MappingStartEvent(CollectionStartEvent): - __slots__ = () - - -class MappingEndEvent(CollectionEndEvent): - __slots__ = () diff --git a/srsly/ruamel_yaml/loader.py b/srsly/ruamel_yaml/loader.py deleted file mode 100755 index 177cac2..0000000 --- a/srsly/ruamel_yaml/loader.py +++ /dev/null @@ -1,70 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - - -from .reader import 
Reader -from .scanner import Scanner, RoundTripScanner -from .parser import Parser, RoundTripParser -from .composer import Composer -from .constructor import ( - BaseConstructor, - SafeConstructor, - Constructor, - RoundTripConstructor, -) -from .resolver import VersionedResolver - -if False: # MYPY - from typing import Any, Dict, List, Union, Optional # NOQA - from .compat import StreamTextType, VersionType # NOQA - -__all__ = ["BaseLoader", "SafeLoader", "Loader", "RoundTripLoader"] - - -class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, VersionedResolver): - def __init__(self, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None - Reader.__init__(self, stream, loader=self) - Scanner.__init__(self, loader=self) - Parser.__init__(self, loader=self) - Composer.__init__(self, loader=self) - BaseConstructor.__init__(self, loader=self) - VersionedResolver.__init__(self, version, loader=self) - - -class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, VersionedResolver): - def __init__(self, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None - Reader.__init__(self, stream, loader=self) - Scanner.__init__(self, loader=self) - Parser.__init__(self, loader=self) - Composer.__init__(self, loader=self) - SafeConstructor.__init__(self, loader=self) - VersionedResolver.__init__(self, version, loader=self) - - -class Loader(Reader, Scanner, Parser, Composer, Constructor, VersionedResolver): - def __init__(self, stream, version=None, preserve_quotes=None): - raise ValueError("Unsafe loader not implemented in this library.") - - -class RoundTripLoader( - Reader, - RoundTripScanner, - RoundTripParser, - Composer, - RoundTripConstructor, - VersionedResolver, -): - def __init__(self, stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None - # self.reader = 
Reader.__init__(self, stream) - Reader.__init__(self, stream, loader=self) - RoundTripScanner.__init__(self, loader=self) - RoundTripParser.__init__(self, loader=self) - Composer.__init__(self, loader=self) - RoundTripConstructor.__init__( - self, preserve_quotes=preserve_quotes, loader=self - ) - VersionedResolver.__init__(self, version, loader=self) diff --git a/srsly/ruamel_yaml/main.py b/srsly/ruamel_yaml/main.py deleted file mode 100755 index 433cf21..0000000 --- a/srsly/ruamel_yaml/main.py +++ /dev/null @@ -1,1561 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import, unicode_literals, print_function - -import sys -import os -import warnings -import glob -from importlib import import_module - - -from . import resolver -from . import emitter -from . import representer -from . import parser -from . import composer -from . import constructor -from . import serializer -from . import scanner -from . import loader -from . import dumper -from . import reader -from .error import UnsafeLoaderWarning, YAMLError # NOQA - -from .tokens import * # NOQA -from .events import * # NOQA -from .nodes import * # NOQA - -from .loader import BaseLoader, SafeLoader, Loader, RoundTripLoader # NOQA -from .dumper import BaseDumper, SafeDumper, Dumper, RoundTripDumper # NOQA -from .compat import StringIO, BytesIO, with_metaclass, PY3, nprint -from .resolver import VersionedResolver, Resolver # NOQA -from .representer import ( - BaseRepresenter, - SafeRepresenter, - Representer, - RoundTripRepresenter, -) -from .constructor import ( - BaseConstructor, - SafeConstructor, - Constructor, - RoundTripConstructor, -) -from .loader import Loader as UnsafeLoader - -if False: # MYPY - from typing import List, Set, Dict, Union, Any, Callable, Optional, Text # NOQA - from .compat import StreamType, StreamTextType, VersionType # NOQA - - if PY3: - from pathlib import Path - else: - Path = Any - -try: - from _ruamel_yaml import CParser, CEmitter # type: ignore -except: # NOQA - CParser 
= CEmitter = None - -# import io - -enforce = object() - - -# YAML is an acronym, i.e. spoken: rhymes with "camel". And thus a -# subset of abbreviations, which should be all caps according to PEP8 - - -class YAML(object): - def __init__( - self, - _kw=enforce, - typ=None, - pure=False, - output=None, - plug_ins=None, # input=None, - ): - # type: (Any, Optional[Text], Any, Any, Any) -> None - """ - _kw: not used, forces keyword arguments in 2.7 (in 3 you can do (*, safe_load=..) - typ: 'rt'/None -> RoundTripLoader/RoundTripDumper, (default) - 'safe' -> SafeLoader/SafeDumper, - 'unsafe' -> normal/unsafe Loader/Dumper - 'base' -> baseloader - pure: if True only use Python modules - input/output: needed to work as context manager - plug_ins: a list of plug-in files - """ - if _kw is not enforce: - raise TypeError( - "{}.__init__() takes no positional argument but at least " - "one was given ({!r})".format(self.__class__.__name__, _kw) - ) - - self.typ = ["rt"] if typ is None else (typ if isinstance(typ, list) else [typ]) - self.pure = pure - - # self._input = input - self._output = output - self._context_manager = None # type: Any - - self.plug_ins = [] # type: List[Any] - for pu in ([] if plug_ins is None else plug_ins) + self.official_plug_ins(): - file_name = pu.replace(os.sep, ".") - self.plug_ins.append(import_module(file_name)) - self.Resolver = resolver.VersionedResolver # type: Any - self.allow_unicode = True - self.Reader = None # type: Any - self.Representer = None # type: Any - self.Constructor = None # type: Any - self.Scanner = None # type: Any - self.Serializer = None # type: Any - self.default_flow_style = None # type: Any - typ_found = 1 - setup_rt = False - if "rt" in self.typ: - setup_rt = True - elif "safe" in self.typ: - self.Emitter = emitter.Emitter if pure or CEmitter is None else CEmitter - self.Representer = representer.SafeRepresenter - self.Parser = parser.Parser if pure or CParser is None else CParser - self.Composer = composer.Composer - 
self.Constructor = constructor.SafeConstructor - elif "base" in self.typ: - self.Emitter = emitter.Emitter - self.Representer = representer.BaseRepresenter - self.Parser = parser.Parser if pure or CParser is None else CParser - self.Composer = composer.Composer - self.Constructor = constructor.BaseConstructor - elif "unsafe" in self.typ: - self.Emitter = emitter.Emitter if pure or CEmitter is None else CEmitter - self.Representer = representer.Representer - self.Parser = parser.Parser if pure or CParser is None else CParser - self.Composer = composer.Composer - self.Constructor = constructor.Constructor - else: - setup_rt = True - typ_found = 0 - if setup_rt: - self.default_flow_style = False - # no optimized rt-dumper yet - self.Emitter = emitter.Emitter - self.Serializer = serializer.Serializer - self.Representer = representer.RoundTripRepresenter - self.Scanner = scanner.RoundTripScanner - # no optimized rt-parser yet - self.Parser = parser.RoundTripParser - self.Composer = composer.Composer - self.Constructor = constructor.RoundTripConstructor - del setup_rt - self.stream = None - self.canonical = None - self.old_indent = None - self.width = None - self.line_break = None - - self.map_indent = None - self.sequence_indent = None - self.sequence_dash_offset = 0 - self.compact_seq_seq = None - self.compact_seq_map = None - self.sort_base_mapping_type_on_output = None # default: sort - - self.top_level_colon_align = None - self.prefix_colon = None - self.version = None - self.preserve_quotes = None - self.allow_duplicate_keys = False # duplicate keys in map, set - self.encoding = "utf-8" - self.explicit_start = None - self.explicit_end = None - self.tags = None - self.default_style = None - self.top_level_block_style_scalar_no_indent_error_1_1 = False - # directives end indicator with single scalar document - self.scalar_after_indicator = None - # [a, b: 1, c: {d: 2}] vs. 
[a, {b: 1}, {c: {d: 2}}] - self.brace_single_entry_mapping_in_flow_sequence = False - for module in self.plug_ins: - if getattr(module, "typ", None) in self.typ: - typ_found += 1 - module.init_typ(self) - break - if typ_found == 0: - raise NotImplementedError( - 'typ "{}"not recognised (need to install plug-in?)'.format(self.typ) - ) - - @property - def reader(self): - # type: () -> Any - try: - return self._reader # type: ignore - except AttributeError: - self._reader = self.Reader(None, loader=self) - return self._reader - - @property - def scanner(self): - # type: () -> Any - try: - return self._scanner # type: ignore - except AttributeError: - self._scanner = self.Scanner(loader=self) - return self._scanner - - @property - def parser(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if not hasattr(self, attr): - if self.Parser is not CParser: - setattr(self, attr, self.Parser(loader=self)) - else: - if getattr(self, "_stream", None) is None: - # wait for the stream - return None - else: - # if not hasattr(self._stream, 'read') and hasattr(self._stream, 'open'): - # # pathlib.Path() instance - # setattr(self, attr, CParser(self._stream)) - # else: - setattr(self, attr, CParser(self._stream)) - # self._parser = self._composer = self - # nprint('scanner', self.loader.scanner) - - return getattr(self, attr) - - @property - def composer(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if not hasattr(self, attr): - setattr(self, attr, self.Composer(loader=self)) - return getattr(self, attr) - - @property - def constructor(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if not hasattr(self, attr): - cnst = self.Constructor(preserve_quotes=self.preserve_quotes, loader=self) - cnst.allow_duplicate_keys = self.allow_duplicate_keys - setattr(self, attr, cnst) - return getattr(self, attr) - - @property - def resolver(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if 
not hasattr(self, attr): - setattr(self, attr, self.Resolver(version=self.version, loader=self)) - return getattr(self, attr) - - @property - def emitter(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if not hasattr(self, attr): - if self.Emitter is not CEmitter: - _emitter = self.Emitter( - None, - canonical=self.canonical, - indent=self.old_indent, - width=self.width, - allow_unicode=self.allow_unicode, - line_break=self.line_break, - prefix_colon=self.prefix_colon, - brace_single_entry_mapping_in_flow_sequence=self.brace_single_entry_mapping_in_flow_sequence, # NOQA - dumper=self, - ) - setattr(self, attr, _emitter) - if self.map_indent is not None: - _emitter.best_map_indent = self.map_indent - if self.sequence_indent is not None: - _emitter.best_sequence_indent = self.sequence_indent - if self.sequence_dash_offset is not None: - _emitter.sequence_dash_offset = self.sequence_dash_offset - # _emitter.block_seq_indent = self.sequence_dash_offset - if self.compact_seq_seq is not None: - _emitter.compact_seq_seq = self.compact_seq_seq - if self.compact_seq_map is not None: - _emitter.compact_seq_map = self.compact_seq_map - else: - if getattr(self, "_stream", None) is None: - # wait for the stream - return None - return None - return getattr(self, attr) - - @property - def serializer(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if not hasattr(self, attr): - setattr( - self, - attr, - self.Serializer( - encoding=self.encoding, - explicit_start=self.explicit_start, - explicit_end=self.explicit_end, - version=self.version, - tags=self.tags, - dumper=self, - ), - ) - return getattr(self, attr) - - @property - def representer(self): - # type: () -> Any - attr = "_" + sys._getframe().f_code.co_name - if not hasattr(self, attr): - repres = self.Representer( - default_style=self.default_style, - default_flow_style=self.default_flow_style, - dumper=self, - ) - if self.sort_base_mapping_type_on_output is not None: - 
repres.sort_base_mapping_type_on_output = ( - self.sort_base_mapping_type_on_output - ) - setattr(self, attr, repres) - return getattr(self, attr) - - # separate output resolver? - - # def load(self, stream=None): - # if self._context_manager: - # if not self._input: - # raise TypeError("Missing input stream while dumping from context manager") - # for data in self._context_manager.load(): - # yield data - # return - # if stream is None: - # raise TypeError("Need a stream argument when not loading from context manager") - # return self.load_one(stream) - - def load(self, stream): - # type: (Union[Path, StreamTextType]) -> Any - """ - at this point you either have the non-pure Parser (which has its own reader and - scanner) or you have the pure Parser. - If the pure Parser is set, then set the Reader and Scanner, if not already set. - If either the Scanner or Reader are set, you cannot use the non-pure Parser, - so reset it to the pure parser and set the Reader resp. Scanner if necessary - """ - if not hasattr(stream, "read") and hasattr(stream, "open"): - # pathlib.Path() instance - with stream.open("rb") as fp: - return self.load(fp) - constructor, parser = self.get_constructor_parser(stream) - try: - return constructor.get_single_data() - finally: - parser.dispose() - try: - self._reader.reset_reader() - except AttributeError: - pass - try: - self._scanner.reset_scanner() - except AttributeError: - pass - - def load_all(self, stream, _kw=enforce): # , skip=None): - # type: (Union[Path, StreamTextType], Any) -> Any - if _kw is not enforce: - raise TypeError( - "{}.__init__() takes no positional argument but at least " - "one was given ({!r})".format(self.__class__.__name__, _kw) - ) - if not hasattr(stream, "read") and hasattr(stream, "open"): - # pathlib.Path() instance - with stream.open("r") as fp: - for d in self.load_all(fp, _kw=enforce): - yield d - return - # if skip is None: - # skip = [] - # elif isinstance(skip, int): - # skip = [skip] - constructor, 
parser = self.get_constructor_parser(stream) - try: - while constructor.check_data(): - yield constructor.get_data() - finally: - parser.dispose() - try: - self._reader.reset_reader() - except AttributeError: - pass - try: - self._scanner.reset_scanner() - except AttributeError: - pass - - def get_constructor_parser(self, stream): - # type: (StreamTextType) -> Any - """ - the old cyaml needs special setup, and therefore the stream - """ - if self.Parser is not CParser: - if self.Reader is None: - self.Reader = reader.Reader - if self.Scanner is None: - self.Scanner = scanner.Scanner - self.reader.stream = stream - else: - if self.Reader is not None: - if self.Scanner is None: - self.Scanner = scanner.Scanner - self.Parser = parser.Parser - self.reader.stream = stream - elif self.Scanner is not None: - if self.Reader is None: - self.Reader = reader.Reader - self.Parser = parser.Parser - self.reader.stream = stream - else: - # combined C level reader>scanner>parser - # does some calls to the resolver, e.g. 
BaseResolver.descend_resolver - # if you just initialise the CParser, to much of resolver.py - # is actually used - rslvr = self.Resolver - # if rslvr is srsly.ruamel_yaml.resolver.VersionedResolver: - # rslvr = srsly.ruamel_yaml.resolver.Resolver - - class XLoader(self.Parser, self.Constructor, rslvr): # type: ignore - def __init__( - selfx, stream, version=self.version, preserve_quotes=None - ): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None # NOQA - CParser.__init__(selfx, stream) - selfx._parser = selfx._composer = selfx - self.Constructor.__init__(selfx, loader=selfx) - selfx.allow_duplicate_keys = self.allow_duplicate_keys - rslvr.__init__(selfx, version=version, loadumper=selfx) - - self._stream = stream - loader = XLoader(stream) - return loader, loader - return self.constructor, self.parser - - def dump(self, data, stream=None, _kw=enforce, transform=None): - # type: (Any, Union[Path, StreamType], Any, Any) -> Any - if self._context_manager: - if not self._output: - raise TypeError( - "Missing output stream while dumping from context manager" - ) - if _kw is not enforce: - raise TypeError( - "{}.dump() takes one positional argument but at least " - "two were given ({!r})".format(self.__class__.__name__, _kw) - ) - if transform is not None: - raise TypeError( - "{}.dump() in the context manager cannot have transform keyword " - "".format(self.__class__.__name__) - ) - self._context_manager.dump(data) - else: # old style - if stream is None: - raise TypeError( - "Need a stream argument when not dumping from context manager" - ) - return self.dump_all([data], stream, _kw, transform=transform) - - def dump_all(self, documents, stream, _kw=enforce, transform=None): - # type: (Any, Union[Path, StreamType], Any, Any) -> Any - if self._context_manager: - raise NotImplementedError - if _kw is not enforce: - raise TypeError( - "{}.dump(_all) takes two positional argument but at least " - "three were given 
({!r})".format(self.__class__.__name__, _kw) - ) - self._output = stream - self._context_manager = YAMLContextManager(self, transform=transform) - for data in documents: - self._context_manager.dump(data) - self._context_manager.teardown_output() - self._output = None - self._context_manager = None - - def Xdump_all(self, documents, stream, _kw=enforce, transform=None): - # type: (Any, Union[Path, StreamType], Any, Any) -> Any - """ - Serialize a sequence of Python objects into a YAML stream. - """ - if not hasattr(stream, "write") and hasattr(stream, "open"): - # pathlib.Path() instance - with stream.open("w") as fp: - return self.dump_all(documents, fp, _kw, transform=transform) - if _kw is not enforce: - raise TypeError( - "{}.dump(_all) takes two positional argument but at least " - "three were given ({!r})".format(self.__class__.__name__, _kw) - ) - # The stream should have the methods `write` and possibly `flush`. - if self.top_level_colon_align is True: - tlca = max([len(str(x)) for x in documents[0]]) # type: Any - else: - tlca = self.top_level_colon_align - if transform is not None: - fstream = stream - if self.encoding is None: - stream = StringIO() - else: - stream = BytesIO() - serializer, representer, emitter = self.get_serializer_representer_emitter( - stream, tlca - ) - try: - self.serializer.open() - for data in documents: - try: - self.representer.represent(data) - except AttributeError: - # nprint(dir(dumper._representer)) - raise - self.serializer.close() - finally: - try: - self.emitter.dispose() - except AttributeError: - raise - # self.dumper.dispose() # cyaml - delattr(self, "_serializer") - delattr(self, "_emitter") - if transform: - val = stream.getvalue() - if self.encoding: - val = val.decode(self.encoding) - if fstream is None: - transform(val) - else: - fstream.write(transform(val)) - return None - - def get_serializer_representer_emitter(self, stream, tlca): - # type: (StreamType, Any) -> Any - # we have only .Serializer to deal with 
(vs .Reader & .Scanner), much simpler - if self.Emitter is not CEmitter: - if self.Serializer is None: - self.Serializer = serializer.Serializer - self.emitter.stream = stream - self.emitter.top_level_colon_align = tlca - if self.scalar_after_indicator is not None: - self.emitter.scalar_after_indicator = self.scalar_after_indicator - return self.serializer, self.representer, self.emitter - if self.Serializer is not None: - # cannot set serializer with CEmitter - self.Emitter = emitter.Emitter - self.emitter.stream = stream - self.emitter.top_level_colon_align = tlca - if self.scalar_after_indicator is not None: - self.emitter.scalar_after_indicator = self.scalar_after_indicator - return self.serializer, self.representer, self.emitter - # C routines - - rslvr = resolver.BaseResolver if "base" in self.typ else resolver.Resolver - - class XDumper(CEmitter, self.Representer, rslvr): # type: ignore - def __init__( - selfx, - stream, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - ): - # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA - CEmitter.__init__( - selfx, - stream, - canonical=canonical, - indent=indent, - width=width, - encoding=encoding, - allow_unicode=allow_unicode, - line_break=line_break, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - ) - selfx._emitter = selfx._serializer = selfx._representer = selfx - self.Representer.__init__( - selfx, - default_style=default_style, - default_flow_style=default_flow_style, - ) - rslvr.__init__(selfx) - - self._stream = stream - dumper = XDumper( - stream, - default_style=self.default_style, - 
default_flow_style=self.default_flow_style, - canonical=self.canonical, - indent=self.old_indent, - width=self.width, - allow_unicode=self.allow_unicode, - line_break=self.line_break, - explicit_start=self.explicit_start, - explicit_end=self.explicit_end, - version=self.version, - tags=self.tags, - ) - self._emitter = self._serializer = dumper - return dumper, dumper, dumper - - # basic types - def map(self, **kw): - # type: (Any) -> Any - if "rt" in self.typ: - from .comments import CommentedMap - - return CommentedMap(**kw) - else: - return dict(**kw) - - def seq(self, *args): - # type: (Any) -> Any - if "rt" in self.typ: - from .comments import CommentedSeq - - return CommentedSeq(*args) - else: - return list(*args) - - # helpers - def official_plug_ins(self): - # type: () -> Any - bd = os.path.dirname(__file__) - gpbd = os.path.dirname(os.path.dirname(bd)) - res = [x.replace(gpbd, "")[1:-3] for x in glob.glob(bd + "/*/__plug_in__.py")] - return res - - def register_class(self, cls): - # type:(Any) -> Any - """ - register a class for dumping loading - - if it has attribute yaml_tag use that to register, else use class name - - if it has methods to_yaml/from_yaml use those to dump/load else dump attributes - as mapping - """ - tag = getattr(cls, "yaml_tag", "!" + cls.__name__) - try: - self.representer.add_representer(cls, cls.to_yaml) - except AttributeError: - - def t_y(representer, data): - # type: (Any, Any) -> Any - return representer.represent_yaml_object( - tag, data, cls, flow_style=representer.default_flow_style - ) - - self.representer.add_representer(cls, t_y) - try: - self.constructor.add_constructor(tag, cls.from_yaml) - except AttributeError: - - def f_y(constructor, node): - # type: (Any, Any) -> Any - return constructor.construct_yaml_object(node, cls) - - self.constructor.add_constructor(tag, f_y) - return cls - - def parse(self, stream): - # type: (StreamTextType) -> Any - """ - Parse a YAML stream and produce parsing events. 
- """ - _, parser = self.get_constructor_parser(stream) - try: - while parser.check_event(): - yield parser.get_event() - finally: - parser.dispose() - try: - self._reader.reset_reader() - except AttributeError: - pass - try: - self._scanner.reset_scanner() - except AttributeError: - pass - - # ### context manager - - def __enter__(self): - # type: () -> Any - self._context_manager = YAMLContextManager(self) - return self - - def __exit__(self, typ, value, traceback): - # type: (Any, Any, Any) -> None - if typ: - nprint("typ", typ) - self._context_manager.teardown_output() - # self._context_manager.teardown_input() - self._context_manager = None - - # ### backwards compatibility - def _indent(self, mapping=None, sequence=None, offset=None): - # type: (Any, Any, Any) -> None - if mapping is not None: - self.map_indent = mapping - if sequence is not None: - self.sequence_indent = sequence - if offset is not None: - self.sequence_dash_offset = offset - - @property - def indent(self): - # type: () -> Any - return self._indent - - @indent.setter - def indent(self, val): - # type: (Any) -> None - self.old_indent = val - - @property - def block_seq_indent(self): - # type: () -> Any - return self.sequence_dash_offset - - @block_seq_indent.setter - def block_seq_indent(self, val): - # type: (Any) -> None - self.sequence_dash_offset = val - - def compact(self, seq_seq=None, seq_map=None): - # type: (Any, Any) -> None - self.compact_seq_seq = seq_seq - self.compact_seq_map = seq_map - - -class YAMLContextManager(object): - def __init__(self, yaml, transform=None): - # type: (Any, Any) -> None # used to be: (Any, Optional[Callable]) -> None - self._yaml = yaml - self._output_inited = False - self._output_path = None - self._output = self._yaml._output - self._transform = transform - - # self._input_inited = False - # self._input = input - # self._input_path = None - # self._transform = yaml.transform - # self._fstream = None - - if not hasattr(self._output, "write") and 
hasattr(self._output, "open"): - # pathlib.Path() instance, open with the same mode - self._output_path = self._output - self._output = self._output_path.open("w") - - # if not hasattr(self._stream, 'write') and hasattr(stream, 'open'): - # if not hasattr(self._input, 'read') and hasattr(self._input, 'open'): - # # pathlib.Path() instance, open with the same mode - # self._input_path = self._input - # self._input = self._input_path.open('r') - - if self._transform is not None: - self._fstream = self._output - if self._yaml.encoding is None: - self._output = StringIO() - else: - self._output = BytesIO() - - def teardown_output(self): - # type: () -> None - if self._output_inited: - self._yaml.serializer.close() - else: - return - try: - self._yaml.emitter.dispose() - except AttributeError: - raise - # self.dumper.dispose() # cyaml - try: - delattr(self._yaml, "_serializer") - delattr(self._yaml, "_emitter") - except AttributeError: - raise - if self._transform: - val = self._output.getvalue() - if self._yaml.encoding: - val = val.decode(self._yaml.encoding) - if self._fstream is None: - self._transform(val) - else: - self._fstream.write(self._transform(val)) - self._fstream.flush() - self._output = self._fstream # maybe not necessary - if self._output_path is not None: - self._output.close() - - def init_output(self, first_data): - # type: (Any) -> None - if self._yaml.top_level_colon_align is True: - tlca = max([len(str(x)) for x in first_data]) # type: Any - else: - tlca = self._yaml.top_level_colon_align - self._yaml.get_serializer_representer_emitter(self._output, tlca) - self._yaml.serializer.open() - self._output_inited = True - - def dump(self, data): - # type: (Any) -> None - if not self._output_inited: - self.init_output(data) - try: - self._yaml.representer.represent(data) - except AttributeError: - # nprint(dir(dumper._representer)) - raise - - # def teardown_input(self): - # pass - # - # def init_input(self): - # # set the constructor and parser on 
YAML() instance - # self._yaml.get_constructor_parser(stream) - # - # def load(self): - # if not self._input_inited: - # self.init_input() - # try: - # while self._yaml.constructor.check_data(): - # yield self._yaml.constructor.get_data() - # finally: - # parser.dispose() - # try: - # self._reader.reset_reader() # type: ignore - # except AttributeError: - # pass - # try: - # self._scanner.reset_scanner() # type: ignore - # except AttributeError: - # pass - - -def yaml_object(yml): - # type: (Any) -> Any - """ decorator for classes that needs to dump/load objects - The tag for such objects is taken from the class attribute yaml_tag (or the - class name in lowercase in case unavailable) - If methods to_yaml and/or from_yaml are available, these are called for dumping resp. - loading, default routines (dumping a mapping of the attributes) used otherwise. - """ - - def yo_deco(cls): - # type: (Any) -> Any - tag = getattr(cls, "yaml_tag", "!" + cls.__name__) - try: - yml.representer.add_representer(cls, cls.to_yaml) - except AttributeError: - - def t_y(representer, data): - # type: (Any, Any) -> Any - return representer.represent_yaml_object( - tag, data, cls, flow_style=representer.default_flow_style - ) - - yml.representer.add_representer(cls, t_y) - try: - yml.constructor.add_constructor(tag, cls.from_yaml) - except AttributeError: - - def f_y(constructor, node): - # type: (Any, Any) -> Any - return constructor.construct_yaml_object(node, cls) - - yml.constructor.add_constructor(tag, f_y) - return cls - - return yo_deco - - -######################################################################################## - - -def scan(stream, Loader=Loader): - # type: (StreamTextType, Any) -> Any - """ - Scan a YAML stream and produce scanning tokens. 
- """ - loader = Loader(stream) - try: - while loader.scanner.check_token(): - yield loader.scanner.get_token() - finally: - loader._parser.dispose() - - -def parse(stream, Loader=Loader): - # type: (StreamTextType, Any) -> Any - """ - Parse a YAML stream and produce parsing events. - """ - loader = Loader(stream) - try: - while loader._parser.check_event(): - yield loader._parser.get_event() - finally: - loader._parser.dispose() - - -def compose(stream, Loader=Loader): - # type: (StreamTextType, Any) -> Any - """ - Parse the first YAML document in a stream - and produce the corresponding representation tree. - """ - loader = Loader(stream) - try: - return loader.get_single_node() - finally: - loader.dispose() - - -def compose_all(stream, Loader=Loader): - # type: (StreamTextType, Any) -> Any - """ - Parse all YAML documents in a stream - and produce corresponding representation trees. - """ - loader = Loader(stream) - try: - while loader.check_node(): - yield loader._composer.get_node() - finally: - loader._parser.dispose() - - -def load(stream, Loader=None, version=None, preserve_quotes=None): - # type: (StreamTextType, Any, Optional[VersionType], Any) -> Any - """ - Parse the first YAML document in a stream - and produce the corresponding Python object. - """ - if Loader is None: - warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2) - Loader = UnsafeLoader - loader = Loader(stream, version, preserve_quotes=preserve_quotes) - try: - return loader._constructor.get_single_data() - finally: - loader._parser.dispose() - try: - loader._reader.reset_reader() - except AttributeError: - pass - try: - loader._scanner.reset_scanner() - except AttributeError: - pass - - -def load_all(stream, Loader=None, version=None, preserve_quotes=None): - # type: (Optional[StreamTextType], Any, Optional[VersionType], Optional[bool]) -> Any # NOQA - """ - Parse all YAML documents in a stream - and produce corresponding Python objects. 
- """ - if Loader is None: - warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2) - Loader = UnsafeLoader - loader = Loader(stream, version, preserve_quotes=preserve_quotes) - try: - while loader._constructor.check_data(): - yield loader._constructor.get_data() - finally: - loader._parser.dispose() - try: - loader._reader.reset_reader() - except AttributeError: - pass - try: - loader._scanner.reset_scanner() - except AttributeError: - pass - - -def safe_load(stream, version=None): - # type: (StreamTextType, Optional[VersionType]) -> Any - """ - Parse the first YAML document in a stream - and produce the corresponding Python object. - Resolve only basic YAML tags. - """ - return load(stream, SafeLoader, version) - - -def safe_load_all(stream, version=None): - # type: (StreamTextType, Optional[VersionType]) -> Any - """ - Parse all YAML documents in a stream - and produce corresponding Python objects. - Resolve only basic YAML tags. - """ - return load_all(stream, SafeLoader, version) - - -def round_trip_load(stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any - """ - Parse the first YAML document in a stream - and produce the corresponding Python object. - Resolve only basic YAML tags. - """ - return load(stream, RoundTripLoader, version, preserve_quotes=preserve_quotes) - - -def round_trip_load_all(stream, version=None, preserve_quotes=None): - # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any - """ - Parse all YAML documents in a stream - and produce corresponding Python objects. - Resolve only basic YAML tags. 
- """ - return load_all(stream, RoundTripLoader, version, preserve_quotes=preserve_quotes) - - -def emit( - events, - stream=None, - Dumper=Dumper, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, -): - # type: (Any, Optional[StreamType], Any, Optional[bool], Union[int, None], Optional[int], Optional[bool], Any) -> Any # NOQA - """ - Emit YAML parsing events into a stream. - If stream is None, return the produced string instead. - """ - getvalue = None - if stream is None: - stream = StringIO() - getvalue = stream.getvalue - dumper = Dumper( - stream, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - ) - try: - for event in events: - dumper.emit(event) - finally: - try: - dumper._emitter.dispose() - except AttributeError: - raise - dumper.dispose() # cyaml - if getvalue is not None: - return getvalue() - - -enc = None if PY3 else "utf-8" - - -def serialize_all( - nodes, - stream=None, - Dumper=Dumper, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=enc, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, -): - # type: (Any, Optional[StreamType], Any, Any, Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any) -> Any # NOQA - """ - Serialize a sequence of representation trees into a YAML stream. - If stream is None, return the produced string instead. 
- """ - getvalue = None - if stream is None: - if encoding is None: - stream = StringIO() - else: - stream = BytesIO() - getvalue = stream.getvalue - dumper = Dumper( - stream, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - encoding=encoding, - version=version, - tags=tags, - explicit_start=explicit_start, - explicit_end=explicit_end, - ) - try: - dumper._serializer.open() - for node in nodes: - dumper.serialize(node) - dumper._serializer.close() - finally: - try: - dumper._emitter.dispose() - except AttributeError: - raise - dumper.dispose() # cyaml - if getvalue is not None: - return getvalue() - - -def serialize(node, stream=None, Dumper=Dumper, **kwds): - # type: (Any, Optional[StreamType], Any, Any) -> Any - """ - Serialize a representation tree into a YAML stream. - If stream is None, return the produced string instead. - """ - return serialize_all([node], stream, Dumper=Dumper, **kwds) - - -def dump_all( - documents, - stream=None, - Dumper=Dumper, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=enc, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, -): - # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> Optional[str] # NOQA - """ - Serialize a sequence of Python objects into a YAML stream. - If stream is None, return the produced string instead. 
- """ - getvalue = None - if top_level_colon_align is True: - top_level_colon_align = max([len(str(x)) for x in documents[0]]) - if stream is None: - if encoding is None: - stream = StringIO() - else: - stream = BytesIO() - getvalue = stream.getvalue - dumper = Dumper( - stream, - default_style=default_style, - default_flow_style=default_flow_style, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - block_seq_indent=block_seq_indent, - top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon, - ) - try: - dumper._serializer.open() - for data in documents: - try: - dumper._representer.represent(data) - except AttributeError: - # nprint(dir(dumper._representer)) - raise - dumper._serializer.close() - finally: - try: - dumper._emitter.dispose() - except AttributeError: - raise - dumper.dispose() # cyaml - if getvalue is not None: - return getvalue() - return None - - -def dump( - data, - stream=None, - Dumper=Dumper, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=enc, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, -): - # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> Optional[str] # NOQA - """ - Serialize a Python object into a YAML stream. - If stream is None, return the produced string instead. 
- - default_style ∈ None, '', '"', "'", '|', '>' - - """ - return dump_all( - [data], - stream, - Dumper=Dumper, - default_style=default_style, - default_flow_style=default_flow_style, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - block_seq_indent=block_seq_indent, - ) - - -def safe_dump_all(documents, stream=None, **kwds): - # type: (Any, Optional[StreamType], Any) -> Optional[str] - """ - Serialize a sequence of Python objects into a YAML stream. - Produce only basic YAML tags. - If stream is None, return the produced string instead. - """ - return dump_all(documents, stream, Dumper=SafeDumper, **kwds) - - -def safe_dump(data, stream=None, **kwds): - # type: (Any, Optional[StreamType], Any) -> Optional[str] - """ - Serialize a Python object into a YAML stream. - Produce only basic YAML tags. - If stream is None, return the produced string instead. 
- """ - return dump_all([data], stream, Dumper=SafeDumper, **kwds) - - -def round_trip_dump( - data, - stream=None, - Dumper=RoundTripDumper, - default_style=None, - default_flow_style=None, - canonical=None, - indent=None, - width=None, - allow_unicode=None, - line_break=None, - encoding=enc, - explicit_start=None, - explicit_end=None, - version=None, - tags=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, -): - # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any, Any, Any) -> Optional[str] # NOQA - allow_unicode = True if allow_unicode is None else allow_unicode - return dump_all( - [data], - stream, - Dumper=Dumper, - default_style=default_style, - default_flow_style=default_flow_style, - canonical=canonical, - indent=indent, - width=width, - allow_unicode=allow_unicode, - line_break=line_break, - encoding=encoding, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - tags=tags, - block_seq_indent=block_seq_indent, - top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon, - ) - - -# Loader/Dumper are no longer composites, to get to the associated -# Resolver()/Representer(), etc., you need to instantiate the class - - -def add_implicit_resolver( - tag, regexp, first=None, Loader=None, Dumper=None, resolver=Resolver -): - # type: (Any, Any, Any, Any, Any, Any) -> None - """ - Add an implicit scalar detector. - If an implicit scalar value matches the given regexp, - the corresponding tag is assigned to the scalar. - first is a sequence of possible initial characters or None. 
- """ - if Loader is None and Dumper is None: - resolver.add_implicit_resolver(tag, regexp, first) - return - if Loader: - if hasattr(Loader, "add_implicit_resolver"): - Loader.add_implicit_resolver(tag, regexp, first) - elif issubclass( - Loader, (BaseLoader, SafeLoader, loader.Loader, RoundTripLoader) - ): - Resolver.add_implicit_resolver(tag, regexp, first) - else: - raise NotImplementedError - if Dumper: - if hasattr(Dumper, "add_implicit_resolver"): - Dumper.add_implicit_resolver(tag, regexp, first) - elif issubclass( - Dumper, (BaseDumper, SafeDumper, dumper.Dumper, RoundTripDumper) - ): - Resolver.add_implicit_resolver(tag, regexp, first) - else: - raise NotImplementedError - - -# this code currently not tested -def add_path_resolver( - tag, path, kind=None, Loader=None, Dumper=None, resolver=Resolver -): - # type: (Any, Any, Any, Any, Any, Any) -> None - """ - Add a path based resolver for the given tag. - A path is a list of keys that forms a path - to a node in the representation tree. - Keys can be string values, integers, or None. - """ - if Loader is None and Dumper is None: - resolver.add_path_resolver(tag, path, kind) - return - if Loader: - if hasattr(Loader, "add_path_resolver"): - Loader.add_path_resolver(tag, path, kind) - elif issubclass( - Loader, (BaseLoader, SafeLoader, loader.Loader, RoundTripLoader) - ): - Resolver.add_path_resolver(tag, path, kind) - else: - raise NotImplementedError - if Dumper: - if hasattr(Dumper, "add_path_resolver"): - Dumper.add_path_resolver(tag, path, kind) - elif issubclass( - Dumper, (BaseDumper, SafeDumper, dumper.Dumper, RoundTripDumper) - ): - Resolver.add_path_resolver(tag, path, kind) - else: - raise NotImplementedError - - -def add_constructor(tag, object_constructor, Loader=None, constructor=Constructor): - # type: (Any, Any, Any, Any) -> None - """ - Add an object constructor for the given tag. 
- object_onstructor is a function that accepts a Loader instance - and a node object and produces the corresponding Python object. - """ - if Loader is None: - constructor.add_constructor(tag, object_constructor) - else: - if hasattr(Loader, "add_constructor"): - Loader.add_constructor(tag, object_constructor) - return - if issubclass(Loader, BaseLoader): - BaseConstructor.add_constructor(tag, object_constructor) - elif issubclass(Loader, SafeLoader): - SafeConstructor.add_constructor(tag, object_constructor) - elif issubclass(Loader, Loader): - Constructor.add_constructor(tag, object_constructor) - elif issubclass(Loader, RoundTripLoader): - RoundTripConstructor.add_constructor(tag, object_constructor) - else: - raise NotImplementedError - - -def add_multi_constructor( - tag_prefix, multi_constructor, Loader=None, constructor=Constructor -): - # type: (Any, Any, Any, Any) -> None - """ - Add a multi-constructor for the given tag prefix. - Multi-constructor is called for a node if its tag starts with tag_prefix. - Multi-constructor accepts a Loader instance, a tag suffix, - and a node object and produces the corresponding Python object. 
- """ - if Loader is None: - constructor.add_multi_constructor(tag_prefix, multi_constructor) - else: - if False and hasattr(Loader, "add_multi_constructor"): - Loader.add_multi_constructor(tag_prefix, constructor) - return - if issubclass(Loader, BaseLoader): - BaseConstructor.add_multi_constructor(tag_prefix, multi_constructor) - elif issubclass(Loader, SafeLoader): - SafeConstructor.add_multi_constructor(tag_prefix, multi_constructor) - elif issubclass(Loader, loader.Loader): - Constructor.add_multi_constructor(tag_prefix, multi_constructor) - elif issubclass(Loader, RoundTripLoader): - RoundTripConstructor.add_multi_constructor(tag_prefix, multi_constructor) - else: - raise NotImplementedError - - -def add_representer( - data_type, object_representer, Dumper=None, representer=Representer -): - # type: (Any, Any, Any, Any) -> None - """ - Add a representer for the given type. - object_representer is a function accepting a Dumper instance - and an instance of the given data type - and producing the corresponding representation node. - """ - if Dumper is None: - representer.add_representer(data_type, object_representer) - else: - if hasattr(Dumper, "add_representer"): - Dumper.add_representer(data_type, object_representer) - return - if issubclass(Dumper, BaseDumper): - BaseRepresenter.add_representer(data_type, object_representer) - elif issubclass(Dumper, SafeDumper): - SafeRepresenter.add_representer(data_type, object_representer) - elif issubclass(Dumper, Dumper): - Representer.add_representer(data_type, object_representer) - elif issubclass(Dumper, RoundTripDumper): - RoundTripRepresenter.add_representer(data_type, object_representer) - else: - raise NotImplementedError - - -# this code currently not tested -def add_multi_representer( - data_type, multi_representer, Dumper=None, representer=Representer -): - # type: (Any, Any, Any, Any) -> None - """ - Add a representer for the given type. 
- multi_representer is a function accepting a Dumper instance - and an instance of the given data type or subtype - and producing the corresponding representation node. - """ - if Dumper is None: - representer.add_multi_representer(data_type, multi_representer) - else: - if hasattr(Dumper, "add_multi_representer"): - Dumper.add_multi_representer(data_type, multi_representer) - return - if issubclass(Dumper, BaseDumper): - BaseRepresenter.add_multi_representer(data_type, multi_representer) - elif issubclass(Dumper, SafeDumper): - SafeRepresenter.add_multi_representer(data_type, multi_representer) - elif issubclass(Dumper, Dumper): - Representer.add_multi_representer(data_type, multi_representer) - elif issubclass(Dumper, RoundTripDumper): - RoundTripRepresenter.add_multi_representer(data_type, multi_representer) - else: - raise NotImplementedError - - -class YAMLObjectMetaclass(type): - """ - The metaclass for YAMLObject. - """ - - def __init__(cls, name, bases, kwds): - # type: (Any, Any, Any) -> None - super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds) - if "yaml_tag" in kwds and kwds["yaml_tag"] is not None: - cls.yaml_constructor.add_constructor( - cls.yaml_tag, cls.from_yaml - ) # type: ignore - cls.yaml_representer.add_representer(cls, cls.to_yaml) # type: ignore - - -class YAMLObject(with_metaclass(YAMLObjectMetaclass)): # type: ignore - """ - An object that can dump itself to a YAML stream - and load itself from a YAML stream. - """ - - __slots__ = () # no direct instantiation, so allow immutable subclasses - - yaml_constructor = Constructor - yaml_representer = Representer - - yaml_tag = None # type: Any - yaml_flow_style = None # type: Any - - @classmethod - def from_yaml(cls, constructor, node): - # type: (Any, Any) -> Any - """ - Convert a representation node to a Python object. 
- """ - return constructor.construct_yaml_object(node, cls) - - @classmethod - def to_yaml(cls, representer, data): - # type: (Any, Any) -> Any - """ - Convert a Python object to a representation node. - """ - return representer.represent_yaml_object( - cls.yaml_tag, data, cls, flow_style=cls.yaml_flow_style - ) diff --git a/srsly/ruamel_yaml/nodes.py b/srsly/ruamel_yaml/nodes.py deleted file mode 100755 index da86e9c..0000000 --- a/srsly/ruamel_yaml/nodes.py +++ /dev/null @@ -1,131 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -import sys -from .compat import string_types - -if False: # MYPY - from typing import Dict, Any, Text # NOQA - - -class Node(object): - __slots__ = 'tag', 'value', 'start_mark', 'end_mark', 'comment', 'anchor' - - def __init__(self, tag, value, start_mark, end_mark, comment=None, anchor=None): - # type: (Any, Any, Any, Any, Any, Any) -> None - self.tag = tag - self.value = value - self.start_mark = start_mark - self.end_mark = end_mark - self.comment = comment - self.anchor = anchor - - def __repr__(self): - # type: () -> str - value = self.value - # if isinstance(value, list): - # if len(value) == 0: - # value = '' - # elif len(value) == 1: - # value = '<1 item>' - # else: - # value = '<%d items>' % len(value) - # else: - # if len(value) > 75: - # value = repr(value[:70]+u' ... 
') - # else: - # value = repr(value) - value = repr(value) - return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value) - - def dump(self, indent=0): - # type: (int) -> None - if isinstance(self.value, string_types): - sys.stdout.write( - '{}{}(tag={!r}, value={!r})\n'.format( - ' ' * indent, self.__class__.__name__, self.tag, self.value - ) - ) - if self.comment: - sys.stdout.write(' {}comment: {})\n'.format(' ' * indent, self.comment)) - return - sys.stdout.write( - '{}{}(tag={!r})\n'.format(' ' * indent, self.__class__.__name__, self.tag) - ) - if self.comment: - sys.stdout.write(' {}comment: {})\n'.format(' ' * indent, self.comment)) - for v in self.value: - if isinstance(v, tuple): - for v1 in v: - v1.dump(indent + 1) - elif isinstance(v, Node): - v.dump(indent + 1) - else: - sys.stdout.write('Node value type? {}\n'.format(type(v))) - - -class ScalarNode(Node): - """ - styles: - ? -> set() ? key, no value - " -> double quoted - ' -> single quoted - | -> literal style - > -> folding style - """ - - __slots__ = ('style',) - id = 'scalar' - - def __init__( - self, tag, value, start_mark=None, end_mark=None, style=None, comment=None, anchor=None - ): - # type: (Any, Any, Any, Any, Any, Any, Any) -> None - Node.__init__(self, tag, value, start_mark, end_mark, comment=comment, anchor=anchor) - self.style = style - - -class CollectionNode(Node): - __slots__ = ('flow_style',) - - def __init__( - self, - tag, - value, - start_mark=None, - end_mark=None, - flow_style=None, - comment=None, - anchor=None, - ): - # type: (Any, Any, Any, Any, Any, Any, Any) -> None - Node.__init__(self, tag, value, start_mark, end_mark, comment=comment) - self.flow_style = flow_style - self.anchor = anchor - - -class SequenceNode(CollectionNode): - __slots__ = () - id = 'sequence' - - -class MappingNode(CollectionNode): - __slots__ = ('merge',) - id = 'mapping' - - def __init__( - self, - tag, - value, - start_mark=None, - end_mark=None, - flow_style=None, - comment=None, - 
anchor=None, - ): - # type: (Any, Any, Any, Any, Any, Any, Any) -> None - CollectionNode.__init__( - self, tag, value, start_mark, end_mark, flow_style, comment, anchor - ) - self.merge = None diff --git a/srsly/ruamel_yaml/parser.py b/srsly/ruamel_yaml/parser.py deleted file mode 100755 index 437bd22..0000000 --- a/srsly/ruamel_yaml/parser.py +++ /dev/null @@ -1,844 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -# The following YAML grammar is LL(1) and is parsed by a recursive descent -# parser. -# -# stream ::= STREAM-START implicit_document? explicit_document* -# STREAM-END -# implicit_document ::= block_node DOCUMENT-END* -# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END* -# block_node_or_indentless_sequence ::= -# ALIAS -# | properties (block_content | -# indentless_block_sequence)? -# | block_content -# | indentless_block_sequence -# block_node ::= ALIAS -# | properties block_content? -# | block_content -# flow_node ::= ALIAS -# | properties flow_content? -# | flow_content -# properties ::= TAG ANCHOR? | ANCHOR TAG? -# block_content ::= block_collection | flow_collection | SCALAR -# flow_content ::= flow_collection | SCALAR -# block_collection ::= block_sequence | block_mapping -# flow_collection ::= flow_sequence | flow_mapping -# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* -# BLOCK-END -# indentless_sequence ::= (BLOCK-ENTRY block_node?)+ -# block_mapping ::= BLOCK-MAPPING_START -# ((KEY block_node_or_indentless_sequence?)? -# (VALUE block_node_or_indentless_sequence?)?)* -# BLOCK-END -# flow_sequence ::= FLOW-SEQUENCE-START -# (flow_sequence_entry FLOW-ENTRY)* -# flow_sequence_entry? -# FLOW-SEQUENCE-END -# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? -# flow_mapping ::= FLOW-MAPPING-START -# (flow_mapping_entry FLOW-ENTRY)* -# flow_mapping_entry? -# FLOW-MAPPING-END -# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? 
-# -# FIRST sets: -# -# stream: { STREAM-START } -# explicit_document: { DIRECTIVE DOCUMENT-START } -# implicit_document: FIRST(block_node) -# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START -# BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START } -# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START } -# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START -# FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } -# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR } -# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START } -# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } -# block_sequence: { BLOCK-SEQUENCE-START } -# block_mapping: { BLOCK-MAPPING-START } -# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR -# BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START -# FLOW-MAPPING-START BLOCK-ENTRY } -# indentless_sequence: { ENTRY } -# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START } -# flow_sequence: { FLOW-SEQUENCE-START } -# flow_mapping: { FLOW-MAPPING-START } -# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START -# FLOW-MAPPING-START KEY } -# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START -# FLOW-MAPPING-START KEY } - -# need to have full path with import, as pkg_resources tries to load parser.py in __init__.py -# only to not do anything with the package afterwards -# and for Jython too - - -from .error import MarkedYAMLError -from .tokens import * # NOQA -from .events import * # NOQA -from .scanner import Scanner, RoundTripScanner, ScannerError # NOQA -from .compat import utf8, nprint, nprintf # NOQA - -if False: # MYPY - from typing import Any, Dict, Optional, List # NOQA - -__all__ = ["Parser", "RoundTripParser", "ParserError"] - - -class ParserError(MarkedYAMLError): - pass - - -class Parser(object): - # Since writing a recursive-descendant parser is a straightforward task, we - # do not give many comments here. 
- - DEFAULT_TAGS = {u"!": u"!", u"!!": u"tag:yaml.org,2002:"} - - def __init__(self, loader): - # type: (Any) -> None - self.loader = loader - if self.loader is not None and getattr(self.loader, "_parser", None) is None: - self.loader._parser = self - self.reset_parser() - - def reset_parser(self): - # type: () -> None - # Reset the state attributes (to clear self-references) - self.current_event = None - self.tag_handles = {} # type: Dict[Any, Any] - self.states = [] # type: List[Any] - self.marks = [] # type: List[Any] - self.state = self.parse_stream_start # type: Any - - def dispose(self): - # type: () -> None - self.reset_parser() - - @property - def scanner(self): - # type: () -> Any - if hasattr(self.loader, "typ"): - return self.loader.scanner - return self.loader._scanner - - @property - def resolver(self): - # type: () -> Any - if hasattr(self.loader, "typ"): - return self.loader.resolver - return self.loader._resolver - - def check_event(self, *choices): - # type: (Any) -> bool - # Check the type of the next event. - if self.current_event is None: - if self.state: - self.current_event = self.state() - if self.current_event is not None: - if not choices: - return True - for choice in choices: - if isinstance(self.current_event, choice): - return True - return False - - def peek_event(self): - # type: () -> Any - # Get the next event. - if self.current_event is None: - if self.state: - self.current_event = self.state() - return self.current_event - - def get_event(self): - # type: () -> Any - # Get the next event and proceed further. - if self.current_event is None: - if self.state: - self.current_event = self.state() - value = self.current_event - self.current_event = None - return value - - # stream ::= STREAM-START implicit_document? explicit_document* - # STREAM-END - # implicit_document ::= block_node DOCUMENT-END* - # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? 
DOCUMENT-END* - - def parse_stream_start(self): - # type: () -> Any - # Parse the stream start. - token = self.scanner.get_token() - token.move_comment(self.scanner.peek_token()) - event = StreamStartEvent( - token.start_mark, token.end_mark, encoding=token.encoding - ) - - # Prepare the next state. - self.state = self.parse_implicit_document_start - - return event - - def parse_implicit_document_start(self): - # type: () -> Any - # Parse an implicit document. - if not self.scanner.check_token( - DirectiveToken, DocumentStartToken, StreamEndToken - ): - self.tag_handles = self.DEFAULT_TAGS - token = self.scanner.peek_token() - start_mark = end_mark = token.start_mark - event = DocumentStartEvent(start_mark, end_mark, explicit=False) - - # Prepare the next state. - self.states.append(self.parse_document_end) - self.state = self.parse_block_node - - return event - - else: - return self.parse_document_start() - - def parse_document_start(self): - # type: () -> Any - # Parse any extra document end indicators. - while self.scanner.check_token(DocumentEndToken): - self.scanner.get_token() - # Parse an explicit document. - if not self.scanner.check_token(StreamEndToken): - token = self.scanner.peek_token() - start_mark = token.start_mark - version, tags = self.process_directives() - if not self.scanner.check_token(DocumentStartToken): - raise ParserError( - None, - None, - "expected '', but found %r" - % self.scanner.peek_token().id, - self.scanner.peek_token().start_mark, - ) - token = self.scanner.get_token() - end_mark = token.end_mark - # if self.loader is not None and \ - # end_mark.line != self.scanner.peek_token().start_mark.line: - # self.loader.scalar_after_indicator = False - event = DocumentStartEvent( - start_mark, end_mark, explicit=True, version=version, tags=tags - ) # type: Any - self.states.append(self.parse_document_end) - self.state = self.parse_document_content - else: - # Parse the end of the stream. 
- token = self.scanner.get_token() - event = StreamEndEvent( - token.start_mark, token.end_mark, comment=token.comment - ) - assert not self.states - assert not self.marks - self.state = None - return event - - def parse_document_end(self): - # type: () -> Any - # Parse the document end. - token = self.scanner.peek_token() - start_mark = end_mark = token.start_mark - explicit = False - if self.scanner.check_token(DocumentEndToken): - token = self.scanner.get_token() - end_mark = token.end_mark - explicit = True - event = DocumentEndEvent(start_mark, end_mark, explicit=explicit) - - # Prepare the next state. - if self.resolver.processing_version == (1, 1): - self.state = self.parse_document_start - else: - self.state = self.parse_implicit_document_start - - return event - - def parse_document_content(self): - # type: () -> Any - if self.scanner.check_token( - DirectiveToken, DocumentStartToken, DocumentEndToken, StreamEndToken - ): - event = self.process_empty_scalar(self.scanner.peek_token().start_mark) - self.state = self.states.pop() - return event - else: - return self.parse_block_node() - - def process_directives(self): - # type: () -> Any - yaml_version = None - self.tag_handles = {} - while self.scanner.check_token(DirectiveToken): - token = self.scanner.get_token() - if token.name == u"YAML": - if yaml_version is not None: - raise ParserError( - None, None, "found duplicate YAML directive", token.start_mark - ) - major, minor = token.value - if major != 1: - raise ParserError( - None, - None, - "found incompatible YAML document (version 1.* is " "required)", - token.start_mark, - ) - yaml_version = token.value - elif token.name == u"TAG": - handle, prefix = token.value - if handle in self.tag_handles: - raise ParserError( - None, - None, - "duplicate tag handle %r" % utf8(handle), - token.start_mark, - ) - self.tag_handles[handle] = prefix - if bool(self.tag_handles): - value = yaml_version, self.tag_handles.copy() # type: Any - else: - value = yaml_version, 
None - if self.loader is not None and hasattr(self.loader, "tags"): - self.loader.version = yaml_version - if self.loader.tags is None: - self.loader.tags = {} - for k in self.tag_handles: - self.loader.tags[k] = self.tag_handles[k] - for key in self.DEFAULT_TAGS: - if key not in self.tag_handles: - self.tag_handles[key] = self.DEFAULT_TAGS[key] - return value - - # block_node_or_indentless_sequence ::= ALIAS - # | properties (block_content | indentless_block_sequence)? - # | block_content - # | indentless_block_sequence - # block_node ::= ALIAS - # | properties block_content? - # | block_content - # flow_node ::= ALIAS - # | properties flow_content? - # | flow_content - # properties ::= TAG ANCHOR? | ANCHOR TAG? - # block_content ::= block_collection | flow_collection | SCALAR - # flow_content ::= flow_collection | SCALAR - # block_collection ::= block_sequence | block_mapping - # flow_collection ::= flow_sequence | flow_mapping - - def parse_block_node(self): - # type: () -> Any - return self.parse_node(block=True) - - def parse_flow_node(self): - # type: () -> Any - return self.parse_node() - - def parse_block_node_or_indentless_sequence(self): - # type: () -> Any - return self.parse_node(block=True, indentless_sequence=True) - - def transform_tag(self, handle, suffix): - # type: (Any, Any) -> Any - return self.tag_handles[handle] + suffix - - def parse_node(self, block=False, indentless_sequence=False): - # type: (bool, bool) -> Any - if self.scanner.check_token(AliasToken): - token = self.scanner.get_token() - event = AliasEvent( - token.value, token.start_mark, token.end_mark - ) # type: Any - self.state = self.states.pop() - return event - - anchor = None - tag = None - start_mark = end_mark = tag_mark = None - if self.scanner.check_token(AnchorToken): - token = self.scanner.get_token() - start_mark = token.start_mark - end_mark = token.end_mark - anchor = token.value - if self.scanner.check_token(TagToken): - token = self.scanner.get_token() - tag_mark = 
token.start_mark - end_mark = token.end_mark - tag = token.value - elif self.scanner.check_token(TagToken): - token = self.scanner.get_token() - start_mark = tag_mark = token.start_mark - end_mark = token.end_mark - tag = token.value - if self.scanner.check_token(AnchorToken): - token = self.scanner.get_token() - start_mark = tag_mark = token.start_mark - end_mark = token.end_mark - anchor = token.value - if tag is not None: - handle, suffix = tag - if handle is not None: - if handle not in self.tag_handles: - raise ParserError( - "while parsing a node", - start_mark, - "found undefined tag handle %r" % utf8(handle), - tag_mark, - ) - tag = self.transform_tag(handle, suffix) - else: - tag = suffix - # if tag == u'!': - # raise ParserError("while parsing a node", start_mark, - # "found non-specific tag '!'", tag_mark, - # "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' - # and share your opinion.") - if start_mark is None: - start_mark = end_mark = self.scanner.peek_token().start_mark - event = None - implicit = tag is None or tag == u"!" 
- if indentless_sequence and self.scanner.check_token(BlockEntryToken): - comment = None - pt = self.scanner.peek_token() - if pt.comment and pt.comment[0]: - comment = [pt.comment[0], []] - pt.comment[0] = None - end_mark = self.scanner.peek_token().end_mark - event = SequenceStartEvent( - anchor, - tag, - implicit, - start_mark, - end_mark, - flow_style=False, - comment=comment, - ) - self.state = self.parse_indentless_sequence_entry - return event - - if self.scanner.check_token(ScalarToken): - token = self.scanner.get_token() - # self.scanner.peek_token_same_line_comment(token) - end_mark = token.end_mark - if (token.plain and tag is None) or tag == u"!": - implicit = (True, False) - elif tag is None: - implicit = (False, True) - else: - implicit = (False, False) - # nprint('se', token.value, token.comment) - event = ScalarEvent( - anchor, - tag, - implicit, - token.value, - start_mark, - end_mark, - style=token.style, - comment=token.comment, - ) - self.state = self.states.pop() - elif self.scanner.check_token(FlowSequenceStartToken): - pt = self.scanner.peek_token() - end_mark = pt.end_mark - event = SequenceStartEvent( - anchor, - tag, - implicit, - start_mark, - end_mark, - flow_style=True, - comment=pt.comment, - ) - self.state = self.parse_flow_sequence_first_entry - elif self.scanner.check_token(FlowMappingStartToken): - pt = self.scanner.peek_token() - end_mark = pt.end_mark - event = MappingStartEvent( - anchor, - tag, - implicit, - start_mark, - end_mark, - flow_style=True, - comment=pt.comment, - ) - self.state = self.parse_flow_mapping_first_key - elif block and self.scanner.check_token(BlockSequenceStartToken): - end_mark = self.scanner.peek_token().start_mark - # should inserting the comment be dependent on the - # indentation? 
- pt = self.scanner.peek_token() - comment = pt.comment - # nprint('pt0', type(pt)) - if comment is None or comment[1] is None: - comment = pt.split_comment() - # nprint('pt1', comment) - event = SequenceStartEvent( - anchor, - tag, - implicit, - start_mark, - end_mark, - flow_style=False, - comment=comment, - ) - self.state = self.parse_block_sequence_first_entry - elif block and self.scanner.check_token(BlockMappingStartToken): - end_mark = self.scanner.peek_token().start_mark - comment = self.scanner.peek_token().comment - event = MappingStartEvent( - anchor, - tag, - implicit, - start_mark, - end_mark, - flow_style=False, - comment=comment, - ) - self.state = self.parse_block_mapping_first_key - elif anchor is not None or tag is not None: - # Empty scalars are allowed even if a tag or an anchor is - # specified. - event = ScalarEvent( - anchor, tag, (implicit, False), "", start_mark, end_mark - ) - self.state = self.states.pop() - else: - if block: - node = "block" - else: - node = "flow" - token = self.scanner.peek_token() - raise ParserError( - "while parsing a %s node" % node, - start_mark, - "expected the node content, but found %r" % token.id, - token.start_mark, - ) - return event - - # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* - # BLOCK-END - - def parse_block_sequence_first_entry(self): - # type: () -> Any - token = self.scanner.get_token() - # move any comment from start token - # token.move_comment(self.scanner.peek_token()) - self.marks.append(token.start_mark) - return self.parse_block_sequence_entry() - - def parse_block_sequence_entry(self): - # type: () -> Any - if self.scanner.check_token(BlockEntryToken): - token = self.scanner.get_token() - token.move_comment(self.scanner.peek_token()) - if not self.scanner.check_token(BlockEntryToken, BlockEndToken): - self.states.append(self.parse_block_sequence_entry) - return self.parse_block_node() - else: - self.state = self.parse_block_sequence_entry - return 
self.process_empty_scalar(token.end_mark) - if not self.scanner.check_token(BlockEndToken): - token = self.scanner.peek_token() - raise ParserError( - "while parsing a block collection", - self.marks[-1], - "expected , but found %r" % token.id, - token.start_mark, - ) - token = self.scanner.get_token() # BlockEndToken - event = SequenceEndEvent( - token.start_mark, token.end_mark, comment=token.comment - ) - self.state = self.states.pop() - self.marks.pop() - return event - - # indentless_sequence ::= (BLOCK-ENTRY block_node?)+ - - # indentless_sequence? - # sequence: - # - entry - # - nested - - def parse_indentless_sequence_entry(self): - # type: () -> Any - if self.scanner.check_token(BlockEntryToken): - token = self.scanner.get_token() - token.move_comment(self.scanner.peek_token()) - if not self.scanner.check_token( - BlockEntryToken, KeyToken, ValueToken, BlockEndToken - ): - self.states.append(self.parse_indentless_sequence_entry) - return self.parse_block_node() - else: - self.state = self.parse_indentless_sequence_entry - return self.process_empty_scalar(token.end_mark) - token = self.scanner.peek_token() - event = SequenceEndEvent( - token.start_mark, token.start_mark, comment=token.comment - ) - self.state = self.states.pop() - return event - - # block_mapping ::= BLOCK-MAPPING_START - # ((KEY block_node_or_indentless_sequence?)? 
- # (VALUE block_node_or_indentless_sequence?)?)* - # BLOCK-END - - def parse_block_mapping_first_key(self): - # type: () -> Any - token = self.scanner.get_token() - self.marks.append(token.start_mark) - return self.parse_block_mapping_key() - - def parse_block_mapping_key(self): - # type: () -> Any - if self.scanner.check_token(KeyToken): - token = self.scanner.get_token() - token.move_comment(self.scanner.peek_token()) - if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken): - self.states.append(self.parse_block_mapping_value) - return self.parse_block_node_or_indentless_sequence() - else: - self.state = self.parse_block_mapping_value - return self.process_empty_scalar(token.end_mark) - if self.resolver.processing_version > (1, 1) and self.scanner.check_token( - ValueToken - ): - self.state = self.parse_block_mapping_value - return self.process_empty_scalar(self.scanner.peek_token().start_mark) - if not self.scanner.check_token(BlockEndToken): - token = self.scanner.peek_token() - raise ParserError( - "while parsing a block mapping", - self.marks[-1], - "expected , but found %r" % token.id, - token.start_mark, - ) - token = self.scanner.get_token() - token.move_comment(self.scanner.peek_token()) - event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment) - self.state = self.states.pop() - self.marks.pop() - return event - - def parse_block_mapping_value(self): - # type: () -> Any - if self.scanner.check_token(ValueToken): - token = self.scanner.get_token() - # value token might have post comment move it to e.g. 
block - if self.scanner.check_token(ValueToken): - token.move_comment(self.scanner.peek_token()) - else: - if not self.scanner.check_token(KeyToken): - token.move_comment(self.scanner.peek_token(), empty=True) - # else: empty value for this key cannot move token.comment - if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken): - self.states.append(self.parse_block_mapping_key) - return self.parse_block_node_or_indentless_sequence() - else: - self.state = self.parse_block_mapping_key - comment = token.comment - if comment is None: - token = self.scanner.peek_token() - comment = token.comment - if comment: - token._comment = [None, comment[1]] - comment = [comment[0], None] - return self.process_empty_scalar(token.end_mark, comment=comment) - else: - self.state = self.parse_block_mapping_key - token = self.scanner.peek_token() - return self.process_empty_scalar(token.start_mark) - - # flow_sequence ::= FLOW-SEQUENCE-START - # (flow_sequence_entry FLOW-ENTRY)* - # flow_sequence_entry? - # FLOW-SEQUENCE-END - # flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? - # - # Note that while production rules for both flow_sequence_entry and - # flow_mapping_entry are equal, their interpretations are different. - # For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?` - # generate an inline mapping (set syntax). 
- - def parse_flow_sequence_first_entry(self): - # type: () -> Any - token = self.scanner.get_token() - self.marks.append(token.start_mark) - return self.parse_flow_sequence_entry(first=True) - - def parse_flow_sequence_entry(self, first=False): - # type: (bool) -> Any - if not self.scanner.check_token(FlowSequenceEndToken): - if not first: - if self.scanner.check_token(FlowEntryToken): - self.scanner.get_token() - else: - token = self.scanner.peek_token() - raise ParserError( - "while parsing a flow sequence", - self.marks[-1], - "expected ',' or ']', but got %r" % token.id, - token.start_mark, - ) - - if self.scanner.check_token(KeyToken): - token = self.scanner.peek_token() - event = MappingStartEvent( - None, None, True, token.start_mark, token.end_mark, flow_style=True - ) # type: Any - self.state = self.parse_flow_sequence_entry_mapping_key - return event - elif not self.scanner.check_token(FlowSequenceEndToken): - self.states.append(self.parse_flow_sequence_entry) - return self.parse_flow_node() - token = self.scanner.get_token() - event = SequenceEndEvent( - token.start_mark, token.end_mark, comment=token.comment - ) - self.state = self.states.pop() - self.marks.pop() - return event - - def parse_flow_sequence_entry_mapping_key(self): - # type: () -> Any - token = self.scanner.get_token() - if not self.scanner.check_token( - ValueToken, FlowEntryToken, FlowSequenceEndToken - ): - self.states.append(self.parse_flow_sequence_entry_mapping_value) - return self.parse_flow_node() - else: - self.state = self.parse_flow_sequence_entry_mapping_value - return self.process_empty_scalar(token.end_mark) - - def parse_flow_sequence_entry_mapping_value(self): - # type: () -> Any - if self.scanner.check_token(ValueToken): - token = self.scanner.get_token() - if not self.scanner.check_token(FlowEntryToken, FlowSequenceEndToken): - self.states.append(self.parse_flow_sequence_entry_mapping_end) - return self.parse_flow_node() - else: - self.state = 
self.parse_flow_sequence_entry_mapping_end - return self.process_empty_scalar(token.end_mark) - else: - self.state = self.parse_flow_sequence_entry_mapping_end - token = self.scanner.peek_token() - return self.process_empty_scalar(token.start_mark) - - def parse_flow_sequence_entry_mapping_end(self): - # type: () -> Any - self.state = self.parse_flow_sequence_entry - token = self.scanner.peek_token() - return MappingEndEvent(token.start_mark, token.start_mark) - - # flow_mapping ::= FLOW-MAPPING-START - # (flow_mapping_entry FLOW-ENTRY)* - # flow_mapping_entry? - # FLOW-MAPPING-END - # flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)? - - def parse_flow_mapping_first_key(self): - # type: () -> Any - token = self.scanner.get_token() - self.marks.append(token.start_mark) - return self.parse_flow_mapping_key(first=True) - - def parse_flow_mapping_key(self, first=False): - # type: (Any) -> Any - if not self.scanner.check_token(FlowMappingEndToken): - if not first: - if self.scanner.check_token(FlowEntryToken): - self.scanner.get_token() - else: - token = self.scanner.peek_token() - raise ParserError( - "while parsing a flow mapping", - self.marks[-1], - "expected ',' or '}', but got %r" % token.id, - token.start_mark, - ) - if self.scanner.check_token(KeyToken): - token = self.scanner.get_token() - if not self.scanner.check_token( - ValueToken, FlowEntryToken, FlowMappingEndToken - ): - self.states.append(self.parse_flow_mapping_value) - return self.parse_flow_node() - else: - self.state = self.parse_flow_mapping_value - return self.process_empty_scalar(token.end_mark) - elif self.resolver.processing_version > (1, 1) and self.scanner.check_token( - ValueToken - ): - self.state = self.parse_flow_mapping_value - return self.process_empty_scalar(self.scanner.peek_token().end_mark) - elif not self.scanner.check_token(FlowMappingEndToken): - self.states.append(self.parse_flow_mapping_empty_value) - return self.parse_flow_node() - token = 
self.scanner.get_token() - event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment) - self.state = self.states.pop() - self.marks.pop() - return event - - def parse_flow_mapping_value(self): - # type: () -> Any - if self.scanner.check_token(ValueToken): - token = self.scanner.get_token() - if not self.scanner.check_token(FlowEntryToken, FlowMappingEndToken): - self.states.append(self.parse_flow_mapping_key) - return self.parse_flow_node() - else: - self.state = self.parse_flow_mapping_key - return self.process_empty_scalar(token.end_mark) - else: - self.state = self.parse_flow_mapping_key - token = self.scanner.peek_token() - return self.process_empty_scalar(token.start_mark) - - def parse_flow_mapping_empty_value(self): - # type: () -> Any - self.state = self.parse_flow_mapping_key - return self.process_empty_scalar(self.scanner.peek_token().start_mark) - - def process_empty_scalar(self, mark, comment=None): - # type: (Any, Any) -> Any - return ScalarEvent(None, None, (True, False), "", mark, mark, comment=comment) - - -class RoundTripParser(Parser): - """roundtrip is a safe loader, that wants to see the unmangled tag""" - - def transform_tag(self, handle, suffix): - # type: (Any, Any) -> Any - # return self.tag_handles[handle]+suffix - if handle == "!!" and suffix in ( - u"null", - u"bool", - u"int", - u"float", - u"binary", - u"timestamp", - u"omap", - u"pairs", - u"set", - u"str", - u"seq", - u"map", - ): - return Parser.transform_tag(self, handle, suffix) - return handle + suffix diff --git a/srsly/ruamel_yaml/py.typed b/srsly/ruamel_yaml/py.typed deleted file mode 100755 index e69de29..0000000 diff --git a/srsly/ruamel_yaml/reader.py b/srsly/ruamel_yaml/reader.py deleted file mode 100755 index b90c73b..0000000 --- a/srsly/ruamel_yaml/reader.py +++ /dev/null @@ -1,327 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -# This module contains abstractions for the input stream. 
You don't have to -# looks further, there are no pretty code. -# -# We define two classes here. -# -# Mark(source, line, column) -# It's just a record and its only use is producing nice error messages. -# Parser does not use it for any other purposes. -# -# Reader(source, data) -# Reader determines the encoding of `data` and converts it to unicode. -# Reader provides the following methods and attributes: -# reader.peek(length=1) - return the next `length` characters -# reader.forward(length=1) - move the current position to `length` -# characters. -# reader.index - the number of the current character. -# reader.line, stream.column - the line and the column of the current -# character. - -import codecs - -from .error import YAMLError, FileMark, StringMark, YAMLStreamError -from .compat import text_type, binary_type, PY3, UNICODE_SIZE -from .util import RegExp - -if False: # MYPY - from typing import Any, Dict, Optional, List, Union, Text, Tuple, Optional # NOQA -# from srsly.ruamel_yaml.compat import StreamTextType # NOQA - -__all__ = ["Reader", "ReaderError"] - - -class ReaderError(YAMLError): - def __init__(self, name, position, character, encoding, reason): - # type: (Any, Any, Any, Any, Any) -> None - self.name = name - self.character = character - self.position = position - self.encoding = encoding - self.reason = reason - - def __str__(self): - # type: () -> str - if isinstance(self.character, binary_type): - return ( - "'%s' codec can't decode byte #x%02x: %s\n" - ' in "%s", position %d' - % ( - self.encoding, - ord(self.character), - self.reason, - self.name, - self.position, - ) - ) - else: - return "unacceptable character #x%04x: %s\n" ' in "%s", position %d' % ( - self.character, - self.reason, - self.name, - self.position, - ) - - -class Reader(object): - # Reader: - # - determines the data encoding and converts it to a unicode string, - # - checks if characters are in allowed range, - # - adds '\0' to the end. 
- - # Reader accepts - # - a `str` object (PY2) / a `bytes` object (PY3), - # - a `unicode` object (PY2) / a `str` object (PY3), - # - a file-like object with its `read` method returning `str`, - # - a file-like object with its `read` method returning `unicode`. - - # Yeah, it's ugly and slow. - - def __init__(self, stream, loader=None): - # type: (Any, Any) -> None - self.loader = loader - if self.loader is not None and getattr(self.loader, "_reader", None) is None: - self.loader._reader = self - self.reset_reader() - self.stream = stream # type: Any # as .read is called - - def reset_reader(self): - # type: () -> None - self.name = None # type: Any - self.stream_pointer = 0 - self.eof = True - self.buffer = "" - self.pointer = 0 - self.raw_buffer = None # type: Any - self.raw_decode = None - self.encoding = None # type: Optional[Text] - self.index = 0 - self.line = 0 - self.column = 0 - - @property - def stream(self): - # type: () -> Any - try: - return self._stream - except AttributeError: - raise YAMLStreamError("input stream needs to specified") - - @stream.setter - def stream(self, val): - # type: (Any) -> None - if val is None: - return - self._stream = None - if isinstance(val, text_type): - self.name = "" - self.check_printable(val) - self.buffer = val + u"\0" # type: ignore - elif isinstance(val, binary_type): - self.name = "" - self.raw_buffer = val - self.determine_encoding() - else: - if not hasattr(val, "read"): - raise YAMLStreamError("stream argument needs to have a read() method") - self._stream = val - self.name = getattr(self.stream, "name", "") - self.eof = False - self.raw_buffer = None - self.determine_encoding() - - def peek(self, index=0): - # type: (int) -> Text - try: - return self.buffer[self.pointer + index] - except IndexError: - self.update(index + 1) - return self.buffer[self.pointer + index] - - def prefix(self, length=1): - # type: (int) -> Any - if self.pointer + length >= len(self.buffer): - self.update(length) - return 
self.buffer[self.pointer : self.pointer + length] - - def forward_1_1(self, length=1): - # type: (int) -> None - if self.pointer + length + 1 >= len(self.buffer): - self.update(length + 1) - while length != 0: - ch = self.buffer[self.pointer] - self.pointer += 1 - self.index += 1 - if ch in u"\n\x85\u2028\u2029" or ( - ch == u"\r" and self.buffer[self.pointer] != u"\n" - ): - self.line += 1 - self.column = 0 - elif ch != u"\uFEFF": - self.column += 1 - length -= 1 - - def forward(self, length=1): - # type: (int) -> None - if self.pointer + length + 1 >= len(self.buffer): - self.update(length + 1) - while length != 0: - ch = self.buffer[self.pointer] - self.pointer += 1 - self.index += 1 - if ch == u"\n" or (ch == u"\r" and self.buffer[self.pointer] != u"\n"): - self.line += 1 - self.column = 0 - elif ch != u"\uFEFF": - self.column += 1 - length -= 1 - - def get_mark(self): - # type: () -> Any - if self.stream is None: - return StringMark( - self.name, self.index, self.line, self.column, self.buffer, self.pointer - ) - else: - return FileMark(self.name, self.index, self.line, self.column) - - def determine_encoding(self): - # type: () -> None - while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2): - self.update_raw() - if isinstance(self.raw_buffer, binary_type): - if self.raw_buffer.startswith(codecs.BOM_UTF16_LE): - self.raw_decode = codecs.utf_16_le_decode # type: ignore - self.encoding = "utf-16-le" - elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE): - self.raw_decode = codecs.utf_16_be_decode # type: ignore - self.encoding = "utf-16-be" - else: - self.raw_decode = codecs.utf_8_decode # type: ignore - self.encoding = "utf-8" - self.update(1) - - if UNICODE_SIZE == 2: - NON_PRINTABLE = RegExp( - u"[^\x09\x0A\x0D\x20-\x7E\x85" u"\xA0-\uD7FF" u"\uE000-\uFFFD" u"]" - ) - else: - NON_PRINTABLE = RegExp( - u"[^\x09\x0A\x0D\x20-\x7E\x85" - u"\xA0-\uD7FF" - u"\uE000-\uFFFD" - u"\U00010000-\U0010FFFF" - u"]" - ) - - _printable_ascii = 
("\x09\x0A\x0D" + "".join(map(chr, range(0x20, 0x7F)))).encode( - "ascii" - ) - - @classmethod - def _get_non_printable_ascii(cls, data): # type: ignore - # type: (Text, bytes) -> Optional[Tuple[int, Text]] - ascii_bytes = data.encode("ascii") - non_printables = ascii_bytes.translate( - None, cls._printable_ascii - ) # type: ignore - if not non_printables: - return None - non_printable = non_printables[:1] - return ascii_bytes.index(non_printable), non_printable.decode("ascii") - - @classmethod - def _get_non_printable_regex(cls, data): - # type: (Text) -> Optional[Tuple[int, Text]] - match = cls.NON_PRINTABLE.search(data) - if not bool(match): - return None - return match.start(), match.group() - - @classmethod - def _get_non_printable(cls, data): - # type: (Text) -> Optional[Tuple[int, Text]] - try: - return cls._get_non_printable_ascii(data) # type: ignore - except UnicodeEncodeError: - return cls._get_non_printable_regex(data) - - def check_printable(self, data): - # type: (Any) -> None - non_printable_match = self._get_non_printable(data) - if non_printable_match is not None: - start, character = non_printable_match - position = self.index + (len(self.buffer) - self.pointer) + start - raise ReaderError( - self.name, - position, - ord(character), - "unicode", - "special characters are not allowed", - ) - - def update(self, length): - # type: (int) -> None - if self.raw_buffer is None: - return - self.buffer = self.buffer[self.pointer :] - self.pointer = 0 - while len(self.buffer) < length: - if not self.eof: - self.update_raw() - if self.raw_decode is not None: - try: - data, converted = self.raw_decode( - self.raw_buffer, "strict", self.eof - ) - except UnicodeDecodeError as exc: - if PY3: - character = self.raw_buffer[exc.start] - else: - character = exc.object[exc.start] - if self.stream is not None: - position = ( - self.stream_pointer - len(self.raw_buffer) + exc.start - ) - elif self.stream is not None: - position = ( - self.stream_pointer - 
len(self.raw_buffer) + exc.start - ) - else: - position = exc.start - raise ReaderError( - self.name, position, character, exc.encoding, exc.reason - ) - else: - data = self.raw_buffer - converted = len(data) - self.check_printable(data) - self.buffer += data - self.raw_buffer = self.raw_buffer[converted:] - if self.eof: - self.buffer += "\0" - self.raw_buffer = None - break - - def update_raw(self, size=None): - # type: (Optional[int]) -> None - if size is None: - size = 4096 if PY3 else 1024 - data = self.stream.read(size) - if self.raw_buffer is None: - self.raw_buffer = data - else: - self.raw_buffer += data - self.stream_pointer += len(data) - if not data: - self.eof = True - - -# try: -# import psyco -# psyco.bind(Reader) -# except ImportError: -# pass diff --git a/srsly/ruamel_yaml/representer.py b/srsly/ruamel_yaml/representer.py deleted file mode 100755 index 73e1793..0000000 --- a/srsly/ruamel_yaml/representer.py +++ /dev/null @@ -1,1330 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division - - -from .error import * # NOQA -from .nodes import * # NOQA -from .compat import text_type, binary_type, to_unicode, PY2, PY3 -from .compat import ordereddict # type: ignore -from .compat import nprint, nprintf # NOQA -from .scalarstring import ( - LiteralScalarString, - FoldedScalarString, - SingleQuotedScalarString, - DoubleQuotedScalarString, - PlainScalarString, -) -from .scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt -from .scalarfloat import ScalarFloat -from .scalarbool import ScalarBoolean -from .timestamp import TimeStamp - -import datetime -import sys -import types - -if PY3: - import copyreg - import base64 -else: - import copy_reg as copyreg # type: ignore - -if False: # MYPY - from typing import Dict, List, Any, Union, Text, Optional # NOQA - -# fmt: off -__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer', - 'RepresenterError', 'RoundTripRepresenter'] -# fmt: on - - -class 
RepresenterError(YAMLError): - pass - - -if PY2: - - def get_classobj_bases(cls): - # type: (Any) -> Any - bases = [cls] - for base in cls.__bases__: - bases.extend(get_classobj_bases(base)) - return bases - - -class BaseRepresenter(object): - - yaml_representers = {} # type: Dict[Any, Any] - yaml_multi_representers = {} # type: Dict[Any, Any] - - def __init__(self, default_style=None, default_flow_style=None, dumper=None): - # type: (Any, Any, Any, Any) -> None - self.dumper = dumper - if self.dumper is not None: - self.dumper._representer = self - self.default_style = default_style - self.default_flow_style = default_flow_style - self.represented_objects = {} # type: Dict[Any, Any] - self.object_keeper = [] # type: List[Any] - self.alias_key = None # type: Optional[int] - self.sort_base_mapping_type_on_output = True - - @property - def serializer(self): - # type: () -> Any - try: - if hasattr(self.dumper, "typ"): - return self.dumper.serializer - return self.dumper._serializer - except AttributeError: - return self # cyaml - - def represent(self, data): - # type: (Any) -> None - node = self.represent_data(data) - self.serializer.serialize(node) - self.represented_objects = {} - self.object_keeper = [] - self.alias_key = None - - def represent_data(self, data): - # type: (Any) -> Any - if self.ignore_aliases(data): - self.alias_key = None - else: - self.alias_key = id(data) - if self.alias_key is not None: - if self.alias_key in self.represented_objects: - node = self.represented_objects[self.alias_key] - # if node is None: - # raise RepresenterError( - # "recursive objects are not allowed: %r" % data) - return node - # self.represented_objects[alias_key] = None - self.object_keeper.append(data) - data_types = type(data).__mro__ - if PY2: - # if type(data) is types.InstanceType: - if isinstance(data, types.InstanceType): - data_types = get_classobj_bases(data.__class__) + list(data_types) - if data_types[0] in self.yaml_representers: - node = 
self.yaml_representers[data_types[0]](self, data) - else: - for data_type in data_types: - if data_type in self.yaml_multi_representers: - node = self.yaml_multi_representers[data_type](self, data) - break - else: - if None in self.yaml_multi_representers: - node = self.yaml_multi_representers[None](self, data) - elif None in self.yaml_representers: - node = self.yaml_representers[None](self, data) - else: - node = ScalarNode(None, text_type(data)) - # if alias_key is not None: - # self.represented_objects[alias_key] = node - return node - - def represent_key(self, data): - # type: (Any) -> Any - """ - David Fraser: Extract a method to represent keys in mappings, so that - a subclass can choose not to quote them (for example) - used in represent_mapping - https://bitbucket.org/davidfraser/pyyaml/commits/d81df6eb95f20cac4a79eed95ae553b5c6f77b8c - """ - return self.represent_data(data) - - @classmethod - def add_representer(cls, data_type, representer): - # type: (Any, Any) -> None - if "yaml_representers" not in cls.__dict__: - cls.yaml_representers = cls.yaml_representers.copy() - cls.yaml_representers[data_type] = representer - - @classmethod - def add_multi_representer(cls, data_type, representer): - # type: (Any, Any) -> None - if "yaml_multi_representers" not in cls.__dict__: - cls.yaml_multi_representers = cls.yaml_multi_representers.copy() - cls.yaml_multi_representers[data_type] = representer - - def represent_scalar(self, tag, value, style=None, anchor=None): - # type: (Any, Any, Any, Any) -> Any - if style is None: - style = self.default_style - comment = None - if style and style[0] in "|>": - comment = getattr(value, "comment", None) - if comment: - comment = [None, [comment]] - node = ScalarNode(tag, value, style=style, comment=comment, anchor=anchor) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - return node - - def represent_sequence(self, tag, sequence, flow_style=None): - # type: (Any, Any, Any) -> Any - value = 
[] # type: List[Any] - node = SequenceNode(tag, value, flow_style=flow_style) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - for item in sequence: - node_item = self.represent_data(item) - if not (isinstance(node_item, ScalarNode) and not node_item.style): - best_style = False - value.append(node_item) - if flow_style is None: - if self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - def represent_omap(self, tag, omap, flow_style=None): - # type: (Any, Any, Any) -> Any - value = [] # type: List[Any] - node = SequenceNode(tag, value, flow_style=flow_style) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - for item_key in omap: - item_val = omap[item_key] - node_item = self.represent_data({item_key: item_val}) - # if not (isinstance(node_item, ScalarNode) \ - # and not node_item.style): - # best_style = False - value.append(node_item) - if flow_style is None: - if self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - def represent_mapping(self, tag, mapping, flow_style=None): - # type: (Any, Any, Any) -> Any - value = [] # type: List[Any] - node = MappingNode(tag, value, flow_style=flow_style) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - if hasattr(mapping, "items"): - mapping = list(mapping.items()) - if self.sort_base_mapping_type_on_output: - try: - mapping = sorted(mapping) - except TypeError: - pass - for item_key, item_value in mapping: - node_key = self.represent_key(item_key) - node_value = self.represent_data(item_value) - if not (isinstance(node_key, ScalarNode) and not node_key.style): - best_style = False - if not (isinstance(node_value, ScalarNode) and not node_value.style): - best_style = False - 
value.append((node_key, node_value)) - if flow_style is None: - if self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - def ignore_aliases(self, data): - # type: (Any) -> bool - return False - - -class SafeRepresenter(BaseRepresenter): - def ignore_aliases(self, data): - # type: (Any) -> bool - # https://docs.python.org/3/reference/expressions.html#parenthesized-forms : - # "i.e. two occurrences of the empty tuple may or may not yield the same object" - # so "data is ()" should not be used - if data is None or (isinstance(data, tuple) and data == ()): - return True - if isinstance(data, (binary_type, text_type, bool, int, float)): - return True - return False - - def represent_none(self, data): - # type: (Any) -> Any - return self.represent_scalar(u"tag:yaml.org,2002:null", u"null") - - if PY3: - - def represent_str(self, data): - # type: (Any) -> Any - return self.represent_scalar(u"tag:yaml.org,2002:str", data) - - def represent_binary(self, data): - # type: (Any) -> Any - if hasattr(base64, "encodebytes"): - data = base64.encodebytes(data).decode("ascii") - else: - data = base64.encodestring(data).decode("ascii") - return self.represent_scalar(u"tag:yaml.org,2002:binary", data, style="|") - - else: - - def represent_str(self, data): - # type: (Any) -> Any - tag = None - style = None - try: - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - except UnicodeDecodeError: - try: - data = unicode(data, "utf-8") - tag = u"tag:yaml.org,2002:str" - except UnicodeDecodeError: - data = data.encode("base64") - tag = u"tag:yaml.org,2002:binary" - style = "|" - return self.represent_scalar(tag, data, style=style) - - def represent_unicode(self, data): - # type: (Any) -> Any - return self.represent_scalar(u"tag:yaml.org,2002:str", data) - - def represent_bool(self, data, anchor=None): - # type: (Any, Optional[Any]) -> Any - try: - value = 
self.dumper.boolean_representation[bool(data)] - except AttributeError: - if data: - value = u"true" - else: - value = u"false" - return self.represent_scalar(u"tag:yaml.org,2002:bool", value, anchor=anchor) - - def represent_int(self, data): - # type: (Any) -> Any - return self.represent_scalar(u"tag:yaml.org,2002:int", text_type(data)) - - if PY2: - - def represent_long(self, data): - # type: (Any) -> Any - return self.represent_scalar(u"tag:yaml.org,2002:int", text_type(data)) - - inf_value = 1e300 - while repr(inf_value) != repr(inf_value * inf_value): - inf_value *= inf_value - - def represent_float(self, data): - # type: (Any) -> Any - if data != data or (data == 0.0 and data == 1.0): - value = u".nan" - elif data == self.inf_value: - value = u".inf" - elif data == -self.inf_value: - value = u"-.inf" - else: - value = to_unicode(repr(data)).lower() - if getattr(self.serializer, "use_version", None) == (1, 1): - if u"." not in value and u"e" in value: - # Note that in some cases `repr(data)` represents a float number - # without the decimal parts. For instance: - # >>> repr(1e17) - # '1e17' - # Unfortunately, this is not a valid float representation according - # to the definition of the `!!float` tag in YAML 1.1. We fix - # this by adding '.0' before the 'e' symbol. 
- value = value.replace(u"e", u".0e", 1) - return self.represent_scalar(u"tag:yaml.org,2002:float", value) - - def represent_list(self, data): - # type: (Any) -> Any - # pairs = (len(data) > 0 and isinstance(data, list)) - # if pairs: - # for item in data: - # if not isinstance(item, tuple) or len(item) != 2: - # pairs = False - # break - # if not pairs: - return self.represent_sequence(u"tag:yaml.org,2002:seq", data) - - # value = [] - # for item_key, item_value in data: - # value.append(self.represent_mapping(u'tag:yaml.org,2002:map', - # [(item_key, item_value)])) - # return SequenceNode(u'tag:yaml.org,2002:pairs', value) - - def represent_dict(self, data): - # type: (Any) -> Any - return self.represent_mapping(u"tag:yaml.org,2002:map", data) - - def represent_ordereddict(self, data): - # type: (Any) -> Any - return self.represent_omap(u"tag:yaml.org,2002:omap", data) - - def represent_set(self, data): - # type: (Any) -> Any - value = {} # type: Dict[Any, None] - for key in data: - value[key] = None - return self.represent_mapping(u"tag:yaml.org,2002:set", value) - - def represent_date(self, data): - # type: (Any) -> Any - value = to_unicode(data.isoformat()) - return self.represent_scalar(u"tag:yaml.org,2002:timestamp", value) - - def represent_datetime(self, data): - # type: (Any) -> Any - value = to_unicode(data.isoformat(" ")) - return self.represent_scalar(u"tag:yaml.org,2002:timestamp", value) - - def represent_yaml_object(self, tag, data, cls, flow_style=None): - # type: (Any, Any, Any, Any) -> Any - if hasattr(data, "__getstate__"): - state = data.__getstate__() - else: - state = data.__dict__.copy() - return self.represent_mapping(tag, state, flow_style=flow_style) - - def represent_undefined(self, data): - # type: (Any) -> None - raise RepresenterError("cannot represent an object: %s" % (data,)) - - -SafeRepresenter.add_representer(type(None), SafeRepresenter.represent_none) - -SafeRepresenter.add_representer(str, SafeRepresenter.represent_str) - -if 
PY2: - SafeRepresenter.add_representer(unicode, SafeRepresenter.represent_unicode) -else: - SafeRepresenter.add_representer(bytes, SafeRepresenter.represent_binary) - -SafeRepresenter.add_representer(bool, SafeRepresenter.represent_bool) - -SafeRepresenter.add_representer(int, SafeRepresenter.represent_int) - -if PY2: - SafeRepresenter.add_representer(long, SafeRepresenter.represent_long) - -SafeRepresenter.add_representer(float, SafeRepresenter.represent_float) - -SafeRepresenter.add_representer(list, SafeRepresenter.represent_list) - -SafeRepresenter.add_representer(tuple, SafeRepresenter.represent_list) - -SafeRepresenter.add_representer(dict, SafeRepresenter.represent_dict) - -SafeRepresenter.add_representer(set, SafeRepresenter.represent_set) - -SafeRepresenter.add_representer(ordereddict, SafeRepresenter.represent_ordereddict) - -if sys.version_info >= (2, 7): - import collections - - SafeRepresenter.add_representer( - collections.OrderedDict, SafeRepresenter.represent_ordereddict - ) - -SafeRepresenter.add_representer(datetime.date, SafeRepresenter.represent_date) - -SafeRepresenter.add_representer(datetime.datetime, SafeRepresenter.represent_datetime) - -SafeRepresenter.add_representer(None, SafeRepresenter.represent_undefined) - - -class Representer(SafeRepresenter): - if PY2: - - def represent_str(self, data): - # type: (Any) -> Any - tag = None - style = None - try: - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - except UnicodeDecodeError: - try: - data = unicode(data, "utf-8") - tag = u"tag:yaml.org,2002:python/str" - except UnicodeDecodeError: - data = data.encode("base64") - tag = u"tag:yaml.org,2002:binary" - style = "|" - return self.represent_scalar(tag, data, style=style) - - def represent_unicode(self, data): - # type: (Any) -> Any - tag = None - try: - data.encode("ascii") - tag = u"tag:yaml.org,2002:python/unicode" - except UnicodeEncodeError: - tag = u"tag:yaml.org,2002:str" - return self.represent_scalar(tag, data) - - def 
represent_long(self, data): - # type: (Any) -> Any - tag = u"tag:yaml.org,2002:int" - if int(data) is not data: - tag = u"tag:yaml.org,2002:python/long" - return self.represent_scalar(tag, to_unicode(data)) - - def represent_complex(self, data): - # type: (Any) -> Any - if data.imag == 0.0: - data = u"%r" % data.real - elif data.real == 0.0: - data = u"%rj" % data.imag - elif data.imag > 0: - data = u"%r+%rj" % (data.real, data.imag) - else: - data = u"%r%rj" % (data.real, data.imag) - return self.represent_scalar(u"tag:yaml.org,2002:python/complex", data) - - def represent_tuple(self, data): - # type: (Any) -> Any - return self.represent_sequence(u"tag:yaml.org,2002:python/tuple", data) - - def represent_name(self, data): - # type: (Any) -> Any - try: - name = u"%s.%s" % (data.__module__, data.__qualname__) - except AttributeError: - # probably PY2 - name = u"%s.%s" % (data.__module__, data.__name__) - return self.represent_scalar(u"tag:yaml.org,2002:python/name:" + name, "") - - def represent_module(self, data): - # type: (Any) -> Any - return self.represent_scalar( - u"tag:yaml.org,2002:python/module:" + data.__name__, "" - ) - - if PY2: - - def represent_instance(self, data): - # type: (Any) -> Any - # For instances of classic classes, we use __getinitargs__ and - # __getstate__ to serialize the data. - - # If data.__getinitargs__ exists, the object must be reconstructed - # by calling cls(**args), where args is a tuple returned by - # __getinitargs__. Otherwise, the cls.__init__ method should never - # be called and the class instance is created by instantiating a - # trivial class and assigning to the instance's __class__ variable. - - # If data.__getstate__ exists, it returns the state of the object. - # Otherwise, the state of the object is data.__dict__. - - # We produce either a !!python/object or !!python/object/new node. - # If data.__getinitargs__ does not exist and state is a dictionary, - # we produce a !!python/object node . 
Otherwise we produce a - # !!python/object/new node. - - cls = data.__class__ - class_name = u"%s.%s" % (cls.__module__, cls.__name__) - args = None - state = None - if hasattr(data, "__getinitargs__"): - args = list(data.__getinitargs__()) - if hasattr(data, "__getstate__"): - state = data.__getstate__() - else: - state = data.__dict__ - if args is None and isinstance(state, dict): - return self.represent_mapping( - u"tag:yaml.org,2002:python/object:" + class_name, state - ) - if isinstance(state, dict) and not state: - return self.represent_sequence( - u"tag:yaml.org,2002:python/object/new:" + class_name, args - ) - value = {} - if bool(args): - value["args"] = args - value["state"] = state # type: ignore - return self.represent_mapping( - u"tag:yaml.org,2002:python/object/new:" + class_name, value - ) - - def represent_object(self, data): - # type: (Any) -> Any - # We use __reduce__ API to save the data. data.__reduce__ returns - # a tuple of length 2-5: - # (function, args, state, listitems, dictitems) - - # For reconstructing, we calls function(*args), then set its state, - # listitems, and dictitems if they are not None. - - # A special case is when function.__name__ == '__newobj__'. In this - # case we create the object with args[0].__new__(*args). - - # Another special case is when __reduce__ returns a string - we don't - # support it. - - # We produce a !!python/object, !!python/object/new or - # !!python/object/apply node. 
- - cls = type(data) - if cls in copyreg.dispatch_table: - reduce = copyreg.dispatch_table[cls](data) - elif hasattr(data, "__reduce_ex__"): - reduce = data.__reduce_ex__(2) - elif hasattr(data, "__reduce__"): - reduce = data.__reduce__() - else: - raise RepresenterError("cannot represent object: %r" % (data,)) - reduce = (list(reduce) + [None] * 5)[:5] - function, args, state, listitems, dictitems = reduce - args = list(args) - if state is None: - state = {} - if listitems is not None: - listitems = list(listitems) - if dictitems is not None: - dictitems = dict(dictitems) - if function.__name__ == "__newobj__": - function = args[0] - args = args[1:] - tag = u"tag:yaml.org,2002:python/object/new:" - newobj = True - else: - tag = u"tag:yaml.org,2002:python/object/apply:" - newobj = False - try: - function_name = u"%s.%s" % (function.__module__, function.__qualname__) - except AttributeError: - # probably PY2 - function_name = u"%s.%s" % (function.__module__, function.__name__) - if ( - not args - and not listitems - and not dictitems - and isinstance(state, dict) - and newobj - ): - return self.represent_mapping( - u"tag:yaml.org,2002:python/object:" + function_name, state - ) - if not listitems and not dictitems and isinstance(state, dict) and not state: - return self.represent_sequence(tag + function_name, args) - value = {} - if args: - value["args"] = args - if state or not isinstance(state, dict): - value["state"] = state - if listitems: - value["listitems"] = listitems - if dictitems: - value["dictitems"] = dictitems - return self.represent_mapping(tag + function_name, value) - - -if PY2: - Representer.add_representer(str, Representer.represent_str) - - Representer.add_representer(unicode, Representer.represent_unicode) - - Representer.add_representer(long, Representer.represent_long) - -Representer.add_representer(complex, Representer.represent_complex) - -Representer.add_representer(tuple, Representer.represent_tuple) - -Representer.add_representer(type, 
Representer.represent_name) - -if PY2: - Representer.add_representer(types.ClassType, Representer.represent_name) - -Representer.add_representer(types.FunctionType, Representer.represent_name) - -Representer.add_representer(types.BuiltinFunctionType, Representer.represent_name) - -Representer.add_representer(types.ModuleType, Representer.represent_module) - -if PY2: - Representer.add_multi_representer( - types.InstanceType, Representer.represent_instance - ) - -Representer.add_multi_representer(object, Representer.represent_object) - -Representer.add_multi_representer(type, Representer.represent_name) - -from .comments import ( - CommentedMap, - CommentedOrderedMap, - CommentedSeq, - CommentedKeySeq, - CommentedKeyMap, - CommentedSet, - comment_attrib, - merge_attrib, - TaggedScalar, -) # NOQA - - -class RoundTripRepresenter(SafeRepresenter): - # need to add type here and write out the .comment - # in serializer and emitter - - def __init__(self, default_style=None, default_flow_style=None, dumper=None): - # type: (Any, Any, Any) -> None - if not hasattr(dumper, "typ") and default_flow_style is None: - default_flow_style = False - SafeRepresenter.__init__( - self, - default_style=default_style, - default_flow_style=default_flow_style, - dumper=dumper, - ) - - def ignore_aliases(self, data): - # type: (Any) -> bool - try: - if data.anchor is not None and data.anchor.value is not None: - return False - except AttributeError: - pass - return SafeRepresenter.ignore_aliases(self, data) - - def represent_none(self, data): - # type: (Any) -> Any - if ( - len(self.represented_objects) == 0 - and not self.serializer.use_explicit_start - ): - # this will be open ended (although it is not yet) - return self.represent_scalar(u"tag:yaml.org,2002:null", u"null") - return self.represent_scalar(u"tag:yaml.org,2002:null", "") - - def represent_literal_scalarstring(self, data): - # type: (Any) -> Any - tag = None - style = "|" - anchor = data.yaml_anchor(any=True) - if PY2 and not 
isinstance(data, unicode): - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - return self.represent_scalar(tag, data, style=style, anchor=anchor) - - represent_preserved_scalarstring = represent_literal_scalarstring - - def represent_folded_scalarstring(self, data): - # type: (Any) -> Any - tag = None - style = ">" - anchor = data.yaml_anchor(any=True) - for fold_pos in reversed(getattr(data, "fold_pos", [])): - if ( - data[fold_pos] == " " - and (fold_pos > 0 and not data[fold_pos - 1].isspace()) - and (fold_pos < len(data) and not data[fold_pos + 1].isspace()) - ): - data = data[:fold_pos] + "\a" + data[fold_pos:] - if PY2 and not isinstance(data, unicode): - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - return self.represent_scalar(tag, data, style=style, anchor=anchor) - - def represent_single_quoted_scalarstring(self, data): - # type: (Any) -> Any - tag = None - style = "'" - anchor = data.yaml_anchor(any=True) - if PY2 and not isinstance(data, unicode): - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - return self.represent_scalar(tag, data, style=style, anchor=anchor) - - def represent_double_quoted_scalarstring(self, data): - # type: (Any) -> Any - tag = None - style = '"' - anchor = data.yaml_anchor(any=True) - if PY2 and not isinstance(data, unicode): - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - return self.represent_scalar(tag, data, style=style, anchor=anchor) - - def represent_plain_scalarstring(self, data): - # type: (Any) -> Any - tag = None - style = "" - anchor = data.yaml_anchor(any=True) - if PY2 and not isinstance(data, unicode): - data = unicode(data, "ascii") - tag = u"tag:yaml.org,2002:str" - return self.represent_scalar(tag, data, style=style, anchor=anchor) - - def insert_underscore(self, prefix, s, underscore, anchor=None): - # type: (Any, Any, Any, Any) -> Any - if underscore is None: - return self.represent_scalar( - u"tag:yaml.org,2002:int", prefix + s, 
anchor=anchor - ) - if underscore[0]: - sl = list(s) - pos = len(s) - underscore[0] - while pos > 0: - sl.insert(pos, "_") - pos -= underscore[0] - s = "".join(sl) - if underscore[1]: - s = "_" + s - if underscore[2]: - s += "_" - return self.represent_scalar( - u"tag:yaml.org,2002:int", prefix + s, anchor=anchor - ) - - def represent_scalar_int(self, data): - # type: (Any) -> Any - if data._width is not None: - s = "{:0{}d}".format(data, data._width) - else: - s = format(data, "d") - anchor = data.yaml_anchor(any=True) - return self.insert_underscore("", s, data._underscore, anchor=anchor) - - def represent_binary_int(self, data): - # type: (Any) -> Any - if data._width is not None: - # cannot use '{:#0{}b}', that strips the zeros - s = "{:0{}b}".format(data, data._width) - else: - s = format(data, "b") - anchor = data.yaml_anchor(any=True) - return self.insert_underscore("0b", s, data._underscore, anchor=anchor) - - def represent_octal_int(self, data): - # type: (Any) -> Any - if data._width is not None: - # cannot use '{:#0{}o}', that strips the zeros - s = "{:0{}o}".format(data, data._width) - else: - s = format(data, "o") - anchor = data.yaml_anchor(any=True) - return self.insert_underscore("0o", s, data._underscore, anchor=anchor) - - def represent_hex_int(self, data): - # type: (Any) -> Any - if data._width is not None: - # cannot use '{:#0{}x}', that strips the zeros - s = "{:0{}x}".format(data, data._width) - else: - s = format(data, "x") - anchor = data.yaml_anchor(any=True) - return self.insert_underscore("0x", s, data._underscore, anchor=anchor) - - def represent_hex_caps_int(self, data): - # type: (Any) -> Any - if data._width is not None: - # cannot use '{:#0{}X}', that strips the zeros - s = "{:0{}X}".format(data, data._width) - else: - s = format(data, "X") - anchor = data.yaml_anchor(any=True) - return self.insert_underscore("0x", s, data._underscore, anchor=anchor) - - def represent_scalar_float(self, data): - # type: (Any) -> Any - """ this is 
way more complicated """ - value = None - anchor = data.yaml_anchor(any=True) - if data != data or (data == 0.0 and data == 1.0): - value = u".nan" - elif data == self.inf_value: - value = u".inf" - elif data == -self.inf_value: - value = u"-.inf" - if value: - return self.represent_scalar( - u"tag:yaml.org,2002:float", value, anchor=anchor - ) - if data._exp is None and data._prec > 0 and data._prec == data._width - 1: - # no exponent, but trailing dot - value = u"{}{:d}.".format( - data._m_sign if data._m_sign else "", abs(int(data)) - ) - elif data._exp is None: - # no exponent, "normal" dot - prec = data._prec - ms = data._m_sign if data._m_sign else "" - # -1 for the dot - value = u"{}{:0{}.{}f}".format( - ms, abs(data), data._width - len(ms), data._width - prec - 1 - ) - if prec == 0 or (prec == 1 and ms != ""): - value = value.replace(u"0.", u".") - while len(value) < data._width: - value += u"0" - else: - # exponent - m, es = u"{:{}.{}e}".format( - # data, data._width, data._width - data._prec + (1 if data._m_sign else 0) - data, - data._width, - data._width + (1 if data._m_sign else 0), - ).split("e") - w = data._width if data._prec > 0 else (data._width + 1) - if data < 0: - w += 1 - m = m[:w] - e = int(es) - m1, m2 = m.split(".") # always second? - while len(m1) + len(m2) < data._width - (1 if data._prec >= 0 else 0): - m2 += u"0" - if data._m_sign and data > 0: - m1 = "+" + m1 - esgn = u"+" if data._e_sign else "" - if data._prec < 0: # mantissa without dot - if m2 != u"0": - e -= len(m2) - else: - m2 = "" - while (len(m1) + len(m2) - (1 if data._m_sign else 0)) < data._width: - m2 += u"0" - e -= 1 - value = ( - m1 + m2 + data._exp + u"{:{}0{}d}".format(e, esgn, data._e_width) - ) - elif data._prec == 0: # mantissa with trailing dot - e -= len(m2) - value = ( - m1 - + m2 - + u"." 
- + data._exp - + u"{:{}0{}d}".format(e, esgn, data._e_width) - ) - else: - if data._m_lead0 > 0: - m2 = u"0" * (data._m_lead0 - 1) + m1 + m2 - m1 = u"0" - m2 = m2[: -data._m_lead0] # these should be zeros - e += data._m_lead0 - while len(m1) < data._prec: - m1 += m2[0] - m2 = m2[1:] - e -= 1 - value = ( - m1 - + u"." - + m2 - + data._exp - + u"{:{}0{}d}".format(e, esgn, data._e_width) - ) - - if value is None: - value = to_unicode(repr(data)).lower() - return self.represent_scalar(u"tag:yaml.org,2002:float", value, anchor=anchor) - - def represent_sequence(self, tag, sequence, flow_style=None): - # type: (Any, Any, Any) -> Any - value = [] # type: List[Any] - # if the flow_style is None, the flow style tacked on to the object - # explicitly will be taken. If that is None as well the default flow - # style rules - try: - flow_style = sequence.fa.flow_style(flow_style) - except AttributeError: - flow_style = flow_style - try: - anchor = sequence.yaml_anchor() - except AttributeError: - anchor = None - node = SequenceNode(tag, value, flow_style=flow_style, anchor=anchor) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - try: - comment = getattr(sequence, comment_attrib) - node.comment = comment.comment - # reset any comment already printed information - if node.comment and node.comment[1]: - for ct in node.comment[1]: - ct.reset() - item_comments = comment.items - for v in item_comments.values(): - if v and v[1]: - for ct in v[1]: - ct.reset() - item_comments = comment.items - node.comment = comment.comment - try: - node.comment.append(comment.end) - except AttributeError: - pass - except AttributeError: - item_comments = {} - for idx, item in enumerate(sequence): - node_item = self.represent_data(item) - self.merge_comments(node_item, item_comments.get(idx)) - if not (isinstance(node_item, ScalarNode) and not node_item.style): - best_style = False - value.append(node_item) - if flow_style is None: - if 
len(sequence) != 0 and self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - def merge_comments(self, node, comments): - # type: (Any, Any) -> Any - if comments is None: - assert hasattr(node, "comment") - return node - if getattr(node, "comment", None) is not None: - for idx, val in enumerate(comments): - if idx >= len(node.comment): - continue - nc = node.comment[idx] - if nc is not None: - assert val is None or val == nc - comments[idx] = nc - node.comment = comments - return node - - def represent_key(self, data): - # type: (Any) -> Any - if isinstance(data, CommentedKeySeq): - self.alias_key = None - return self.represent_sequence( - u"tag:yaml.org,2002:seq", data, flow_style=True - ) - if isinstance(data, CommentedKeyMap): - self.alias_key = None - return self.represent_mapping( - u"tag:yaml.org,2002:map", data, flow_style=True - ) - return SafeRepresenter.represent_key(self, data) - - def represent_mapping(self, tag, mapping, flow_style=None): - # type: (Any, Any, Any) -> Any - value = [] # type: List[Any] - try: - flow_style = mapping.fa.flow_style(flow_style) - except AttributeError: - flow_style = flow_style - try: - anchor = mapping.yaml_anchor() - except AttributeError: - anchor = None - node = MappingNode(tag, value, flow_style=flow_style, anchor=anchor) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - # no sorting! !! 
- try: - comment = getattr(mapping, comment_attrib) - node.comment = comment.comment - if node.comment and node.comment[1]: - for ct in node.comment[1]: - ct.reset() - item_comments = comment.items - for v in item_comments.values(): - if v and v[1]: - for ct in v[1]: - ct.reset() - try: - node.comment.append(comment.end) - except AttributeError: - pass - except AttributeError: - item_comments = {} - merge_list = [m[1] for m in getattr(mapping, merge_attrib, [])] - try: - merge_pos = getattr(mapping, merge_attrib, [[0]])[0][0] - except IndexError: - merge_pos = 0 - item_count = 0 - if bool(merge_list): - items = mapping.non_merged_items() - else: - items = mapping.items() - for item_key, item_value in items: - item_count += 1 - node_key = self.represent_key(item_key) - node_value = self.represent_data(item_value) - item_comment = item_comments.get(item_key) - if item_comment: - assert getattr(node_key, "comment", None) is None - node_key.comment = item_comment[:2] - nvc = getattr(node_value, "comment", None) - if nvc is not None: # end comment already there - nvc[0] = item_comment[2] - nvc[1] = item_comment[3] - else: - node_value.comment = item_comment[2:] - if not (isinstance(node_key, ScalarNode) and not node_key.style): - best_style = False - if not (isinstance(node_value, ScalarNode) and not node_value.style): - best_style = False - value.append((node_key, node_value)) - if flow_style is None: - if ( - (item_count != 0) or bool(merge_list) - ) and self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - if bool(merge_list): - # because of the call to represent_data here, the anchors - # are marked as being used and thereby created - if len(merge_list) == 1: - arg = self.represent_data(merge_list[0]) - else: - arg = self.represent_data(merge_list) - arg.flow_style = True - value.insert(merge_pos, (ScalarNode(u"tag:yaml.org,2002:merge", "<<"), arg)) - return node - - def represent_omap(self, tag, 
omap, flow_style=None): - # type: (Any, Any, Any) -> Any - value = [] # type: List[Any] - try: - flow_style = omap.fa.flow_style(flow_style) - except AttributeError: - flow_style = flow_style - try: - anchor = omap.yaml_anchor() - except AttributeError: - anchor = None - node = SequenceNode(tag, value, flow_style=flow_style, anchor=anchor) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - try: - comment = getattr(omap, comment_attrib) - node.comment = comment.comment - if node.comment and node.comment[1]: - for ct in node.comment[1]: - ct.reset() - item_comments = comment.items - for v in item_comments.values(): - if v and v[1]: - for ct in v[1]: - ct.reset() - try: - node.comment.append(comment.end) - except AttributeError: - pass - except AttributeError: - item_comments = {} - for item_key in omap: - item_val = omap[item_key] - node_item = self.represent_data({item_key: item_val}) - # node_item.flow_style = False - # node item has two scalars in value: node_key and node_value - item_comment = item_comments.get(item_key) - if item_comment: - if item_comment[1]: - node_item.comment = [None, item_comment[1]] - assert getattr(node_item.value[0][0], "comment", None) is None - node_item.value[0][0].comment = [item_comment[0], None] - nvc = getattr(node_item.value[0][1], "comment", None) - if nvc is not None: # end comment already there - nvc[0] = item_comment[2] - nvc[1] = item_comment[3] - else: - node_item.value[0][1].comment = item_comment[2:] - # if not (isinstance(node_item, ScalarNode) \ - # and not node_item.style): - # best_style = False - value.append(node_item) - if flow_style is None: - if self.default_flow_style is not None: - node.flow_style = self.default_flow_style - else: - node.flow_style = best_style - return node - - def represent_set(self, setting): - # type: (Any) -> Any - flow_style = False - tag = u"tag:yaml.org,2002:set" - # return self.represent_mapping(tag, value) - value = [] # type: 
List[Any] - flow_style = setting.fa.flow_style(flow_style) - try: - anchor = setting.yaml_anchor() - except AttributeError: - anchor = None - node = MappingNode(tag, value, flow_style=flow_style, anchor=anchor) - if self.alias_key is not None: - self.represented_objects[self.alias_key] = node - best_style = True - # no sorting! !! - try: - comment = getattr(setting, comment_attrib) - node.comment = comment.comment - if node.comment and node.comment[1]: - for ct in node.comment[1]: - ct.reset() - item_comments = comment.items - for v in item_comments.values(): - if v and v[1]: - for ct in v[1]: - ct.reset() - try: - node.comment.append(comment.end) - except AttributeError: - pass - except AttributeError: - item_comments = {} - for item_key in setting.odict: - node_key = self.represent_key(item_key) - node_value = self.represent_data(None) - item_comment = item_comments.get(item_key) - if item_comment: - assert getattr(node_key, "comment", None) is None - node_key.comment = item_comment[:2] - node_key.style = node_value.style = "?" 
- if not (isinstance(node_key, ScalarNode) and not node_key.style): - best_style = False - if not (isinstance(node_value, ScalarNode) and not node_value.style): - best_style = False - value.append((node_key, node_value)) - best_style = best_style - return node - - def represent_dict(self, data): - # type: (Any) -> Any - """write out tag if saved on loading""" - try: - t = data.tag.value - except AttributeError: - t = None - if t: - if t.startswith("!!"): - tag = "tag:yaml.org,2002:" + t[2:] - else: - tag = t - else: - tag = u"tag:yaml.org,2002:map" - return self.represent_mapping(tag, data) - - def represent_list(self, data): - # type: (Any) -> Any - try: - t = data.tag.value - except AttributeError: - t = None - if t: - if t.startswith("!!"): - tag = "tag:yaml.org,2002:" + t[2:] - else: - tag = t - else: - tag = u"tag:yaml.org,2002:seq" - return self.represent_sequence(tag, data) - - def represent_datetime(self, data): - # type: (Any) -> Any - inter = "T" if data._yaml["t"] else " " - _yaml = data._yaml - if _yaml["delta"]: - data += _yaml["delta"] - value = data.isoformat(inter) - else: - value = data.isoformat(inter) - if _yaml["tz"]: - value += _yaml["tz"] - return self.represent_scalar(u"tag:yaml.org,2002:timestamp", to_unicode(value)) - - def represent_tagged_scalar(self, data): - # type: (Any) -> Any - try: - tag = data.tag.value - except AttributeError: - tag = None - try: - anchor = data.yaml_anchor() - except AttributeError: - anchor = None - return self.represent_scalar(tag, data.value, style=data.style, anchor=anchor) - - def represent_scalar_bool(self, data): - # type: (Any) -> Any - try: - anchor = data.yaml_anchor() - except AttributeError: - anchor = None - return SafeRepresenter.represent_bool(self, data, anchor=anchor) - - -RoundTripRepresenter.add_representer(type(None), RoundTripRepresenter.represent_none) - -RoundTripRepresenter.add_representer( - LiteralScalarString, RoundTripRepresenter.represent_literal_scalarstring -) - 
-RoundTripRepresenter.add_representer( - FoldedScalarString, RoundTripRepresenter.represent_folded_scalarstring -) - -RoundTripRepresenter.add_representer( - SingleQuotedScalarString, RoundTripRepresenter.represent_single_quoted_scalarstring -) - -RoundTripRepresenter.add_representer( - DoubleQuotedScalarString, RoundTripRepresenter.represent_double_quoted_scalarstring -) - -RoundTripRepresenter.add_representer( - PlainScalarString, RoundTripRepresenter.represent_plain_scalarstring -) - -# RoundTripRepresenter.add_representer(tuple, Representer.represent_tuple) - -RoundTripRepresenter.add_representer( - ScalarInt, RoundTripRepresenter.represent_scalar_int -) - -RoundTripRepresenter.add_representer( - BinaryInt, RoundTripRepresenter.represent_binary_int -) - -RoundTripRepresenter.add_representer(OctalInt, RoundTripRepresenter.represent_octal_int) - -RoundTripRepresenter.add_representer(HexInt, RoundTripRepresenter.represent_hex_int) - -RoundTripRepresenter.add_representer( - HexCapsInt, RoundTripRepresenter.represent_hex_caps_int -) - -RoundTripRepresenter.add_representer( - ScalarFloat, RoundTripRepresenter.represent_scalar_float -) - -RoundTripRepresenter.add_representer( - ScalarBoolean, RoundTripRepresenter.represent_scalar_bool -) - -RoundTripRepresenter.add_representer(CommentedSeq, RoundTripRepresenter.represent_list) - -RoundTripRepresenter.add_representer(CommentedMap, RoundTripRepresenter.represent_dict) - -RoundTripRepresenter.add_representer( - CommentedOrderedMap, RoundTripRepresenter.represent_ordereddict -) - -if sys.version_info >= (2, 7): - import collections - - RoundTripRepresenter.add_representer( - collections.OrderedDict, RoundTripRepresenter.represent_ordereddict - ) - -RoundTripRepresenter.add_representer(CommentedSet, RoundTripRepresenter.represent_set) - -RoundTripRepresenter.add_representer( - TaggedScalar, RoundTripRepresenter.represent_tagged_scalar -) - -RoundTripRepresenter.add_representer(TimeStamp, 
RoundTripRepresenter.represent_datetime) diff --git a/srsly/ruamel_yaml/resolver.py b/srsly/ruamel_yaml/resolver.py deleted file mode 100755 index edcee9b..0000000 --- a/srsly/ruamel_yaml/resolver.py +++ /dev/null @@ -1,410 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -import re - -if False: # MYPY - from typing import Any, Dict, List, Union, Text, Optional # NOQA - from .compat import VersionType # NOQA - -from .compat import string_types, _DEFAULT_YAML_VERSION # NOQA -from .error import * # NOQA -from .nodes import MappingNode, ScalarNode, SequenceNode # NOQA -from .util import RegExp # NOQA - -__all__ = ["BaseResolver", "Resolver", "VersionedResolver"] - - -# fmt: off -# resolvers consist of -# - a list of applicable version -# - a tag -# - a regexp -# - a list of first characters to match -implicit_resolvers = [ - ([(1, 2)], - u'tag:yaml.org,2002:bool', - RegExp(u'''^(?:true|True|TRUE|false|False|FALSE)$''', re.X), - list(u'tTfF')), - ([(1, 1)], - u'tag:yaml.org,2002:bool', - RegExp(u'''^(?:y|Y|yes|Yes|YES|n|N|no|No|NO - |true|True|TRUE|false|False|FALSE - |on|On|ON|off|Off|OFF)$''', re.X), - list(u'yYnNtTfFoO')), - ([(1, 2)], - u'tag:yaml.org,2002:float', - RegExp(u'''^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |[-+]?\\.[0-9_]+(?:[eE][-+][0-9]+)? - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')), - ([(1, 1)], - u'tag:yaml.org,2002:float', - RegExp(u'''^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
- |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* # sexagesimal float - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')), - ([(1, 2)], - u'tag:yaml.org,2002:int', - RegExp(u'''^(?:[-+]?0b[0-1_]+ - |[-+]?0o?[0-7_]+ - |[-+]?[0-9_]+ - |[-+]?0x[0-9a-fA-F_]+)$''', re.X), - list(u'-+0123456789')), - ([(1, 1)], - u'tag:yaml.org,2002:int', - RegExp(u'''^(?:[-+]?0b[0-1_]+ - |[-+]?0?[0-7_]+ - |[-+]?(?:0|[1-9][0-9_]*) - |[-+]?0x[0-9a-fA-F_]+ - |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X), # sexagesimal int - list(u'-+0123456789')), - ([(1, 2), (1, 1)], - u'tag:yaml.org,2002:merge', - RegExp(u'^(?:<<)$'), - [u'<']), - ([(1, 2), (1, 1)], - u'tag:yaml.org,2002:null', - RegExp(u'''^(?: ~ - |null|Null|NULL - | )$''', re.X), - [u'~', u'n', u'N', u'']), - ([(1, 2), (1, 1)], - u'tag:yaml.org,2002:timestamp', - RegExp(u'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9] - |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]? - (?:[Tt]|[ \\t]+)[0-9][0-9]? - :[0-9][0-9] :[0-9][0-9] (?:\\.[0-9]*)? - (?:[ \\t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X), - list(u'0123456789')), - ([(1, 2), (1, 1)], - u'tag:yaml.org,2002:value', - RegExp(u'^(?:=)$'), - [u'=']), - # The following resolver is only for documentation purposes. It cannot work - # because plain scalars cannot start with '!', '&', or '*'. 
- ([(1, 2), (1, 1)], - u'tag:yaml.org,2002:yaml', - RegExp(u'^(?:!|&|\\*)$'), - list(u'!&*')), -] -# fmt: on - - -class ResolverError(YAMLError): - pass - - -class BaseResolver(object): - - DEFAULT_SCALAR_TAG = u"tag:yaml.org,2002:str" - DEFAULT_SEQUENCE_TAG = u"tag:yaml.org,2002:seq" - DEFAULT_MAPPING_TAG = u"tag:yaml.org,2002:map" - - yaml_implicit_resolvers = {} # type: Dict[Any, Any] - yaml_path_resolvers = {} # type: Dict[Any, Any] - - def __init__(self, loadumper=None): - # type: (Any, Any) -> None - self.loadumper = loadumper - if ( - self.loadumper is not None - and getattr(self.loadumper, "_resolver", None) is None - ): - self.loadumper._resolver = self.loadumper - self._loader_version = None # type: Any - self.resolver_exact_paths = [] # type: List[Any] - self.resolver_prefix_paths = [] # type: List[Any] - - @property - def parser(self): - # type: () -> Any - if self.loadumper is not None: - if hasattr(self.loadumper, "typ"): - return self.loadumper.parser - return self.loadumper._parser - return None - - @classmethod - def add_implicit_resolver_base(cls, tag, regexp, first): - # type: (Any, Any, Any) -> None - if "yaml_implicit_resolvers" not in cls.__dict__: - # deepcopy doesn't work here - cls.yaml_implicit_resolvers = dict( - (k, cls.yaml_implicit_resolvers[k][:]) - for k in cls.yaml_implicit_resolvers - ) - if first is None: - first = [None] - for ch in first: - cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp)) - - @classmethod - def add_implicit_resolver(cls, tag, regexp, first): - # type: (Any, Any, Any) -> None - if "yaml_implicit_resolvers" not in cls.__dict__: - # deepcopy doesn't work here - cls.yaml_implicit_resolvers = dict( - (k, cls.yaml_implicit_resolvers[k][:]) - for k in cls.yaml_implicit_resolvers - ) - if first is None: - first = [None] - for ch in first: - cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp)) - implicit_resolvers.append(([(1, 2), (1, 1)], tag, regexp, first)) - - # @classmethod - # 
def add_implicit_resolver(cls, tag, regexp, first): - - @classmethod - def add_path_resolver(cls, tag, path, kind=None): - # type: (Any, Any, Any) -> None - # Note: `add_path_resolver` is experimental. The API could be changed. - # `new_path` is a pattern that is matched against the path from the - # root to the node that is being considered. `node_path` elements are - # tuples `(node_check, index_check)`. `node_check` is a node class: - # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None` - # matches any kind of a node. `index_check` could be `None`, a boolean - # value, a string value, or a number. `None` and `False` match against - # any _value_ of sequence and mapping nodes. `True` matches against - # any _key_ of a mapping node. A string `index_check` matches against - # a mapping value that corresponds to a scalar key which content is - # equal to the `index_check` value. An integer `index_check` matches - # against a sequence value with the index equal to `index_check`. 
- if "yaml_path_resolvers" not in cls.__dict__: - cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy() - new_path = [] # type: List[Any] - for element in path: - if isinstance(element, (list, tuple)): - if len(element) == 2: - node_check, index_check = element - elif len(element) == 1: - node_check = element[0] - index_check = True - else: - raise ResolverError("Invalid path element: %s" % (element,)) - else: - node_check = None - index_check = element - if node_check is str: - node_check = ScalarNode - elif node_check is list: - node_check = SequenceNode - elif node_check is dict: - node_check = MappingNode - elif ( - node_check not in [ScalarNode, SequenceNode, MappingNode] - and not isinstance(node_check, string_types) - and node_check is not None - ): - raise ResolverError("Invalid node checker: %s" % (node_check,)) - if ( - not isinstance(index_check, (string_types, int)) - and index_check is not None - ): - raise ResolverError("Invalid index checker: %s" % (index_check,)) - new_path.append((node_check, index_check)) - if kind is str: - kind = ScalarNode - elif kind is list: - kind = SequenceNode - elif kind is dict: - kind = MappingNode - elif kind not in [ScalarNode, SequenceNode, MappingNode] and kind is not None: - raise ResolverError("Invalid node kind: %s" % (kind,)) - cls.yaml_path_resolvers[tuple(new_path), kind] = tag - - def descend_resolver(self, current_node, current_index): - # type: (Any, Any) -> None - if not self.yaml_path_resolvers: - return - exact_paths = {} - prefix_paths = [] - if current_node: - depth = len(self.resolver_prefix_paths) - for path, kind in self.resolver_prefix_paths[-1]: - if self.check_resolver_prefix( - depth, path, kind, current_node, current_index - ): - if len(path) > depth: - prefix_paths.append((path, kind)) - else: - exact_paths[kind] = self.yaml_path_resolvers[path, kind] - else: - for path, kind in self.yaml_path_resolvers: - if not path: - exact_paths[kind] = self.yaml_path_resolvers[path, kind] - else: - 
prefix_paths.append((path, kind)) - self.resolver_exact_paths.append(exact_paths) - self.resolver_prefix_paths.append(prefix_paths) - - def ascend_resolver(self): - # type: () -> None - if not self.yaml_path_resolvers: - return - self.resolver_exact_paths.pop() - self.resolver_prefix_paths.pop() - - def check_resolver_prefix(self, depth, path, kind, current_node, current_index): - # type: (int, Text, Any, Any, Any) -> bool - node_check, index_check = path[depth - 1] - if isinstance(node_check, string_types): - if current_node.tag != node_check: - return False - elif node_check is not None: - if not isinstance(current_node, node_check): - return False - if index_check is True and current_index is not None: - return False - if (index_check is False or index_check is None) and current_index is None: - return False - if isinstance(index_check, string_types): - if not ( - isinstance(current_index, ScalarNode) - and index_check == current_index.value - ): - return False - elif isinstance(index_check, int) and not isinstance(index_check, bool): - if index_check != current_index: - return False - return True - - def resolve(self, kind, value, implicit): - # type: (Any, Any, Any) -> Any - if kind is ScalarNode and implicit[0]: - if value == "": - resolvers = self.yaml_implicit_resolvers.get("", []) - else: - resolvers = self.yaml_implicit_resolvers.get(value[0], []) - resolvers += self.yaml_implicit_resolvers.get(None, []) - for tag, regexp in resolvers: - if regexp.match(value): - return tag - implicit = implicit[1] - if bool(self.yaml_path_resolvers): - exact_paths = self.resolver_exact_paths[-1] - if kind in exact_paths: - return exact_paths[kind] - if None in exact_paths: - return exact_paths[None] - if kind is ScalarNode: - return self.DEFAULT_SCALAR_TAG - elif kind is SequenceNode: - return self.DEFAULT_SEQUENCE_TAG - elif kind is MappingNode: - return self.DEFAULT_MAPPING_TAG - - @property - def processing_version(self): - # type: () -> Any - return None - - -class 
Resolver(BaseResolver): - pass - - -for ir in implicit_resolvers: - if (1, 2) in ir[0]: - Resolver.add_implicit_resolver_base(*ir[1:]) - - -class VersionedResolver(BaseResolver): - """ - contrary to the "normal" resolver, the smart resolver delays loading - the pattern matching rules. That way it can decide to load 1.1 rules - or the (default) 1.2 rules, that no longer support octal without 0o, sexagesimals - and Yes/No/On/Off booleans. - """ - - def __init__(self, version=None, loader=None, loadumper=None): - # type: (Optional[VersionType], Any, Any) -> None - if loader is None and loadumper is not None: - loader = loadumper - BaseResolver.__init__(self, loader) - self._loader_version = self.get_loader_version(version) - self._version_implicit_resolver = {} # type: Dict[Any, Any] - - def add_version_implicit_resolver(self, version, tag, regexp, first): - # type: (VersionType, Any, Any, Any) -> None - if first is None: - first = [None] - impl_resolver = self._version_implicit_resolver.setdefault(version, {}) - for ch in first: - impl_resolver.setdefault(ch, []).append((tag, regexp)) - - def get_loader_version(self, version): - # type: (Optional[VersionType]) -> Any - if version is None or isinstance(version, tuple): - return version - if isinstance(version, list): - return tuple(version) - # assume string - return tuple(map(int, version.split(u"."))) - - @property - def versioned_resolver(self): - # type: () -> Any - """ - select the resolver based on the version we are parsing - """ - version = self.processing_version - if version not in self._version_implicit_resolver: - for x in implicit_resolvers: - if version in x[0]: - self.add_version_implicit_resolver(version, x[1], x[2], x[3]) - return self._version_implicit_resolver[version] - - def resolve(self, kind, value, implicit): - # type: (Any, Any, Any) -> Any - if kind is ScalarNode and implicit[0]: - if value == "": - resolvers = self.versioned_resolver.get("", []) - else: - resolvers = 
self.versioned_resolver.get(value[0], []) - resolvers += self.versioned_resolver.get(None, []) - for tag, regexp in resolvers: - if regexp.match(value): - return tag - implicit = implicit[1] - if bool(self.yaml_path_resolvers): - exact_paths = self.resolver_exact_paths[-1] - if kind in exact_paths: - return exact_paths[kind] - if None in exact_paths: - return exact_paths[None] - if kind is ScalarNode: - return self.DEFAULT_SCALAR_TAG - elif kind is SequenceNode: - return self.DEFAULT_SEQUENCE_TAG - elif kind is MappingNode: - return self.DEFAULT_MAPPING_TAG - - @property - def processing_version(self): - # type: () -> Any - try: - version = self.loadumper._scanner.yaml_version - except AttributeError: - try: - if hasattr(self.loadumper, "typ"): - version = self.loadumper.version - else: - version = self.loadumper._serializer.use_version # dumping - except AttributeError: - version = None - if version is None: - version = self._loader_version - if version is None: - version = _DEFAULT_YAML_VERSION - return version diff --git a/srsly/ruamel_yaml/scalarbool.py b/srsly/ruamel_yaml/scalarbool.py deleted file mode 100755 index 277f4d1..0000000 --- a/srsly/ruamel_yaml/scalarbool.py +++ /dev/null @@ -1,51 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -""" -You cannot subclass bool, and this is necessary for round-tripping anchored -bool values (and also if you want to preserve the original way of writing) - -bool.__bases__ is type 'int', so that is what is used as the basis for ScalarBoolean as well. 
- -You can use these in an if statement, but not when testing equivalence -""" - -from .anchor import Anchor - -if False: # MYPY - from typing import Text, Any, Dict, List # NOQA - -__all__ = ["ScalarBoolean"] - -# no need for no_limit_int -> int - - -class ScalarBoolean(int): - def __new__(cls, *args, **kw): - # type: (Any, Any, Any) -> Any - anchor = kw.pop("anchor", None) # type: ignore - b = int.__new__(cls, *args, **kw) # type: ignore - if anchor is not None: - b.yaml_set_anchor(anchor, always_dump=True) - return b - - @property - def anchor(self): - # type: () -> Any - if not hasattr(self, Anchor.attrib): - setattr(self, Anchor.attrib, Anchor()) - return getattr(self, Anchor.attrib) - - def yaml_anchor(self, any=False): - # type: (bool) -> Any - if not hasattr(self, Anchor.attrib): - return None - if any or self.anchor.always_dump: - return self.anchor - return None - - def yaml_set_anchor(self, value, always_dump=False): - # type: (Any, bool) -> None - self.anchor.value = value - self.anchor.always_dump = always_dump diff --git a/srsly/ruamel_yaml/scalarfloat.py b/srsly/ruamel_yaml/scalarfloat.py deleted file mode 100755 index 57fd7f9..0000000 --- a/srsly/ruamel_yaml/scalarfloat.py +++ /dev/null @@ -1,137 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -import sys -from .compat import no_limit_int # NOQA -from .anchor import Anchor - -if False: # MYPY - from typing import Text, Any, Dict, List # NOQA - -__all__ = ["ScalarFloat", "ExponentialFloat", "ExponentialCapsFloat"] - - -class ScalarFloat(float): - def __new__(cls, *args, **kw): - # type: (Any, Any, Any) -> Any - width = kw.pop("width", None) # type: ignore - prec = kw.pop("prec", None) # type: ignore - m_sign = kw.pop("m_sign", None) # type: ignore - m_lead0 = kw.pop("m_lead0", 0) # type: ignore - exp = kw.pop("exp", None) # type: ignore - e_width = kw.pop("e_width", None) # type: ignore - e_sign = kw.pop("e_sign", None) # type: ignore - 
underscore = kw.pop("underscore", None) # type: ignore - anchor = kw.pop("anchor", None) # type: ignore - v = float.__new__(cls, *args, **kw) # type: ignore - v._width = width - v._prec = prec - v._m_sign = m_sign - v._m_lead0 = m_lead0 - v._exp = exp - v._e_width = e_width - v._e_sign = e_sign - v._underscore = underscore - if anchor is not None: - v.yaml_set_anchor(anchor, always_dump=True) - return v - - def __iadd__(self, a): # type: ignore - # type: (Any) -> Any - return float(self) + a - x = type(self)(self + a) - x._width = self._width - x._underscore = ( - self._underscore[:] if self._underscore is not None else None - ) # NOQA - return x - - def __ifloordiv__(self, a): # type: ignore - # type: (Any) -> Any - return float(self) // a - x = type(self)(self // a) - x._width = self._width - x._underscore = ( - self._underscore[:] if self._underscore is not None else None - ) # NOQA - return x - - def __imul__(self, a): # type: ignore - # type: (Any) -> Any - return float(self) * a - x = type(self)(self * a) - x._width = self._width - x._underscore = ( - self._underscore[:] if self._underscore is not None else None - ) # NOQA - x._prec = self._prec # check for others - return x - - def __ipow__(self, a): # type: ignore - # type: (Any) -> Any - return float(self) ** a - x = type(self)(self ** a) - x._width = self._width - x._underscore = ( - self._underscore[:] if self._underscore is not None else None - ) # NOQA - return x - - def __isub__(self, a): # type: ignore - # type: (Any) -> Any - return float(self) - a - x = type(self)(self - a) - x._width = self._width - x._underscore = ( - self._underscore[:] if self._underscore is not None else None - ) # NOQA - return x - - @property - def anchor(self): - # type: () -> Any - if not hasattr(self, Anchor.attrib): - setattr(self, Anchor.attrib, Anchor()) - return getattr(self, Anchor.attrib) - - def yaml_anchor(self, any=False): - # type: (bool) -> Any - if not hasattr(self, Anchor.attrib): - return None - if any or 
self.anchor.always_dump: - return self.anchor - return None - - def yaml_set_anchor(self, value, always_dump=False): - # type: (Any, bool) -> None - self.anchor.value = value - self.anchor.always_dump = always_dump - - def dump(self, out=sys.stdout): - # type: (Any) -> Any - out.write( - "ScalarFloat({}| w:{}, p:{}, s:{}, lz:{}, _:{}|{}, w:{}, s:{})\n".format( - self, - self._width, # type: ignore - self._prec, # type: ignore - self._m_sign, # type: ignore - self._m_lead0, # type: ignore - self._underscore, # type: ignore - self._exp, # type: ignore - self._e_width, # type: ignore - self._e_sign, # type: ignore - ) - ) - - -class ExponentialFloat(ScalarFloat): - def __new__(cls, value, width=None, underscore=None): - # type: (Any, Any, Any) -> Any - return ScalarFloat.__new__(cls, value, width=width, underscore=underscore) - - -class ExponentialCapsFloat(ScalarFloat): - def __new__(cls, value, width=None, underscore=None): - # type: (Any, Any, Any) -> Any - return ScalarFloat.__new__(cls, value, width=width, underscore=underscore) diff --git a/srsly/ruamel_yaml/scalarint.py b/srsly/ruamel_yaml/scalarint.py deleted file mode 100755 index 383e25c..0000000 --- a/srsly/ruamel_yaml/scalarint.py +++ /dev/null @@ -1,150 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -from .compat import no_limit_int # NOQA -from .anchor import Anchor - -if False: # MYPY - from typing import Text, Any, Dict, List # NOQA - -__all__ = ["ScalarInt", "BinaryInt", "OctalInt", "HexInt", "HexCapsInt", "DecimalInt"] - - -class ScalarInt(no_limit_int): - def __new__(cls, *args, **kw): - # type: (Any, Any, Any) -> Any - width = kw.pop("width", None) # type: ignore - underscore = kw.pop("underscore", None) # type: ignore - anchor = kw.pop("anchor", None) # type: ignore - v = no_limit_int.__new__(cls, *args, **kw) # type: ignore - v._width = width - v._underscore = underscore - if anchor is not None: - v.yaml_set_anchor(anchor, 
always_dump=True) - return v - - def __iadd__(self, a): # type: ignore - # type: (Any) -> Any - x = type(self)(self + a) - x._width = self._width # type: ignore - x._underscore = ( # type: ignore - self._underscore[:] - if self._underscore is not None - else None # type: ignore - ) # NOQA - return x - - def __ifloordiv__(self, a): # type: ignore - # type: (Any) -> Any - x = type(self)(self // a) - x._width = self._width # type: ignore - x._underscore = ( # type: ignore - self._underscore[:] - if self._underscore is not None - else None # type: ignore - ) # NOQA - return x - - def __imul__(self, a): # type: ignore - # type: (Any) -> Any - x = type(self)(self * a) - x._width = self._width # type: ignore - x._underscore = ( # type: ignore - self._underscore[:] - if self._underscore is not None - else None # type: ignore - ) # NOQA - return x - - def __ipow__(self, a): # type: ignore - # type: (Any) -> Any - x = type(self)(self ** a) - x._width = self._width # type: ignore - x._underscore = ( # type: ignore - self._underscore[:] - if self._underscore is not None - else None # type: ignore - ) # NOQA - return x - - def __isub__(self, a): # type: ignore - # type: (Any) -> Any - x = type(self)(self - a) - x._width = self._width # type: ignore - x._underscore = ( # type: ignore - self._underscore[:] - if self._underscore is not None - else None # type: ignore - ) # NOQA - return x - - @property - def anchor(self): - # type: () -> Any - if not hasattr(self, Anchor.attrib): - setattr(self, Anchor.attrib, Anchor()) - return getattr(self, Anchor.attrib) - - def yaml_anchor(self, any=False): - # type: (bool) -> Any - if not hasattr(self, Anchor.attrib): - return None - if any or self.anchor.always_dump: - return self.anchor - return None - - def yaml_set_anchor(self, value, always_dump=False): - # type: (Any, bool) -> None - self.anchor.value = value - self.anchor.always_dump = always_dump - - -class BinaryInt(ScalarInt): - def __new__(cls, value, width=None, underscore=None, 
anchor=None): - # type: (Any, Any, Any, Any) -> Any - return ScalarInt.__new__( - cls, value, width=width, underscore=underscore, anchor=anchor - ) - - -class OctalInt(ScalarInt): - def __new__(cls, value, width=None, underscore=None, anchor=None): - # type: (Any, Any, Any, Any) -> Any - return ScalarInt.__new__( - cls, value, width=width, underscore=underscore, anchor=anchor - ) - - -# mixed casing of A-F is not supported, when loading the first non digit -# determines the case - - -class HexInt(ScalarInt): - """uses lower case (a-f)""" - - def __new__(cls, value, width=None, underscore=None, anchor=None): - # type: (Any, Any, Any, Any) -> Any - return ScalarInt.__new__( - cls, value, width=width, underscore=underscore, anchor=anchor - ) - - -class HexCapsInt(ScalarInt): - """uses upper case (A-F)""" - - def __new__(cls, value, width=None, underscore=None, anchor=None): - # type: (Any, Any, Any, Any) -> Any - return ScalarInt.__new__( - cls, value, width=width, underscore=underscore, anchor=anchor - ) - - -class DecimalInt(ScalarInt): - """needed if anchor""" - - def __new__(cls, value, width=None, underscore=None, anchor=None): - # type: (Any, Any, Any, Any) -> Any - return ScalarInt.__new__( - cls, value, width=width, underscore=underscore, anchor=anchor - ) diff --git a/srsly/ruamel_yaml/scalarstring.py b/srsly/ruamel_yaml/scalarstring.py deleted file mode 100755 index 0638e4b..0000000 --- a/srsly/ruamel_yaml/scalarstring.py +++ /dev/null @@ -1,156 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -from .compat import text_type -from .anchor import Anchor - -if False: # MYPY - from typing import Text, Any, Dict, List # NOQA - -__all__ = [ - "ScalarString", - "LiteralScalarString", - "FoldedScalarString", - "SingleQuotedScalarString", - "DoubleQuotedScalarString", - "PlainScalarString", - # PreservedScalarString is the old name, as it was the first to be preserved on rt, - # use LiteralScalarString 
instead - "PreservedScalarString", -] - - -class ScalarString(text_type): - __slots__ = Anchor.attrib - - def __new__(cls, *args, **kw): - # type: (Any, Any) -> Any - anchor = kw.pop("anchor", None) # type: ignore - ret_val = text_type.__new__(cls, *args, **kw) # type: ignore - if anchor is not None: - ret_val.yaml_set_anchor(anchor, always_dump=True) - return ret_val - - def replace(self, old, new, maxreplace=-1): - # type: (Any, Any, int) -> Any - return type(self)((text_type.replace(self, old, new, maxreplace))) - - @property - def anchor(self): - # type: () -> Any - if not hasattr(self, Anchor.attrib): - setattr(self, Anchor.attrib, Anchor()) - return getattr(self, Anchor.attrib) - - def yaml_anchor(self, any=False): - # type: (bool) -> Any - if not hasattr(self, Anchor.attrib): - return None - if any or self.anchor.always_dump: - return self.anchor - return None - - def yaml_set_anchor(self, value, always_dump=False): - # type: (Any, bool) -> None - self.anchor.value = value - self.anchor.always_dump = always_dump - - -class LiteralScalarString(ScalarString): - __slots__ = "comment" # the comment after the | on the first line - - style = "|" - - def __new__(cls, value, anchor=None): - # type: (Text, Any) -> Any - return ScalarString.__new__(cls, value, anchor=anchor) - - -PreservedScalarString = LiteralScalarString - - -class FoldedScalarString(ScalarString): - __slots__ = ("fold_pos", "comment") # the comment after the > on the first line - - style = ">" - - def __new__(cls, value, anchor=None): - # type: (Text, Any) -> Any - return ScalarString.__new__(cls, value, anchor=anchor) - - -class SingleQuotedScalarString(ScalarString): - __slots__ = () - - style = "'" - - def __new__(cls, value, anchor=None): - # type: (Text, Any) -> Any - return ScalarString.__new__(cls, value, anchor=anchor) - - -class DoubleQuotedScalarString(ScalarString): - __slots__ = () - - style = '"' - - def __new__(cls, value, anchor=None): - # type: (Text, Any) -> Any - return 
ScalarString.__new__(cls, value, anchor=anchor) - - -class PlainScalarString(ScalarString): - __slots__ = () - - style = "" - - def __new__(cls, value, anchor=None): - # type: (Text, Any) -> Any - return ScalarString.__new__(cls, value, anchor=anchor) - - -def preserve_literal(s): - # type: (Text) -> Text - return LiteralScalarString(s.replace("\r\n", "\n").replace("\r", "\n")) - - -def walk_tree(base, map=None): - # type: (Any, Any) -> None - """ - the routine here walks over a simple yaml tree (recursing in - dict values and list items) and converts strings that - have multiple lines to literal scalars - - You can also provide an explicit (ordered) mapping for multiple transforms - (first of which is executed): - map = .compat.ordereddict - map['\n'] = preserve_literal - map[':'] = SingleQuotedScalarString - walk_tree(data, map=map) - """ - from .compat import string_types - from .compat import MutableMapping, MutableSequence # type: ignore - - if map is None: - map = {"\n": preserve_literal} - - if isinstance(base, MutableMapping): - for k in base: - v = base[k] # type: Text - if isinstance(v, string_types): - for ch in map: - if ch in v: - base[k] = map[ch](v) - break - else: - walk_tree(v) - elif isinstance(base, MutableSequence): - for idx, elem in enumerate(base): - if isinstance(elem, string_types): - for ch in map: - if ch in elem: # type: ignore - base[idx] = map[ch](elem) - break - else: - walk_tree(elem) diff --git a/srsly/ruamel_yaml/scanner.py b/srsly/ruamel_yaml/scanner.py deleted file mode 100755 index a1d1ad0..0000000 --- a/srsly/ruamel_yaml/scanner.py +++ /dev/null @@ -1,2011 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -# Scanner produces tokens of the following types: -# STREAM-START -# STREAM-END -# DIRECTIVE(name, value) -# DOCUMENT-START -# DOCUMENT-END -# BLOCK-SEQUENCE-START -# BLOCK-MAPPING-START -# BLOCK-END -# FLOW-SEQUENCE-START -# FLOW-MAPPING-START -# 
FLOW-SEQUENCE-END -# FLOW-MAPPING-END -# BLOCK-ENTRY -# FLOW-ENTRY -# KEY -# VALUE -# ALIAS(value) -# ANCHOR(value) -# TAG(value) -# SCALAR(value, plain, style) -# -# RoundTripScanner -# COMMENT(value) -# -# Read comments in the Scanner code for more details. -# - -from .error import MarkedYAMLError -from .tokens import * # NOQA -from .compat import utf8, unichr, PY3, check_anchorname_char, nprint # NOQA - -if False: # MYPY - from typing import Any, Dict, Optional, List, Union, Text # NOQA - from .compat import VersionType # NOQA - -__all__ = ["Scanner", "RoundTripScanner", "ScannerError"] - - -_THE_END = "\n\0\r\x85\u2028\u2029" -_THE_END_SPACE_TAB = " \n\0\t\r\x85\u2028\u2029" -_SPACE_TAB = " \t" - - -class ScannerError(MarkedYAMLError): - pass - - -class SimpleKey(object): - # See below simple keys treatment. - - def __init__(self, token_number, required, index, line, column, mark): - # type: (Any, Any, int, int, int, Any) -> None - self.token_number = token_number - self.required = required - self.index = index - self.line = line - self.column = column - self.mark = mark - - -class Scanner(object): - def __init__(self, loader=None): - # type: (Any) -> None - """Initialize the scanner.""" - # It is assumed that Scanner and Reader will have a common descendant. - # Reader do the dirty work of checking for BOM and converting the - # input data to Unicode. It also adds NUL to the end. 
- # - # Reader supports the following methods - # self.peek(i=0) # peek the next i-th character - # self.prefix(l=1) # peek the next l characters - # self.forward(l=1) # read the next l characters and move the pointer - - self.loader = loader - if self.loader is not None and getattr(self.loader, "_scanner", None) is None: - self.loader._scanner = self - self.reset_scanner() - self.first_time = False - self.yaml_version = None # type: Any - - @property - def flow_level(self): - # type: () -> int - return len(self.flow_context) - - def reset_scanner(self): - # type: () -> None - # Had we reached the end of the stream? - self.done = False - - # flow_context is an expanding/shrinking list consisting of '{' and '[' - # for each unclosed flow context. If empty list that means block context - self.flow_context = [] # type: List[Text] - - # List of processed tokens that are not yet emitted. - self.tokens = [] # type: List[Any] - - # Add the STREAM-START token. - self.fetch_stream_start() - - # Number of tokens that were emitted through the `get_token` method. - self.tokens_taken = 0 - - # The current indentation level. - self.indent = -1 - - # Past indentation levels. - self.indents = [] # type: List[int] - - # Variables related to simple keys treatment. - - # A simple key is a key that is not denoted by the '?' indicator. - # Example of simple keys: - # --- - # block simple key: value - # ? not a simple key: - # : { flow simple key: value } - # We emit the KEY token before all keys, so when we find a potential - # simple key, we try to locate the corresponding ':' indicator. - # Simple keys should be limited to a single line and 1024 characters. - - # Can a simple key start at the current position? A simple key may - # start: - # - at the beginning of the line, not counting indentation spaces - # (in block context), - # - after '{', '[', ',' (in the flow context), - # - after '?', ':', '-' (in the block context). 
- # In the block context, this flag also signifies if a block collection - # may start at the current position. - self.allow_simple_key = True - - # Keep track of possible simple keys. This is a dictionary. The key - # is `flow_level`; there can be no more that one possible simple key - # for each level. The value is a SimpleKey record: - # (token_number, required, index, line, column, mark) - # A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow), - # '[', or '{' tokens. - self.possible_simple_keys = {} # type: Dict[Any, Any] - - @property - def reader(self): - # type: () -> Any - try: - return self._scanner_reader # type: ignore - except AttributeError: - if hasattr(self.loader, "typ"): - self._scanner_reader = self.loader.reader - else: - self._scanner_reader = self.loader._reader - return self._scanner_reader - - @property - def scanner_processing_version(self): # prefix until un-composited - # type: () -> Any - if hasattr(self.loader, "typ"): - return self.loader.resolver.processing_version - return self.loader.processing_version - - # Public methods. - - def check_token(self, *choices): - # type: (Any) -> bool - # Check if the next token is one of the given types. - while self.need_more_tokens(): - self.fetch_more_tokens() - if bool(self.tokens): - if not choices: - return True - for choice in choices: - if isinstance(self.tokens[0], choice): - return True - return False - - def peek_token(self): - # type: () -> Any - # Return the next token, but do not delete if from the queue. - while self.need_more_tokens(): - self.fetch_more_tokens() - if bool(self.tokens): - return self.tokens[0] - - def get_token(self): - # type: () -> Any - # Return the next token. - while self.need_more_tokens(): - self.fetch_more_tokens() - if bool(self.tokens): - self.tokens_taken += 1 - return self.tokens.pop(0) - - # Private methods. 
- - def need_more_tokens(self): - # type: () -> bool - if self.done: - return False - if not self.tokens: - return True - # The current token may be a potential simple key, so we - # need to look further. - self.stale_possible_simple_keys() - if self.next_possible_simple_key() == self.tokens_taken: - return True - return False - - def fetch_comment(self, comment): - # type: (Any) -> None - raise NotImplementedError - - def fetch_more_tokens(self): - # type: () -> Any - # Eat whitespaces and comments until we reach the next token. - comment = self.scan_to_next_token() - if comment is not None: # never happens for base scanner - return self.fetch_comment(comment) - # Remove obsolete possible simple keys. - self.stale_possible_simple_keys() - - # Compare the current indentation and column. It may add some tokens - # and decrease the current indentation level. - self.unwind_indent(self.reader.column) - - # Peek the next character. - ch = self.reader.peek() - - # Is it the end of stream? - if ch == "\0": - return self.fetch_stream_end() - - # Is it a directive? - if ch == "%" and self.check_directive(): - return self.fetch_directive() - - # Is it the document start? - if ch == "-" and self.check_document_start(): - return self.fetch_document_start() - - # Is it the document end? - if ch == "." and self.check_document_end(): - return self.fetch_document_end() - - # TODO: support for BOM within a stream. - # if ch == u'\uFEFF': - # return self.fetch_bom() <-- issue BOMToken - - # Note: the order of the following checks is NOT significant. - - # Is it the flow sequence start indicator? - if ch == "[": - return self.fetch_flow_sequence_start() - - # Is it the flow mapping start indicator? - if ch == "{": - return self.fetch_flow_mapping_start() - - # Is it the flow sequence end indicator? - if ch == "]": - return self.fetch_flow_sequence_end() - - # Is it the flow mapping end indicator? 
- if ch == "}": - return self.fetch_flow_mapping_end() - - # Is it the flow entry indicator? - if ch == ",": - return self.fetch_flow_entry() - - # Is it the block entry indicator? - if ch == "-" and self.check_block_entry(): - return self.fetch_block_entry() - - # Is it the key indicator? - if ch == "?" and self.check_key(): - return self.fetch_key() - - # Is it the value indicator? - if ch == ":" and self.check_value(): - return self.fetch_value() - - # Is it an alias? - if ch == "*": - return self.fetch_alias() - - # Is it an anchor? - if ch == "&": - return self.fetch_anchor() - - # Is it a tag? - if ch == "!": - return self.fetch_tag() - - # Is it a literal scalar? - if ch == "|" and not self.flow_level: - return self.fetch_literal() - - # Is it a folded scalar? - if ch == ">" and not self.flow_level: - return self.fetch_folded() - - # Is it a single quoted scalar? - if ch == "'": - return self.fetch_single() - - # Is it a double quoted scalar? - if ch == '"': - return self.fetch_double() - - # It must be a plain scalar then. - if self.check_plain(): - return self.fetch_plain() - - # No? It's an error. Let's produce a nice error message. - raise ScannerError( - "while scanning for the next token", - None, - "found character %r that cannot start any token" % utf8(ch), - self.reader.get_mark(), - ) - - # Simple keys treatment. - - def next_possible_simple_key(self): - # type: () -> Any - # Return the number of the nearest possible simple key. Actually we - # don't need to loop through the whole dictionary. 
We may replace it - # with the following code: - # if not self.possible_simple_keys: - # return None - # return self.possible_simple_keys[ - # min(self.possible_simple_keys.keys())].token_number - min_token_number = None - for level in self.possible_simple_keys: - key = self.possible_simple_keys[level] - if min_token_number is None or key.token_number < min_token_number: - min_token_number = key.token_number - return min_token_number - - def stale_possible_simple_keys(self): - # type: () -> None - # Remove entries that are no longer possible simple keys. According to - # the YAML specification, simple keys - # - should be limited to a single line, - # - should be no longer than 1024 characters. - # Disabling this procedure will allow simple keys of any length and - # height (may cause problems if indentation is broken though). - for level in list(self.possible_simple_keys): - key = self.possible_simple_keys[level] - if key.line != self.reader.line or self.reader.index - key.index > 1024: - if key.required: - raise ScannerError( - "while scanning a simple key", - key.mark, - "could not find expected ':'", - self.reader.get_mark(), - ) - del self.possible_simple_keys[level] - - def save_possible_simple_key(self): - # type: () -> None - # The next token may start a simple key. We check if it's possible - # and save its position. This function is called for - # ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'. - - # Check if a simple key is required at the current position. - required = not self.flow_level and self.indent == self.reader.column - - # The next token might be a simple key. Let's save it's number and - # position. 
- if self.allow_simple_key: - self.remove_possible_simple_key() - token_number = self.tokens_taken + len(self.tokens) - key = SimpleKey( - token_number, - required, - self.reader.index, - self.reader.line, - self.reader.column, - self.reader.get_mark(), - ) - self.possible_simple_keys[self.flow_level] = key - - def remove_possible_simple_key(self): - # type: () -> None - # Remove the saved possible key position at the current flow level. - if self.flow_level in self.possible_simple_keys: - key = self.possible_simple_keys[self.flow_level] - - if key.required: - raise ScannerError( - "while scanning a simple key", - key.mark, - "could not find expected ':'", - self.reader.get_mark(), - ) - - del self.possible_simple_keys[self.flow_level] - - # Indentation functions. - - def unwind_indent(self, column): - # type: (Any) -> None - # In flow context, tokens should respect indentation. - # Actually the condition should be `self.indent >= column` according to - # the spec. But this condition will prohibit intuitively correct - # constructions such as - # key : { - # } - # #### - # if self.flow_level and self.indent > column: - # raise ScannerError(None, None, - # "invalid intendation or unclosed '[' or '{'", - # self.reader.get_mark()) - - # In the flow context, indentation is ignored. We make the scanner less - # restrictive then specification requires. - if bool(self.flow_level): - return - - # In block context, we may need to issue the BLOCK-END tokens. - while self.indent > column: - mark = self.reader.get_mark() - self.indent = self.indents.pop() - self.tokens.append(BlockEndToken(mark, mark)) - - def add_indent(self, column): - # type: (int) -> bool - # Check if we need to increase indentation. - if self.indent < column: - self.indents.append(self.indent) - self.indent = column - return True - return False - - # Fetchers. - - def fetch_stream_start(self): - # type: () -> None - # We always add STREAM-START as the first token and STREAM-END as the - # last token. 
- # Read the token. - mark = self.reader.get_mark() - # Add STREAM-START. - self.tokens.append(StreamStartToken(mark, mark, encoding=self.reader.encoding)) - - def fetch_stream_end(self): - # type: () -> None - # Set the current intendation to -1. - self.unwind_indent(-1) - # Reset simple keys. - self.remove_possible_simple_key() - self.allow_simple_key = False - self.possible_simple_keys = {} - # Read the token. - mark = self.reader.get_mark() - # Add STREAM-END. - self.tokens.append(StreamEndToken(mark, mark)) - # The steam is finished. - self.done = True - - def fetch_directive(self): - # type: () -> None - # Set the current intendation to -1. - self.unwind_indent(-1) - - # Reset simple keys. - self.remove_possible_simple_key() - self.allow_simple_key = False - - # Scan and add DIRECTIVE. - self.tokens.append(self.scan_directive()) - - def fetch_document_start(self): - # type: () -> None - self.fetch_document_indicator(DocumentStartToken) - - def fetch_document_end(self): - # type: () -> None - self.fetch_document_indicator(DocumentEndToken) - - def fetch_document_indicator(self, TokenClass): - # type: (Any) -> None - # Set the current intendation to -1. - self.unwind_indent(-1) - - # Reset simple keys. Note that there could not be a block collection - # after '---'. - self.remove_possible_simple_key() - self.allow_simple_key = False - - # Add DOCUMENT-START or DOCUMENT-END. - start_mark = self.reader.get_mark() - self.reader.forward(3) - end_mark = self.reader.get_mark() - self.tokens.append(TokenClass(start_mark, end_mark)) - - def fetch_flow_sequence_start(self): - # type: () -> None - self.fetch_flow_collection_start(FlowSequenceStartToken, to_push="[") - - def fetch_flow_mapping_start(self): - # type: () -> None - self.fetch_flow_collection_start(FlowMappingStartToken, to_push="{") - - def fetch_flow_collection_start(self, TokenClass, to_push): - # type: (Any, Text) -> None - # '[' and '{' may start a simple key. 
- self.save_possible_simple_key() - # Increase the flow level. - self.flow_context.append(to_push) - # Simple keys are allowed after '[' and '{'. - self.allow_simple_key = True - # Add FLOW-SEQUENCE-START or FLOW-MAPPING-START. - start_mark = self.reader.get_mark() - self.reader.forward() - end_mark = self.reader.get_mark() - self.tokens.append(TokenClass(start_mark, end_mark)) - - def fetch_flow_sequence_end(self): - # type: () -> None - self.fetch_flow_collection_end(FlowSequenceEndToken) - - def fetch_flow_mapping_end(self): - # type: () -> None - self.fetch_flow_collection_end(FlowMappingEndToken) - - def fetch_flow_collection_end(self, TokenClass): - # type: (Any) -> None - # Reset possible simple key on the current level. - self.remove_possible_simple_key() - # Decrease the flow level. - try: - popped = self.flow_context.pop() # NOQA - except IndexError: - # We must not be in a list or object. - # Defer error handling to the parser. - pass - # No simple keys after ']' or '}'. - self.allow_simple_key = False - # Add FLOW-SEQUENCE-END or FLOW-MAPPING-END. - start_mark = self.reader.get_mark() - self.reader.forward() - end_mark = self.reader.get_mark() - self.tokens.append(TokenClass(start_mark, end_mark)) - - def fetch_flow_entry(self): - # type: () -> None - # Simple keys are allowed after ','. - self.allow_simple_key = True - # Reset possible simple key on the current level. - self.remove_possible_simple_key() - # Add FLOW-ENTRY. - start_mark = self.reader.get_mark() - self.reader.forward() - end_mark = self.reader.get_mark() - self.tokens.append(FlowEntryToken(start_mark, end_mark)) - - def fetch_block_entry(self): - # type: () -> None - # Block context needs additional checks. - if not self.flow_level: - # Are we allowed to start a new entry? - if not self.allow_simple_key: - raise ScannerError( - None, - None, - "sequence entries are not allowed here", - self.reader.get_mark(), - ) - # We may need to add BLOCK-SEQUENCE-START. 
- if self.add_indent(self.reader.column): - mark = self.reader.get_mark() - self.tokens.append(BlockSequenceStartToken(mark, mark)) - # It's an error for the block entry to occur in the flow context, - # but we let the parser detect this. - else: - pass - # Simple keys are allowed after '-'. - self.allow_simple_key = True - # Reset possible simple key on the current level. - self.remove_possible_simple_key() - - # Add BLOCK-ENTRY. - start_mark = self.reader.get_mark() - self.reader.forward() - end_mark = self.reader.get_mark() - self.tokens.append(BlockEntryToken(start_mark, end_mark)) - - def fetch_key(self): - # type: () -> None - # Block context needs additional checks. - if not self.flow_level: - - # Are we allowed to start a key (not nessesary a simple)? - if not self.allow_simple_key: - raise ScannerError( - None, - None, - "mapping keys are not allowed here", - self.reader.get_mark(), - ) - - # We may need to add BLOCK-MAPPING-START. - if self.add_indent(self.reader.column): - mark = self.reader.get_mark() - self.tokens.append(BlockMappingStartToken(mark, mark)) - - # Simple keys are allowed after '?' in the block context. - self.allow_simple_key = not self.flow_level - - # Reset possible simple key on the current level. - self.remove_possible_simple_key() - - # Add KEY. - start_mark = self.reader.get_mark() - self.reader.forward() - end_mark = self.reader.get_mark() - self.tokens.append(KeyToken(start_mark, end_mark)) - - def fetch_value(self): - # type: () -> None - # Do we determine a simple key? - if self.flow_level in self.possible_simple_keys: - # Add KEY. - key = self.possible_simple_keys[self.flow_level] - del self.possible_simple_keys[self.flow_level] - self.tokens.insert( - key.token_number - self.tokens_taken, KeyToken(key.mark, key.mark) - ) - - # If this key starts a new block mapping, we need to add - # BLOCK-MAPPING-START. 
- if not self.flow_level: - if self.add_indent(key.column): - self.tokens.insert( - key.token_number - self.tokens_taken, - BlockMappingStartToken(key.mark, key.mark), - ) - - # There cannot be two simple keys one after another. - self.allow_simple_key = False - - # It must be a part of a complex key. - else: - - # Block context needs additional checks. - # (Do we really need them? They will be caught by the parser - # anyway.) - if not self.flow_level: - - # We are allowed to start a complex value if and only if - # we can start a simple key. - if not self.allow_simple_key: - raise ScannerError( - None, - None, - "mapping values are not allowed here", - self.reader.get_mark(), - ) - - # If this value starts a new block mapping, we need to add - # BLOCK-MAPPING-START. It will be detected as an error later by - # the parser. - if not self.flow_level: - if self.add_indent(self.reader.column): - mark = self.reader.get_mark() - self.tokens.append(BlockMappingStartToken(mark, mark)) - - # Simple keys are allowed after ':' in the block context. - self.allow_simple_key = not self.flow_level - - # Reset possible simple key on the current level. - self.remove_possible_simple_key() - - # Add VALUE. - start_mark = self.reader.get_mark() - self.reader.forward() - end_mark = self.reader.get_mark() - self.tokens.append(ValueToken(start_mark, end_mark)) - - def fetch_alias(self): - # type: () -> None - # ALIAS could be a simple key. - self.save_possible_simple_key() - # No simple keys after ALIAS. - self.allow_simple_key = False - # Scan and add ALIAS. - self.tokens.append(self.scan_anchor(AliasToken)) - - def fetch_anchor(self): - # type: () -> None - # ANCHOR could start a simple key. - self.save_possible_simple_key() - # No simple keys after ANCHOR. - self.allow_simple_key = False - # Scan and add ANCHOR. - self.tokens.append(self.scan_anchor(AnchorToken)) - - def fetch_tag(self): - # type: () -> None - # TAG could start a simple key. 
- self.save_possible_simple_key() - # No simple keys after TAG. - self.allow_simple_key = False - # Scan and add TAG. - self.tokens.append(self.scan_tag()) - - def fetch_literal(self): - # type: () -> None - self.fetch_block_scalar(style="|") - - def fetch_folded(self): - # type: () -> None - self.fetch_block_scalar(style=">") - - def fetch_block_scalar(self, style): - # type: (Any) -> None - # A simple key may follow a block scalar. - self.allow_simple_key = True - # Reset possible simple key on the current level. - self.remove_possible_simple_key() - # Scan and add SCALAR. - self.tokens.append(self.scan_block_scalar(style)) - - def fetch_single(self): - # type: () -> None - self.fetch_flow_scalar(style="'") - - def fetch_double(self): - # type: () -> None - self.fetch_flow_scalar(style='"') - - def fetch_flow_scalar(self, style): - # type: (Any) -> None - # A flow scalar could be a simple key. - self.save_possible_simple_key() - # No simple keys after flow scalars. - self.allow_simple_key = False - # Scan and add SCALAR. - self.tokens.append(self.scan_flow_scalar(style)) - - def fetch_plain(self): - # type: () -> None - # A plain scalar could be a simple key. - self.save_possible_simple_key() - # No simple keys after plain scalars. But note that `scan_plain` will - # change this flag if the scan is finished at the beginning of the - # line. - self.allow_simple_key = False - # Scan and add SCALAR. May change `allow_simple_key`. - self.tokens.append(self.scan_plain()) - - # Checkers. - - def check_directive(self): - # type: () -> Any - # DIRECTIVE: ^ '%' ... - # The '%' indicator is already checked. 
- if self.reader.column == 0: - return True - return None - - def check_document_start(self): - # type: () -> Any - # DOCUMENT-START: ^ '---' (' '|'\n') - if self.reader.column == 0: - if ( - self.reader.prefix(3) == "---" - and self.reader.peek(3) in _THE_END_SPACE_TAB - ): - return True - return None - - def check_document_end(self): - # type: () -> Any - # DOCUMENT-END: ^ '...' (' '|'\n') - if self.reader.column == 0: - if ( - self.reader.prefix(3) == "..." - and self.reader.peek(3) in _THE_END_SPACE_TAB - ): - return True - return None - - def check_block_entry(self): - # type: () -> Any - # BLOCK-ENTRY: '-' (' '|'\n') - return self.reader.peek(1) in _THE_END_SPACE_TAB - - def check_key(self): - # type: () -> Any - # KEY(flow context): '?' - if bool(self.flow_level): - return True - # KEY(block context): '?' (' '|'\n') - return self.reader.peek(1) in _THE_END_SPACE_TAB - - def check_value(self): - # type: () -> Any - # VALUE(flow context): ':' - if self.scanner_processing_version == (1, 1): - if bool(self.flow_level): - return True - else: - if bool(self.flow_level): - if self.flow_context[-1] == "[": - if self.reader.peek(1) not in _THE_END_SPACE_TAB: - return False - elif self.tokens and isinstance(self.tokens[-1], ValueToken): - # mapping flow context scanning a value token - if self.reader.peek(1) not in _THE_END_SPACE_TAB: - return False - return True - # VALUE(block context): ':' (' '|'\n') - return self.reader.peek(1) in _THE_END_SPACE_TAB - - def check_plain(self): - # type: () -> Any - # A plain scalar may start with any non-space character except: - # '-', '?', ':', ',', '[', ']', '{', '}', - # '#', '&', '*', '!', '|', '>', '\'', '\"', - # '%', '@', '`'. - # - # It may also start with - # '-', '?', ':' - # if it is followed by a non-space character. - # - # Note that we limit the last rule to the block context (except the - # '-' character) because we want the flow context to be space - # independent. 
- srp = self.reader.peek - ch = srp() - if self.scanner_processing_version == (1, 1): - return ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`" or ( - srp(1) not in _THE_END_SPACE_TAB - and (ch == "-" or (not self.flow_level and ch in "?:")) - ) - # YAML 1.2 - if ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`": - # ################### ^ ??? - return True - ch1 = srp(1) - if ch == "-" and ch1 not in _THE_END_SPACE_TAB: - return True - if ch == ":" and bool(self.flow_level) and ch1 not in _SPACE_TAB: - return True - - return srp(1) not in _THE_END_SPACE_TAB and ( - ch == "-" or (not self.flow_level and ch in "?:") - ) - - # Scanners. - - def scan_to_next_token(self): - # type: () -> Any - # We ignore spaces, line breaks and comments. - # If we find a line break in the block context, we set the flag - # `allow_simple_key` on. - # The byte order mark is stripped if it's the first character in the - # stream. We do not yet support BOM inside the stream as the - # specification requires. Any such mark will be considered as a part - # of the document. - # - # TODO: We need to make tab handling rules more sane. A good rule is - # Tabs cannot precede tokens - # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END, - # KEY(block), VALUE(block), BLOCK-ENTRY - # So the checking code is - # if : - # self.allow_simple_keys = False - # We also need to add the check for `allow_simple_keys == True` to - # `unwind_indent` before issuing BLOCK-END. - # Scanners for block, flow, and plain scalars need to be modified. 
- srp = self.reader.peek - srf = self.reader.forward - if self.reader.index == 0 and srp() == "\uFEFF": - srf() - found = False - _the_end = _THE_END - while not found: - while srp() == " ": - srf() - if srp() == "#": - while srp() not in _the_end: - srf() - if self.scan_line_break(): - if not self.flow_level: - self.allow_simple_key = True - else: - found = True - return None - - def scan_directive(self): - # type: () -> Any - # See the specification for details. - srp = self.reader.peek - srf = self.reader.forward - start_mark = self.reader.get_mark() - srf() - name = self.scan_directive_name(start_mark) - value = None - if name == "YAML": - value = self.scan_yaml_directive_value(start_mark) - end_mark = self.reader.get_mark() - elif name == "TAG": - value = self.scan_tag_directive_value(start_mark) - end_mark = self.reader.get_mark() - else: - end_mark = self.reader.get_mark() - while srp() not in _THE_END: - srf() - self.scan_directive_ignored_line(start_mark) - return DirectiveToken(name, value, start_mark, end_mark) - - def scan_directive_name(self, start_mark): - # type: (Any) -> Any - # See the specification for details. - length = 0 - srp = self.reader.peek - ch = srp(length) - while "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_:.": - length += 1 - ch = srp(length) - if not length: - raise ScannerError( - "while scanning a directive", - start_mark, - "expected alphabetic or numeric character, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - value = self.reader.prefix(length) - self.reader.forward(length) - ch = srp() - if ch not in "\0 \r\n\x85\u2028\u2029": - raise ScannerError( - "while scanning a directive", - start_mark, - "expected alphabetic or numeric character, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - return value - - def scan_yaml_directive_value(self, start_mark): - # type: (Any) -> Any - # See the specification for details. 
- srp = self.reader.peek - srf = self.reader.forward - while srp() == " ": - srf() - major = self.scan_yaml_directive_number(start_mark) - if srp() != ".": - raise ScannerError( - "while scanning a directive", - start_mark, - "expected a digit or '.', but found %r" % utf8(srp()), - self.reader.get_mark(), - ) - srf() - minor = self.scan_yaml_directive_number(start_mark) - if srp() not in "\0 \r\n\x85\u2028\u2029": - raise ScannerError( - "while scanning a directive", - start_mark, - "expected a digit or ' ', but found %r" % utf8(srp()), - self.reader.get_mark(), - ) - self.yaml_version = (major, minor) - return self.yaml_version - - def scan_yaml_directive_number(self, start_mark): - # type: (Any) -> Any - # See the specification for details. - srp = self.reader.peek - srf = self.reader.forward - ch = srp() - if not ("0" <= ch <= "9"): - raise ScannerError( - "while scanning a directive", - start_mark, - "expected a digit, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - length = 0 - while "0" <= srp(length) <= "9": - length += 1 - value = int(self.reader.prefix(length)) - srf(length) - return value - - def scan_tag_directive_value(self, start_mark): - # type: (Any) -> Any - # See the specification for details. - srp = self.reader.peek - srf = self.reader.forward - while srp() == " ": - srf() - handle = self.scan_tag_directive_handle(start_mark) - while srp() == " ": - srf() - prefix = self.scan_tag_directive_prefix(start_mark) - return (handle, prefix) - - def scan_tag_directive_handle(self, start_mark): - # type: (Any) -> Any - # See the specification for details. - value = self.scan_tag_handle("directive", start_mark) - ch = self.reader.peek() - if ch != " ": - raise ScannerError( - "while scanning a directive", - start_mark, - "expected ' ', but found %r" % utf8(ch), - self.reader.get_mark(), - ) - return value - - def scan_tag_directive_prefix(self, start_mark): - # type: (Any) -> Any - # See the specification for details. 
- value = self.scan_tag_uri("directive", start_mark) - ch = self.reader.peek() - if ch not in "\0 \r\n\x85\u2028\u2029": - raise ScannerError( - "while scanning a directive", - start_mark, - "expected ' ', but found %r" % utf8(ch), - self.reader.get_mark(), - ) - return value - - def scan_directive_ignored_line(self, start_mark): - # type: (Any) -> None - # See the specification for details. - srp = self.reader.peek - srf = self.reader.forward - while srp() == " ": - srf() - if srp() == "#": - while srp() not in _THE_END: - srf() - ch = srp() - if ch not in _THE_END: - raise ScannerError( - "while scanning a directive", - start_mark, - "expected a comment or a line break, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - self.scan_line_break() - - def scan_anchor(self, TokenClass): - # type: (Any) -> Any - # The specification does not restrict characters for anchors and - # aliases. This may lead to problems, for instance, the document: - # [ *alias, value ] - # can be interpteted in two ways, as - # [ "value" ] - # and - # [ *alias , "value" ] - # Therefore we restrict aliases to numbers and ASCII letters. 
- srp = self.reader.peek - start_mark = self.reader.get_mark() - indicator = srp() - if indicator == "*": - name = "alias" - else: - name = "anchor" - self.reader.forward() - length = 0 - ch = srp(length) - # while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' or u'a' <= ch <= u'z' \ - # or ch in u'-_': - while check_anchorname_char(ch): - length += 1 - ch = srp(length) - if not length: - raise ScannerError( - "while scanning an %s" % (name,), - start_mark, - "expected alphabetic or numeric character, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - value = self.reader.prefix(length) - self.reader.forward(length) - # ch1 = ch - # ch = srp() # no need to peek, ch is already set - # assert ch1 == ch - if ch not in "\0 \t\r\n\x85\u2028\u2029?:,[]{}%@`": - raise ScannerError( - "while scanning an %s" % (name,), - start_mark, - "expected alphabetic or numeric character, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - end_mark = self.reader.get_mark() - return TokenClass(value, start_mark, end_mark) - - def scan_tag(self): - # type: () -> Any - # See the specification for details. - srp = self.reader.peek - start_mark = self.reader.get_mark() - ch = srp(1) - if ch == "<": - handle = None - self.reader.forward(2) - suffix = self.scan_tag_uri("tag", start_mark) - if srp() != ">": - raise ScannerError( - "while parsing a tag", - start_mark, - "expected '>', but found %r" % utf8(srp()), - self.reader.get_mark(), - ) - self.reader.forward() - elif ch in _THE_END_SPACE_TAB: - handle = None - suffix = "!" - self.reader.forward() - else: - length = 1 - use_handle = False - while ch not in "\0 \r\n\x85\u2028\u2029": - if ch == "!": - use_handle = True - break - length += 1 - ch = srp(length) - handle = "!" - if use_handle: - handle = self.scan_tag_handle("tag", start_mark) - else: - handle = "!" 
- self.reader.forward() - suffix = self.scan_tag_uri("tag", start_mark) - ch = srp() - if ch not in "\0 \r\n\x85\u2028\u2029": - raise ScannerError( - "while scanning a tag", - start_mark, - "expected ' ', but found %r" % utf8(ch), - self.reader.get_mark(), - ) - value = (handle, suffix) - end_mark = self.reader.get_mark() - return TagToken(value, start_mark, end_mark) - - def scan_block_scalar(self, style, rt=False): - # type: (Any, Optional[bool]) -> Any - # See the specification for details. - srp = self.reader.peek - if style == ">": - folded = True - else: - folded = False - - chunks = [] # type: List[Any] - start_mark = self.reader.get_mark() - - # Scan the header. - self.reader.forward() - chomping, increment = self.scan_block_scalar_indicators(start_mark) - # block scalar comment e.g. : |+ # comment text - block_scalar_comment = self.scan_block_scalar_ignored_line(start_mark) - - # Determine the indentation level and go to the first non-empty line. - min_indent = self.indent + 1 - if increment is None: - # no increment and top level, min_indent could be 0 - if min_indent < 1 and ( - style not in "|>" - or (self.scanner_processing_version == (1, 1)) - and getattr( - self.loader, - "top_level_block_style_scalar_no_indent_error_1_1", - False, - ) - ): - min_indent = 1 - breaks, max_indent, end_mark = self.scan_block_scalar_indentation() - indent = max(min_indent, max_indent) - else: - if min_indent < 1: - min_indent = 1 - indent = min_indent + increment - 1 - breaks, end_mark = self.scan_block_scalar_breaks(indent) - line_break = "" - - # Scan the inner part of the block scalar. 
- while self.reader.column == indent and srp() != "\0": - chunks.extend(breaks) - leading_non_space = srp() not in " \t" - length = 0 - while srp(length) not in _THE_END: - length += 1 - chunks.append(self.reader.prefix(length)) - self.reader.forward(length) - line_break = self.scan_line_break() - breaks, end_mark = self.scan_block_scalar_breaks(indent) - if style in "|>" and min_indent == 0: - # at the beginning of a line, if in block style see if - # end of document/start_new_document - if self.check_document_start() or self.check_document_end(): - break - if self.reader.column == indent and srp() != "\0": - - # Unfortunately, folding rules are ambiguous. - # - # This is the folding according to the specification: - - if rt and folded and line_break == "\n": - chunks.append("\a") - if ( - folded - and line_break == "\n" - and leading_non_space - and srp() not in " \t" - ): - if not breaks: - chunks.append(" ") - else: - chunks.append(line_break) - - # This is Clark Evans's interpretation (also in the spec - # examples): - # - # if folded and line_break == u'\n': - # if not breaks: - # if srp() not in ' \t': - # chunks.append(u' ') - # else: - # chunks.append(line_break) - # else: - # chunks.append(line_break) - else: - break - - # Process trailing line breaks. The 'chomping' setting determines - # whether they are included in the value. - trailing = [] # type: List[Any] - if chomping in [None, True]: - chunks.append(line_break) - if chomping is True: - chunks.extend(breaks) - elif chomping in [None, False]: - trailing.extend(breaks) - - # We are done. - token = ScalarToken("".join(chunks), False, start_mark, end_mark, style) - if block_scalar_comment is not None: - token.add_pre_comments([block_scalar_comment]) - if len(trailing) > 0: - # nprint('trailing 1', trailing) # XXXXX - # Eat whitespaces and comments until we reach the next token. 
- comment = self.scan_to_next_token() - while comment: - trailing.append(" " * comment[1].column + comment[0]) - comment = self.scan_to_next_token() - - # Keep track of the trailing whitespace and following comments - # as a comment token, if isn't all included in the actual value. - comment_end_mark = self.reader.get_mark() - comment = CommentToken("".join(trailing), end_mark, comment_end_mark) - token.add_post_comment(comment) - return token - - def scan_block_scalar_indicators(self, start_mark): - # type: (Any) -> Any - # See the specification for details. - srp = self.reader.peek - chomping = None - increment = None - ch = srp() - if ch in "+-": - if ch == "+": - chomping = True - else: - chomping = False - self.reader.forward() - ch = srp() - if ch in "0123456789": - increment = int(ch) - if increment == 0: - raise ScannerError( - "while scanning a block scalar", - start_mark, - "expected indentation indicator in the range 1-9, " - "but found 0", - self.reader.get_mark(), - ) - self.reader.forward() - elif ch in "0123456789": - increment = int(ch) - if increment == 0: - raise ScannerError( - "while scanning a block scalar", - start_mark, - "expected indentation indicator in the range 1-9, " "but found 0", - self.reader.get_mark(), - ) - self.reader.forward() - ch = srp() - if ch in "+-": - if ch == "+": - chomping = True - else: - chomping = False - self.reader.forward() - ch = srp() - if ch not in "\0 \r\n\x85\u2028\u2029": - raise ScannerError( - "while scanning a block scalar", - start_mark, - "expected chomping or indentation indicators, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - return chomping, increment - - def scan_block_scalar_ignored_line(self, start_mark): - # type: (Any) -> Any - # See the specification for details. 
- srp = self.reader.peek - srf = self.reader.forward - prefix = "" - comment = None - while srp() == " ": - prefix += srp() - srf() - if srp() == "#": - comment = prefix - while srp() not in _THE_END: - comment += srp() - srf() - ch = srp() - if ch not in _THE_END: - raise ScannerError( - "while scanning a block scalar", - start_mark, - "expected a comment or a line break, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - self.scan_line_break() - return comment - - def scan_block_scalar_indentation(self): - # type: () -> Any - # See the specification for details. - srp = self.reader.peek - srf = self.reader.forward - chunks = [] - max_indent = 0 - end_mark = self.reader.get_mark() - while srp() in " \r\n\x85\u2028\u2029": - if srp() != " ": - chunks.append(self.scan_line_break()) - end_mark = self.reader.get_mark() - else: - srf() - if self.reader.column > max_indent: - max_indent = self.reader.column - return chunks, max_indent, end_mark - - def scan_block_scalar_breaks(self, indent): - # type: (int) -> Any - # See the specification for details. - chunks = [] - srp = self.reader.peek - srf = self.reader.forward - end_mark = self.reader.get_mark() - while self.reader.column < indent and srp() == " ": - srf() - while srp() in "\r\n\x85\u2028\u2029": - chunks.append(self.scan_line_break()) - end_mark = self.reader.get_mark() - while self.reader.column < indent and srp() == " ": - srf() - return chunks, end_mark - - def scan_flow_scalar(self, style): - # type: (Any) -> Any - # See the specification for details. - # Note that we loose indentation rules for quoted scalars. Quoted - # scalars don't need to adhere indentation because " and ' clearly - # mark the beginning and the end of them. Therefore we are less - # restrictive then the specification requires. We only need to check - # that document separators are not included in scalars. 
- if style == '"': - double = True - else: - double = False - srp = self.reader.peek - chunks = [] # type: List[Any] - start_mark = self.reader.get_mark() - quote = srp() - self.reader.forward() - chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) - while srp() != quote: - chunks.extend(self.scan_flow_scalar_spaces(double, start_mark)) - chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark)) - self.reader.forward() - end_mark = self.reader.get_mark() - return ScalarToken("".join(chunks), False, start_mark, end_mark, style) - - ESCAPE_REPLACEMENTS = { - "0": "\0", - "a": "\x07", - "b": "\x08", - "t": "\x09", - "\t": "\x09", - "n": "\x0A", - "v": "\x0B", - "f": "\x0C", - "r": "\x0D", - "e": "\x1B", - " ": "\x20", - '"': '"', - "/": "/", # as per http://www.json.org/ - "\\": "\\", - "N": "\x85", - "_": "\xA0", - "L": "\u2028", - "P": "\u2029", - } - - ESCAPE_CODES = {"x": 2, "u": 4, "U": 8} - - def scan_flow_scalar_non_spaces(self, double, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. 
- chunks = [] # type: List[Any] - srp = self.reader.peek - srf = self.reader.forward - while True: - length = 0 - while srp(length) not in " \n'\"\\\0\t\r\x85\u2028\u2029": - length += 1 - if length != 0: - chunks.append(self.reader.prefix(length)) - srf(length) - ch = srp() - if not double and ch == "'" and srp(1) == "'": - chunks.append("'") - srf(2) - elif (double and ch == "'") or (not double and ch in '"\\'): - chunks.append(ch) - srf() - elif double and ch == "\\": - srf() - ch = srp() - if ch in self.ESCAPE_REPLACEMENTS: - chunks.append(self.ESCAPE_REPLACEMENTS[ch]) - srf() - elif ch in self.ESCAPE_CODES: - length = self.ESCAPE_CODES[ch] - srf() - for k in range(length): - if srp(k) not in "0123456789ABCDEFabcdef": - raise ScannerError( - "while scanning a double-quoted scalar", - start_mark, - "expected escape sequence of %d hexdecimal " - "numbers, but found %r" % (length, utf8(srp(k))), - self.reader.get_mark(), - ) - code = int(self.reader.prefix(length), 16) - chunks.append(unichr(code)) - srf(length) - elif ch in "\n\r\x85\u2028\u2029": - self.scan_line_break() - chunks.extend(self.scan_flow_scalar_breaks(double, start_mark)) - else: - raise ScannerError( - "while scanning a double-quoted scalar", - start_mark, - "found unknown escape character %r" % utf8(ch), - self.reader.get_mark(), - ) - else: - return chunks - - def scan_flow_scalar_spaces(self, double, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. 
- srp = self.reader.peek - chunks = [] - length = 0 - while srp(length) in " \t": - length += 1 - whitespaces = self.reader.prefix(length) - self.reader.forward(length) - ch = srp() - if ch == "\0": - raise ScannerError( - "while scanning a quoted scalar", - start_mark, - "found unexpected end of stream", - self.reader.get_mark(), - ) - elif ch in "\r\n\x85\u2028\u2029": - line_break = self.scan_line_break() - breaks = self.scan_flow_scalar_breaks(double, start_mark) - if line_break != "\n": - chunks.append(line_break) - elif not breaks: - chunks.append(" ") - chunks.extend(breaks) - else: - chunks.append(whitespaces) - return chunks - - def scan_flow_scalar_breaks(self, double, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. - chunks = [] # type: List[Any] - srp = self.reader.peek - srf = self.reader.forward - while True: - # Instead of checking indentation, we check for document - # separators. - prefix = self.reader.prefix(3) - if (prefix == "---" or prefix == "...") and srp(3) in _THE_END_SPACE_TAB: - raise ScannerError( - "while scanning a quoted scalar", - start_mark, - "found unexpected document separator", - self.reader.get_mark(), - ) - while srp() in " \t": - srf() - if srp() in "\r\n\x85\u2028\u2029": - chunks.append(self.scan_line_break()) - else: - return chunks - - def scan_plain(self): - # type: () -> Any - # See the specification for details. - # We add an additional restriction for the flow context: - # plain scalars in the flow context cannot contain ',', ': ' and '?'. - # We also keep track of the `allow_simple_key` flag here. - # Indentation rules are loosed for the flow context. - srp = self.reader.peek - srf = self.reader.forward - chunks = [] # type: List[Any] - start_mark = self.reader.get_mark() - end_mark = start_mark - indent = self.indent + 1 - # We allow zero indentation for scalars, but then we need to check for - # document separators at the beginning of the line. 
- # if indent == 0: - # indent = 1 - spaces = [] # type: List[Any] - while True: - length = 0 - if srp() == "#": - break - while True: - ch = srp(length) - if ch == ":" and srp(length + 1) not in _THE_END_SPACE_TAB: - pass - elif ch == "?" and self.scanner_processing_version != (1, 1): - pass - elif ( - ch in _THE_END_SPACE_TAB - or ( - not self.flow_level - and ch == ":" - and srp(length + 1) in _THE_END_SPACE_TAB - ) - or (self.flow_level and ch in ",:?[]{}") - ): - break - length += 1 - # It's not clear what we should do with ':' in the flow context. - if ( - self.flow_level - and ch == ":" - and srp(length + 1) not in "\0 \t\r\n\x85\u2028\u2029,[]{}" - ): - srf(length) - raise ScannerError( - "while scanning a plain scalar", - start_mark, - "found unexpected ':'", - self.reader.get_mark(), - "Please check " - "http://pyyaml.org/wiki/YAMLColonInFlowContext " - "for details.", - ) - if length == 0: - break - self.allow_simple_key = False - chunks.extend(spaces) - chunks.append(self.reader.prefix(length)) - srf(length) - end_mark = self.reader.get_mark() - spaces = self.scan_plain_spaces(indent, start_mark) - if ( - not spaces - or srp() == "#" - or (not self.flow_level and self.reader.column < indent) - ): - break - - token = ScalarToken("".join(chunks), True, start_mark, end_mark) - if spaces and spaces[0] == "\n": - # Create a comment token to preserve the trailing line breaks. - comment = CommentToken("".join(spaces) + "\n", start_mark, end_mark) - token.add_post_comment(comment) - return token - - def scan_plain_spaces(self, indent, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. - # The specification is really confusing about tabs in plain scalars. - # We just forbid them completely. Do not use tabs in YAML! 
- srp = self.reader.peek - srf = self.reader.forward - chunks = [] - length = 0 - while srp(length) in " ": - length += 1 - whitespaces = self.reader.prefix(length) - self.reader.forward(length) - ch = srp() - if ch in "\r\n\x85\u2028\u2029": - line_break = self.scan_line_break() - self.allow_simple_key = True - prefix = self.reader.prefix(3) - if (prefix == "---" or prefix == "...") and srp(3) in _THE_END_SPACE_TAB: - return - breaks = [] - while srp() in " \r\n\x85\u2028\u2029": - if srp() == " ": - srf() - else: - breaks.append(self.scan_line_break()) - prefix = self.reader.prefix(3) - if (prefix == "---" or prefix == "...") and srp( - 3 - ) in _THE_END_SPACE_TAB: - return - if line_break != "\n": - chunks.append(line_break) - elif not breaks: - chunks.append(" ") - chunks.extend(breaks) - elif whitespaces: - chunks.append(whitespaces) - return chunks - - def scan_tag_handle(self, name, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. - # For some strange reasons, the specification does not allow '_' in - # tag handles. I have allowed it anyway. - srp = self.reader.peek - ch = srp() - if ch != "!": - raise ScannerError( - "while scanning a %s" % (name,), - start_mark, - "expected '!', but found %r" % utf8(ch), - self.reader.get_mark(), - ) - length = 1 - ch = srp(length) - if ch != " ": - while ( - "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_" - ): - length += 1 - ch = srp(length) - if ch != "!": - self.reader.forward(length) - raise ScannerError( - "while scanning a %s" % (name,), - start_mark, - "expected '!', but found %r" % utf8(ch), - self.reader.get_mark(), - ) - length += 1 - value = self.reader.prefix(length) - self.reader.forward(length) - return value - - def scan_tag_uri(self, name, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. - # Note: we do not check if URI is well-formed. 
- srp = self.reader.peek - chunks = [] - length = 0 - ch = srp(length) - while ( - "0" <= ch <= "9" - or "A" <= ch <= "Z" - or "a" <= ch <= "z" - or ch in "-;/?:@&=+$,_.!~*'()[]%" - or ((self.scanner_processing_version > (1, 1)) and ch == "#") - ): - if ch == "%": - chunks.append(self.reader.prefix(length)) - self.reader.forward(length) - length = 0 - chunks.append(self.scan_uri_escapes(name, start_mark)) - else: - length += 1 - ch = srp(length) - if length != 0: - chunks.append(self.reader.prefix(length)) - self.reader.forward(length) - length = 0 - if not chunks: - raise ScannerError( - "while parsing a %s" % (name,), - start_mark, - "expected URI, but found %r" % utf8(ch), - self.reader.get_mark(), - ) - return "".join(chunks) - - def scan_uri_escapes(self, name, start_mark): - # type: (Any, Any) -> Any - # See the specification for details. - srp = self.reader.peek - srf = self.reader.forward - code_bytes = [] # type: List[Any] - mark = self.reader.get_mark() - while srp() == "%": - srf() - for k in range(2): - if srp(k) not in "0123456789ABCDEFabcdef": - raise ScannerError( - "while scanning a %s" % (name,), - start_mark, - "expected URI escape sequence of 2 hexdecimal numbers," - " but found %r" % utf8(srp(k)), - self.reader.get_mark(), - ) - if PY3: - code_bytes.append(int(self.reader.prefix(2), 16)) - else: - code_bytes.append(chr(int(self.reader.prefix(2), 16))) - srf(2) - try: - if PY3: - value = bytes(code_bytes).decode("utf-8") - else: - value = unicode(b"".join(code_bytes), "utf-8") - except UnicodeDecodeError as exc: - raise ScannerError( - "while scanning a %s" % (name,), start_mark, str(exc), mark - ) - return value - - def scan_line_break(self): - # type: () -> Any - # Transforms: - # '\r\n' : '\n' - # '\r' : '\n' - # '\n' : '\n' - # '\x85' : '\n' - # '\u2028' : '\u2028' - # '\u2029 : '\u2029' - # default : '' - ch = self.reader.peek() - if ch in "\r\n\x85": - if self.reader.prefix(2) == "\r\n": - self.reader.forward(2) - else: - 
self.reader.forward() - return "\n" - elif ch in "\u2028\u2029": - self.reader.forward() - return ch - return "" - - -class RoundTripScanner(Scanner): - def check_token(self, *choices): - # type: (Any) -> bool - # Check if the next token is one of the given types. - while self.need_more_tokens(): - self.fetch_more_tokens() - self._gather_comments() - if bool(self.tokens): - if not choices: - return True - for choice in choices: - if isinstance(self.tokens[0], choice): - return True - return False - - def peek_token(self): - # type: () -> Any - # Return the next token, but do not delete if from the queue. - while self.need_more_tokens(): - self.fetch_more_tokens() - self._gather_comments() - if bool(self.tokens): - return self.tokens[0] - return None - - def _gather_comments(self): - # type: () -> Any - """combine multiple comment lines""" - comments = [] # type: List[Any] - if not self.tokens: - return comments - if isinstance(self.tokens[0], CommentToken): - comment = self.tokens.pop(0) - self.tokens_taken += 1 - comments.append(comment) - while self.need_more_tokens(): - self.fetch_more_tokens() - if not self.tokens: - return comments - if isinstance(self.tokens[0], CommentToken): - self.tokens_taken += 1 - comment = self.tokens.pop(0) - # nprint('dropping2', comment) - comments.append(comment) - if len(comments) >= 1: - self.tokens[0].add_pre_comments(comments) - # pull in post comment on e.g. ':' - if not self.done and len(self.tokens) < 2: - self.fetch_more_tokens() - - def get_token(self): - # type: () -> Any - # Return the next token. - while self.need_more_tokens(): - self.fetch_more_tokens() - self._gather_comments() - if bool(self.tokens): - # nprint('tk', self.tokens) - # only add post comment to single line tokens: - # scalar, value token. 
FlowXEndToken, otherwise - # hidden streamtokens could get them (leave them and they will be - # pre comments for the next map/seq - if ( - len(self.tokens) > 1 - and isinstance( - self.tokens[0], - ( - ScalarToken, - ValueToken, - FlowSequenceEndToken, - FlowMappingEndToken, - ), - ) - and isinstance(self.tokens[1], CommentToken) - and self.tokens[0].end_mark.line == self.tokens[1].start_mark.line - ): - self.tokens_taken += 1 - c = self.tokens.pop(1) - self.fetch_more_tokens() - while len(self.tokens) > 1 and isinstance(self.tokens[1], CommentToken): - self.tokens_taken += 1 - c1 = self.tokens.pop(1) - c.value = c.value + (" " * c1.start_mark.column) + c1.value - self.fetch_more_tokens() - self.tokens[0].add_post_comment(c) - elif ( - len(self.tokens) > 1 - and isinstance(self.tokens[0], ScalarToken) - and isinstance(self.tokens[1], CommentToken) - and self.tokens[0].end_mark.line != self.tokens[1].start_mark.line - ): - self.tokens_taken += 1 - c = self.tokens.pop(1) - c.value = ( - "\n" * (c.start_mark.line - self.tokens[0].end_mark.line) - + (" " * c.start_mark.column) - + c.value - ) - self.tokens[0].add_post_comment(c) - self.fetch_more_tokens() - while len(self.tokens) > 1 and isinstance(self.tokens[1], CommentToken): - self.tokens_taken += 1 - c1 = self.tokens.pop(1) - c.value = c.value + (" " * c1.start_mark.column) + c1.value - self.fetch_more_tokens() - self.tokens_taken += 1 - return self.tokens.pop(0) - return None - - def fetch_comment(self, comment): - # type: (Any) -> None - value, start_mark, end_mark = comment - while value and value[-1] == " ": - # empty line within indented key context - # no need to update end-mark, that is not used - value = value[:-1] - self.tokens.append(CommentToken(value, start_mark, end_mark)) - - # scanner - - def scan_to_next_token(self): - # type: () -> Any - # We ignore spaces, line breaks and comments. - # If we find a line break in the block context, we set the flag - # `allow_simple_key` on. 
- # The byte order mark is stripped if it's the first character in the - # stream. We do not yet support BOM inside the stream as the - # specification requires. Any such mark will be considered as a part - # of the document. - # - # TODO: We need to make tab handling rules more sane. A good rule is - # Tabs cannot precede tokens - # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END, - # KEY(block), VALUE(block), BLOCK-ENTRY - # So the checking code is - # if : - # self.allow_simple_keys = False - # We also need to add the check for `allow_simple_keys == True` to - # `unwind_indent` before issuing BLOCK-END. - # Scanners for block, flow, and plain scalars need to be modified. - - srp = self.reader.peek - srf = self.reader.forward - if self.reader.index == 0 and srp() == "\uFEFF": - srf() - found = False - while not found: - while srp() == " ": - srf() - ch = srp() - if ch == "#": - start_mark = self.reader.get_mark() - comment = ch - srf() - while ch not in _THE_END: - ch = srp() - if ch == "\0": # don't gobble the end-of-stream character - # but add an explicit newline as "YAML processors should terminate - # the stream with an explicit line break - # https://yaml.org/spec/1.2/spec.html#id2780069 - comment += "\n" - break - comment += ch - srf() - # gather any blank lines following the comment too - ch = self.scan_line_break() - while len(ch) > 0: - comment += ch - ch = self.scan_line_break() - end_mark = self.reader.get_mark() - if not self.flow_level: - self.allow_simple_key = True - return comment, start_mark, end_mark - if bool(self.scan_line_break()): - start_mark = self.reader.get_mark() - if not self.flow_level: - self.allow_simple_key = True - ch = srp() - if ch == "\n": # empty toplevel lines - start_mark = self.reader.get_mark() - comment = "" - while ch: - ch = self.scan_line_break(empty_line=True) - comment += ch - if srp() == "#": - # empty line followed by indented real comment - comment = comment.rsplit("\n", 1)[0] + "\n" - end_mark = 
self.reader.get_mark() - return comment, start_mark, end_mark - else: - found = True - return None - - def scan_line_break(self, empty_line=False): - # type: (bool) -> Text - # Transforms: - # '\r\n' : '\n' - # '\r' : '\n' - # '\n' : '\n' - # '\x85' : '\n' - # '\u2028' : '\u2028' - # '\u2029 : '\u2029' - # default : '' - ch = self.reader.peek() # type: Text - if ch in "\r\n\x85": - if self.reader.prefix(2) == "\r\n": - self.reader.forward(2) - else: - self.reader.forward() - return "\n" - elif ch in "\u2028\u2029": - self.reader.forward() - return ch - elif empty_line and ch in "\t ": - self.reader.forward() - return ch - return "" - - def scan_block_scalar(self, style, rt=True): - # type: (Any, Optional[bool]) -> Any - return Scanner.scan_block_scalar(self, style, rt=rt) - - -# try: -# import psyco -# psyco.bind(Scanner) -# except ImportError: -# pass diff --git a/srsly/ruamel_yaml/serializer.py b/srsly/ruamel_yaml/serializer.py deleted file mode 100755 index 7888cdc..0000000 --- a/srsly/ruamel_yaml/serializer.py +++ /dev/null @@ -1,250 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import - -from .error import YAMLError -from .compat import nprint, DBG_NODE, dbg, string_types, nprintf # NOQA -from .util import RegExp - -from .events import ( - StreamStartEvent, - StreamEndEvent, - MappingStartEvent, - MappingEndEvent, - SequenceStartEvent, - SequenceEndEvent, - AliasEvent, - ScalarEvent, - DocumentStartEvent, - DocumentEndEvent, -) -from .nodes import MappingNode, ScalarNode, SequenceNode - -if False: # MYPY - from typing import Any, Dict, Union, Text, Optional # NOQA - from .compat import VersionType # NOQA - -__all__ = ["Serializer", "SerializerError"] - - -class SerializerError(YAMLError): - pass - - -class Serializer(object): - - # 'id' and 3+ numbers, but not 000 - ANCHOR_TEMPLATE = u"id%03d" - ANCHOR_RE = RegExp(u"id(?!000$)\\d{3,}") - - def __init__( - self, - encoding=None, - explicit_start=None, - explicit_end=None, - version=None, - 
tags=None, - dumper=None, - ): - # type: (Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> None # NOQA - self.dumper = dumper - if self.dumper is not None: - self.dumper._serializer = self - self.use_encoding = encoding - self.use_explicit_start = explicit_start - self.use_explicit_end = explicit_end - if isinstance(version, string_types): - self.use_version = tuple(map(int, version.split("."))) - else: - self.use_version = version # type: ignore - self.use_tags = tags - self.serialized_nodes = {} # type: Dict[Any, Any] - self.anchors = {} # type: Dict[Any, Any] - self.last_anchor_id = 0 - self.closed = None # type: Optional[bool] - self._templated_id = None - - @property - def emitter(self): - # type: () -> Any - if hasattr(self.dumper, "typ"): - return self.dumper.emitter - return self.dumper._emitter - - @property - def resolver(self): - # type: () -> Any - if hasattr(self.dumper, "typ"): - self.dumper.resolver - return self.dumper._resolver - - def open(self): - # type: () -> None - if self.closed is None: - self.emitter.emit(StreamStartEvent(encoding=self.use_encoding)) - self.closed = False - elif self.closed: - raise SerializerError("serializer is closed") - else: - raise SerializerError("serializer is already opened") - - def close(self): - # type: () -> None - if self.closed is None: - raise SerializerError("serializer is not opened") - elif not self.closed: - self.emitter.emit(StreamEndEvent()) - self.closed = True - - # def __del__(self): - # self.close() - - def serialize(self, node): - # type: (Any) -> None - if dbg(DBG_NODE): - nprint("Serializing nodes") - node.dump() - if self.closed is None: - raise SerializerError("serializer is not opened") - elif self.closed: - raise SerializerError("serializer is closed") - self.emitter.emit( - DocumentStartEvent( - explicit=self.use_explicit_start, - version=self.use_version, - tags=self.use_tags, - ) - ) - self.anchor_node(node) - self.serialize_node(node, None, None) - 
self.emitter.emit(DocumentEndEvent(explicit=self.use_explicit_end)) - self.serialized_nodes = {} - self.anchors = {} - self.last_anchor_id = 0 - - def anchor_node(self, node): - # type: (Any) -> None - if node in self.anchors: - if self.anchors[node] is None: - self.anchors[node] = self.generate_anchor(node) - else: - anchor = None - try: - if node.anchor.always_dump: - anchor = node.anchor.value - except: # NOQA - pass - self.anchors[node] = anchor - if isinstance(node, SequenceNode): - for item in node.value: - self.anchor_node(item) - elif isinstance(node, MappingNode): - for key, value in node.value: - self.anchor_node(key) - self.anchor_node(value) - - def generate_anchor(self, node): - # type: (Any) -> Any - try: - anchor = node.anchor.value - except: # NOQA - anchor = None - if anchor is None: - self.last_anchor_id += 1 - return self.ANCHOR_TEMPLATE % self.last_anchor_id - return anchor - - def serialize_node(self, node, parent, index): - # type: (Any, Any, Any) -> None - alias = self.anchors[node] - if node in self.serialized_nodes: - self.emitter.emit(AliasEvent(alias)) - else: - self.serialized_nodes[node] = True - self.resolver.descend_resolver(parent, index) - if isinstance(node, ScalarNode): - # here check if the node.tag equals the one that would result from parsing - # if not equal quoting is necessary for strings - detected_tag = self.resolver.resolve( - ScalarNode, node.value, (True, False) - ) - default_tag = self.resolver.resolve( - ScalarNode, node.value, (False, True) - ) - implicit = ( - (node.tag == detected_tag), - (node.tag == default_tag), - node.tag.startswith("tag:yaml.org,2002:"), - ) - self.emitter.emit( - ScalarEvent( - alias, - node.tag, - implicit, - node.value, - style=node.style, - comment=node.comment, - ) - ) - elif isinstance(node, SequenceNode): - implicit = node.tag == self.resolver.resolve( - SequenceNode, node.value, True - ) - comment = node.comment - end_comment = None - seq_comment = None - if node.flow_style is True: - 
if comment: # eol comment on flow style sequence - seq_comment = comment[0] - # comment[0] = None - if comment and len(comment) > 2: - end_comment = comment[2] - else: - end_comment = None - self.emitter.emit( - SequenceStartEvent( - alias, - node.tag, - implicit, - flow_style=node.flow_style, - comment=node.comment, - ) - ) - index = 0 - for item in node.value: - self.serialize_node(item, node, index) - index += 1 - self.emitter.emit(SequenceEndEvent(comment=[seq_comment, end_comment])) - elif isinstance(node, MappingNode): - implicit = node.tag == self.resolver.resolve( - MappingNode, node.value, True - ) - comment = node.comment - end_comment = None - map_comment = None - if node.flow_style is True: - if comment: # eol comment on flow style sequence - map_comment = comment[0] - # comment[0] = None - if comment and len(comment) > 2: - end_comment = comment[2] - self.emitter.emit( - MappingStartEvent( - alias, - node.tag, - implicit, - flow_style=node.flow_style, - comment=node.comment, - nr_items=len(node.value), - ) - ) - for key, value in node.value: - self.serialize_node(key, node, None) - self.serialize_node(value, node, key) - self.emitter.emit(MappingEndEvent(comment=[map_comment, end_comment])) - self.resolver.ascend_resolver() - - -def templated_id(s): - # type: (Text) -> Any - return Serializer.ANCHOR_RE.match(s) diff --git a/srsly/ruamel_yaml/timestamp.py b/srsly/ruamel_yaml/timestamp.py deleted file mode 100755 index 374e4c0..0000000 --- a/srsly/ruamel_yaml/timestamp.py +++ /dev/null @@ -1,28 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -import datetime -import copy - -# ToDo: at least on PY3 you could probably attach the tzinfo correctly to the object -# a more complete datetime might be used by safe loading as well - -if False: # MYPY - from typing import Any, Dict, Optional, List # NOQA - - -class TimeStamp(datetime.datetime): - def __init__(self, *args, **kw): - # type: (Any, Any) 
-> None - self._yaml = dict(t=False, tz=None, delta=0) # type: Dict[Any, Any] - - def __new__(cls, *args, **kw): # datetime is immutable - # type: (Any, Any) -> Any - return datetime.datetime.__new__(cls, *args, **kw) # type: ignore - - def __deepcopy__(self, memo): - # type: (Any) -> Any - ts = TimeStamp(self.year, self.month, self.day, self.hour, self.minute, self.second) - ts._yaml = copy.deepcopy(self._yaml) - return ts diff --git a/srsly/ruamel_yaml/tokens.py b/srsly/ruamel_yaml/tokens.py deleted file mode 100755 index 5f5a663..0000000 --- a/srsly/ruamel_yaml/tokens.py +++ /dev/null @@ -1,286 +0,0 @@ -# # header -# coding: utf-8 - -from __future__ import unicode_literals - -if False: # MYPY - from typing import Text, Any, Dict, Optional, List # NOQA - from .error import StreamMark # NOQA - -SHOWLINES = True - - -class Token(object): - __slots__ = 'start_mark', 'end_mark', '_comment' - - def __init__(self, start_mark, end_mark): - # type: (StreamMark, StreamMark) -> None - self.start_mark = start_mark - self.end_mark = end_mark - - def __repr__(self): - # type: () -> Any - # attributes = [key for key in self.__slots__ if not key.endswith('_mark') and - # hasattr('self', key)] - attributes = [key for key in self.__slots__ if not key.endswith('_mark')] - attributes.sort() - arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) for key in attributes]) - if SHOWLINES: - try: - arguments += ', line: ' + str(self.start_mark.line) - except: # NOQA - pass - try: - arguments += ', comment: ' + str(self._comment) - except: # NOQA - pass - return '{}({})'.format(self.__class__.__name__, arguments) - - def add_post_comment(self, comment): - # type: (Any) -> None - if not hasattr(self, '_comment'): - self._comment = [None, None] - self._comment[0] = comment - - def add_pre_comments(self, comments): - # type: (Any) -> None - if not hasattr(self, '_comment'): - self._comment = [None, None] - assert self._comment[1] is None - self._comment[1] = comments - - def 
get_comment(self): - # type: () -> Any - return getattr(self, '_comment', None) - - @property - def comment(self): - # type: () -> Any - return getattr(self, '_comment', None) - - def move_comment(self, target, empty=False): - # type: (Any, bool) -> Any - """move a comment from this token to target (normally next token) - used to combine e.g. comments before a BlockEntryToken to the - ScalarToken that follows it - empty is a special for empty values -> comment after key - """ - c = self.comment - if c is None: - return - # don't push beyond last element - if isinstance(target, (StreamEndToken, DocumentStartToken)): - return - delattr(self, '_comment') - tc = target.comment - if not tc: # target comment, just insert - # special for empty value in key: value issue 25 - if empty: - c = [c[0], c[1], None, None, c[0]] - target._comment = c - # nprint('mco2:', self, target, target.comment, empty) - return self - if c[0] and tc[0] or c[1] and tc[1]: - raise NotImplementedError('overlap in comment %r %r' % (c, tc)) - if c[0]: - tc[0] = c[0] - if c[1]: - tc[1] = c[1] - return self - - def split_comment(self): - # type: () -> Any - """ split the post part of a comment, and return it - as comment to be added. 
Delete second part if [None, None] - abc: # this goes to sequence - # this goes to first element - - first element - """ - comment = self.comment - if comment is None or comment[0] is None: - return None # nothing to do - ret_val = [comment[0], None] - if comment[1] is None: - delattr(self, '_comment') - return ret_val - - -# class BOMToken(Token): -# id = '' - - -class DirectiveToken(Token): - __slots__ = 'name', 'value' - id = '' - - def __init__(self, name, value, start_mark, end_mark): - # type: (Any, Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.name = name - self.value = value - - -class DocumentStartToken(Token): - __slots__ = () - id = '' - - -class DocumentEndToken(Token): - __slots__ = () - id = '' - - -class StreamStartToken(Token): - __slots__ = ('encoding',) - id = '' - - def __init__(self, start_mark=None, end_mark=None, encoding=None): - # type: (Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.encoding = encoding - - -class StreamEndToken(Token): - __slots__ = () - id = '' - - -class BlockSequenceStartToken(Token): - __slots__ = () - id = '' - - -class BlockMappingStartToken(Token): - __slots__ = () - id = '' - - -class BlockEndToken(Token): - __slots__ = () - id = '' - - -class FlowSequenceStartToken(Token): - __slots__ = () - id = '[' - - -class FlowMappingStartToken(Token): - __slots__ = () - id = '{' - - -class FlowSequenceEndToken(Token): - __slots__ = () - id = ']' - - -class FlowMappingEndToken(Token): - __slots__ = () - id = '}' - - -class KeyToken(Token): - __slots__ = () - id = '?' 
- - # def x__repr__(self): - # return 'KeyToken({})'.format( - # self.start_mark.buffer[self.start_mark.index:].split(None, 1)[0]) - - -class ValueToken(Token): - __slots__ = () - id = ':' - - -class BlockEntryToken(Token): - __slots__ = () - id = '-' - - -class FlowEntryToken(Token): - __slots__ = () - id = ',' - - -class AliasToken(Token): - __slots__ = ('value',) - id = '' - - def __init__(self, value, start_mark, end_mark): - # type: (Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.value = value - - -class AnchorToken(Token): - __slots__ = ('value',) - id = '' - - def __init__(self, value, start_mark, end_mark): - # type: (Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.value = value - - -class TagToken(Token): - __slots__ = ('value',) - id = '' - - def __init__(self, value, start_mark, end_mark): - # type: (Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.value = value - - -class ScalarToken(Token): - __slots__ = 'value', 'plain', 'style' - id = '' - - def __init__(self, value, plain, start_mark, end_mark, style=None): - # type: (Any, Any, Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.value = value - self.plain = plain - self.style = style - - -class CommentToken(Token): - __slots__ = 'value', 'pre_done' - id = '' - - def __init__(self, value, start_mark, end_mark): - # type: (Any, Any, Any) -> None - Token.__init__(self, start_mark, end_mark) - self.value = value - - def reset(self): - # type: () -> None - if hasattr(self, 'pre_done'): - delattr(self, 'pre_done') - - def __repr__(self): - # type: () -> Any - v = '{!r}'.format(self.value) - if SHOWLINES: - try: - v += ', line: ' + str(self.start_mark.line) - v += ', col: ' + str(self.start_mark.column) - except: # NOQA - pass - return 'CommentToken({})'.format(v) - - def __eq__(self, other): - # type: (Any) -> bool - if self.start_mark != other.start_mark: - return False - if self.end_mark != 
other.end_mark: - return False - if self.value != other.value: - return False - return True - - def __ne__(self, other): - # type: (Any) -> bool - return not self.__eq__(other) diff --git a/srsly/ruamel_yaml/util.py b/srsly/ruamel_yaml/util.py deleted file mode 100755 index 3eb7d76..0000000 --- a/srsly/ruamel_yaml/util.py +++ /dev/null @@ -1,190 +0,0 @@ -# coding: utf-8 - -""" -some helper functions that might be generally useful -""" - -from __future__ import absolute_import, print_function - -from functools import partial -import re - -from .compat import text_type, binary_type - -if False: # MYPY - from typing import Any, Dict, Optional, List, Text # NOQA - from .compat import StreamTextType # NOQA - - -class LazyEval(object): - """ - Lightweight wrapper around lazily evaluated func(*args, **kwargs). - - func is only evaluated when any attribute of its return value is accessed. - Every attribute access is passed through to the wrapped value. - (This only excludes special cases like method-wrappers, e.g., __hash__.) - The sole additional attribute is the lazy_self function which holds the - return value (or, prior to evaluation, func and arguments), in its closure. 
- """ - - def __init__(self, func, *args, **kwargs): - # type: (Any, Any, Any) -> None - def lazy_self(): - # type: () -> Any - return_value = func(*args, **kwargs) - object.__setattr__(self, 'lazy_self', lambda: return_value) - return return_value - - object.__setattr__(self, 'lazy_self', lazy_self) - - def __getattribute__(self, name): - # type: (Any) -> Any - lazy_self = object.__getattribute__(self, 'lazy_self') - if name == 'lazy_self': - return lazy_self - return getattr(lazy_self(), name) - - def __setattr__(self, name, value): - # type: (Any, Any) -> None - setattr(self.lazy_self(), name, value) - - -RegExp = partial(LazyEval, re.compile) - - -# originally as comment -# https://github.com/pre-commit/pre-commit/pull/211#issuecomment-186466605 -# if you use this in your code, I suggest adding a test in your test suite -# that check this routines output against a known piece of your YAML -# before upgrades to this code break your round-tripped YAML -def load_yaml_guess_indent(stream, **kw): - # type: (StreamTextType, Any) -> Any - """guess the indent and block sequence indent of yaml stream/string - - returns round_trip_loaded stream, indent level, block sequence indent - - block sequence indent is the number of spaces before a dash relative to previous indent - - if there are no block sequences, indent is taken from nested mappings, block sequence - indent is unset (None) in that case - """ - from .main import round_trip_load - - # load a yaml file guess the indentation, if you use TABs ... 
- def leading_spaces(l): - # type: (Any) -> int - idx = 0 - while idx < len(l) and l[idx] == ' ': - idx += 1 - return idx - - if isinstance(stream, text_type): - yaml_str = stream # type: Any - elif isinstance(stream, binary_type): - # most likely, but the Reader checks BOM for this - yaml_str = stream.decode('utf-8') - else: - yaml_str = stream.read() - map_indent = None - indent = None # default if not found for some reason - block_seq_indent = None - prev_line_key_only = None - key_indent = 0 - for line in yaml_str.splitlines(): - rline = line.rstrip() - lline = rline.lstrip() - if lline.startswith('- '): - l_s = leading_spaces(line) - block_seq_indent = l_s - key_indent - idx = l_s + 1 - while line[idx] == ' ': # this will end as we rstripped - idx += 1 - if line[idx] == '#': # comment after - - continue - indent = idx - key_indent - break - if map_indent is None and prev_line_key_only is not None and rline: - idx = 0 - while line[idx] in ' -': - idx += 1 - if idx > prev_line_key_only: - map_indent = idx - prev_line_key_only - if rline.endswith(':'): - key_indent = leading_spaces(line) - idx = 0 - while line[idx] == ' ': # this will end on ':' - idx += 1 - prev_line_key_only = idx - continue - prev_line_key_only = None - if indent is None and map_indent is not None: - indent = map_indent - return round_trip_load(yaml_str, **kw), indent, block_seq_indent - - -def configobj_walker(cfg): - # type: (Any) -> Any - """ - walks over a ConfigObj (INI file with comments) generating - corresponding YAML output (including comments - """ - from configobj import ConfigObj # type: ignore - - assert isinstance(cfg, ConfigObj) - for c in cfg.initial_comment: - if c.strip(): - yield c - for s in _walk_section(cfg): - if s.strip(): - yield s - for c in cfg.final_comment: - if c.strip(): - yield c - - -def _walk_section(s, level=0): - # type: (Any, int) -> Any - from configobj import Section - - assert isinstance(s, Section) - indent = u' ' * level - for name in s.scalars: - for 
c in s.comments[name]: - yield indent + c.strip() - x = s[name] - if u'\n' in x: - i = indent + u' ' - x = u'|\n' + i + x.strip().replace(u'\n', u'\n' + i) - elif ':' in x: - x = u"'" + x.replace(u"'", u"''") + u"'" - line = u'{0}{1}: {2}'.format(indent, name, x) - c = s.inline_comments[name] - if c: - line += u' ' + c - yield line - for name in s.sections: - for c in s.comments[name]: - yield indent + c.strip() - line = u'{0}{1}:'.format(indent, name) - c = s.inline_comments[name] - if c: - line += u' ' + c - yield line - for val in _walk_section(s[name], level=level + 1): - yield val - - -# def config_obj_2_rt_yaml(cfg): -# from .comments import CommentedMap, CommentedSeq -# from configobj import ConfigObj -# assert isinstance(cfg, ConfigObj) -# #for c in cfg.initial_comment: -# # if c.strip(): -# # pass -# cm = CommentedMap() -# for name in s.sections: -# cm[name] = d = CommentedMap() -# -# -# #for c in cfg.final_comment: -# # if c.strip(): -# # yield c -# return cm diff --git a/srsly/tests/cloudpickle/__init__.py b/srsly/tests/cloudpickle/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/srsly/tests/cloudpickle/cloudpickle_file_test.py b/srsly/tests/cloudpickle/cloudpickle_file_test.py deleted file mode 100644 index 218566f..0000000 --- a/srsly/tests/cloudpickle/cloudpickle_file_test.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import shutil -import sys -import tempfile -import unittest - -import pytest - -import srsly.cloudpickle as cloudpickle -from srsly.cloudpickle.compat import pickle - - -class CloudPickleFileTests(unittest.TestCase): - """In Cloudpickle, expected behaviour when pickling an opened file - is to send its contents over the wire and seek to the same position.""" - - def setUp(self): - self.tmpdir = tempfile.mkdtemp() - self.tmpfilepath = os.path.join(self.tmpdir, 'testfile') - self.teststring = 'Hello world!' 
- - def tearDown(self): - shutil.rmtree(self.tmpdir) - - def test_empty_file(self): - # Empty file - open(self.tmpfilepath, 'w').close() - with open(self.tmpfilepath, 'r') as f: - self.assertEqual('', pickle.loads(cloudpickle.dumps(f)).read()) - os.remove(self.tmpfilepath) - - def test_closed_file(self): - # Write & close - with open(self.tmpfilepath, 'w') as f: - f.write(self.teststring) - with pytest.raises(pickle.PicklingError) as excinfo: - cloudpickle.dumps(f) - assert "Cannot pickle closed files" in str(excinfo.value) - os.remove(self.tmpfilepath) - - def test_r_mode(self): - # Write & close - with open(self.tmpfilepath, 'w') as f: - f.write(self.teststring) - # Open for reading - with open(self.tmpfilepath, 'r') as f: - new_f = pickle.loads(cloudpickle.dumps(f)) - self.assertEqual(self.teststring, new_f.read()) - os.remove(self.tmpfilepath) - - def test_w_mode(self): - with open(self.tmpfilepath, 'w') as f: - f.write(self.teststring) - f.seek(0) - self.assertRaises(pickle.PicklingError, - lambda: cloudpickle.dumps(f)) - os.remove(self.tmpfilepath) - - def test_plus_mode(self): - # Write, then seek to 0 - with open(self.tmpfilepath, 'w+') as f: - f.write(self.teststring) - f.seek(0) - new_f = pickle.loads(cloudpickle.dumps(f)) - self.assertEqual(self.teststring, new_f.read()) - os.remove(self.tmpfilepath) - - def test_seek(self): - # Write, then seek to arbitrary position - with open(self.tmpfilepath, 'w+') as f: - f.write(self.teststring) - f.seek(4) - unpickled = pickle.loads(cloudpickle.dumps(f)) - # unpickled StringIO is at position 4 - self.assertEqual(4, unpickled.tell()) - self.assertEqual(self.teststring[4:], unpickled.read()) - # but unpickled StringIO also contained the start - unpickled.seek(0) - self.assertEqual(self.teststring, unpickled.read()) - os.remove(self.tmpfilepath) - - @pytest.mark.skip(reason="Requires pytest -s to pass") - def test_pickling_special_file_handles(self): - # Warning: if you want to run your tests with nose, add -s option 
- for out in sys.stdout, sys.stderr: # Regression test for SPARK-3415 - self.assertEqual(out, pickle.loads(cloudpickle.dumps(out))) - self.assertRaises(pickle.PicklingError, - lambda: cloudpickle.dumps(sys.stdin)) - - -if __name__ == '__main__': - unittest.main() diff --git a/srsly/tests/cloudpickle/cloudpickle_test.py b/srsly/tests/cloudpickle/cloudpickle_test.py deleted file mode 100644 index 25451df..0000000 --- a/srsly/tests/cloudpickle/cloudpickle_test.py +++ /dev/null @@ -1,2847 +0,0 @@ -import _collections_abc -import abc -import collections -import base64 -import functools -import io -import itertools -import logging -import math -import multiprocessing -from operator import itemgetter, attrgetter -import pickletools -import platform -import random -import re -import shutil -import subprocess -import sys -import tempfile -import textwrap -import types -import unittest -import weakref -import os -import enum -import typing -from functools import wraps - -import pytest - -try: - # try importing numpy and scipy. 
These are not hard dependencies and - # tests should be skipped if these modules are not available - import numpy as np - import scipy.special as spp -except (ImportError, RuntimeError): - np = None - spp = None - -try: - # Ditto for Tornado - import tornado -except ImportError: - tornado = None - -import srsly.cloudpickle as cloudpickle -from srsly.cloudpickle.compat import pickle -from srsly.cloudpickle import register_pickle_by_value -from srsly.cloudpickle import unregister_pickle_by_value -from srsly.cloudpickle import list_registry_pickle_by_value -from srsly.cloudpickle.cloudpickle import _should_pickle_by_reference -from srsly.cloudpickle.cloudpickle import _make_empty_cell, cell_set -from srsly.cloudpickle.cloudpickle import _extract_class_dict, _whichmodule -from srsly.cloudpickle.cloudpickle import _lookup_module_and_qualname - -from .testutils import subprocess_pickle_echo -from .testutils import subprocess_pickle_string -from .testutils import assert_run_python_script -from .testutils import subprocess_worker - - -_TEST_GLOBAL_VARIABLE = "default_value" -_TEST_GLOBAL_VARIABLE2 = "another_value" - - -class RaiserOnPickle: - - def __init__(self, exc): - self.exc = exc - - def __reduce__(self): - raise self.exc - - -def pickle_depickle(obj, protocol=cloudpickle.DEFAULT_PROTOCOL): - """Helper function to test whether object pickled with cloudpickle can be - depickled with pickle - """ - return pickle.loads(cloudpickle.dumps(obj, protocol=protocol)) - - -def _escape(raw_filepath): - # Ugly hack to embed filepaths in code templates for windows - return raw_filepath.replace("\\", r"\\\\") - - -def _maybe_remove(list_, item): - try: - list_.remove(item) - except ValueError: - pass - return list_ - - -def test_extract_class_dict(): - class A(int): - """A docstring""" - def method(self): - return "a" - - class B: - """B docstring""" - B_CONSTANT = 42 - - def method(self): - return "b" - - class C(A, B): - C_CONSTANT = 43 - - def method_c(self): - return "c" - - 
clsdict = _extract_class_dict(C) - if sys.version_info >= (3, 13): - expected_keys = ["C_CONSTANT", "__doc__", "__firstlineno__", "method_c"] - else: - expected_keys = ["C_CONSTANT", "__doc__", "method_c"] - assert sorted(clsdict) == expected_keys - assert clsdict["C_CONSTANT"] == 43 - assert clsdict["__doc__"] is None - assert clsdict["method_c"](C()) == C().method_c() - - -class CloudPickleTest(unittest.TestCase): - - protocol = cloudpickle.DEFAULT_PROTOCOL - - def setUp(self): - self.tmpdir = tempfile.mkdtemp(prefix="tmp_cloudpickle_test_") - - def tearDown(self): - shutil.rmtree(self.tmpdir) - - @pytest.mark.skipif( - platform.python_implementation() != "CPython" or - (sys.version_info >= (3, 8, 0) and sys.version_info < (3, 8, 2)), - reason="Underlying bug fixed upstream starting Python 3.8.2") - def test_reducer_override_reference_cycle(self): - # Early versions of Python 3.8 introduced a reference cycle between a - # Pickler and it's reducer_override method. Because a Pickler - # object references every object it has pickled through its memo, this - # cycle prevented the garbage-collection of those external pickled - # objects. See #327 as well as https://bugs.python.org/issue39492 - # This bug was fixed in Python 3.8.2, but is still present using - # cloudpickle and Python 3.8.0/1, hence the skipif directive. 
- class MyClass: - pass - - my_object = MyClass() - wr = weakref.ref(my_object) - - cloudpickle.dumps(my_object) - del my_object - assert wr() is None, "'del'-ed my_object has not been collected" - - def test_itemgetter(self): - d = range(10) - getter = itemgetter(1) - - getter2 = pickle_depickle(getter, protocol=self.protocol) - self.assertEqual(getter(d), getter2(d)) - - getter = itemgetter(0, 3) - getter2 = pickle_depickle(getter, protocol=self.protocol) - self.assertEqual(getter(d), getter2(d)) - - def test_attrgetter(self): - class C: - def __getattr__(self, item): - return item - d = C() - getter = attrgetter("a") - getter2 = pickle_depickle(getter, protocol=self.protocol) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("a", "b") - getter2 = pickle_depickle(getter, protocol=self.protocol) - self.assertEqual(getter(d), getter2(d)) - - d.e = C() - getter = attrgetter("e.a") - getter2 = pickle_depickle(getter, protocol=self.protocol) - self.assertEqual(getter(d), getter2(d)) - getter = attrgetter("e.a", "e.b") - getter2 = pickle_depickle(getter, protocol=self.protocol) - self.assertEqual(getter(d), getter2(d)) - - # Regression test for SPARK-3415 - @pytest.mark.skip(reason="Requires pytest -s to pass") - def test_pickling_file_handles(self): - out1 = sys.stderr - out2 = pickle.loads(cloudpickle.dumps(out1, protocol=self.protocol)) - self.assertEqual(out1, out2) - - def test_func_globals(self): - class Unpicklable: - def __reduce__(self): - raise Exception("not picklable") - - global exit - exit = Unpicklable() - - self.assertRaises(Exception, lambda: cloudpickle.dumps( - exit, protocol=self.protocol)) - - def foo(): - sys.exit(0) - - self.assertTrue("exit" in foo.__code__.co_names) - cloudpickle.dumps(foo) - - def test_buffer(self): - try: - buffer_obj = buffer("Hello") - buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol) - self.assertEqual(buffer_clone, str(buffer_obj)) - buffer_obj = buffer("Hello", 2, 3) - buffer_clone = 
pickle_depickle(buffer_obj, protocol=self.protocol) - self.assertEqual(buffer_clone, str(buffer_obj)) - except NameError: # Python 3 does no longer support buffers - pass - - def test_memoryview(self): - buffer_obj = memoryview(b"Hello") - self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol), - buffer_obj.tobytes()) - - def test_dict_keys(self): - keys = {"a": 1, "b": 2}.keys() - results = pickle_depickle(keys) - self.assertEqual(results, keys) - assert isinstance(results, _collections_abc.dict_keys) - - def test_dict_values(self): - values = {"a": 1, "b": 2}.values() - results = pickle_depickle(values) - self.assertEqual(sorted(results), sorted(values)) - assert isinstance(results, _collections_abc.dict_values) - - def test_dict_items(self): - items = {"a": 1, "b": 2}.items() - results = pickle_depickle(items) - self.assertEqual(results, items) - assert isinstance(results, _collections_abc.dict_items) - - def test_odict_keys(self): - keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys() - results = pickle_depickle(keys) - self.assertEqual(results, keys) - assert type(keys) == type(results) - - def test_odict_values(self): - values = collections.OrderedDict([("a", 1), ("b", 2)]).values() - results = pickle_depickle(values) - self.assertEqual(list(results), list(values)) - assert type(values) == type(results) - - def test_odict_items(self): - items = collections.OrderedDict([("a", 1), ("b", 2)]).items() - results = pickle_depickle(items) - self.assertEqual(results, items) - assert type(items) == type(results) - - def test_sliced_and_non_contiguous_memoryview(self): - buffer_obj = memoryview(b"Hello!" * 3)[2:15:2] - self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol), - buffer_obj.tobytes()) - - def test_large_memoryview(self): - buffer_obj = memoryview(b"Hello!" 
* int(1e7)) - self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol), - buffer_obj.tobytes()) - - def test_lambda(self): - self.assertEqual( - pickle_depickle(lambda: 1, protocol=self.protocol)(), 1) - - def test_nested_lambdas(self): - a, b = 1, 2 - f1 = lambda x: x + a - f2 = lambda x: f1(x) // b - self.assertEqual(pickle_depickle(f2, protocol=self.protocol)(1), 1) - - def test_recursive_closure(self): - def f1(): - def g(): - return g - return g - - def f2(base): - def g(n): - return base if n <= 1 else n * g(n - 1) - return g - - g1 = pickle_depickle(f1(), protocol=self.protocol) - self.assertEqual(g1(), g1) - - g2 = pickle_depickle(f2(2), protocol=self.protocol) - self.assertEqual(g2(5), 240) - - def test_closure_none_is_preserved(self): - def f(): - """a function with no closure cells - """ - - self.assertTrue( - f.__closure__ is None, - msg='f actually has closure cells!', - ) - - g = pickle_depickle(f, protocol=self.protocol) - - self.assertTrue( - g.__closure__ is None, - msg='g now has closure cells even though f does not', - ) - - def test_empty_cell_preserved(self): - def f(): - if False: # pragma: no cover - cell = None - - def g(): - cell # NameError, unbound free variable - - return g - - g1 = f() - with pytest.raises(NameError): - g1() - - g2 = pickle_depickle(g1, protocol=self.protocol) - with pytest.raises(NameError): - g2() - - def test_unhashable_closure(self): - def f(): - s = {1, 2} # mutable set is unhashable - - def g(): - return len(s) - - return g - - g = pickle_depickle(f(), protocol=self.protocol) - self.assertEqual(g(), 2) - - def test_dynamically_generated_class_that_uses_super(self): - - class Base: - def method(self): - return 1 - - class Derived(Base): - "Derived Docstring" - def method(self): - return super().method() + 1 - - self.assertEqual(Derived().method(), 2) - - # Pickle and unpickle the class. 
- UnpickledDerived = pickle_depickle(Derived, protocol=self.protocol) - self.assertEqual(UnpickledDerived().method(), 2) - - # We have special logic for handling __doc__ because it's a readonly - # attribute on PyPy. - self.assertEqual(UnpickledDerived.__doc__, "Derived Docstring") - - # Pickle and unpickle an instance. - orig_d = Derived() - d = pickle_depickle(orig_d, protocol=self.protocol) - self.assertEqual(d.method(), 2) - - def test_cycle_in_classdict_globals(self): - - class C: - - def it_works(self): - return "woohoo!" - - C.C_again = C - C.instance_of_C = C() - - depickled_C = pickle_depickle(C, protocol=self.protocol) - depickled_instance = pickle_depickle(C()) - - # Test instance of depickled class. - self.assertEqual(depickled_C().it_works(), "woohoo!") - self.assertEqual(depickled_C.C_again().it_works(), "woohoo!") - self.assertEqual(depickled_C.instance_of_C.it_works(), "woohoo!") - self.assertEqual(depickled_instance.it_works(), "woohoo!") - - def test_locally_defined_function_and_class(self): - LOCAL_CONSTANT = 42 - - def some_function(x, y): - # Make sure the __builtins__ are not broken (see #211) - sum(range(10)) - return (x + y) / LOCAL_CONSTANT - - # pickle the function definition - self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(41, 1), 1) - self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(81, 3), 2) - - hidden_constant = lambda: LOCAL_CONSTANT - - class SomeClass: - """Overly complicated class with nested references to symbols""" - def __init__(self, value): - self.value = value - - def one(self): - return LOCAL_CONSTANT / hidden_constant() - - def some_method(self, x): - return self.one() + some_function(x, 1) + self.value - - # pickle the class definition - clone_class = pickle_depickle(SomeClass, protocol=self.protocol) - self.assertEqual(clone_class(1).one(), 1) - self.assertEqual(clone_class(5).some_method(41), 7) - clone_class = subprocess_pickle_echo(SomeClass, protocol=self.protocol) - 
self.assertEqual(clone_class(5).some_method(41), 7) - - # pickle the class instances - self.assertEqual(pickle_depickle(SomeClass(1)).one(), 1) - self.assertEqual(pickle_depickle(SomeClass(5)).some_method(41), 7) - new_instance = subprocess_pickle_echo(SomeClass(5), - protocol=self.protocol) - self.assertEqual(new_instance.some_method(41), 7) - - # pickle the method instances - self.assertEqual(pickle_depickle(SomeClass(1).one)(), 1) - self.assertEqual(pickle_depickle(SomeClass(5).some_method)(41), 7) - new_method = subprocess_pickle_echo(SomeClass(5).some_method, - protocol=self.protocol) - self.assertEqual(new_method(41), 7) - - def test_partial(self): - partial_obj = functools.partial(min, 1) - partial_clone = pickle_depickle(partial_obj, protocol=self.protocol) - self.assertEqual(partial_clone(4), 1) - - @pytest.mark.skipif(platform.python_implementation() == 'PyPy', - reason="Skip numpy and scipy tests on PyPy") - def test_ufunc(self): - # test a numpy ufunc (universal function), which is a C-based function - # that is applied on a numpy array - - if np: - # simple ufunc: np.add - self.assertEqual(pickle_depickle(np.add, protocol=self.protocol), - np.add) - else: # skip if numpy is not available - pass - - if spp: - # custom ufunc: scipy.special.iv - self.assertEqual(pickle_depickle(spp.iv, protocol=self.protocol), - spp.iv) - else: # skip if scipy is not available - pass - - def test_loads_namespace(self): - obj = 1, 2, 3, 4 - returned_obj = cloudpickle.loads(cloudpickle.dumps( - obj, protocol=self.protocol)) - self.assertEqual(obj, returned_obj) - - def test_load_namespace(self): - obj = 1, 2, 3, 4 - bio = io.BytesIO() - cloudpickle.dump(obj, bio) - bio.seek(0) - returned_obj = cloudpickle.load(bio) - self.assertEqual(obj, returned_obj) - - def test_generator(self): - - def some_generator(cnt): - for i in range(cnt): - yield i - - gen2 = pickle_depickle(some_generator, protocol=self.protocol) - - assert type(gen2(3)) == type(some_generator(3)) - assert 
list(gen2(3)) == list(range(3)) - - def test_classmethod(self): - class A: - @staticmethod - def test_sm(): - return "sm" - @classmethod - def test_cm(cls): - return "cm" - - sm = A.__dict__["test_sm"] - cm = A.__dict__["test_cm"] - - A.test_sm = pickle_depickle(sm, protocol=self.protocol) - A.test_cm = pickle_depickle(cm, protocol=self.protocol) - - self.assertEqual(A.test_sm(), "sm") - self.assertEqual(A.test_cm(), "cm") - - def test_bound_classmethod(self): - class A: - @classmethod - def test_cm(cls): - return "cm" - - A.test_cm = pickle_depickle(A.test_cm, protocol=self.protocol) - self.assertEqual(A.test_cm(), "cm") - - def test_method_descriptors(self): - f = pickle_depickle(str.upper) - self.assertEqual(f('abc'), 'ABC') - - def test_instancemethods_without_self(self): - class F: - def f(self, x): - return x + 1 - - g = pickle_depickle(F.f, protocol=self.protocol) - self.assertEqual(g.__name__, F.f.__name__) - # self.assertEqual(g(F(), 1), 2) # still fails - - def test_module(self): - pickle_clone = pickle_depickle(pickle, protocol=self.protocol) - self.assertEqual(pickle, pickle_clone) - - def test_dynamic_module(self): - mod = types.ModuleType('mod') - code = ''' - x = 1 - def f(y): - return x + y - - class Foo: - def method(self, x): - return f(x) - ''' - exec(textwrap.dedent(code), mod.__dict__) - mod2 = pickle_depickle(mod, protocol=self.protocol) - self.assertEqual(mod.x, mod2.x) - self.assertEqual(mod.f(5), mod2.f(5)) - self.assertEqual(mod.Foo().method(5), mod2.Foo().method(5)) - - if platform.python_implementation() != 'PyPy': - # XXX: this fails with excessive recursion on PyPy. 
- mod3 = subprocess_pickle_echo(mod, protocol=self.protocol) - self.assertEqual(mod.x, mod3.x) - self.assertEqual(mod.f(5), mod3.f(5)) - self.assertEqual(mod.Foo().method(5), mod3.Foo().method(5)) - - # Test dynamic modules when imported back are singletons - mod1, mod2 = pickle_depickle([mod, mod]) - self.assertEqual(id(mod1), id(mod2)) - - # Ensure proper pickling of mod's functions when module "looks" like a - # file-backed module even though it is not: - try: - sys.modules['mod'] = mod - depickled_f = pickle_depickle(mod.f, protocol=self.protocol) - self.assertEqual(mod.f(5), depickled_f(5)) - finally: - sys.modules.pop('mod', None) - - def test_module_locals_behavior(self): - # Makes sure that a local function defined in another module is - # correctly serialized. This notably checks that the globals are - # accessible and that there is no issue with the builtins (see #211) - - pickled_func_path = os.path.join(self.tmpdir, 'local_func_g.pkl') - - child_process_script = ''' - from srsly.cloudpickle.compat import pickle - import gc - with open("{pickled_func_path}", 'rb') as f: - func = pickle.load(f) - - assert func(range(10)) == 45 - ''' - - child_process_script = child_process_script.format( - pickled_func_path=_escape(pickled_func_path)) - - try: - - from srsly.tests.cloudpickle.testutils import make_local_function - - g = make_local_function() - with open(pickled_func_path, 'wb') as f: - cloudpickle.dump(g, f, protocol=self.protocol) - - assert_run_python_script(textwrap.dedent(child_process_script)) - - finally: - os.unlink(pickled_func_path) - - def test_dynamic_module_with_unpicklable_builtin(self): - # Reproducer of https://github.com/cloudpipe/cloudpickle/issues/316 - # Some modules such as scipy inject some unpicklable objects into the - # __builtins__ module, which appears in every module's __dict__ under - # the '__builtins__' key. In such cases, cloudpickle used to fail - # when pickling dynamic modules. 
- class UnpickleableObject: - def __reduce__(self): - raise ValueError('Unpicklable object') - - mod = types.ModuleType("mod") - - exec('f = lambda x: abs(x)', mod.__dict__) - assert mod.f(-1) == 1 - assert '__builtins__' in mod.__dict__ - - unpicklable_obj = UnpickleableObject() - with pytest.raises(ValueError): - cloudpickle.dumps(unpicklable_obj) - - # Emulate the behavior of scipy by injecting an unpickleable object - # into mod's builtins. - # The __builtins__ entry of mod's __dict__ can either be the - # __builtins__ module, or the __builtins__ module's __dict__. #316 - # happens only in the latter case. - if isinstance(mod.__dict__['__builtins__'], dict): - mod.__dict__['__builtins__']['unpickleable_obj'] = unpicklable_obj - elif isinstance(mod.__dict__['__builtins__'], types.ModuleType): - mod.__dict__['__builtins__'].unpickleable_obj = unpicklable_obj - - depickled_mod = pickle_depickle(mod, protocol=self.protocol) - assert '__builtins__' in depickled_mod.__dict__ - - if isinstance(depickled_mod.__dict__['__builtins__'], dict): - assert "abs" in depickled_mod.__builtins__ - elif isinstance( - depickled_mod.__dict__['__builtins__'], types.ModuleType): - assert hasattr(depickled_mod.__builtins__, "abs") - assert depickled_mod.f(-1) == 1 - - # Additional check testing that the issue #425 is fixed: without the - # fix for #425, `mod.f` would not have access to `__builtins__`, and - # thus calling `mod.f(-1)` (which relies on the `abs` builtin) would - # fail. - assert mod.f(-1) == 1 - - def test_load_dynamic_module_in_grandchild_process(self): - # Make sure that when loaded, a dynamic module preserves its dynamic - # property. Otherwise, this will lead to an ImportError if pickled in - # the child process and reloaded in another one. - - # We create a new dynamic module - mod = types.ModuleType('mod') - code = ''' - x = 1 - ''' - exec(textwrap.dedent(code), mod.__dict__) - - # This script will be ran in a separate child process. 
It will import - # the pickled dynamic module, and then re-pickle it under a new name. - # Finally, it will create a child process that will load the re-pickled - # dynamic module. - parent_process_module_file = os.path.join( - self.tmpdir, 'dynamic_module_from_parent_process.pkl') - child_process_module_file = os.path.join( - self.tmpdir, 'dynamic_module_from_child_process.pkl') - child_process_script = ''' - from srsly.cloudpickle.compat import pickle - import textwrap - - import srsly.cloudpickle as cloudpickle - from srsly.tests.cloudpickle.testutils import assert_run_python_script - - - child_of_child_process_script = {child_of_child_process_script} - - with open('{parent_process_module_file}', 'rb') as f: - mod = pickle.load(f) - - with open('{child_process_module_file}', 'wb') as f: - cloudpickle.dump(mod, f, protocol={protocol}) - - assert_run_python_script(textwrap.dedent(child_of_child_process_script)) - ''' - - # The script ran by the process created by the child process - child_of_child_process_script = """ ''' - from srsly.cloudpickle.compat import pickle - with open('{child_process_module_file}','rb') as fid: - mod = pickle.load(fid) - ''' """ - - # Filling the two scripts with the pickled modules filepaths and, - # for the first child process, the script to be executed by its - # own child process. 
- child_of_child_process_script = child_of_child_process_script.format( - child_process_module_file=child_process_module_file) - - child_process_script = child_process_script.format( - parent_process_module_file=_escape(parent_process_module_file), - child_process_module_file=_escape(child_process_module_file), - child_of_child_process_script=_escape(child_of_child_process_script), - protocol=self.protocol) - - try: - with open(parent_process_module_file, 'wb') as fid: - cloudpickle.dump(mod, fid, protocol=self.protocol) - - assert_run_python_script(textwrap.dedent(child_process_script)) - - finally: - # Remove temporary created files - if os.path.exists(parent_process_module_file): - os.unlink(parent_process_module_file) - if os.path.exists(child_process_module_file): - os.unlink(child_process_module_file) - - def test_correct_globals_import(self): - def nested_function(x): - return x + 1 - - def unwanted_function(x): - return math.exp(x) - - def my_small_function(x, y): - return nested_function(x) + y - - b = cloudpickle.dumps(my_small_function, protocol=self.protocol) - - # Make sure that the pickle byte string only includes the definition - # of my_small_function and its dependency nested_function while - # extra functions and modules such as unwanted_function and the math - # module are not included so as to keep the pickle payload as - # lightweight as possible. 
- - assert b'my_small_function' in b - assert b'nested_function' in b - - assert b'unwanted_function' not in b - assert b'math' not in b - - def test_module_importability(self): - pytest.importorskip("_cloudpickle_testpkg") - from srsly.cloudpickle.compat import pickle - import os.path - import collections - import collections.abc - - assert _should_pickle_by_reference(pickle) - assert _should_pickle_by_reference(os.path) # fake (aliased) module - assert _should_pickle_by_reference(collections) # package - assert _should_pickle_by_reference(collections.abc) # module in package - - dynamic_module = types.ModuleType('dynamic_module') - assert not _should_pickle_by_reference(dynamic_module) - - if platform.python_implementation() == 'PyPy': - import _codecs - assert _should_pickle_by_reference(_codecs) - - # #354: Check that modules created dynamically during the import of - # their parent modules are considered importable by cloudpickle. - # See the mod_with_dynamic_submodule documentation for more - # details of this use case. - import _cloudpickle_testpkg.mod.dynamic_submodule as m - assert _should_pickle_by_reference(m) - assert pickle_depickle(m, protocol=self.protocol) is m - - # Check for similar behavior for a module that cannot be imported by - # attribute lookup. - from _cloudpickle_testpkg.mod import dynamic_submodule_two as m2 - # Note: import _cloudpickle_testpkg.mod.dynamic_submodule_two as m2 - # works only for Python 3.7+ - assert _should_pickle_by_reference(m2) - assert pickle_depickle(m2, protocol=self.protocol) is m2 - - # Submodule_three is a dynamic module only importable via module lookup - with pytest.raises(ImportError): - import _cloudpickle_testpkg.mod.submodule_three # noqa - from _cloudpickle_testpkg.mod import submodule_three as m3 - assert not _should_pickle_by_reference(m3) - - # This module cannot be pickled using attribute lookup (as it does not - # have a `__module__` attribute like classes and functions. 
- assert not hasattr(m3, '__module__') - depickled_m3 = pickle_depickle(m3, protocol=self.protocol) - assert depickled_m3 is not m3 - assert m3.f(1) == depickled_m3.f(1) - - # Do the same for an importable dynamic submodule inside a dynamic - # module inside a file-backed module. - import _cloudpickle_testpkg.mod.dynamic_submodule.dynamic_subsubmodule as sm # noqa - assert _should_pickle_by_reference(sm) - assert pickle_depickle(sm, protocol=self.protocol) is sm - - expected = "cannot check importability of object instances" - with pytest.raises(TypeError, match=expected): - _should_pickle_by_reference(object()) - - def test_Ellipsis(self): - self.assertEqual(Ellipsis, - pickle_depickle(Ellipsis, protocol=self.protocol)) - - def test_NotImplemented(self): - ExcClone = pickle_depickle(NotImplemented, protocol=self.protocol) - self.assertEqual(NotImplemented, ExcClone) - - def test_NoneType(self): - res = pickle_depickle(type(None), protocol=self.protocol) - self.assertEqual(type(None), res) - - def test_EllipsisType(self): - res = pickle_depickle(type(Ellipsis), protocol=self.protocol) - self.assertEqual(type(Ellipsis), res) - - def test_NotImplementedType(self): - res = pickle_depickle(type(NotImplemented), protocol=self.protocol) - self.assertEqual(type(NotImplemented), res) - - def test_builtin_function(self): - # Note that builtin_function_or_method are special-cased by cloudpickle - # only in python2. - - # builtin function from the __builtin__ module - assert pickle_depickle(zip, protocol=self.protocol) is zip - - from os import mkdir - # builtin function from a "regular" module - assert pickle_depickle(mkdir, protocol=self.protocol) is mkdir - - def test_builtin_type_constructor(self): - # This test makes sure that cloudpickling builtin-type - # constructors works for all python versions/implementation. 
- - # pickle_depickle some builtin methods of the __builtin__ module - for t in list, tuple, set, frozenset, dict, object: - cloned_new = pickle_depickle(t.__new__, protocol=self.protocol) - assert isinstance(cloned_new(t), t) - - # The next 4 tests cover all cases into which builtin python methods can - # appear. - # There are 4 kinds of method: 'classic' methods, classmethods, - # staticmethods and slotmethods. They will appear under different types - # depending on whether they are called from the __dict__ of their - # class, their class itself, or an instance of their class. This makes - # 12 total combinations. - # This discussion and the following tests are relevant for the CPython - # implementation only. In PyPy, there is no builtin method or builtin - # function types/flavours. The only way into which a builtin method can be - # identified is with it's builtin-code __code__ attribute. - - def test_builtin_classicmethod(self): - obj = 1.5 # float object - - bound_classicmethod = obj.hex # builtin_function_or_method - unbound_classicmethod = type(obj).hex # method_descriptor - clsdict_classicmethod = type(obj).__dict__['hex'] # method_descriptor - - assert unbound_classicmethod is clsdict_classicmethod - - depickled_bound_meth = pickle_depickle( - bound_classicmethod, protocol=self.protocol) - depickled_unbound_meth = pickle_depickle( - unbound_classicmethod, protocol=self.protocol) - depickled_clsdict_meth = pickle_depickle( - clsdict_classicmethod, protocol=self.protocol) - - # No identity on the bound methods they are bound to different float - # instances - assert depickled_bound_meth() == bound_classicmethod() - assert depickled_unbound_meth is unbound_classicmethod - assert depickled_clsdict_meth is clsdict_classicmethod - - - @pytest.mark.skipif( - (platform.machine() == "aarch64" and sys.version_info[:2] >= (3, 10)) - or platform.python_implementation() == "PyPy" - or (sys.version_info[:2] == (3, 10) and sys.version_info >= (3, 10, 8)) - # Skipping 
tests on 3.11 due to https://github.com/cloudpipe/cloudpickle/pull/486. - or sys.version_info[:2] >= (3, 11), - reason="Fails on aarch64 + python 3.10+ in cibuildwheel, currently unable to replicate failure elsewhere; fails sometimes for pypy on conda-forge; fails for python 3.10.8+ and 3.11+") - def test_builtin_classmethod(self): - obj = 1.5 # float object - - bound_clsmethod = obj.fromhex # builtin_function_or_method - unbound_clsmethod = type(obj).fromhex # builtin_function_or_method - clsdict_clsmethod = type( - obj).__dict__['fromhex'] # classmethod_descriptor - - depickled_bound_meth = pickle_depickle( - bound_clsmethod, protocol=self.protocol) - depickled_unbound_meth = pickle_depickle( - unbound_clsmethod, protocol=self.protocol) - depickled_clsdict_meth = pickle_depickle( - clsdict_clsmethod, protocol=self.protocol) - - # float.fromhex takes a string as input. - arg = "0x1" - - # Identity on both the bound and the unbound methods cannot be - # tested: the bound methods are bound to different objects, and the - # unbound methods are actually recreated at each call. - assert depickled_bound_meth(arg) == bound_clsmethod(arg) - assert depickled_unbound_meth(arg) == unbound_clsmethod(arg) - - if platform.python_implementation() == 'CPython': - # Roundtripping a classmethod_descriptor results in a - # builtin_function_or_method (CPython upstream issue). - assert depickled_clsdict_meth(arg) == clsdict_clsmethod(float, arg) - if platform.python_implementation() == 'PyPy': - # builtin-classmethods are simple classmethod in PyPy (not - # callable). We test equality of types and the functionality of the - # __func__ attribute instead. We do not test the the identity of - # the functions as __func__ attributes of classmethods are not - # pickleable and must be reconstructed at depickling time. 
- assert type(depickled_clsdict_meth) == type(clsdict_clsmethod) - assert depickled_clsdict_meth.__func__( - float, arg) == clsdict_clsmethod.__func__(float, arg) - - def test_builtin_slotmethod(self): - obj = 1.5 # float object - - bound_slotmethod = obj.__repr__ # method-wrapper - unbound_slotmethod = type(obj).__repr__ # wrapper_descriptor - clsdict_slotmethod = type(obj).__dict__['__repr__'] # ditto - - depickled_bound_meth = pickle_depickle( - bound_slotmethod, protocol=self.protocol) - depickled_unbound_meth = pickle_depickle( - unbound_slotmethod, protocol=self.protocol) - depickled_clsdict_meth = pickle_depickle( - clsdict_slotmethod, protocol=self.protocol) - - # No identity tests on the bound slotmethod are they are bound to - # different float instances - assert depickled_bound_meth() == bound_slotmethod() - assert depickled_unbound_meth is unbound_slotmethod - assert depickled_clsdict_meth is clsdict_slotmethod - - @pytest.mark.skipif( - platform.python_implementation() == "PyPy", - reason="No known staticmethod example in the pypy stdlib") - def test_builtin_staticmethod(self): - obj = "foo" # str object - - bound_staticmethod = obj.maketrans # builtin_function_or_method - unbound_staticmethod = type(obj).maketrans # ditto - clsdict_staticmethod = type(obj).__dict__['maketrans'] # staticmethod - - assert bound_staticmethod is unbound_staticmethod - - depickled_bound_meth = pickle_depickle( - bound_staticmethod, protocol=self.protocol) - depickled_unbound_meth = pickle_depickle( - unbound_staticmethod, protocol=self.protocol) - depickled_clsdict_meth = pickle_depickle( - clsdict_staticmethod, protocol=self.protocol) - - assert depickled_bound_meth is bound_staticmethod - assert depickled_unbound_meth is unbound_staticmethod - - # staticmethod objects are recreated at depickling time, but the - # underlying __func__ object is pickled by attribute. 
- assert depickled_clsdict_meth.__func__ is clsdict_staticmethod.__func__ - type(depickled_clsdict_meth) is type(clsdict_staticmethod) - - @pytest.mark.skipif(tornado is None, - reason="test needs Tornado installed") - def test_tornado_coroutine(self): - # Pickling a locally defined coroutine function - from tornado import gen, ioloop - - @gen.coroutine - def f(x, y): - yield gen.sleep(x) - raise gen.Return(y + 1) - - @gen.coroutine - def g(y): - res = yield f(0.01, y) - raise gen.Return(res + 1) - - data = cloudpickle.dumps([g, g], protocol=self.protocol) - f = g = None - g2, g3 = pickle.loads(data) - self.assertTrue(g2 is g3) - loop = ioloop.IOLoop.current() - res = loop.run_sync(functools.partial(g2, 5)) - self.assertEqual(res, 7) - - @pytest.mark.skipif( - (3, 11, 0, 'beta') <= sys.version_info < (3, 11, 0, 'beta', 4), - reason="https://github.com/python/cpython/issues/92932" - ) - def test_extended_arg(self): - # Functions with more than 65535 global vars prefix some global - # variable references with the EXTENDED_ARG opcode. - nvars = 65537 + 258 - names = ['g%d' % i for i in range(1, nvars)] - r = random.Random(42) - d = {name: r.randrange(100) for name in names} - # def f(x): - # x = g1, g2, ... - # return zlib.crc32(bytes(bytearray(x))) - code = """ - import zlib - - def f(): - x = {tup} - return zlib.crc32(bytes(bytearray(x))) - """.format(tup=', '.join(names)) - exec(textwrap.dedent(code), d, d) - f = d['f'] - res = f() - data = cloudpickle.dumps([f, f], protocol=self.protocol) - d = f = None - f2, f3 = pickle.loads(data) - self.assertTrue(f2 is f3) - self.assertEqual(f2(), res) - - def test_submodule(self): - # Function that refers (by attribute) to a sub-module of a package. 
- - # Choose any module NOT imported by __init__ of its parent package - # examples in standard library include: - # - http.cookies, unittest.mock, curses.textpad, xml.etree.ElementTree - - global xml # imitate performing this import at top of file - import xml.etree.ElementTree - def example(): - x = xml.etree.ElementTree.Comment # potential AttributeError - - s = cloudpickle.dumps(example, protocol=self.protocol) - - # refresh the environment, i.e., unimport the dependency - del xml - for item in list(sys.modules): - if item.split('.')[0] == 'xml': - del sys.modules[item] - - # deserialise - f = pickle.loads(s) - f() # perform test for error - - def test_submodule_closure(self): - # Same as test_submodule except the package is not a global - def scope(): - import xml.etree.ElementTree - def example(): - x = xml.etree.ElementTree.Comment # potential AttributeError - return example - example = scope() - - s = cloudpickle.dumps(example, protocol=self.protocol) - - # refresh the environment (unimport dependency) - for item in list(sys.modules): - if item.split('.')[0] == 'xml': - del sys.modules[item] - - f = cloudpickle.loads(s) - f() # test - - def test_multiprocess(self): - # running a function pickled by another process (a la dask.distributed) - def scope(): - def example(): - x = xml.etree.ElementTree.Comment - return example - global xml - import xml.etree.ElementTree - example = scope() - - s = cloudpickle.dumps(example, protocol=self.protocol) - - # choose "subprocess" rather than "multiprocessing" because the latter - # library uses fork to preserve the parent environment. 
- command = ("import base64; " - "from srsly.cloudpickle.compat import pickle; " - "pickle.loads(base64.b32decode('" + - base64.b32encode(s).decode('ascii') + - "'))()") - assert not subprocess.call([sys.executable, '-c', command]) - - def test_import(self): - # like test_multiprocess except subpackage modules referenced directly - # (unlike test_submodule) - global etree - def scope(): - import xml.etree as foobar - def example(): - x = etree.Comment - x = foobar.ElementTree - return example - example = scope() - import xml.etree.ElementTree as etree - - s = cloudpickle.dumps(example, protocol=self.protocol) - - command = ("import base64; " - "from srsly.cloudpickle.compat import pickle; " - "pickle.loads(base64.b32decode('" + - base64.b32encode(s).decode('ascii') + - "'))()") - assert not subprocess.call([sys.executable, '-c', command]) - - def test_multiprocessing_lock_raises(self): - lock = multiprocessing.Lock() - with pytest.raises(RuntimeError, match="only be shared between processes through inheritance"): - cloudpickle.dumps(lock) - - def test_cell_manipulation(self): - cell = _make_empty_cell() - - with pytest.raises(ValueError): - cell.cell_contents - - ob = object() - cell_set(cell, ob) - self.assertTrue( - cell.cell_contents is ob, - msg='cell contents not set correctly', - ) - - def check_logger(self, name): - logger = logging.getLogger(name) - pickled = pickle_depickle(logger, protocol=self.protocol) - self.assertTrue(pickled is logger, (pickled, logger)) - - dumped = cloudpickle.dumps(logger) - - code = """if 1: - import base64, srsly.cloudpickle as cloudpickle, logging - - logging.basicConfig(level=logging.INFO) - logger = cloudpickle.loads(base64.b32decode(b'{}')) - logger.info('hello') - """.format(base64.b32encode(dumped).decode('ascii')) - proc = subprocess.Popen([sys.executable, "-W ignore", "-c", code], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - out, _ = proc.communicate() - self.assertEqual(proc.wait(), 0) - 
self.assertEqual(out.strip().decode(), - f'INFO:{logger.name}:hello') - - def test_logger(self): - # logging.RootLogger object - self.check_logger(None) - # logging.Logger object - self.check_logger('cloudpickle.dummy_test_logger') - - def test_getset_descriptor(self): - assert isinstance(float.real, types.GetSetDescriptorType) - depickled_descriptor = pickle_depickle(float.real) - self.assertIs(depickled_descriptor, float.real) - - def test_abc_cache_not_pickled(self): - # cloudpickle issue #302: make sure that cloudpickle does not pickle - # the caches populated during instance/subclass checks of abc.ABCMeta - # instances. - MyClass = abc.ABCMeta('MyClass', (), {}) - - class MyUnrelatedClass: - pass - - class MyRelatedClass: - pass - - MyClass.register(MyRelatedClass) - - assert not issubclass(MyUnrelatedClass, MyClass) - assert issubclass(MyRelatedClass, MyClass) - - s = cloudpickle.dumps(MyClass) - - assert b"MyUnrelatedClass" not in s - assert b"MyRelatedClass" in s - - depickled_class = cloudpickle.loads(s) - assert not issubclass(MyUnrelatedClass, depickled_class) - assert issubclass(MyRelatedClass, depickled_class) - - def test_abc(self): - - class AbstractClass(abc.ABC): - @abc.abstractmethod - def some_method(self): - """A method""" - - @classmethod - @abc.abstractmethod - def some_classmethod(cls): - """A classmethod""" - - @staticmethod - @abc.abstractmethod - def some_staticmethod(): - """A staticmethod""" - - @property - @abc.abstractmethod - def some_property(): - """A property""" - - class ConcreteClass(AbstractClass): - def some_method(self): - return 'it works!' - - @classmethod - def some_classmethod(cls): - assert cls == ConcreteClass - return 'it works!' - - @staticmethod - def some_staticmethod(): - return 'it works!' - - @property - def some_property(self): - return 'it works!' - - # This abstract class is locally defined so we can safely register - # tuple in it to verify the unpickled class also register tuple. 
- AbstractClass.register(tuple) - - concrete_instance = ConcreteClass() - depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol) - depickled_class = pickle_depickle(ConcreteClass, - protocol=self.protocol) - depickled_instance = pickle_depickle(concrete_instance) - - assert issubclass(tuple, AbstractClass) - assert issubclass(tuple, depickled_base) - - self.assertEqual(depickled_class().some_method(), 'it works!') - self.assertEqual(depickled_instance.some_method(), 'it works!') - - self.assertEqual(depickled_class.some_classmethod(), 'it works!') - self.assertEqual(depickled_instance.some_classmethod(), 'it works!') - - self.assertEqual(depickled_class().some_staticmethod(), 'it works!') - self.assertEqual(depickled_instance.some_staticmethod(), 'it works!') - - self.assertEqual(depickled_class().some_property, 'it works!') - self.assertEqual(depickled_instance.some_property, 'it works!') - self.assertRaises(TypeError, depickled_base) - - class DepickledBaseSubclass(depickled_base): - def some_method(self): - return 'it works for realz!' - - @classmethod - def some_classmethod(cls): - assert cls == DepickledBaseSubclass - return 'it works for realz!' - - @staticmethod - def some_staticmethod(): - return 'it works for realz!' - - @property - def some_property(): - return 'it works for realz!' - - self.assertEqual(DepickledBaseSubclass().some_method(), - 'it works for realz!') - - class IncompleteBaseSubclass(depickled_base): - def some_method(self): - return 'this class lacks some concrete methods' - - self.assertRaises(TypeError, IncompleteBaseSubclass) - - def test_abstracts(self): - # Same as `test_abc` but using deprecated `abc.abstract*` methods. 
- # See https://github.com/cloudpipe/cloudpickle/issues/367 - - class AbstractClass(abc.ABC): - @abc.abstractmethod - def some_method(self): - """A method""" - - @abc.abstractclassmethod - def some_classmethod(cls): - """A classmethod""" - - @abc.abstractstaticmethod - def some_staticmethod(): - """A staticmethod""" - - @abc.abstractproperty - def some_property(self): - """A property""" - - class ConcreteClass(AbstractClass): - def some_method(self): - return 'it works!' - - @classmethod - def some_classmethod(cls): - assert cls == ConcreteClass - return 'it works!' - - @staticmethod - def some_staticmethod(): - return 'it works!' - - @property - def some_property(self): - return 'it works!' - - # This abstract class is locally defined so we can safely register - # tuple in it to verify the unpickled class also register tuple. - AbstractClass.register(tuple) - - concrete_instance = ConcreteClass() - depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol) - depickled_class = pickle_depickle(ConcreteClass, - protocol=self.protocol) - depickled_instance = pickle_depickle(concrete_instance) - - assert issubclass(tuple, AbstractClass) - assert issubclass(tuple, depickled_base) - - self.assertEqual(depickled_class().some_method(), 'it works!') - self.assertEqual(depickled_instance.some_method(), 'it works!') - - self.assertEqual(depickled_class.some_classmethod(), 'it works!') - self.assertEqual(depickled_instance.some_classmethod(), 'it works!') - - self.assertEqual(depickled_class().some_staticmethod(), 'it works!') - self.assertEqual(depickled_instance.some_staticmethod(), 'it works!') - - self.assertEqual(depickled_class().some_property, 'it works!') - self.assertEqual(depickled_instance.some_property, 'it works!') - self.assertRaises(TypeError, depickled_base) - - class DepickledBaseSubclass(depickled_base): - def some_method(self): - return 'it works for realz!' 
- - @classmethod - def some_classmethod(cls): - assert cls == DepickledBaseSubclass - return 'it works for realz!' - - @staticmethod - def some_staticmethod(): - return 'it works for realz!' - - @property - def some_property(self): - return 'it works for realz!' - - self.assertEqual(DepickledBaseSubclass().some_method(), - 'it works for realz!') - - class IncompleteBaseSubclass(depickled_base): - def some_method(self): - return 'this class lacks some concrete methods' - - self.assertRaises(TypeError, IncompleteBaseSubclass) - - def test_weakset_identity_preservation(self): - # Test that weaksets don't lose all their inhabitants if they're - # pickled in a larger data structure that includes other references to - # their inhabitants. - - class SomeClass: - def __init__(self, x): - self.x = x - - obj1, obj2, obj3 = SomeClass(1), SomeClass(2), SomeClass(3) - - things = [weakref.WeakSet([obj1, obj2]), obj1, obj2, obj3] - result = pickle_depickle(things, protocol=self.protocol) - - weakset, depickled1, depickled2, depickled3 = result - - self.assertEqual(depickled1.x, 1) - self.assertEqual(depickled2.x, 2) - self.assertEqual(depickled3.x, 3) - self.assertEqual(len(weakset), 2) - - self.assertEqual(set(weakset), {depickled1, depickled2}) - - def test_non_module_object_passing_whichmodule_test(self): - # https://github.com/cloudpipe/cloudpickle/pull/326: cloudpickle should - # not try to instrospect non-modules object when trying to discover the - # module of a function/class. 
This happenened because codecov injects - # tuples (and not modules) into sys.modules, but type-checks were not - # carried out on the entries of sys.modules, causing cloupdickle to - # then error in unexpected ways - def func(x): - return x ** 2 - - # Trigger a loop during the execution of whichmodule(func) by - # explicitly setting the function's module to None - func.__module__ = None - - class NonModuleObject: - def __ini__(self): - self.some_attr = None - - def __getattr__(self, name): - # We whitelist func so that a _whichmodule(func, None) call - # returns the NonModuleObject instance if a type check on the - # entries of sys.modules is not carried out, but manipulating - # this instance thinking it really is a module later on in the - # pickling process of func errors out - if name == 'func': - return func - else: - raise AttributeError - - non_module_object = NonModuleObject() - - assert func(2) == 4 - assert func is non_module_object.func - - # Any manipulation of non_module_object relying on attribute access - # will raise an Exception - with pytest.raises(AttributeError): - _ = non_module_object.some_attr - - try: - sys.modules['NonModuleObject'] = non_module_object - - func_module_name = _whichmodule(func, None) - assert func_module_name != 'NonModuleObject' - assert func_module_name is None - - depickled_func = pickle_depickle(func, protocol=self.protocol) - assert depickled_func(2) == 4 - - finally: - sys.modules.pop('NonModuleObject') - - def test_unrelated_faulty_module(self): - # Check that pickling a dynamically defined function or class does not - # fail when introspecting the currently loaded modules in sys.modules - # as long as those faulty modules are unrelated to the class or - # function we are currently pickling. 
- for base_class in (object, types.ModuleType): - for module_name in ['_missing_module', None]: - class FaultyModule(base_class): - def __getattr__(self, name): - # This throws an exception while looking up within - # pickle.whichmodule or getattr(module, name, None) - raise Exception() - - class Foo: - __module__ = module_name - - def foo(self): - return "it works!" - - def foo(): - return "it works!" - - foo.__module__ = module_name - - if base_class is types.ModuleType: # noqa - faulty_module = FaultyModule('_faulty_module') - else: - faulty_module = FaultyModule() - sys.modules["_faulty_module"] = faulty_module - - try: - # Test whichmodule in save_global. - self.assertEqual(pickle_depickle(Foo()).foo(), "it works!") - - # Test whichmodule in save_function. - cloned = pickle_depickle(foo, protocol=self.protocol) - self.assertEqual(cloned(), "it works!") - finally: - sys.modules.pop("_faulty_module", None) - - @pytest.mark.skip(reason="fails for pytest v7.2.0") - def test_dynamic_pytest_module(self): - # Test case for pull request https://github.com/cloudpipe/cloudpickle/pull/116 - import py - - def f(): - s = py.builtin.set([1]) - return s.pop() - - # some setup is required to allow pytest apimodules to be correctly - # serializable. 
- from srsly.cloudpickle import CloudPickler - from srsly.cloudpickle import cloudpickle_fast as cp_fast - CloudPickler.dispatch_table[type(py.builtin)] = cp_fast._module_reduce - - g = cloudpickle.loads(cloudpickle.dumps(f, protocol=self.protocol)) - - result = g() - self.assertEqual(1, result) - - def test_function_module_name(self): - func = lambda x: x - cloned = pickle_depickle(func, protocol=self.protocol) - self.assertEqual(cloned.__module__, func.__module__) - - def test_function_qualname(self): - def func(x): - return x - # Default __qualname__ attribute (Python 3 only) - if hasattr(func, '__qualname__'): - cloned = pickle_depickle(func, protocol=self.protocol) - self.assertEqual(cloned.__qualname__, func.__qualname__) - - # Mutated __qualname__ attribute - func.__qualname__ = '' - cloned = pickle_depickle(func, protocol=self.protocol) - self.assertEqual(cloned.__qualname__, func.__qualname__) - - def test_property(self): - # Note that the @property decorator only has an effect on new-style - # classes. 
- class MyObject: - _read_only_value = 1 - _read_write_value = 1 - - @property - def read_only_value(self): - "A read-only attribute" - return self._read_only_value - - @property - def read_write_value(self): - return self._read_write_value - - @read_write_value.setter - def read_write_value(self, value): - self._read_write_value = value - - - - my_object = MyObject() - - assert my_object.read_only_value == 1 - assert MyObject.read_only_value.__doc__ == "A read-only attribute" - - with pytest.raises(AttributeError): - my_object.read_only_value = 2 - my_object.read_write_value = 2 - - depickled_obj = pickle_depickle(my_object) - - assert depickled_obj.read_only_value == 1 - assert depickled_obj.read_write_value == 2 - - # make sure the depickled read_only_value attribute is still read-only - with pytest.raises(AttributeError): - my_object.read_only_value = 2 - - # make sure the depickled read_write_value attribute is writeable - depickled_obj.read_write_value = 3 - assert depickled_obj.read_write_value == 3 - type(depickled_obj).read_only_value.__doc__ == "A read-only attribute" - - - def test_namedtuple(self): - MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c']) - t1 = MyTuple(1, 2, 3) - t2 = MyTuple(3, 2, 1) - - depickled_t1, depickled_MyTuple, depickled_t2 = pickle_depickle( - [t1, MyTuple, t2], protocol=self.protocol) - - assert isinstance(depickled_t1, MyTuple) - assert depickled_t1 == t1 - assert depickled_MyTuple is MyTuple - assert isinstance(depickled_t2, MyTuple) - assert depickled_t2 == t2 - - @pytest.mark.skipif(platform.python_implementation() == "PyPy", - reason="fails sometimes for pypy on conda-forge") - def test_interactively_defined_function(self): - # Check that callables defined in the __main__ module of a Python - # script (or jupyter kernel) can be pickled / unpickled / executed. 
- code = """\ - from srsly.tests.cloudpickle.testutils import subprocess_pickle_echo - - CONSTANT = 42 - - class Foo(object): - - def method(self, x): - return x - - foo = Foo() - - def f0(x): - return x ** 2 - - def f1(): - return Foo - - def f2(x): - return Foo().method(x) - - def f3(): - return Foo().method(CONSTANT) - - def f4(x): - return foo.method(x) - - def f5(x): - # Recursive call to a dynamically defined function. - if x <= 0: - return f4(x) - return f5(x - 1) + 1 - - cloned = subprocess_pickle_echo(lambda x: x**2, protocol={protocol}) - assert cloned(3) == 9 - - cloned = subprocess_pickle_echo(f0, protocol={protocol}) - assert cloned(3) == 9 - - cloned = subprocess_pickle_echo(Foo, protocol={protocol}) - assert cloned().method(2) == Foo().method(2) - - cloned = subprocess_pickle_echo(Foo(), protocol={protocol}) - assert cloned.method(2) == Foo().method(2) - - cloned = subprocess_pickle_echo(f1, protocol={protocol}) - assert cloned()().method('a') == f1()().method('a') - - cloned = subprocess_pickle_echo(f2, protocol={protocol}) - assert cloned(2) == f2(2) - - cloned = subprocess_pickle_echo(f3, protocol={protocol}) - assert cloned() == f3() - - cloned = subprocess_pickle_echo(f4, protocol={protocol}) - assert cloned(2) == f4(2) - - cloned = subprocess_pickle_echo(f5, protocol={protocol}) - assert cloned(7) == f5(7) == 7 - """.format(protocol=self.protocol) - assert_run_python_script(textwrap.dedent(code)) - - def test_interactively_defined_global_variable(self): - # Check that callables defined in the __main__ module of a Python - # script (or jupyter kernel) correctly retrieve global variables. 
- code_template = """\ - from srsly.tests.cloudpickle.testutils import subprocess_pickle_echo - from srsly.cloudpickle import dumps, loads - - def local_clone(obj, protocol=None): - return loads(dumps(obj, protocol=protocol)) - - VARIABLE = "default_value" - - def f0(): - global VARIABLE - VARIABLE = "changed_by_f0" - - def f1(): - return VARIABLE - - assert f0.__globals__ is f1.__globals__ - - # pickle f0 and f1 inside the same pickle_string - cloned_f0, cloned_f1 = {clone_func}([f0, f1], protocol={protocol}) - - # cloned_f0 and cloned_f1 now share a global namespace that is isolated - # from any previously existing namespace - assert cloned_f0.__globals__ is cloned_f1.__globals__ - assert cloned_f0.__globals__ is not f0.__globals__ - - # pickle f1 another time, but in a new pickle string - pickled_f1 = dumps(f1, protocol={protocol}) - - # Change the value of the global variable in f0's new global namespace - cloned_f0() - - # thanks to cloudpickle isolation, depickling and calling f0 and f1 - # should not affect the globals of already existing modules - assert VARIABLE == "default_value", VARIABLE - - # Ensure that cloned_f1 and cloned_f0 share the same globals, as f1 and - # f0 shared the same globals at pickling time, and cloned_f1 was - # depickled from the same pickle string as cloned_f0 - shared_global_var = cloned_f1() - assert shared_global_var == "changed_by_f0", shared_global_var - - # f1 is unpickled another time, but because it comes from another - # pickle string than pickled_f1 and pickled_f0, it will not share the - # same globals as the latter two. 
- new_cloned_f1 = loads(pickled_f1) - assert new_cloned_f1.__globals__ is not cloned_f1.__globals__ - assert new_cloned_f1.__globals__ is not f1.__globals__ - - # get the value of new_cloned_f1's VARIABLE - new_global_var = new_cloned_f1() - assert new_global_var == "default_value", new_global_var - """ - for clone_func in ['local_clone', 'subprocess_pickle_echo']: - code = code_template.format(protocol=self.protocol, - clone_func=clone_func) - assert_run_python_script(textwrap.dedent(code)) - - def test_closure_interacting_with_a_global_variable(self): - global _TEST_GLOBAL_VARIABLE - assert _TEST_GLOBAL_VARIABLE == "default_value" - orig_value = _TEST_GLOBAL_VARIABLE - try: - def f0(): - global _TEST_GLOBAL_VARIABLE - _TEST_GLOBAL_VARIABLE = "changed_by_f0" - - def f1(): - return _TEST_GLOBAL_VARIABLE - - # pickle f0 and f1 inside the same pickle_string - cloned_f0, cloned_f1 = pickle_depickle([f0, f1], - protocol=self.protocol) - - # cloned_f0 and cloned_f1 now share a global namespace that is - # isolated from any previously existing namespace - assert cloned_f0.__globals__ is cloned_f1.__globals__ - assert cloned_f0.__globals__ is not f0.__globals__ - - # pickle f1 another time, but in a new pickle string - pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol) - - # Change the global variable's value in f0's new global namespace - cloned_f0() - - # depickling f0 and f1 should not affect the globals of already - # existing modules - assert _TEST_GLOBAL_VARIABLE == "default_value" - - # Ensure that cloned_f1 and cloned_f0 share the same globals, as f1 - # and f0 shared the same globals at pickling time, and cloned_f1 - # was depickled from the same pickle string as cloned_f0 - shared_global_var = cloned_f1() - assert shared_global_var == "changed_by_f0", shared_global_var - - # f1 is unpickled another time, but because it comes from another - # pickle string than pickled_f1 and pickled_f0, it will not share - # the same globals as the latter two. 
- new_cloned_f1 = pickle.loads(pickled_f1) - assert new_cloned_f1.__globals__ is not cloned_f1.__globals__ - assert new_cloned_f1.__globals__ is not f1.__globals__ - - # get the value of new_cloned_f1's VARIABLE - new_global_var = new_cloned_f1() - assert new_global_var == "default_value", new_global_var - finally: - _TEST_GLOBAL_VARIABLE = orig_value - - def test_interactive_remote_function_calls(self): - code = """if __name__ == "__main__": - from srsly.tests.cloudpickle.testutils import subprocess_worker - - def interactive_function(x): - return x + 1 - - with subprocess_worker(protocol={protocol}) as w: - - assert w.run(interactive_function, 41) == 42 - - # Define a new function that will call an updated version of - # the previously called function: - - def wrapper_func(x): - return interactive_function(x) - - def interactive_function(x): - return x - 1 - - # The change in the definition of interactive_function in the main - # module of the main process should be reflected transparently - # in the worker process: the worker process does not recall the - # previous definition of `interactive_function`: - - assert w.run(wrapper_func, 41) == 40 - """.format(protocol=self.protocol) - assert_run_python_script(code) - - def test_interactive_remote_function_calls_no_side_effect(self): - code = """if __name__ == "__main__": - from srsly.tests.cloudpickle.testutils import subprocess_worker - import sys - - with subprocess_worker(protocol={protocol}) as w: - - GLOBAL_VARIABLE = 0 - - class CustomClass(object): - - def mutate_globals(self): - global GLOBAL_VARIABLE - GLOBAL_VARIABLE += 1 - return GLOBAL_VARIABLE - - custom_object = CustomClass() - assert w.run(custom_object.mutate_globals) == 1 - - # The caller global variable is unchanged in the main process. - - assert GLOBAL_VARIABLE == 0 - - # Calling the same function again starts again from zero. 
The - # worker process is stateless: it has no memory of the past call: - - assert w.run(custom_object.mutate_globals) == 1 - - # The symbols defined in the main process __main__ module are - # not set in the worker process main module to leave the worker - # as stateless as possible: - - def is_in_main(name): - return hasattr(sys.modules["__main__"], name) - - assert is_in_main("CustomClass") - assert not w.run(is_in_main, "CustomClass") - - assert is_in_main("GLOBAL_VARIABLE") - assert not w.run(is_in_main, "GLOBAL_VARIABLE") - - """.format(protocol=self.protocol) - assert_run_python_script(code) - - def test_interactive_dynamic_type_and_remote_instances(self): - code = """if __name__ == "__main__": - from srsly.tests.cloudpickle.testutils import subprocess_worker - - with subprocess_worker(protocol={protocol}) as w: - - class CustomCounter: - def __init__(self): - self.count = 0 - def increment(self): - self.count += 1 - return self - - counter = CustomCounter().increment() - assert counter.count == 1 - - returned_counter = w.run(counter.increment) - assert returned_counter.count == 2, returned_counter.count - - # Check that the class definition of the returned instance was - # matched back to the original class definition living in __main__. - - assert isinstance(returned_counter, CustomCounter) - - # Check that memoization does not break provenance tracking: - - def echo(*args): - return args - - C1, C2, c1, c2 = w.run(echo, CustomCounter, CustomCounter, - CustomCounter(), returned_counter) - assert C1 is CustomCounter - assert C2 is CustomCounter - assert isinstance(c1, CustomCounter) - assert isinstance(c2, CustomCounter) - - """.format(protocol=self.protocol) - assert_run_python_script(code) - - def test_interactive_dynamic_type_and_stored_remote_instances(self): - """Simulate objects stored on workers to check isinstance semantics - - Such instances stored in the memory of running worker processes are - similar to dask-distributed futures for instance. 
- """ - code = """if __name__ == "__main__": - import srsly.cloudpickle as cloudpickle, uuid - from srsly.tests.cloudpickle.testutils import subprocess_worker - - with subprocess_worker(protocol={protocol}) as w: - - class A: - '''Original class definition''' - pass - - def store(x): - storage = getattr(cloudpickle, "_test_storage", None) - if storage is None: - storage = cloudpickle._test_storage = dict() - obj_id = uuid.uuid4().hex - storage[obj_id] = x - return obj_id - - def lookup(obj_id): - return cloudpickle._test_storage[obj_id] - - id1 = w.run(store, A()) - - # The stored object on the worker is matched to a singleton class - # definition thanks to provenance tracking: - assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1) - - # Retrieving the object from the worker yields a local copy that - # is matched back the local class definition this instance - # originally stems from. - assert isinstance(w.run(lookup, id1), A) - - # Changing the local class definition should be taken into account - # in all subsequent calls. 
In particular the old instances on the - # worker do not map back to the new class definition, neither on - # the worker itself, nor locally on the main program when the old - # instance is retrieved: - - class A: - '''Updated class definition''' - pass - - assert not w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1) - retrieved1 = w.run(lookup, id1) - assert not isinstance(retrieved1, A) - assert retrieved1.__class__ is not A - assert retrieved1.__class__.__doc__ == "Original class definition" - - # New instances on the other hand are proper instances of the new - # class definition everywhere: - - a = A() - id2 = w.run(store, a) - assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id2) - assert isinstance(w.run(lookup, id2), A) - - # Monkeypatch the class defintion in the main process to a new - # class method: - A.echo = lambda cls, x: x - - # Calling this method on an instance will automatically update - # the remote class definition on the worker to propagate the monkey - # patch dynamically. 
- assert w.run(a.echo, 42) == 42 - - # The stored instance can therefore also access the new class - # method: - assert w.run(lambda obj_id: lookup(obj_id).echo(43), id2) == 43 - - """.format(protocol=self.protocol) - assert_run_python_script(code) - - @pytest.mark.skip(reason="Seems to have issues outside of linux and CPython") - def test_interactive_remote_function_calls_no_memory_leak(self): - code = """if __name__ == "__main__": - from srsly.tests.cloudpickle.testutils import subprocess_worker - import struct - - with subprocess_worker(protocol={protocol}) as w: - - reference_size = w.memsize() - assert reference_size > 0 - - - def make_big_closure(i): - # Generate a byte string of size 1MB - itemsize = len(struct.pack("l", 1)) - data = struct.pack("l", i) * (int(1e6) // itemsize) - def process_data(): - return len(data) - return process_data - - for i in range(100): - func = make_big_closure(i) - result = w.run(func) - assert result == int(1e6), result - - import gc - w.run(gc.collect) - - # By this time the worker process has processed 100MB worth of data - # passed in the closures. The worker memory size should not have - # grown by more than a few MB as closures are garbage collected at - # the end of each remote function call. - growth = w.memsize() - reference_size - - # For some reason, the memory growth after processing 100MB of - # data is ~10MB on MacOS, and ~1MB on Linux, so the upper bound on - # memory growth we use is only tight for MacOS. 
However, - # - 10MB is still 10x lower than the expected memory growth in case - # of a leak (which would be the total size of the processed data, - # 100MB) - # - the memory usage growth does not increase if using 10000 - # iterations instead of 100 as used now (100x more data) - assert growth < 1.5e7, growth - - """.format(protocol=self.protocol) - assert_run_python_script(code) - - def test_pickle_reraise(self): - for exc_type in [Exception, ValueError, TypeError, RuntimeError]: - obj = RaiserOnPickle(exc_type("foo")) - with pytest.raises((exc_type, pickle.PicklingError)): - cloudpickle.dumps(obj, protocol=self.protocol) - - def test_unhashable_function(self): - d = {'a': 1} - depickled_method = pickle_depickle(d.get, protocol=self.protocol) - self.assertEqual(depickled_method('a'), 1) - self.assertEqual(depickled_method('b'), None) - - @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Deprecation warning in python 3.12 about future deprecation in python 3.14") - def test_itertools_count(self): - counter = itertools.count(1, step=2) - - # advance the counter a bit - next(counter) - next(counter) - - new_counter = pickle_depickle(counter, protocol=self.protocol) - - self.assertTrue(counter is not new_counter) - - for _ in range(10): - self.assertEqual(next(counter), next(new_counter)) - - def test_wraps_preserves_function_name(self): - from functools import wraps - - def f(): - pass - - @wraps(f) - def g(): - f() - - f2 = pickle_depickle(g, protocol=self.protocol) - - self.assertEqual(f2.__name__, f.__name__) - - def test_wraps_preserves_function_doc(self): - from functools import wraps - - def f(): - """42""" - pass - - @wraps(f) - def g(): - f() - - f2 = pickle_depickle(g, protocol=self.protocol) - - self.assertEqual(f2.__doc__, f.__doc__) - - def test_wraps_preserves_function_annotations(self): - def f(x): - pass - - f.__annotations__ = {'x': 1, 'return': float} - - @wraps(f) - def g(x): - f(x) - - f2 = pickle_depickle(g, protocol=self.protocol) - - 
self.assertEqual(f2.__annotations__, f.__annotations__) - - def test_type_hint(self): - t = typing.Union[list, int] - assert pickle_depickle(t) == t - - def test_instance_with_slots(self): - for slots in [["registered_attribute"], "registered_attribute"]: - class ClassWithSlots: - __slots__ = slots - - def __init__(self): - self.registered_attribute = 42 - - initial_obj = ClassWithSlots() - depickled_obj = pickle_depickle( - initial_obj, protocol=self.protocol) - - for obj in [initial_obj, depickled_obj]: - self.assertEqual(obj.registered_attribute, 42) - with pytest.raises(AttributeError): - obj.non_registered_attribute = 1 - - class SubclassWithSlots(ClassWithSlots): - def __init__(self): - self.unregistered_attribute = 1 - - obj = SubclassWithSlots() - s = cloudpickle.dumps(obj, protocol=self.protocol) - del SubclassWithSlots - depickled_obj = cloudpickle.loads(s) - assert depickled_obj.unregistered_attribute == 1 - - - @unittest.skipIf(not hasattr(types, "MappingProxyType"), - "Old versions of Python do not have this type.") - def test_mappingproxy(self): - mp = types.MappingProxyType({"some_key": "some value"}) - assert mp == pickle_depickle(mp, protocol=self.protocol) - - def test_dataclass(self): - dataclasses = pytest.importorskip("dataclasses") - - DataClass = dataclasses.make_dataclass('DataClass', [('x', int)]) - data = DataClass(x=42) - - pickle_depickle(DataClass, protocol=self.protocol) - assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42 - - def test_locally_defined_enum(self): - class StringEnum(str, enum.Enum): - """Enum when all members are also (and must be) strings""" - - class Color(StringEnum): - """3-element color space""" - RED = "1" - GREEN = "2" - BLUE = "3" - - def is_green(self): - return self is Color.GREEN - - green1, green2, ClonedColor = pickle_depickle( - [Color.GREEN, Color.GREEN, Color], protocol=self.protocol) - assert green1 is green2 - assert green1 is ClonedColor.GREEN - assert green1 is not 
ClonedColor.BLUE - assert isinstance(green1, str) - assert green1.is_green() - - # cloudpickle systematically tracks provenance of class definitions - # and ensure reconciliation in case of round trips: - assert green1 is Color.GREEN - assert ClonedColor is Color - - green3 = pickle_depickle(Color.GREEN, protocol=self.protocol) - assert green3 is Color.GREEN - - def test_locally_defined_intenum(self): - # Try again with a IntEnum defined with the functional API - DynamicColor = enum.IntEnum("Color", {"RED": 1, "GREEN": 2, "BLUE": 3}) - - green1, green2, ClonedDynamicColor = pickle_depickle( - [DynamicColor.GREEN, DynamicColor.GREEN, DynamicColor], - protocol=self.protocol) - - assert green1 is green2 - assert green1 is ClonedDynamicColor.GREEN - assert green1 is not ClonedDynamicColor.BLUE - assert ClonedDynamicColor is DynamicColor - - def test_interactively_defined_enum(self): - code = """if __name__ == "__main__": - from enum import Enum - from srsly.tests.cloudpickle.testutils import subprocess_worker - - with subprocess_worker(protocol={protocol}) as w: - - class Color(Enum): - RED = 1 - GREEN = 2 - - def check_positive(x): - return Color.GREEN if x >= 0 else Color.RED - - result = w.run(check_positive, 1) - - # Check that the returned enum instance is reconciled with the - # locally defined Color enum type definition: - - assert result is Color.GREEN - - # Check that changing the definition of the Enum class is taken - # into account on the worker for subsequent calls: - - class Color(Enum): - RED = 1 - BLUE = 2 - - def check_positive(x): - return Color.BLUE if x >= 0 else Color.RED - - result = w.run(check_positive, 1) - assert result is Color.BLUE - """.format(protocol=self.protocol) - assert_run_python_script(code) - - def test_relative_import_inside_function(self): - pytest.importorskip("_cloudpickle_testpkg") - # Make sure relative imports inside round-tripped functions is not - # broken. 
This was a bug in cloudpickle versions <= 0.5.3 and was - # re-introduced in 0.8.0. - from _cloudpickle_testpkg import relative_imports_factory - f, g = relative_imports_factory() - for func, source in zip([f, g], ["module", "package"]): - # Make sure relative imports are initially working - assert func() == f"hello from a {source}!" - - # Make sure relative imports still work after round-tripping - cloned_func = pickle_depickle(func, protocol=self.protocol) - assert cloned_func() == f"hello from a {source}!" - - def test_interactively_defined_func_with_keyword_only_argument(self): - # fixes https://github.com/cloudpipe/cloudpickle/issues/263 - def f(a, *, b=1): - return a + b - - depickled_f = pickle_depickle(f, protocol=self.protocol) - - for func in (f, depickled_f): - assert func(2) == 3 - assert func.__kwdefaults__ == {'b': 1} - - @pytest.mark.skipif(not hasattr(types.CodeType, "co_posonlyargcount"), - reason="Requires positional-only argument syntax") - def test_interactively_defined_func_with_positional_only_argument(self): - # Fixes https://github.com/cloudpipe/cloudpickle/issues/266 - # The source code of this test is bundled in a string and is ran from - # the __main__ module of a subprocess in order to avoid a SyntaxError - # in versions of python that do not support positional-only argument - # syntax. 
- code = """ - import pytest - from srsly.cloudpickle import loads, dumps - - def f(a, /, b=1): - return a + b - - depickled_f = loads(dumps(f, protocol={protocol})) - - for func in (f, depickled_f): - assert func(2) == 3 - assert func.__code__.co_posonlyargcount == 1 - with pytest.raises(TypeError): - func(a=2) - - """.format(protocol=self.protocol) - assert_run_python_script(textwrap.dedent(code)) - - def test___reduce___returns_string(self): - # Non regression test for objects with a __reduce__ method returning a - # string, meaning "save by attribute using save_global" - pytest.importorskip("_cloudpickle_testpkg") - from _cloudpickle_testpkg import some_singleton - assert some_singleton.__reduce__() == "some_singleton" - depickled_singleton = pickle_depickle( - some_singleton, protocol=self.protocol) - assert depickled_singleton is some_singleton - - def test_cloudpickle_extract_nested_globals(self): - def function_factory(): - def inner_function(): - global _TEST_GLOBAL_VARIABLE - return _TEST_GLOBAL_VARIABLE - return inner_function - - globals_ = set(cloudpickle.cloudpickle._extract_code_globals( - function_factory.__code__).keys()) - assert globals_ == {'_TEST_GLOBAL_VARIABLE'} - - depickled_factory = pickle_depickle(function_factory, - protocol=self.protocol) - inner_func = depickled_factory() - assert inner_func() == _TEST_GLOBAL_VARIABLE - - def test_recursion_during_pickling(self): - class A: - def __getattribute__(self, name): - return getattr(self, name) - - a = A() - with pytest.raises(pickle.PicklingError, match='recursion'): - cloudpickle.dumps(a) - - def test_out_of_band_buffers(self): - if self.protocol < 5: - pytest.skip("Need Pickle Protocol 5 or later") - np = pytest.importorskip("numpy") - - class LocallyDefinedClass: - data = np.zeros(10) - - data_instance = LocallyDefinedClass() - buffers = [] - pickle_bytes = cloudpickle.dumps(data_instance, protocol=self.protocol, - buffer_callback=buffers.append) - assert len(buffers) == 1 - reconstructed 
= pickle.loads(pickle_bytes, buffers=buffers) - np.testing.assert_allclose(reconstructed.data, data_instance.data) - - def test_pickle_dynamic_typevar(self): - T = typing.TypeVar('T') - depickled_T = pickle_depickle(T, protocol=self.protocol) - attr_list = [ - "__name__", "__bound__", "__constraints__", "__covariant__", - "__contravariant__" - ] - for attr in attr_list: - assert getattr(T, attr) == getattr(depickled_T, attr) - - def test_pickle_dynamic_typevar_tracking(self): - T = typing.TypeVar("T") - T2 = subprocess_pickle_echo(T, protocol=self.protocol) - assert T is T2 - - def test_pickle_dynamic_typevar_memoization(self): - T = typing.TypeVar('T') - depickled_T1, depickled_T2 = pickle_depickle((T, T), - protocol=self.protocol) - assert depickled_T1 is depickled_T2 - - def test_pickle_importable_typevar(self): - pytest.importorskip("_cloudpickle_testpkg") - from _cloudpickle_testpkg import T - T1 = pickle_depickle(T, protocol=self.protocol) - assert T1 is T - - # Standard Library TypeVar - from typing import AnyStr - assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol) - - def test_generic_type(self): - T = typing.TypeVar('T') - - class C(typing.Generic[T]): - pass - - assert pickle_depickle(C, protocol=self.protocol) is C - - # Identity is not part of the typing contract: only test for - # equality instead. 
- assert pickle_depickle(C[int], protocol=self.protocol) == C[int] - - with subprocess_worker(protocol=self.protocol) as worker: - - def check_generic(generic, origin, type_value, use_args): - assert generic.__origin__ is origin - - assert len(origin.__orig_bases__) == 1 - ob = origin.__orig_bases__[0] - assert ob.__origin__ is typing.Generic - - if use_args: - assert len(generic.__args__) == 1 - assert generic.__args__[0] is type_value - else: - assert len(generic.__parameters__) == 1 - assert generic.__parameters__[0] is type_value - assert len(ob.__parameters__) == 1 - - return "ok" - - # backward-compat for old Python 3.5 versions that sometimes relies - # on __parameters__ - use_args = getattr(C[int], '__args__', ()) != () - assert check_generic(C[int], C, int, use_args) == "ok" - assert worker.run(check_generic, C[int], C, int, use_args) == "ok" - - def test_generic_subclass(self): - T = typing.TypeVar('T') - - class Base(typing.Generic[T]): - pass - - class DerivedAny(Base): - pass - - class LeafAny(DerivedAny): - pass - - class DerivedInt(Base[int]): - pass - - class LeafInt(DerivedInt): - pass - - class DerivedT(Base[T]): - pass - - class LeafT(DerivedT[T]): - pass - - klasses = [ - Base, DerivedAny, LeafAny, DerivedInt, LeafInt, DerivedT, LeafT - ] - for klass in klasses: - assert pickle_depickle(klass, protocol=self.protocol) is klass - - with subprocess_worker(protocol=self.protocol) as worker: - - def check_mro(klass, expected_mro): - assert klass.mro() == expected_mro - return "ok" - - for klass in klasses: - mro = klass.mro() - assert check_mro(klass, mro) - assert worker.run(check_mro, klass, mro) == "ok" - - def test_locally_defined_class_with_type_hints(self): - with subprocess_worker(protocol=self.protocol) as worker: - for type_ in _all_types_to_test(): - class MyClass: - def method(self, arg: type_) -> type_: - return arg - MyClass.__annotations__ = {'attribute': type_} - - def check_annotations(obj, expected_type, expected_type_str): - assert 
obj.__annotations__["attribute"] == expected_type - assert ( - obj.method.__annotations__["arg"] == expected_type - ) - assert ( - obj.method.__annotations__["return"] - == expected_type - ) - return "ok" - - obj = MyClass() - assert check_annotations(obj, type_, "type_") == "ok" - assert ( - worker.run(check_annotations, obj, type_, "type_") == "ok" - ) - - def test_generic_extensions_literal(self): - typing_extensions = pytest.importorskip('typing_extensions') - for obj in [typing_extensions.Literal, typing_extensions.Literal['a']]: - depickled_obj = pickle_depickle(obj, protocol=self.protocol) - assert depickled_obj == obj - - def test_generic_extensions_final(self): - typing_extensions = pytest.importorskip('typing_extensions') - for obj in [typing_extensions.Final, typing_extensions.Final[int]]: - depickled_obj = pickle_depickle(obj, protocol=self.protocol) - assert depickled_obj == obj - - def test_class_annotations(self): - class C: - pass - C.__annotations__ = {'a': int} - - C1 = pickle_depickle(C, protocol=self.protocol) - assert C1.__annotations__ == C.__annotations__ - - def test_function_annotations(self): - def f(a: int) -> str: - pass - - f1 = pickle_depickle(f, protocol=self.protocol) - assert f1.__annotations__ == f.__annotations__ - - def test_always_use_up_to_date_copyreg(self): - # test that updates of copyreg.dispatch_table are taken in account by - # cloudpickle - import copyreg - try: - class MyClass: - pass - - def reduce_myclass(x): - return MyClass, (), {'custom_reduce': True} - - copyreg.dispatch_table[MyClass] = reduce_myclass - my_obj = MyClass() - depickled_myobj = pickle_depickle(my_obj, protocol=self.protocol) - assert hasattr(depickled_myobj, 'custom_reduce') - finally: - copyreg.dispatch_table.pop(MyClass) - - def test_literal_misdetection(self): - # see https://github.com/cloudpipe/cloudpickle/issues/403 - class MyClass: - @property - def __values__(self): - return () - - o = MyClass() - pickle_depickle(o, protocol=self.protocol) - 
- def test_final_or_classvar_misdetection(self): - # see https://github.com/cloudpipe/cloudpickle/issues/403 - class MyClass: - @property - def __type__(self): - return int - - o = MyClass() - pickle_depickle(o, protocol=self.protocol) - - @pytest.mark.skip(reason="Requires pytest -s to pass") - def test_pickle_constructs_from_module_registered_for_pickling_by_value(self): # noqa - _prev_sys_path = sys.path.copy() - try: - # We simulate an interactive session that: - # - we start from the /path/to/cloudpickle/tests directory, where a - # local .py file (mock_local_file) is located. - # - uses constructs from mock_local_file in remote workers that do - # not have access to this file. This situation is - # the justification behind the - # (un)register_pickle_by_value(module) api that cloudpickle - # exposes. - _mock_interactive_session_cwd = os.path.dirname(__file__) - - # First, remove sys.path entries that could point to - # /path/to/cloudpickle/tests and be in inherited by the worker - _maybe_remove(sys.path, '') - _maybe_remove(sys.path, _mock_interactive_session_cwd) - - # Add the desired session working directory - sys.path.insert(0, _mock_interactive_session_cwd) - - with subprocess_worker(protocol=self.protocol) as w: - # Make the module unavailable in the remote worker - w.run( - lambda p: sys.path.remove(p), _mock_interactive_session_cwd - ) - # Import the actual file after starting the module since the - # worker is started using fork on Linux, which will inherits - # the parent sys.modules. On Python>3.6, the worker can be - # started using spawn using mp_context in ProcessPoolExectutor. - # TODO Once Python 3.6 reaches end of life, rely on mp_context - # instead. - import mock_local_folder.mod as mod - # The constructs whose pickling mechanism is changed using - # register_pickle_by_value are functions, classes, TypeVar and - # modules. 
- from mock_local_folder.mod import ( - local_function, LocalT, LocalClass - ) - - # Make sure the module/constructs are unimportable in the - # worker. - with pytest.raises(ImportError): - w.run(lambda: __import__("mock_local_folder.mod")) - with pytest.raises(ImportError): - w.run( - lambda: __import__("mock_local_folder.subfolder.mod") - ) - - for o in [mod, local_function, LocalT, LocalClass]: - with pytest.raises(ImportError): - w.run(lambda: o) - - register_pickle_by_value(mod) - # function - assert w.run(lambda: local_function()) == local_function() - # typevar - assert w.run(lambda: LocalT.__name__) == LocalT.__name__ - # classes - assert ( - w.run(lambda: LocalClass().method()) - == LocalClass().method() - ) - # modules - assert ( - w.run(lambda: mod.local_function()) == local_function() - ) - - # Constructs from modules inside subfolders should be pickled - # by value if a namespace module pointing to some parent folder - # was registered for pickling by value. A "mock_local_folder" - # namespace module falls into that category, but a - # "mock_local_folder.mod" one does not. 
- from mock_local_folder.subfolder.submod import ( - LocalSubmodClass, LocalSubmodT, local_submod_function - ) - # Shorter aliases to comply with line-length limits - _t, _func, _class = ( - LocalSubmodT, local_submod_function, LocalSubmodClass - ) - with pytest.raises(ImportError): - w.run( - lambda: __import__("mock_local_folder.subfolder.mod") - ) - with pytest.raises(ImportError): - w.run(lambda: local_submod_function) - - unregister_pickle_by_value(mod) - - with pytest.raises(ImportError): - w.run(lambda: local_function) - - with pytest.raises(ImportError): - w.run(lambda: __import__("mock_local_folder.mod")) - - # Test the namespace folder case - import mock_local_folder - register_pickle_by_value(mock_local_folder) - assert w.run(lambda: local_function()) == local_function() - assert w.run(lambda: _func()) == _func() - unregister_pickle_by_value(mock_local_folder) - - with pytest.raises(ImportError): - w.run(lambda: local_function) - with pytest.raises(ImportError): - w.run(lambda: local_submod_function) - - # Test the case of registering a single module inside a - # subfolder. 
- import mock_local_folder.subfolder.submod - register_pickle_by_value(mock_local_folder.subfolder.submod) - assert w.run(lambda: _func()) == _func() - assert w.run(lambda: _t.__name__) == _t.__name__ - assert w.run(lambda: _class().method()) == _class().method() - - # Registering a module from a subfolder for pickling by value - # should not make constructs from modules from the parent - # folder pickleable - with pytest.raises(ImportError): - w.run(lambda: local_function) - with pytest.raises(ImportError): - w.run(lambda: __import__("mock_local_folder.mod")) - - unregister_pickle_by_value( - mock_local_folder.subfolder.submod - ) - with pytest.raises(ImportError): - w.run(lambda: local_submod_function) - - # Test the subfolder namespace module case - import mock_local_folder.subfolder - register_pickle_by_value(mock_local_folder.subfolder) - assert w.run(lambda: _func()) == _func() - assert w.run(lambda: _t.__name__) == _t.__name__ - assert w.run(lambda: _class().method()) == _class().method() - - unregister_pickle_by_value(mock_local_folder.subfolder) - finally: - _fname = "mock_local_folder" - sys.path = _prev_sys_path - for m in [_fname, f"{_fname}.mod", f"{_fname}.subfolder", - f"{_fname}.subfolder.submod"]: - mod = sys.modules.pop(m, None) - if mod and mod.__name__ in list_registry_pickle_by_value(): - unregister_pickle_by_value(mod) - - def test_pickle_constructs_from_installed_packages_registered_for_pickling_by_value( # noqa - self - ): - pytest.importorskip("_cloudpickle_testpkg") - for package_or_module in ["package", "module"]: - if package_or_module == "package": - import _cloudpickle_testpkg as m - f = m.package_function_with_global - _original_global = m.global_variable - elif package_or_module == "module": - import _cloudpickle_testpkg.mod as m - f = m.module_function_with_global - _original_global = m.global_variable - try: - with subprocess_worker(protocol=self.protocol) as w: - assert w.run(lambda: f()) == _original_global - - # Test that f is 
pickled by value by modifying a global - # variable that f uses, and making sure that this - # modification shows up when calling the function remotely - register_pickle_by_value(m) - assert w.run(lambda: f()) == _original_global - m.global_variable = "modified global" - assert m.global_variable != _original_global - assert w.run(lambda: f()) == "modified global" - unregister_pickle_by_value(m) - finally: - m.global_variable = _original_global - if m.__name__ in list_registry_pickle_by_value(): - unregister_pickle_by_value(m) - - def test_pickle_various_versions_of_the_same_function_with_different_pickling_method( # noqa - self - ): - pytest.importorskip("_cloudpickle_testpkg") - # Make sure that different versions of the same function (possibly - # pickled in a different way - by value and/or by reference) can - # peacefully co-exist (e.g. without globals interaction) in a remote - # worker. - import _cloudpickle_testpkg - from _cloudpickle_testpkg import package_function_with_global as f - _original_global = _cloudpickle_testpkg.global_variable - - def _create_registry(): - _main = __import__("sys").modules["__main__"] - _main._cloudpickle_registry = {} - # global _cloudpickle_registry - - def _add_to_registry(v, k): - _main = __import__("sys").modules["__main__"] - _main._cloudpickle_registry[k] = v - - def _call_from_registry(k): - _main = __import__("sys").modules["__main__"] - return _main._cloudpickle_registry[k]() - - try: - with subprocess_worker(protocol=self.protocol) as w: - w.run(_create_registry) - w.run(_add_to_registry, f, "f_by_ref") - - register_pickle_by_value(_cloudpickle_testpkg) - _cloudpickle_testpkg.global_variable = "modified global" - w.run(_add_to_registry, f, "f_by_val") - assert ( - w.run(_call_from_registry, "f_by_ref") == _original_global - ) - assert ( - w.run(_call_from_registry, "f_by_val") == "modified global" - ) - - finally: - _cloudpickle_testpkg.global_variable = _original_global - - if "_cloudpickle_testpkg" in 
list_registry_pickle_by_value(): - unregister_pickle_by_value(_cloudpickle_testpkg) - - @pytest.mark.skipif( - sys.version_info < (3, 7), - reason="Determinism can only be guaranteed for Python 3.7+" - ) - def test_deterministic_pickle_bytes_for_function(self): - # Ensure that functions with references to several global names are - # pickled to fixed bytes that do not depend on the PYTHONHASHSEED of - # the Python process. - vals = set() - - def func_with_globals(): - return _TEST_GLOBAL_VARIABLE + _TEST_GLOBAL_VARIABLE2 - - for i in range(5): - vals.add( - subprocess_pickle_string(func_with_globals, - protocol=self.protocol, - add_env={"PYTHONHASHSEED": str(i)})) - if len(vals) > 1: - # Print additional debug info on stdout with dis: - for val in vals: - pickletools.dis(val) - pytest.fail( - "Expected a single deterministic payload, got %d/5" % len(vals) - ) - - -class Protocol2CloudPickleTest(CloudPickleTest): - - protocol = 2 - - -def test_lookup_module_and_qualname_dynamic_typevar(): - T = typing.TypeVar('T') - module_and_name = _lookup_module_and_qualname(T, name=T.__name__) - assert module_and_name is None - - -def test_lookup_module_and_qualname_importable_typevar(): - pytest.importorskip("_cloudpickle_testpkg") - import _cloudpickle_testpkg - T = _cloudpickle_testpkg.T - module_and_name = _lookup_module_and_qualname(T, name=T.__name__) - assert module_and_name is not None - module, name = module_and_name - assert module is _cloudpickle_testpkg - assert name == 'T' - - -def test_lookup_module_and_qualname_stdlib_typevar(): - module_and_name = _lookup_module_and_qualname(typing.AnyStr, - name=typing.AnyStr.__name__) - assert module_and_name is not None - module, name = module_and_name - assert module is typing - assert name == 'AnyStr' - - -def test_register_pickle_by_value(): - pytest.importorskip("_cloudpickle_testpkg") - import _cloudpickle_testpkg as pkg - import _cloudpickle_testpkg.mod as mod - - assert list_registry_pickle_by_value() == set() - - 
register_pickle_by_value(pkg) - assert list_registry_pickle_by_value() == {pkg.__name__} - - register_pickle_by_value(mod) - assert list_registry_pickle_by_value() == {pkg.__name__, mod.__name__} - - unregister_pickle_by_value(mod) - assert list_registry_pickle_by_value() == {pkg.__name__} - - msg = f"Input should be a module object, got {pkg.__name__} instead" - with pytest.raises(ValueError, match=msg): - unregister_pickle_by_value(pkg.__name__) - - unregister_pickle_by_value(pkg) - assert list_registry_pickle_by_value() == set() - - msg = f"{pkg} is not registered for pickle by value" - with pytest.raises(ValueError, match=re.escape(msg)): - unregister_pickle_by_value(pkg) - - msg = f"Input should be a module object, got {pkg.__name__} instead" - with pytest.raises(ValueError, match=msg): - register_pickle_by_value(pkg.__name__) - - dynamic_mod = types.ModuleType('dynamic_mod') - msg = ( - f"{dynamic_mod} was not imported correctly, have you used an " - f"`import` statement to access it?" 
- ) - with pytest.raises(ValueError, match=re.escape(msg)): - register_pickle_by_value(dynamic_mod) - - -def _all_types_to_test(): - T = typing.TypeVar('T') - - class C(typing.Generic[T]): - pass - - types_to_test = [ - C, C[int], - T, typing.Any, typing.Optional, - typing.Generic, typing.Union, - typing.Optional[int], - typing.Generic[T], - typing.Callable[[int], typing.Any], - typing.Callable[..., typing.Any], - typing.Callable[[], typing.Any], - typing.Tuple[int, ...], - typing.Tuple[int, C[int]], - typing.List[int], - typing.Dict[int, str], - typing.ClassVar, - typing.ClassVar[C[int]], - typing.NoReturn, - ] - return types_to_test - - -def test_module_level_pickler(): - # #366: cloudpickle should expose its pickle.Pickler subclass as - # cloudpickle.Pickler - assert hasattr(cloudpickle, "Pickler") - assert cloudpickle.Pickler is cloudpickle.CloudPickler - - -if __name__ == '__main__': - unittest.main() diff --git a/srsly/tests/cloudpickle/mock_local_folder/mod.py b/srsly/tests/cloudpickle/mock_local_folder/mod.py deleted file mode 100644 index 1a1c1da..0000000 --- a/srsly/tests/cloudpickle/mock_local_folder/mod.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -In the distributed computing setting, this file plays the role of a "local -development" file, e.g. a file that is importable locally, but unimportable in -remote workers. Constructs defined in this file and usually pickled by -reference should instead flagged to cloudpickle for pickling by value: this is -done using the register_pickle_by_value api exposed by cloudpickle. -""" -import typing - - -def local_function(): - return "hello from a function importable locally!" 
- - -class LocalClass: - def method(self): - return "hello from a class importable locally" - - -LocalT = typing.TypeVar("LocalT") diff --git a/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py b/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py deleted file mode 100644 index deebc14..0000000 --- a/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py +++ /dev/null @@ -1,13 +0,0 @@ -import typing - - -def local_submod_function(): - return "hello from a file located in a locally-importable subfolder!" - - -class LocalSubmodClass: - def method(self): - return "hello from a class located in a locally-importable subfolder!" - - -LocalSubmodT = typing.TypeVar("LocalSubmodT") diff --git a/srsly/tests/cloudpickle/testutils.py b/srsly/tests/cloudpickle/testutils.py deleted file mode 100644 index e0890b4..0000000 --- a/srsly/tests/cloudpickle/testutils.py +++ /dev/null @@ -1,217 +0,0 @@ -import sys -import os -import os.path as op -import tempfile -from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError -from srsly.cloudpickle.compat import pickle -from contextlib import contextmanager -from concurrent.futures import ProcessPoolExecutor - -import psutil -from srsly.cloudpickle import dumps -from subprocess import TimeoutExpired - -loads = pickle.loads -TIMEOUT = 60 -TEST_GLOBALS = "a test value" - - -def make_local_function(): - def g(x): - # this function checks that the globals are correctly handled and that - # the builtins are available - assert TEST_GLOBALS == "a test value" - return sum(range(10)) - - return g - - -def _make_cwd_env(): - """Helper to prepare environment for the child processes""" - cloudpickle_repo_folder = op.normpath( - op.join(op.dirname(__file__), '..')) - env = os.environ.copy() - pythonpath = "{src}{sep}tests{pathsep}{src}".format( - src=cloudpickle_repo_folder, sep=os.sep, pathsep=os.pathsep) - env['PYTHONPATH'] = pythonpath - return cloudpickle_repo_folder, env - - -def 
subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT, - add_env=None): - """Retrieve pickle string of an object generated by a child Python process - - Pickle the input data into a buffer, send it to a subprocess via - stdin, expect the subprocess to unpickle, re-pickle that data back - and send it back to the parent process via stdout for final unpickling. - - >>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2) - b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.' - - """ - # run then pickle_echo(protocol=protocol) in __main__: - - # Protect stderr from any warning, as we will assume an error will happen - # if it is not empty. A concrete example is pytest using the imp module, - # which is deprecated in python 3.8 - cmd = [sys.executable, '-W ignore', __file__, "--protocol", str(protocol)] - cwd, env = _make_cwd_env() - if add_env: - env.update(add_env) - proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env, - bufsize=4096) - pickle_string = dumps(input_data, protocol=protocol) - try: - comm_kwargs = {} - comm_kwargs['timeout'] = timeout - out, err = proc.communicate(pickle_string, **comm_kwargs) - if proc.returncode != 0 or len(err): - message = "Subprocess returned %d: " % proc.returncode - message += err.decode('utf-8') - raise RuntimeError(message) - return out - except TimeoutExpired as e: - proc.kill() - out, err = proc.communicate() - message = "\n".join([out.decode('utf-8'), err.decode('utf-8')]) - raise RuntimeError(message) from e - - -def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT, - add_env=None): - """Echo function with a child Python process - Pickle the input data into a buffer, send it to a subprocess via - stdin, expect the subprocess to unpickle, re-pickle that data back - and send it back to the parent process via stdout for final unpickling. 
- >>> subprocess_pickle_echo([1, 'a', None]) - [1, 'a', None] - """ - out = subprocess_pickle_string(input_data, - protocol=protocol, - timeout=timeout, - add_env=add_env) - return loads(out) - - -def _read_all_bytes(stream_in, chunk_size=4096): - all_data = b"" - while True: - data = stream_in.read(chunk_size) - all_data += data - if len(data) < chunk_size: - break - return all_data - - -def pickle_echo(stream_in=None, stream_out=None, protocol=None): - """Read a pickle from stdin and pickle it back to stdout""" - if stream_in is None: - stream_in = sys.stdin - if stream_out is None: - stream_out = sys.stdout - - # Force the use of bytes streams under Python 3 - if hasattr(stream_in, 'buffer'): - stream_in = stream_in.buffer - if hasattr(stream_out, 'buffer'): - stream_out = stream_out.buffer - - input_bytes = _read_all_bytes(stream_in) - stream_in.close() - obj = loads(input_bytes) - repickled_bytes = dumps(obj, protocol=protocol) - stream_out.write(repickled_bytes) - stream_out.close() - - -def call_func(payload, protocol): - """Remote function call that uses cloudpickle to transport everthing""" - func, args, kwargs = loads(payload) - try: - result = func(*args, **kwargs) - except BaseException as e: - result = e - return dumps(result, protocol=protocol) - - -class _Worker: - def __init__(self, protocol=None): - self.protocol = protocol - self.pool = ProcessPoolExecutor(max_workers=1) - self.pool.submit(id, 42).result() # start the worker process - - def run(self, func, *args, **kwargs): - """Synchronous remote function call""" - - input_payload = dumps((func, args, kwargs), protocol=self.protocol) - result_payload = self.pool.submit( - call_func, input_payload, self.protocol).result() - result = loads(result_payload) - - if isinstance(result, BaseException): - raise result - return result - - def memsize(self): - workers_pids = [p.pid if hasattr(p, "pid") else p - for p in list(self.pool._processes)] - num_workers = len(workers_pids) - if num_workers == 0: - 
return 0 - elif num_workers > 1: - raise RuntimeError("Unexpected number of workers: %d" - % num_workers) - return psutil.Process(workers_pids[0]).memory_info().rss - - def close(self): - self.pool.shutdown(wait=True) - - -@contextmanager -def subprocess_worker(protocol=None): - worker = _Worker(protocol=protocol) - yield worker - worker.close() - - -def assert_run_python_script(source_code, timeout=TIMEOUT): - """Utility to help check pickleability of objects defined in __main__ - - The script provided in the source code should return 0 and not print - anything on stderr or stdout. - """ - fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py') - os.close(fd) - try: - with open(source_file, 'wb') as f: - f.write(source_code.encode('utf-8')) - cmd = [sys.executable, '-W ignore', source_file] - cwd, env = _make_cwd_env() - kwargs = { - 'cwd': cwd, - 'stderr': STDOUT, - 'env': env, - } - # If coverage is running, pass the config file to the subprocess - coverage_rc = os.environ.get("COVERAGE_PROCESS_START") - if coverage_rc: - kwargs['env']['COVERAGE_PROCESS_START'] = coverage_rc - kwargs['timeout'] = timeout - try: - try: - out = check_output(cmd, **kwargs) - except CalledProcessError as e: - raise RuntimeError("script errored with output:\n%s" - % e.output.decode('utf-8')) from e - if out != b"": - raise AssertionError(out.decode('utf-8')) - except TimeoutExpired as e: - raise RuntimeError("script timeout, output so far:\n%s" - % e.output.decode('utf-8')) from e - finally: - os.unlink(source_file) - - -if __name__ == '__main__': - protocol = int(sys.argv[sys.argv.index('--protocol') + 1]) - pickle_echo(protocol=protocol) diff --git a/srsly/tests/msgpack/__init__.py b/srsly/tests/msgpack/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/srsly/tests/msgpack/test_buffer.py b/srsly/tests/msgpack/test_buffer.py deleted file mode 100644 index db7b213..0000000 --- a/srsly/tests/msgpack/test_buffer.py +++ /dev/null @@ -1,27 +0,0 @@ 
-from srsly.msgpack import packb, unpackb - - -def test_unpack_buffer(): - from array import array - - buf = array("b") - buf.frombytes(packb((b"foo", b"bar"))) - obj = unpackb(buf, use_list=1) - assert [b"foo", b"bar"] == obj - - -def test_unpack_bytearray(): - buf = bytearray(packb(("foo", "bar"))) - obj = unpackb(buf, use_list=1) - assert [b"foo", b"bar"] == obj - expected_type = bytes - assert all(type(s) == expected_type for s in obj) - - -def test_unpack_memoryview(): - buf = bytearray(packb(("foo", "bar"))) - view = memoryview(buf) - obj = unpackb(view, use_list=1) - assert [b"foo", b"bar"] == obj - expected_type = bytes - assert all(type(s) == expected_type for s in obj) diff --git a/srsly/tests/msgpack/test_case.py b/srsly/tests/msgpack/test_case.py deleted file mode 100644 index d07fe22..0000000 --- a/srsly/tests/msgpack/test_case.py +++ /dev/null @@ -1,135 +0,0 @@ -from srsly.msgpack import packb, unpackb - - -def check(length, obj): - v = packb(obj) - assert len(v) == length, "%r length should be %r but get %r" % (obj, length, len(v)) - assert unpackb(v, use_list=0) == obj - - -def test_1(): - for o in [ - None, - True, - False, - 0, - 1, - (1 << 6), - (1 << 7) - 1, - -1, - -((1 << 5) - 1), - -(1 << 5), - ]: - check(1, o) - - -def test_2(): - for o in [1 << 7, (1 << 8) - 1, -((1 << 5) + 1), -(1 << 7)]: - check(2, o) - - -def test_3(): - for o in [1 << 8, (1 << 16) - 1, -((1 << 7) + 1), -(1 << 15)]: - check(3, o) - - -def test_5(): - for o in [1 << 16, (1 << 32) - 1, -((1 << 15) + 1), -(1 << 31)]: - check(5, o) - - -def test_9(): - for o in [ - 1 << 32, - (1 << 64) - 1, - -((1 << 31) + 1), - -(1 << 63), - 1.0, - 0.1, - -0.1, - -1.0, - ]: - check(9, o) - - -def check_raw(overhead, num): - check(num + overhead, b" " * num) - - -def test_fixraw(): - check_raw(1, 0) - check_raw(1, (1 << 5) - 1) - - -def test_raw16(): - check_raw(3, 1 << 5) - check_raw(3, (1 << 16) - 1) - - -def test_raw32(): - check_raw(5, 1 << 16) - - -def check_array(overhead, num): - 
check(num + overhead, (None,) * num) - - -def test_fixarray(): - check_array(1, 0) - check_array(1, (1 << 4) - 1) - - -def test_array16(): - check_array(3, 1 << 4) - check_array(3, (1 << 16) - 1) - - -def test_array32(): - check_array(5, (1 << 16)) - - -def match(obj, buf): - assert packb(obj) == buf - assert unpackb(buf, use_list=0) == obj - - -def test_match(): - cases = [ - (None, b"\xc0"), - (False, b"\xc2"), - (True, b"\xc3"), - (0, b"\x00"), - (127, b"\x7f"), - (128, b"\xcc\x80"), - (256, b"\xcd\x01\x00"), - (-1, b"\xff"), - (-33, b"\xd0\xdf"), - (-129, b"\xd1\xff\x7f"), - ({1: 1}, b"\x81\x01\x01"), - (1.0, b"\xcb\x3f\xf0\x00\x00\x00\x00\x00\x00"), - ((), b"\x90"), - ( - tuple(range(15)), - b"\x9f\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e", - ), - ( - tuple(range(16)), - b"\xdc\x00\x10\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", - ), - ({}, b"\x80"), - ( - dict([(x, x) for x in range(15)]), - b"\x8f\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e", - ), - ( - dict([(x, x) for x in range(16)]), - b"\xde\x00\x10\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e\x0f\x0f", - ), - ] - - for v, p in cases: - match(v, p) - - -def test_unicode(): - assert unpackb(packb("foobar"), use_list=1) == b"foobar" diff --git a/srsly/tests/msgpack/test_except.py b/srsly/tests/msgpack/test_except.py deleted file mode 100644 index 2322e08..0000000 --- a/srsly/tests/msgpack/test_except.py +++ /dev/null @@ -1,59 +0,0 @@ -from pytest import raises -import datetime -from srsly.msgpack import packb, unpackb, Unpacker, FormatError, StackError, OutOfData - - -class DummyException(Exception): - pass - - -def test_raise_on_find_unsupported_value(): - with raises(TypeError): - packb(datetime.datetime.now()) - - -def test_raise_from_object_hook(): - def hook(obj): - raise DummyException - - raises(DummyException, unpackb, packb({}), 
object_hook=hook) - raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_hook=hook) - raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_pairs_hook=hook) - raises(DummyException, unpackb, packb({"fizz": {"buzz": "spam"}}), object_hook=hook) - raises( - DummyException, - unpackb, - packb({"fizz": {"buzz": "spam"}}), - object_pairs_hook=hook, - ) - - -def test_invalidvalue(): - incomplete = b"\xd9\x97#DL_" # raw8 - length=0x97 - with raises(ValueError): - unpackb(incomplete) - - with raises(OutOfData): - unpacker = Unpacker() - unpacker.feed(incomplete) - unpacker.unpack() - - with raises(FormatError): - unpackb(b"\xc1") # (undefined tag) - - with raises(FormatError): - unpackb(b"\x91\xc1") # fixarray(len=1) [ (undefined tag) ] - - with raises(StackError): - unpackb(b"\x91" * 3000) # nested fixarray(len=1) - - -def test_strict_map_key(): - valid = {u"unicode": 1, b"bytes": 2} - packed = packb(valid, use_bin_type=True) - assert valid == unpackb(packed, raw=False, strict_map_key=True) - - invalid = {42: 1} - packed = packb(invalid, use_bin_type=True) - with raises(ValueError): - unpackb(packed, raw=False, strict_map_key=True) diff --git a/srsly/tests/msgpack/test_extension.py b/srsly/tests/msgpack/test_extension.py deleted file mode 100644 index 5cbaafd..0000000 --- a/srsly/tests/msgpack/test_extension.py +++ /dev/null @@ -1,76 +0,0 @@ -import array -from srsly import msgpack -from srsly.msgpack.ext import ExtType - - -def test_pack_ext_type(): - def p(s): - packer = msgpack.Packer() - packer.pack_ext_type(0x42, s) - return packer.bytes() - - assert p(b"A") == b"\xd4\x42A" # fixext 1 - assert p(b"AB") == b"\xd5\x42AB" # fixext 2 - assert p(b"ABCD") == b"\xd6\x42ABCD" # fixext 4 - assert p(b"ABCDEFGH") == b"\xd7\x42ABCDEFGH" # fixext 8 - assert p(b"A" * 16) == b"\xd8\x42" + b"A" * 16 # fixext 16 - assert p(b"ABC") == b"\xc7\x03\x42ABC" # ext 8 - assert p(b"A" * 0x0123) == b"\xc8\x01\x23\x42" + b"A" * 0x0123 # ext 16 - assert ( - p(b"A" * 
0x00012345) == b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345 - ) # ext 32 - - -def test_unpack_ext_type(): - def check(b, expected): - assert msgpack.unpackb(b) == expected - - check(b"\xd4\x42A", ExtType(0x42, b"A")) # fixext 1 - check(b"\xd5\x42AB", ExtType(0x42, b"AB")) # fixext 2 - check(b"\xd6\x42ABCD", ExtType(0x42, b"ABCD")) # fixext 4 - check(b"\xd7\x42ABCDEFGH", ExtType(0x42, b"ABCDEFGH")) # fixext 8 - check(b"\xd8\x42" + b"A" * 16, ExtType(0x42, b"A" * 16)) # fixext 16 - check(b"\xc7\x03\x42ABC", ExtType(0x42, b"ABC")) # ext 8 - check(b"\xc8\x01\x23\x42" + b"A" * 0x0123, ExtType(0x42, b"A" * 0x0123)) # ext 16 - check( - b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345, - ExtType(0x42, b"A" * 0x00012345), - ) # ext 32 - - -def test_extension_type(): - def default(obj): - print("default called", obj) - if isinstance(obj, array.array): - typecode = 123 # application specific typecode - data = obj.tobytes() - return ExtType(typecode, data) - raise TypeError("Unknown type object %r" % (obj,)) - - def ext_hook(code, data): - print("ext_hook called", code, data) - assert code == 123 - obj = array.array("d") - obj.frombytes(data) - return obj - - obj = [42, b"hello", array.array("d", [1.1, 2.2, 3.3])] - s = msgpack.packb(obj, default=default) - obj2 = msgpack.unpackb(s, ext_hook=ext_hook) - assert obj == obj2 - - -def test_overriding_hooks(): - def default(obj): - if isinstance(obj, int): - return {"__type__": "long", "__data__": str(obj)} - else: - return obj - - obj = {"testval": int(1823746192837461928374619)} - refobj = {"testval": default(obj["testval"])} - refout = msgpack.packb(refobj) - assert isinstance(refout, (str, bytes)) - testout = msgpack.packb(obj, default=default) - - assert refout == testout diff --git a/srsly/tests/msgpack/test_format.py b/srsly/tests/msgpack/test_format.py deleted file mode 100644 index c6f0871..0000000 --- a/srsly/tests/msgpack/test_format.py +++ /dev/null @@ -1,82 +0,0 @@ -from srsly.msgpack import unpackb - - -def 
check(src, should, use_list=0): - assert unpackb(src, use_list=use_list) == should - - -def testSimpleValue(): - check(b"\x93\xc0\xc2\xc3", (None, False, True)) - - -def testFixnum(): - check(b"\x92\x93\x00\x40\x7f\x93\xe0\xf0\xff", ((0, 64, 127), (-32, -16, -1))) - - -def testFixArray(): - check(b"\x92\x90\x91\x91\xc0", ((), ((None,),))) - - -def testFixRaw(): - check(b"\x94\xa0\xa1a\xa2bc\xa3def", (b"", b"a", b"bc", b"def")) - - -def testFixMap(): - check( - b"\x82\xc2\x81\xc0\xc0\xc3\x81\xc0\x80", {False: {None: None}, True: {None: {}}} - ) - - -def testUnsignedInt(): - check( - b"\x99\xcc\x00\xcc\x80\xcc\xff\xcd\x00\x00\xcd\x80\x00" - b"\xcd\xff\xff\xce\x00\x00\x00\x00\xce\x80\x00\x00\x00" - b"\xce\xff\xff\xff\xff", - (0, 128, 255, 0, 32768, 65535, 0, 2147483648, 4294967295), - ) - - -def testSignedInt(): - check( - b"\x99\xd0\x00\xd0\x80\xd0\xff\xd1\x00\x00\xd1\x80\x00" - b"\xd1\xff\xff\xd2\x00\x00\x00\x00\xd2\x80\x00\x00\x00" - b"\xd2\xff\xff\xff\xff", - (0, -128, -1, 0, -32768, -1, 0, -2147483648, -1), - ) - - -def testRaw(): - check( - b"\x96\xda\x00\x00\xda\x00\x01a\xda\x00\x02ab\xdb\x00\x00" - b"\x00\x00\xdb\x00\x00\x00\x01a\xdb\x00\x00\x00\x02ab", - (b"", b"a", b"ab", b"", b"a", b"ab"), - ) - - -def testArray(): - check( - b"\x96\xdc\x00\x00\xdc\x00\x01\xc0\xdc\x00\x02\xc2\xc3\xdd\x00" - b"\x00\x00\x00\xdd\x00\x00\x00\x01\xc0\xdd\x00\x00\x00\x02" - b"\xc2\xc3", - ((), (None,), (False, True), (), (None,), (False, True)), - ) - - -def testMap(): - check( - b"\x96" - b"\xde\x00\x00" - b"\xde\x00\x01\xc0\xc2" - b"\xde\x00\x02\xc0\xc2\xc3\xc2" - b"\xdf\x00\x00\x00\x00" - b"\xdf\x00\x00\x00\x01\xc0\xc2" - b"\xdf\x00\x00\x00\x02\xc0\xc2\xc3\xc2", - ( - {}, - {None: False}, - {True: False, None: False}, - {}, - {None: False}, - {True: False, None: False}, - ), - ) diff --git a/srsly/tests/msgpack/test_limits.py b/srsly/tests/msgpack/test_limits.py deleted file mode 100644 index bafa4c1..0000000 --- a/srsly/tests/msgpack/test_limits.py +++ /dev/null @@ -1,129 
+0,0 @@ -import pytest -from srsly.msgpack import packb, unpackb, Packer, Unpacker, ExtType -from srsly.msgpack import PackOverflowError, PackValueError, UnpackValueError - - -def test_integer(): - x = -(2 ** 63) - assert unpackb(packb(x)) == x - with pytest.raises(PackOverflowError): - packb(x - 1) - - x = 2 ** 64 - 1 - assert unpackb(packb(x)) == x - with pytest.raises(PackOverflowError): - packb(x + 1) - - -def test_array_header(): - packer = Packer() - packer.pack_array_header(2 ** 32 - 1) - with pytest.raises(PackValueError): - packer.pack_array_header(2 ** 32) - - -def test_map_header(): - packer = Packer() - packer.pack_map_header(2 ** 32 - 1) - with pytest.raises(PackValueError): - packer.pack_array_header(2 ** 32) - - -def test_max_str_len(): - d = "x" * 3 - packed = packb(d) - - unpacker = Unpacker(max_str_len=3, raw=False) - unpacker.feed(packed) - assert unpacker.unpack() == d - - unpacker = Unpacker(max_str_len=2, raw=False) - with pytest.raises(UnpackValueError): - unpacker.feed(packed) - unpacker.unpack() - - -def test_max_bin_len(): - d = b"x" * 3 - packed = packb(d, use_bin_type=True) - - unpacker = Unpacker(max_bin_len=3) - unpacker.feed(packed) - assert unpacker.unpack() == d - - unpacker = Unpacker(max_bin_len=2) - with pytest.raises(UnpackValueError): - unpacker.feed(packed) - unpacker.unpack() - - -def test_max_array_len(): - d = [1, 2, 3] - packed = packb(d) - - unpacker = Unpacker(max_array_len=3) - unpacker.feed(packed) - assert unpacker.unpack() == d - - unpacker = Unpacker(max_array_len=2) - with pytest.raises(UnpackValueError): - unpacker.feed(packed) - unpacker.unpack() - - -def test_max_map_len(): - d = {1: 2, 3: 4, 5: 6} - packed = packb(d) - - unpacker = Unpacker(max_map_len=3) - unpacker.feed(packed) - assert unpacker.unpack() == d - - unpacker = Unpacker(max_map_len=2) - with pytest.raises(UnpackValueError): - unpacker.feed(packed) - unpacker.unpack() - - -def test_max_ext_len(): - d = ExtType(42, b"abc") - packed = packb(d) - - 
unpacker = Unpacker(max_ext_len=3) - unpacker.feed(packed) - assert unpacker.unpack() == d - - unpacker = Unpacker(max_ext_len=2) - with pytest.raises(UnpackValueError): - unpacker.feed(packed) - unpacker.unpack() - - -# PyPy fails following tests because of constant folding? -# https://bugs.pypy.org/issue1721 -# @pytest.mark.skipif(True, reason="Requires very large memory.") -# def test_binary(): -# x = b'x' * (2**32 - 1) -# assert unpackb(packb(x)) == x -# del x -# x = b'x' * (2**32) -# with pytest.raises(ValueError): -# packb(x) -# -# -# @pytest.mark.skipif(True, reason="Requires very large memory.") -# def test_string(): -# x = 'x' * (2**32 - 1) -# assert unpackb(packb(x)) == x -# x += 'y' -# with pytest.raises(ValueError): -# packb(x) -# -# -# @pytest.mark.skipif(True, reason="Requires very large memory.") -# def test_array(): -# x = [0] * (2**32 - 1) -# assert unpackb(packb(x)) == x -# x.append(0) -# with pytest.raises(ValueError): -# packb(x) diff --git a/srsly/tests/msgpack/test_memoryview.py b/srsly/tests/msgpack/test_memoryview.py deleted file mode 100644 index b182c7f..0000000 --- a/srsly/tests/msgpack/test_memoryview.py +++ /dev/null @@ -1,95 +0,0 @@ -from array import array -from srsly.msgpack import packb, unpackb - - -make_memoryview = memoryview - - -def make_array(f, data): - a = array(f) - a.frombytes(data) - return a - - -def get_data(a): - return a.tobytes() - - -def _runtest(format, nbytes, expected_header, expected_prefix, use_bin_type): - # create a new array - original_array = array(format) - original_array.fromlist([255] * (nbytes // original_array.itemsize)) - original_data = get_data(original_array) - view = make_memoryview(original_array) - - # pack, unpack, and reconstruct array - packed = packb(view, use_bin_type=use_bin_type) - unpacked = unpackb(packed) - reconstructed_array = make_array(format, unpacked) - - # check that we got the right amount of data - assert len(original_data) == nbytes - # check packed header - assert packed[:1] 
== expected_header - # check packed length prefix, if any - assert packed[1 : 1 + len(expected_prefix)] == expected_prefix - # check packed data - assert packed[1 + len(expected_prefix) :] == original_data - # check array unpacked correctly - assert original_array == reconstructed_array - - -def test_fixstr_from_byte(): - _runtest("B", 1, b"\xa1", b"", False) - _runtest("B", 31, b"\xbf", b"", False) - - -def test_fixstr_from_float(): - _runtest("f", 4, b"\xa4", b"", False) - _runtest("f", 28, b"\xbc", b"", False) - - -def test_str16_from_byte(): - _runtest("B", 2 ** 8, b"\xda", b"\x01\x00", False) - _runtest("B", 2 ** 16 - 1, b"\xda", b"\xff\xff", False) - - -def test_str16_from_float(): - _runtest("f", 2 ** 8, b"\xda", b"\x01\x00", False) - _runtest("f", 2 ** 16 - 4, b"\xda", b"\xff\xfc", False) - - -def test_str32_from_byte(): - _runtest("B", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False) - - -def test_str32_from_float(): - _runtest("f", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False) - - -def test_bin8_from_byte(): - _runtest("B", 1, b"\xc4", b"\x01", True) - _runtest("B", 2 ** 8 - 1, b"\xc4", b"\xff", True) - - -def test_bin8_from_float(): - _runtest("f", 4, b"\xc4", b"\x04", True) - _runtest("f", 2 ** 8 - 4, b"\xc4", b"\xfc", True) - - -def test_bin16_from_byte(): - _runtest("B", 2 ** 8, b"\xc5", b"\x01\x00", True) - _runtest("B", 2 ** 16 - 1, b"\xc5", b"\xff\xff", True) - - -def test_bin16_from_float(): - _runtest("f", 2 ** 8, b"\xc5", b"\x01\x00", True) - _runtest("f", 2 ** 16 - 4, b"\xc5", b"\xff\xfc", True) - - -def test_bin32_from_byte(): - _runtest("B", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True) - - -def test_bin32_from_float(): - _runtest("f", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True) diff --git a/srsly/tests/msgpack/test_newspec.py b/srsly/tests/msgpack/test_newspec.py deleted file mode 100644 index 316280b..0000000 --- a/srsly/tests/msgpack/test_newspec.py +++ /dev/null @@ -1,88 +0,0 @@ -from srsly.msgpack import packb, unpackb, ExtType - - -def 
test_str8(): - header = b"\xd9" - data = b"x" * 32 - b = packb(data.decode(), use_bin_type=True) - assert len(b) == len(data) + 2 - assert b[0:2] == header + b"\x20" - assert b[2:] == data - assert unpackb(b) == data - - data = b"x" * 255 - b = packb(data.decode(), use_bin_type=True) - assert len(b) == len(data) + 2 - assert b[0:2] == header + b"\xff" - assert b[2:] == data - assert unpackb(b) == data - - -def test_bin8(): - header = b"\xc4" - data = b"" - b = packb(data, use_bin_type=True) - assert len(b) == len(data) + 2 - assert b[0:2] == header + b"\x00" - assert b[2:] == data - assert unpackb(b) == data - - data = b"x" * 255 - b = packb(data, use_bin_type=True) - assert len(b) == len(data) + 2 - assert b[0:2] == header + b"\xff" - assert b[2:] == data - assert unpackb(b) == data - - -def test_bin16(): - header = b"\xc5" - data = b"x" * 256 - b = packb(data, use_bin_type=True) - assert len(b) == len(data) + 3 - assert b[0:1] == header - assert b[1:3] == b"\x01\x00" - assert b[3:] == data - assert unpackb(b) == data - - data = b"x" * 65535 - b = packb(data, use_bin_type=True) - assert len(b) == len(data) + 3 - assert b[0:1] == header - assert b[1:3] == b"\xff\xff" - assert b[3:] == data - assert unpackb(b) == data - - -def test_bin32(): - header = b"\xc6" - data = b"x" * 65536 - b = packb(data, use_bin_type=True) - assert len(b) == len(data) + 5 - assert b[0:1] == header - assert b[1:5] == b"\x00\x01\x00\x00" - assert b[5:] == data - assert unpackb(b) == data - - -def test_ext(): - def check(ext, packed): - assert packb(ext) == packed - assert unpackb(packed) == ext - - check(ExtType(0x42, b"Z"), b"\xd4\x42Z") # fixext 1 - check(ExtType(0x42, b"ZZ"), b"\xd5\x42ZZ") # fixext 2 - check(ExtType(0x42, b"Z" * 4), b"\xd6\x42" + b"Z" * 4) # fixext 4 - check(ExtType(0x42, b"Z" * 8), b"\xd7\x42" + b"Z" * 8) # fixext 8 - check(ExtType(0x42, b"Z" * 16), b"\xd8\x42" + b"Z" * 16) # fixext 16 - # ext 8 - check(ExtType(0x42, b""), b"\xc7\x00\x42") - check(ExtType(0x42, b"Z" * 
255), b"\xc7\xff\x42" + b"Z" * 255) - # ext 16 - check(ExtType(0x42, b"Z" * 256), b"\xc8\x01\x00\x42" + b"Z" * 256) - check(ExtType(0x42, b"Z" * 0xFFFF), b"\xc8\xff\xff\x42" + b"Z" * 0xFFFF) - # ext 32 - check(ExtType(0x42, b"Z" * 0x10000), b"\xc9\x00\x01\x00\x00\x42" + b"Z" * 0x10000) - # needs large memory - # check(ExtType(0x42, b'Z'*0xffffffff), - # b'\xc9\xff\xff\xff\xff\x42' + b'Z'*0xffffffff) diff --git a/srsly/tests/msgpack/test_pack.py b/srsly/tests/msgpack/test_pack.py deleted file mode 100644 index 22f088d..0000000 --- a/srsly/tests/msgpack/test_pack.py +++ /dev/null @@ -1,201 +0,0 @@ -import struct -import pytest -from collections import OrderedDict -from io import BytesIO -from srsly.msgpack import packb, unpackb, Unpacker, Packer - - -def check(data, use_list=False): - re = unpackb(packb(data), use_list=use_list) - assert re == data - - -def testPack(): - test_data = [ - 0, - 1, - 127, - 128, - 255, - 256, - 65535, - 65536, - 4294967295, - 4294967296, - -1, - -32, - -33, - -128, - -129, - -32768, - -32769, - -4294967296, - -4294967297, - 1.0, - b"", - b"a", - b"a" * 31, - b"a" * 32, - None, - True, - False, - (), - ((),), - ((), None), - {None: 0}, - (1 << 23), - ] - for td in test_data: - check(td) - - -def testPackUnicode(): - test_data = ["", "abcd", ["defgh"], "Русский текст"] - for td in test_data: - re = unpackb(packb(td), use_list=1, raw=False) - assert re == td - packer = Packer() - data = packer.pack(td) - re = Unpacker(BytesIO(data), raw=False, use_list=1).unpack() - assert re == td - - -def testPackUTF32(): # deprecated - re = unpackb(packb("", encoding="utf-32"), use_list=1, encoding="utf-32") - assert re == "" - re = unpackb(packb("abcd", encoding="utf-32"), use_list=1, encoding="utf-32") - assert re == "abcd" - re = unpackb(packb(["defgh"], encoding="utf-32"), use_list=1, encoding="utf-32") - assert re == ["defgh"] - try: - packb("Русский текст", encoding="utf-32") - except LookupError as e: - pytest.xfail(str(e)) - # try: - # test_data 
= ["", "abcd", ["defgh"], "Русский текст"] - # for td in test_data: - # except LookupError as e: - # pytest.xfail(e) - - -def testPackBytes(): - test_data = [b"", b"abcd", (b"defgh",)] - for td in test_data: - check(td) - - -def testPackByteArrays(): - test_data = [bytearray(b""), bytearray(b"abcd"), (bytearray(b"defgh"),)] - for td in test_data: - check(td) - - -def testIgnoreUnicodeErrors(): # deprecated - re = unpackb( - packb(b"abc\xeddef"), encoding="utf-8", unicode_errors="ignore", use_list=1 - ) - assert re == "abcdef" - - -def testStrictUnicodeUnpack(): - with pytest.raises(UnicodeDecodeError): - unpackb(packb(b"abc\xeddef"), raw=False, use_list=1) - - -def testStrictUnicodePack(): # deprecated - with pytest.raises(UnicodeEncodeError): - packb("abc\xeddef", encoding="ascii", unicode_errors="strict") - - -def testIgnoreErrorsPack(): # deprecated - re = unpackb( - packb("abcФФФdef", encoding="ascii", unicode_errors="ignore"), - raw=False, - use_list=1, - ) - assert re == "abcdef" - - -def testDecodeBinary(): - re = unpackb(packb(b"abc"), encoding=None, use_list=1) - assert re == b"abc" - - -def testPackFloat(): - assert packb(1.0, use_single_float=True) == b"\xca" + struct.pack(str(">f"), 1.0) - assert packb(1.0, use_single_float=False) == b"\xcb" + struct.pack(str(">d"), 1.0) - - -def testArraySize(sizes=[0, 5, 50, 1000]): - bio = BytesIO() - packer = Packer() - for size in sizes: - bio.write(packer.pack_array_header(size)) - for i in range(size): - bio.write(packer.pack(i)) - - bio.seek(0) - unpacker = Unpacker(bio, use_list=1) - for size in sizes: - assert unpacker.unpack() == list(range(size)) - - -def test_manualreset(sizes=[0, 5, 50, 1000]): - packer = Packer(autoreset=False) - for size in sizes: - packer.pack_array_header(size) - for i in range(size): - packer.pack(i) - - bio = BytesIO(packer.bytes()) - unpacker = Unpacker(bio, use_list=1) - for size in sizes: - assert unpacker.unpack() == list(range(size)) - - packer.reset() - assert packer.bytes() == 
b"" - - -def testMapSize(sizes=[0, 5, 50, 1000]): - bio = BytesIO() - packer = Packer() - for size in sizes: - bio.write(packer.pack_map_header(size)) - for i in range(size): - bio.write(packer.pack(i)) # key - bio.write(packer.pack(i * 2)) # value - - bio.seek(0) - unpacker = Unpacker(bio) - for size in sizes: - assert unpacker.unpack() == dict((i, i * 2) for i in range(size)) - - -def test_odict(): - seq = [(b"one", 1), (b"two", 2), (b"three", 3), (b"four", 4)] - od = OrderedDict(seq) - assert unpackb(packb(od), use_list=1) == dict(seq) - - def pair_hook(seq): - return list(seq) - - assert unpackb(packb(od), object_pairs_hook=pair_hook, use_list=1) == seq - - -def test_pairlist(): - pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")] - packer = Packer() - packed = packer.pack_map_pairs(pairlist) - unpacked = unpackb(packed, object_pairs_hook=list) - assert pairlist == unpacked - - -def test_get_buffer(): - packer = Packer(autoreset=0, use_bin_type=True) - packer.pack([1, 2]) - strm = BytesIO() - strm.write(packer.getbuffer()) - written = strm.getvalue() - - expected = packb([1, 2], use_bin_type=True) - assert written == expected diff --git a/srsly/tests/msgpack/test_read_size.py b/srsly/tests/msgpack/test_read_size.py deleted file mode 100644 index 63e8b47..0000000 --- a/srsly/tests/msgpack/test_read_size.py +++ /dev/null @@ -1,71 +0,0 @@ -from srsly.msgpack import packb, Unpacker, OutOfData - - -UnexpectedTypeException = ValueError - - -def test_read_array_header(): - unpacker = Unpacker() - unpacker.feed(packb(["a", "b", "c"])) - assert unpacker.read_array_header() == 3 - assert unpacker.unpack() == b"a" - assert unpacker.unpack() == b"b" - assert unpacker.unpack() == b"c" - try: - unpacker.unpack() - assert 0, "should raise exception" - except OutOfData: - assert 1, "okay" - - -def test_read_map_header(): - unpacker = Unpacker() - unpacker.feed(packb({"a": "A"})) - assert unpacker.read_map_header() == 1 - assert unpacker.unpack() == b"a" - assert 
unpacker.unpack() == b"A" - try: - unpacker.unpack() - assert 0, "should raise exception" - except OutOfData: - assert 1, "okay" - - -def test_incorrect_type_array(): - unpacker = Unpacker() - unpacker.feed(packb(1)) - try: - unpacker.read_array_header() - assert 0, "should raise exception" - except UnexpectedTypeException: - assert 1, "okay" - - -def test_incorrect_type_map(): - unpacker = Unpacker() - unpacker.feed(packb(1)) - try: - unpacker.read_map_header() - assert 0, "should raise exception" - except UnexpectedTypeException: - assert 1, "okay" - - -def test_correct_type_nested_array(): - unpacker = Unpacker() - unpacker.feed(packb({"a": ["b", "c", "d"]})) - try: - unpacker.read_array_header() - assert 0, "should raise exception" - except UnexpectedTypeException: - assert 1, "okay" - - -def test_incorrect_type_nested_map(): - unpacker = Unpacker() - unpacker.feed(packb([{"a": "b"}])) - try: - unpacker.read_map_header() - assert 0, "should raise exception" - except UnexpectedTypeException: - assert 1, "okay" diff --git a/srsly/tests/msgpack/test_seq.py b/srsly/tests/msgpack/test_seq.py deleted file mode 100644 index c50e1cf..0000000 --- a/srsly/tests/msgpack/test_seq.py +++ /dev/null @@ -1,39 +0,0 @@ -import io -from srsly import msgpack - - -binarydata = bytes(bytearray(range(256))) - - -def gen_binary_data(idx): - return binarydata[: idx % 300] - - -def test_exceeding_unpacker_read_size(): - dumpf = io.BytesIO() - - packer = msgpack.Packer() - - NUMBER_OF_STRINGS = 6 - read_size = 16 - # 5 ok for read_size=16, while 6 glibc detected *** python: double free or corruption (fasttop): - # 20 ok for read_size=256, while 25 segfaults / glibc detected *** python: double free or corruption (!prev) - # 40 ok for read_size=1024, while 50 introduces errors - # 7000 ok for read_size=1024*1024, while 8000 leads to glibc detected *** python: double free or corruption (!prev): - - for idx in range(NUMBER_OF_STRINGS): - data = gen_binary_data(idx) - 
dumpf.write(packer.pack(data)) - - f = io.BytesIO(dumpf.getvalue()) - dumpf.close() - - unpacker = msgpack.Unpacker(f, read_size=read_size, use_list=1) - - read_count = 0 - for idx, o in enumerate(unpacker): - assert type(o) == bytes - assert o == gen_binary_data(idx) - read_count += 1 - - assert read_count == NUMBER_OF_STRINGS diff --git a/srsly/tests/msgpack/test_sequnpack.py b/srsly/tests/msgpack/test_sequnpack.py deleted file mode 100644 index bfd7afa..0000000 --- a/srsly/tests/msgpack/test_sequnpack.py +++ /dev/null @@ -1,128 +0,0 @@ -import io -import pytest -from srsly.msgpack import Unpacker, BufferFull -from srsly.msgpack import pack -from srsly.msgpack.exceptions import OutOfData - - -def test_partialdata(): - unpacker = Unpacker() - unpacker.feed(b"\xa5") - with pytest.raises(StopIteration): - next(iter(unpacker)) - unpacker.feed(b"h") - with pytest.raises(StopIteration): - next(iter(unpacker)) - unpacker.feed(b"a") - with pytest.raises(StopIteration): - next(iter(unpacker)) - unpacker.feed(b"l") - with pytest.raises(StopIteration): - next(iter(unpacker)) - unpacker.feed(b"l") - with pytest.raises(StopIteration): - next(iter(unpacker)) - unpacker.feed(b"o") - assert next(iter(unpacker)) == b"hallo" - - -def test_foobar(): - unpacker = Unpacker(read_size=3, use_list=1) - unpacker.feed(b"foobar") - assert unpacker.unpack() == ord(b"f") - assert unpacker.unpack() == ord(b"o") - assert unpacker.unpack() == ord(b"o") - assert unpacker.unpack() == ord(b"b") - assert unpacker.unpack() == ord(b"a") - assert unpacker.unpack() == ord(b"r") - with pytest.raises(OutOfData): - unpacker.unpack() - - unpacker.feed(b"foo") - unpacker.feed(b"bar") - - k = 0 - for o, e in zip(unpacker, "foobarbaz"): - assert o == ord(e) - k += 1 - assert k == len(b"foobar") - - -def test_foobar_skip(): - unpacker = Unpacker(read_size=3, use_list=1) - unpacker.feed(b"foobar") - assert unpacker.unpack() == ord(b"f") - unpacker.skip() - assert unpacker.unpack() == ord(b"o") - unpacker.skip() 
- assert unpacker.unpack() == ord(b"a") - unpacker.skip() - with pytest.raises(OutOfData): - unpacker.unpack() - - -def test_maxbuffersize(): - with pytest.raises(ValueError): - Unpacker(read_size=5, max_buffer_size=3) - unpacker = Unpacker(read_size=3, max_buffer_size=3, use_list=1) - unpacker.feed(b"fo") - with pytest.raises(BufferFull): - unpacker.feed(b"ob") - unpacker.feed(b"o") - assert ord("f") == next(unpacker) - unpacker.feed(b"b") - assert ord("o") == next(unpacker) - assert ord("o") == next(unpacker) - assert ord("b") == next(unpacker) - - -def test_readbytes(): - unpacker = Unpacker(read_size=3) - unpacker.feed(b"foobar") - assert unpacker.unpack() == ord(b"f") - assert unpacker.read_bytes(3) == b"oob" - assert unpacker.unpack() == ord(b"a") - assert unpacker.unpack() == ord(b"r") - - # Test buffer refill - unpacker = Unpacker(io.BytesIO(b"foobar"), read_size=3) - assert unpacker.unpack() == ord(b"f") - assert unpacker.read_bytes(3) == b"oob" - assert unpacker.unpack() == ord(b"a") - assert unpacker.unpack() == ord(b"r") - - -def test_issue124(): - unpacker = Unpacker() - unpacker.feed(b"\xa1?\xa1!") - assert tuple(unpacker) == (b"?", b"!") - assert tuple(unpacker) == () - unpacker.feed(b"\xa1?\xa1") - assert tuple(unpacker) == (b"?",) - assert tuple(unpacker) == () - unpacker.feed(b"!") - assert tuple(unpacker) == (b"!",) - assert tuple(unpacker) == () - - -def test_unpack_tell(): - stream = io.BytesIO() - messages = [2 ** i - 1 for i in range(65)] - messages += [-(2 ** i) for i in range(1, 64)] - messages += [ - b"hello", - b"hello" * 1000, - list(range(20)), - {i: bytes(i) * i for i in range(10)}, - {i: bytes(i) * i for i in range(32)}, - ] - offsets = [] - for m in messages: - pack(m, stream) - offsets.append(stream.tell()) - stream.seek(0) - unpacker = Unpacker(stream) - for m, o in zip(messages, offsets): - m2 = next(unpacker) - assert m == m2 - assert o == unpacker.tell() diff --git a/srsly/tests/msgpack/test_stricttype.py 
b/srsly/tests/msgpack/test_stricttype.py deleted file mode 100644 index fd99185..0000000 --- a/srsly/tests/msgpack/test_stricttype.py +++ /dev/null @@ -1,60 +0,0 @@ -from collections import namedtuple -from srsly.msgpack import packb, unpackb, ExtType - - -def test_namedtuple(): - T = namedtuple("T", "foo bar") - - def default(o): - if isinstance(o, T): - return dict(o._asdict()) - raise TypeError("Unsupported type %s" % (type(o),)) - - packed = packb(T(1, 42), strict_types=True, use_bin_type=True, default=default) - unpacked = unpackb(packed, raw=False) - assert unpacked == {"foo": 1, "bar": 42} - - -def test_tuple(): - t = ("one", 2, b"three", (4,)) - - def default(o): - if isinstance(o, tuple): - return {"__type__": "tuple", "value": list(o)} - raise TypeError("Unsupported type %s" % (type(o),)) - - def convert(o): - if o.get("__type__") == "tuple": - return tuple(o["value"]) - return o - - data = packb(t, strict_types=True, use_bin_type=True, default=default) - expected = unpackb(data, raw=False, object_hook=convert) - - assert expected == t - - -def test_tuple_ext(): - t = ("one", 2, b"three", (4,)) - - MSGPACK_EXT_TYPE_TUPLE = 0 - - def default(o): - if isinstance(o, tuple): - # Convert to list and pack - payload = packb( - list(o), strict_types=True, use_bin_type=True, default=default - ) - return ExtType(MSGPACK_EXT_TYPE_TUPLE, payload) - raise TypeError(repr(o)) - - def convert(code, payload): - if code == MSGPACK_EXT_TYPE_TUPLE: - # Unpack and convert to tuple - return tuple(unpackb(payload, raw=False, ext_hook=convert)) - raise ValueError("Unknown Ext code {}".format(code)) - - data = packb(t, strict_types=True, use_bin_type=True, default=default) - expected = unpackb(data, raw=False, ext_hook=convert) - - assert expected == t diff --git a/srsly/tests/msgpack/test_subtype.py b/srsly/tests/msgpack/test_subtype.py deleted file mode 100644 index 7dfe08c..0000000 --- a/srsly/tests/msgpack/test_subtype.py +++ /dev/null @@ -1,23 +0,0 @@ -from collections 
import namedtuple -from srsly.msgpack import packb - - -class MyList(list): - pass - - -class MyDict(dict): - pass - - -class MyTuple(tuple): - pass - - -MyNamedTuple = namedtuple("MyNamedTuple", "x y") - - -def test_types(): - assert packb(MyDict()) == packb(dict()) - assert packb(MyList()) == packb(list()) - assert packb(MyNamedTuple(1, 2)) == packb((1, 2)) diff --git a/srsly/tests/msgpack/test_unpack.py b/srsly/tests/msgpack/test_unpack.py deleted file mode 100644 index 6045ea0..0000000 --- a/srsly/tests/msgpack/test_unpack.py +++ /dev/null @@ -1,70 +0,0 @@ -from io import BytesIO -import sys -import pytest -from srsly.msgpack import Unpacker, packb, OutOfData, ExtType - - -def test_unpack_array_header_from_file(): - f = BytesIO(packb([1, 2, 3, 4])) - unpacker = Unpacker(f) - assert unpacker.read_array_header() == 4 - assert unpacker.unpack() == 1 - assert unpacker.unpack() == 2 - assert unpacker.unpack() == 3 - assert unpacker.unpack() == 4 - with pytest.raises(OutOfData): - unpacker.unpack() - - -@pytest.mark.skipif( - "not hasattr(sys, 'getrefcount') == True", - reason="sys.getrefcount() is needed to pass this test", -) -def test_unpacker_hook_refcnt(): - result = [] - - def hook(x): - result.append(x) - return x - - basecnt = sys.getrefcount(hook) - - up = Unpacker(object_hook=hook, list_hook=hook) - - assert sys.getrefcount(hook) >= basecnt + 2 - - up.feed(packb([{}])) - up.feed(packb([{}])) - assert up.unpack() == [{}] - assert up.unpack() == [{}] - assert result == [{}, [{}], {}, [{}]] - - del up - - assert sys.getrefcount(hook) == basecnt - - -def test_unpacker_ext_hook(): - class MyUnpacker(Unpacker): - def __init__(self): - super(MyUnpacker, self).__init__(ext_hook=self._hook, raw=False) - - def _hook(self, code, data): - if code == 1: - return int(data) - else: - return ExtType(code, data) - - unpacker = MyUnpacker() - unpacker.feed(packb({"a": 1})) - assert unpacker.unpack() == {"a": 1} - unpacker.feed(packb({"a": ExtType(1, b"123")})) - assert 
unpacker.unpack() == {"a": 123} - unpacker.feed(packb({"a": ExtType(2, b"321")})) - assert unpacker.unpack() == {"a": ExtType(2, b"321")} - - -if __name__ == "__main__": - test_unpack_array_header_from_file() - test_unpacker_hook_refcnt() - test_unpacker_ext_hook() diff --git a/srsly/tests/ruamel_yaml/__init__.py b/srsly/tests/ruamel_yaml/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/srsly/tests/ruamel_yaml/roundtrip.py b/srsly/tests/ruamel_yaml/roundtrip.py deleted file mode 100755 index e3fd26f..0000000 --- a/srsly/tests/ruamel_yaml/roundtrip.py +++ /dev/null @@ -1,308 +0,0 @@ -""" -helper routines for testing round trip of commented YAML data -""" -import sys -import textwrap -from pathlib import Path - -enforce = object() - - -def dedent(data): - try: - position_of_first_newline = data.index("\n") - for idx in range(position_of_first_newline): - if not data[idx].isspace(): - raise ValueError - except ValueError: - pass - else: - data = data[position_of_first_newline + 1 :] - return textwrap.dedent(data) - - -def round_trip_load(inp, preserve_quotes=None, version=None): - import srsly.ruamel_yaml # NOQA - - dinp = dedent(inp) - return srsly.ruamel_yaml.load( - dinp, - Loader=srsly.ruamel_yaml.RoundTripLoader, - preserve_quotes=preserve_quotes, - version=version, - ) - - -def round_trip_load_all(inp, preserve_quotes=None, version=None): - import srsly.ruamel_yaml # NOQA - - dinp = dedent(inp) - return srsly.ruamel_yaml.load_all( - dinp, - Loader=srsly.ruamel_yaml.RoundTripLoader, - preserve_quotes=preserve_quotes, - version=version, - ) - - -def round_trip_dump( - data, - stream=None, - indent=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - explicit_start=None, - explicit_end=None, - version=None, -): - import srsly.ruamel_yaml # NOQA - - return srsly.ruamel_yaml.round_trip_dump( - data, - stream=stream, - indent=indent, - block_seq_indent=block_seq_indent, - 
top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - ) - - -def diff(inp, outp, file_name="stdin"): - import difflib - - inl = inp.splitlines(True) # True for keepends - outl = outp.splitlines(True) - diff = difflib.unified_diff(inl, outl, file_name, "round trip YAML") - # 2.6 difflib has trailing space on filename lines %-) - strip_trailing_space = sys.version_info < (2, 7) - for line in diff: - if strip_trailing_space and line[:4] in ["--- ", "+++ "]: - line = line.rstrip() + "\n" - sys.stdout.write(line) - - -def round_trip( - inp, - outp=None, - extra=None, - intermediate=None, - indent=None, - block_seq_indent=None, - top_level_colon_align=None, - prefix_colon=None, - preserve_quotes=None, - explicit_start=None, - explicit_end=None, - version=None, - dump_data=None, -): - """ - inp: input string to parse - outp: expected output (equals input if not specified) - """ - if outp is None: - outp = inp - doutp = dedent(outp) - if extra is not None: - doutp += extra - data = round_trip_load(inp, preserve_quotes=preserve_quotes) - if dump_data: - print("data", data) - if intermediate is not None: - if isinstance(intermediate, dict): - for k, v in intermediate.items(): - if data[k] != v: - print("{0!r} <> {1!r}".format(data[k], v)) - raise ValueError - res = round_trip_dump( - data, - indent=indent, - block_seq_indent=block_seq_indent, - top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - ) - if res != doutp: - diff(doutp, res, "input string") - print("\nroundtrip data:\n", res, sep="") - assert res == doutp - res = round_trip_dump( - data, - indent=indent, - block_seq_indent=block_seq_indent, - top_level_colon_align=top_level_colon_align, - prefix_colon=prefix_colon, - explicit_start=explicit_start, - explicit_end=explicit_end, - version=version, - ) - 
print("roundtrip second round data:\n", res, sep="") - assert res == doutp - return data - - -def na_round_trip( - inp, - outp=None, - extra=None, - intermediate=None, - indent=None, - top_level_colon_align=None, - prefix_colon=None, - preserve_quotes=None, - explicit_start=None, - explicit_end=None, - version=None, - dump_data=None, -): - """ - inp: input string to parse - outp: expected output (equals input if not specified) - """ - inp = dedent(inp) - if outp is None: - outp = inp - if version is not None: - version = version - doutp = dedent(outp) - if extra is not None: - doutp += extra - yaml = YAML() - yaml.preserve_quotes = preserve_quotes - yaml.scalar_after_indicator = False # newline after every directives end - data = yaml.load(inp) - if dump_data: - print("data", data) - if intermediate is not None: - if isinstance(intermediate, dict): - for k, v in intermediate.items(): - if data[k] != v: - print("{0!r} <> {1!r}".format(data[k], v)) - raise ValueError - yaml.indent = indent - yaml.top_level_colon_align = top_level_colon_align - yaml.prefix_colon = prefix_colon - yaml.explicit_start = explicit_start - yaml.explicit_end = explicit_end - res = yaml.dump(data, compare=doutp) - return res - - -def YAML(**kw): - import srsly.ruamel_yaml # NOQA - - class MyYAML(srsly.ruamel_yaml.YAML): - """auto dedent string parameters on load""" - - def load(self, stream): - if isinstance(stream, str): - if stream and stream[0] == "\n": - stream = stream[1:] - stream = textwrap.dedent(stream) - return srsly.ruamel_yaml.YAML.load(self, stream) - - def load_all(self, stream): - if isinstance(stream, str): - if stream and stream[0] == "\n": - stream = stream[1:] - stream = textwrap.dedent(stream) - for d in srsly.ruamel_yaml.YAML.load_all(self, stream): - yield d - - def dump(self, data, **kw): - from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA - - assert ("stream" in kw) ^ ("compare" in kw) - if "stream" in kw: - return srsly.ruamel_yaml.YAML.dump(data, **kw) - 
lkw = kw.copy() - expected = textwrap.dedent(lkw.pop("compare")) - unordered_lines = lkw.pop("unordered_lines", False) - if expected and expected[0] == "\n": - expected = expected[1:] - lkw["stream"] = st = StringIO() - srsly.ruamel_yaml.YAML.dump(self, data, **lkw) - res = st.getvalue() - print(res) - if unordered_lines: - res = sorted(res.splitlines()) - expected = sorted(expected.splitlines()) - assert res == expected - - def round_trip(self, stream, **kw): - from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA - - assert isinstance(stream, (srsly.ruamel_yaml.compat.text_type, str)) - lkw = kw.copy() - if stream and stream[0] == "\n": - stream = stream[1:] - stream = textwrap.dedent(stream) - data = srsly.ruamel_yaml.YAML.load(self, stream) - outp = lkw.pop("outp", stream) - lkw["stream"] = st = StringIO() - srsly.ruamel_yaml.YAML.dump(self, data, **lkw) - res = st.getvalue() - if res != outp: - diff(outp, res, "input string") - assert res == outp - - def round_trip_all(self, stream, **kw): - from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA - - assert isinstance(stream, (srsly.ruamel_yaml.compat.text_type, str)) - lkw = kw.copy() - if stream and stream[0] == "\n": - stream = stream[1:] - stream = textwrap.dedent(stream) - data = list(srsly.ruamel_yaml.YAML.load_all(self, stream)) - outp = lkw.pop("outp", stream) - lkw["stream"] = st = StringIO() - srsly.ruamel_yaml.YAML.dump_all(self, data, **lkw) - res = st.getvalue() - if res != outp: - diff(outp, res, "input string") - assert res == outp - - return MyYAML(**kw) - - -def save_and_run(program, base_dir=None, output=None, file_name=None, optimized=False): - """ - safe and run a python program, thereby circumventing any restrictions on module level - imports - """ - from subprocess import check_output, STDOUT, CalledProcessError - - if not hasattr(base_dir, "hash"): - base_dir = Path(str(base_dir)) - if file_name is None: - file_name = "safe_and_run_tmp.py" - file_name = base_dir / file_name 
- file_name.write_text(dedent(program)) - - try: - cmd = [sys.executable] - if optimized: - cmd.append("-O") - cmd.append(str(file_name)) - print("running:", *cmd) - res = check_output(cmd, stderr=STDOUT, universal_newlines=True) - if output is not None: - if "__pypy__" in sys.builtin_module_names: - res = res.splitlines(True) - res = [line for line in res if "no version info" not in line] - res = "".join(res) - print("result: ", res, end="") - print("expected:", output, end="") - assert res == output - except CalledProcessError as exception: - print("##### Running '{} {}' FAILED #####".format(sys.executable, file_name)) - print(exception.output) - return exception.returncode - return 0 diff --git a/srsly/tests/ruamel_yaml/test_a_dedent.py b/srsly/tests/ruamel_yaml/test_a_dedent.py deleted file mode 100755 index 146dfd0..0000000 --- a/srsly/tests/ruamel_yaml/test_a_dedent.py +++ /dev/null @@ -1,55 +0,0 @@ -from .roundtrip import dedent - - -class TestDedent: - def test_start_newline(self): - # fmt: off - x = dedent(""" - 123 - 456 - """) - # fmt: on - assert x == "123\n 456\n" - - def test_start_space_newline(self): - # special construct to prevent stripping of following whitespace - # fmt: off - x = dedent(" " """ - 123 - """) - # fmt: on - assert x == "123\n" - - def test_start_no_newline(self): - # special construct to prevent stripping of following whitespac - x = dedent( - """\ - 123 - 456 - """ - ) - assert x == "123\n 456\n" - - def test_preserve_no_newline_at_end(self): - x = dedent( - """ - 123""" - ) - assert x == "123" - - def test_preserve_no_newline_at_all(self): - x = dedent( - """\ - 123""" - ) - assert x == "123" - - def test_multiple_dedent(self): - x = dedent( - dedent( - """ - 123 - """ - ) - ) - assert x == "123\n" diff --git a/srsly/tests/ruamel_yaml/test_add_xxx.py b/srsly/tests/ruamel_yaml/test_add_xxx.py deleted file mode 100755 index 14ed103..0000000 --- a/srsly/tests/ruamel_yaml/test_add_xxx.py +++ /dev/null @@ -1,184 +0,0 @@ -# coding: 
utf-8 - -import re -import pytest # NOQA - -from .roundtrip import dedent - - -# from PyYAML docs -class Dice(tuple): - def __new__(cls, a, b): - return tuple.__new__(cls, [a, b]) - - def __repr__(self): - return "Dice(%s,%s)" % self - - -def dice_constructor(loader, node): - value = loader.construct_scalar(node) - a, b = map(int, value.split("d")) - return Dice(a, b) - - -def dice_representer(dumper, data): - return dumper.represent_scalar(u"!dice", u"{}d{}".format(*data)) - - -def test_dice_constructor(): - import srsly.ruamel_yaml # NOQA - - srsly.ruamel_yaml.add_constructor(u"!dice", dice_constructor) - with pytest.raises(ValueError): - data = srsly.ruamel_yaml.load( - "initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader - ) - assert str(data) == "{'initial hit points': Dice(8,4)}" - - -def test_dice_constructor_with_loader(): - import srsly.ruamel_yaml # NOQA - - with pytest.raises(ValueError): - srsly.ruamel_yaml.add_constructor( - u"!dice", dice_constructor, Loader=srsly.ruamel_yaml.Loader - ) - data = srsly.ruamel_yaml.load( - "initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader - ) - assert str(data) == "{'initial hit points': Dice(8,4)}" - - -def test_dice_representer(): - import srsly.ruamel_yaml # NOQA - - srsly.ruamel_yaml.add_representer(Dice, dice_representer) - # srsly.ruamel_yaml 0.15.8+ no longer forces quotes tagged scalars - assert ( - srsly.ruamel_yaml.dump(dict(gold=Dice(10, 6)), default_flow_style=False) - == "gold: !dice 10d6\n" - ) - - -def test_dice_implicit_resolver(): - import srsly.ruamel_yaml # NOQA - - pattern = re.compile(r"^\d+d\d+$") - with pytest.raises(ValueError): - srsly.ruamel_yaml.add_implicit_resolver(u"!dice", pattern) - assert ( - srsly.ruamel_yaml.dump(dict(treasure=Dice(10, 20)), default_flow_style=False) - == "treasure: 10d20\n" - ) - assert srsly.ruamel_yaml.load( - "damage: 5d10", Loader=srsly.ruamel_yaml.Loader - ) == dict(damage=Dice(5, 10)) - - -class Obj1(dict): - def __init__(self, 
suffix): - self._suffix = suffix - self._node = None - - def add_node(self, n): - self._node = n - - def __repr__(self): - return "Obj1(%s->%s)" % (self._suffix, self.items()) - - def dump(self): - return repr(self._node) - - -class YAMLObj1(object): - yaml_tag = u"!obj:" - - @classmethod - def from_yaml(cls, loader, suffix, node): - import srsly.ruamel_yaml # NOQA - - obj1 = Obj1(suffix) - if isinstance(node, srsly.ruamel_yaml.MappingNode): - obj1.add_node(loader.construct_mapping(node)) - else: - raise NotImplementedError - return obj1 - - @classmethod - def to_yaml(cls, dumper, data): - return dumper.represent_scalar(cls.yaml_tag + data._suffix, data.dump()) - - -def test_yaml_obj(): - import srsly.ruamel_yaml # NOQA - - srsly.ruamel_yaml.add_representer(Obj1, YAMLObj1.to_yaml) - srsly.ruamel_yaml.add_multi_constructor(YAMLObj1.yaml_tag, YAMLObj1.from_yaml) - with pytest.raises(ValueError): - x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader) - print(x) - assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n""" - - -def test_yaml_obj_with_loader_and_dumper(): - import srsly.ruamel_yaml # NOQA - - srsly.ruamel_yaml.add_representer( - Obj1, YAMLObj1.to_yaml, Dumper=srsly.ruamel_yaml.Dumper - ) - srsly.ruamel_yaml.add_multi_constructor( - YAMLObj1.yaml_tag, YAMLObj1.from_yaml, Loader=srsly.ruamel_yaml.Loader - ) - with pytest.raises(ValueError): - x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader) - # x = srsly.ruamel_yaml.load('!obj:x.2\na: 1') - print(x) - assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n""" - - -# ToDo use nullege to search add_multi_representer and add_path_resolver -# and add some test code - -# Issue 127 reported by Tommy Wang - - -def test_issue_127(): - import srsly.ruamel_yaml # NOQA - - class Ref(srsly.ruamel_yaml.YAMLObject): - yaml_constructor = srsly.ruamel_yaml.RoundTripConstructor - yaml_representer = srsly.ruamel_yaml.RoundTripRepresenter - yaml_tag = u"!Ref" - 
- def __init__(self, logical_id): - self.logical_id = logical_id - - @classmethod - def from_yaml(cls, loader, node): - return cls(loader.construct_scalar(node)) - - @classmethod - def to_yaml(cls, dumper, data): - if isinstance(data.logical_id, srsly.ruamel_yaml.scalarstring.ScalarString): - style = data.logical_id.style # srsly.ruamel_yaml>0.15.8 - else: - style = None - return dumper.represent_scalar(cls.yaml_tag, data.logical_id, style=style) - - document = dedent( - """\ - AList: - - !Ref One - - !Ref 'Two' - - !Ref - Two and a half - BList: [!Ref Three, !Ref "Four"] - CList: - - Five Six - - 'Seven Eight' - """ - ) - data = srsly.ruamel_yaml.round_trip_load(document, preserve_quotes=True) - assert srsly.ruamel_yaml.round_trip_dump( - data, indent=4, block_seq_indent=2 - ) == document.replace("\n Two and", " Two and") diff --git a/srsly/tests/ruamel_yaml/test_anchor.py b/srsly/tests/ruamel_yaml/test_anchor.py deleted file mode 100755 index ad0859f..0000000 --- a/srsly/tests/ruamel_yaml/test_anchor.py +++ /dev/null @@ -1,575 +0,0 @@ -""" -testing of anchors and the aliases referring to them -""" - -import pytest -from textwrap import dedent -import platform -import srsly - -from .roundtrip import ( - round_trip, - dedent, - round_trip_load, - round_trip_dump, - YAML, -) # NOQA - - -def load(s): - return round_trip_load(dedent(s)) - - -def compare(d, s): - assert round_trip_dump(d) == dedent(s) - - -class TestAnchorsAliases: - def test_anchor_id_renumber(self): - from srsly.ruamel_yaml.serializer import Serializer - - assert Serializer.ANCHOR_TEMPLATE == "id%03d" - data = load( - """ - a: &id002 - b: 1 - c: 2 - d: *id002 - """ - ) - compare( - data, - """ - a: &id001 - b: 1 - c: 2 - d: *id001 - """, - ) - - def test_template_matcher(self): - """test if id matches the anchor template""" - from srsly.ruamel_yaml.serializer import templated_id - - assert templated_id(u"id001") - assert templated_id(u"id999") - assert templated_id(u"id1000") - assert 
templated_id(u"id0001") - assert templated_id(u"id0000") - assert not templated_id(u"id02") - assert not templated_id(u"id000") - assert not templated_id(u"x000") - - # def test_re_matcher(self): - # import re - # assert re.compile(u'id(?!000)\\d{3,}').match('id001') - # assert not re.compile(u'id(?!000\\d*)\\d{3,}').match('id000') - # assert re.compile(u'id(?!000$)\\d{3,}').match('id0001') - - def test_anchor_assigned(self): - from srsly.ruamel_yaml.comments import CommentedMap - - data = load( - """ - a: &id002 - b: 1 - c: 2 - d: *id002 - e: &etemplate - b: 1 - c: 2 - f: *etemplate - """ - ) - d = data["d"] - assert isinstance(d, CommentedMap) - assert d.yaml_anchor() is None # got dropped as it matches pattern - e = data["e"] - assert isinstance(e, CommentedMap) - assert e.yaml_anchor().value == "etemplate" - assert e.yaml_anchor().always_dump is False - - def test_anchor_id_retained(self): - data = load( - """ - a: &id002 - b: 1 - c: 2 - d: *id002 - e: &etemplate - b: 1 - c: 2 - f: *etemplate - """ - ) - compare( - data, - """ - a: &id001 - b: 1 - c: 2 - d: *id001 - e: &etemplate - b: 1 - c: 2 - f: *etemplate - """, - ) - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_alias_before_anchor(self): - from srsly.ruamel_yaml.composer import ComposerError - - with pytest.raises(ComposerError): - data = load( - """ - d: *id002 - a: &id002 - b: 1 - c: 2 - """ - ) - data = data - - def test_anchor_on_sequence(self): - # as reported by Bjorn Stabell - # https://bitbucket.org/ruamel/yaml/issue/7/anchor-names-not-preserved - from srsly.ruamel_yaml.comments import CommentedSeq - - data = load( - """ - nut1: &alice - - 1 - - 2 - nut2: &blake - - some data - - *alice - nut3: - - *blake - - *alice - """ - ) - r = data["nut1"] - assert isinstance(r, CommentedSeq) - assert r.yaml_anchor() is not None - assert r.yaml_anchor().value == "alice" - - merge_yaml = dedent( - """ - - &CENTER {x: 1, y: 2} - 
- &LEFT {x: 0, y: 2} - - &BIG {r: 10} - - &SMALL {r: 1} - # All the following maps are equal: - # Explicit keys - - x: 1 - y: 2 - r: 10 - label: center/small - # Merge one map - - <<: *CENTER - r: 10 - label: center/medium - # Merge multiple maps - - <<: [*CENTER, *BIG] - label: center/big - # Override - - <<: [*BIG, *LEFT, *SMALL] - x: 1 - label: center/huge - """ - ) - - def test_merge_00(self): - data = load(self.merge_yaml) - d = data[4] - ok = True - for k in d: - for o in [5, 6, 7]: - x = d.get(k) - y = data[o].get(k) - if not isinstance(x, int): - x = x.split("/")[0] - y = y.split("/")[0] - if x != y: - ok = False - print("key", k, d.get(k), data[o].get(k)) - assert ok - - def test_merge_accessible(self): - from srsly.ruamel_yaml.comments import CommentedMap, merge_attrib - - data = load( - """ - k: &level_2 { a: 1, b2 } - l: &level_1 { a: 10, c: 3 } - m: - <<: *level_1 - c: 30 - d: 40 - """ - ) - d = data["m"] - assert isinstance(d, CommentedMap) - assert hasattr(d, merge_attrib) - - def test_merge_01(self): - data = load(self.merge_yaml) - compare(data, self.merge_yaml) - - def test_merge_nested(self): - yaml = """ - a: - <<: &content - 1: plugh - 2: plover - 0: xyzzy - b: - <<: *content - """ - data = round_trip(yaml) # NOQA - - def test_merge_nested_with_sequence(self): - yaml = """ - a: - <<: &content - <<: &y2 - 1: plugh - 2: plover - 0: xyzzy - b: - <<: [*content, *y2] - """ - data = round_trip(yaml) # NOQA - - def test_add_anchor(self): - from srsly.ruamel_yaml.comments import CommentedMap - - data = CommentedMap() - data_a = CommentedMap() - data["a"] = data_a - data_a["c"] = 3 - data["b"] = 2 - data.yaml_set_anchor("klm", always_dump=True) - data["a"].yaml_set_anchor("xyz", always_dump=True) - compare( - data, - """ - &klm - a: &xyz - c: 3 - b: 2 - """, - ) - - # this is an error in PyYAML - def test_reused_anchor(self): - from srsly.ruamel_yaml.error import ReusedAnchorWarning - - yaml = """ - - &a - x: 1 - - <<: *a - - &a - x: 2 - - <<: *a - """ 
- with pytest.warns(ReusedAnchorWarning): - data = round_trip(yaml) # NOQA - - def test_issue_130(self): - # issue 130 reported by Devid Fee - ys = dedent( - """\ - components: - server: &server_component - type: spark.server:ServerComponent - host: 0.0.0.0 - port: 8000 - shell: &shell_component - type: spark.shell:ShellComponent - - services: - server: &server_service - <<: *server_component - shell: &shell_service - <<: *shell_component - components: - server: {<<: *server_service} - """ - ) - data = srsly.ruamel_yaml.safe_load(ys) - assert data["services"]["shell"]["components"]["server"]["port"] == 8000 - - def test_issue_130a(self): - # issue 130 reported by Devid Fee - ys = dedent( - """\ - components: - server: &server_component - type: spark.server:ServerComponent - host: 0.0.0.0 - port: 8000 - shell: &shell_component - type: spark.shell:ShellComponent - - services: - server: &server_service - <<: *server_component - port: 4000 - shell: &shell_service - <<: *shell_component - components: - server: {<<: *server_service} - """ - ) - data = srsly.ruamel_yaml.safe_load(ys) - assert data["services"]["shell"]["components"]["server"]["port"] == 4000 - - -class TestMergeKeysValues: - - yaml_str = dedent( - """\ - - &mx - a: x1 - b: x2 - c: x3 - - &my - a: y1 - b: y2 # masked by the one in &mx - d: y4 - - - a: 1 - <<: [*mx, *my] - m: 6 - """ - ) - - # in the following d always has "expanded" the merges - - def test_merge_for(self): - from srsly.ruamel_yaml import safe_load - - d = safe_load(self.yaml_str) - data = round_trip_load(self.yaml_str) - count = 0 - for x in data[2]: - count += 1 - print(count, x) - assert count == len(d[2]) - - def test_merge_keys(self): - from srsly.ruamel_yaml import safe_load - - d = safe_load(self.yaml_str) - data = round_trip_load(self.yaml_str) - count = 0 - for x in data[2].keys(): - count += 1 - print(count, x) - assert count == len(d[2]) - - def test_merge_values(self): - from srsly.ruamel_yaml import safe_load - - d = 
safe_load(self.yaml_str) - data = round_trip_load(self.yaml_str) - count = 0 - for x in data[2].values(): - count += 1 - print(count, x) - assert count == len(d[2]) - - def test_merge_items(self): - from srsly.ruamel_yaml import safe_load - - d = safe_load(self.yaml_str) - data = round_trip_load(self.yaml_str) - count = 0 - for x in data[2].items(): - count += 1 - print(count, x) - assert count == len(d[2]) - - def test_len_items_delete(self): - from srsly.ruamel_yaml import safe_load - from srsly.ruamel_yaml.compat import PY3 - - d = safe_load(self.yaml_str) - data = round_trip_load(self.yaml_str) - x = data[2].items() - print("d2 items", d[2].items(), len(d[2].items()), x, len(x)) - ref = len(d[2].items()) - print("ref", ref) - assert len(x) == ref - del data[2]["m"] - if PY3: - ref -= 1 - assert len(x) == ref - del data[2]["d"] - if PY3: - ref -= 1 - assert len(x) == ref - del data[2]["a"] - if PY3: - ref -= 1 - assert len(x) == ref - - def test_issue_196_cast_of_dict(self, capsys): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - mapping = yaml.load( - """\ - anchored: &anchor - a : 1 - - mapping: - <<: *anchor - b: 2 - """ - )["mapping"] - - for k in mapping: - print("k", k) - for k in mapping.copy(): - print("kc", k) - - print("v", list(mapping.keys())) - print("v", list(mapping.values())) - print("v", list(mapping.items())) - print(len(mapping)) - print("-----") - - # print({**mapping}) - # print(type({**mapping})) - # assert 'a' in {**mapping} - assert "a" in mapping - x = {} - for k in mapping: - x[k] = mapping[k] - assert "a" in x - assert "a" in mapping.keys() - assert mapping["a"] == 1 - assert mapping.__getitem__("a") == 1 - assert "a" in dict(mapping) - assert "a" in dict(mapping.items()) - - def test_values_of_merged(self): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - data = yaml.load(dedent(self.yaml_str)) - assert list(data[2].values()) == [1, 6, "x2", "x3", "y4"] - - def test_issue_213_copy_of_merge(self): - from 
srsly.ruamel_yaml import YAML - - yaml = YAML() - d = yaml.load( - """\ - foo: &foo - a: a - foo2: - <<: *foo - b: b - """ - )["foo2"] - assert d["a"] == "a" - d2 = d.copy() - assert d2["a"] == "a" - print("d", d) - del d["a"] - assert "a" not in d - assert "a" in d2 - - -class TestDuplicateKeyThroughAnchor: - def test_duplicate_key_00(self): - from srsly.ruamel_yaml import version_info - from srsly.ruamel_yaml import safe_load, round_trip_load - from srsly.ruamel_yaml.constructor import ( - DuplicateKeyFutureWarning, - DuplicateKeyError, - ) - - s = dedent( - """\ - &anchor foo: - foo: bar - *anchor : duplicate key - baz: bat - *anchor : duplicate key - """ - ) - if version_info < (0, 15, 1): - pass - elif version_info < (0, 16, 0): - with pytest.warns(DuplicateKeyFutureWarning): - safe_load(s) - with pytest.warns(DuplicateKeyFutureWarning): - round_trip_load(s) - else: - with pytest.raises(DuplicateKeyError): - safe_load(s) - with pytest.raises(DuplicateKeyError): - round_trip_load(s) - - def test_duplicate_key_01(self): - # so issue https://stackoverflow.com/a/52852106/1307905 - from srsly.ruamel_yaml import version_info - from srsly.ruamel_yaml.constructor import DuplicateKeyError - - s = dedent( - """\ - - &name-name - a: 1 - - &help-name - b: 2 - - <<: *name-name - <<: *help-name - """ - ) - if version_info < (0, 15, 1): - pass - else: - with pytest.raises(DuplicateKeyError): - yaml = YAML(typ="safe") - yaml.load(s) - with pytest.raises(DuplicateKeyError): - yaml = YAML() - yaml.load(s) - - -class TestFullCharSetAnchors: - def test_master_of_orion(self): - # https://bitbucket.org/ruamel/yaml/issues/72/not-allowed-in-anchor-names - # submitted by Shalon Wood - yaml_str = """ - - collection: &Backend.Civilizations.RacialPerk - items: - - key: perk_population_growth_modifier - - *Backend.Civilizations.RacialPerk - """ - data = load(yaml_str) # NOQA - - def test_roundtrip_00(self): - yaml_str = """ - - &dotted.words.here - a: 1 - b: 2 - - *dotted.words.here - """ 
- data = round_trip(yaml_str) # NOQA - - def test_roundtrip_01(self): - yaml_str = """ - - &dotted.words.here[a, b] - - *dotted.words.here - """ - data = load(yaml_str) # NOQA - compare(data, yaml_str.replace("[", " [")) # an extra space is inserted diff --git a/srsly/tests/ruamel_yaml/test_api_change.py b/srsly/tests/ruamel_yaml/test_api_change.py deleted file mode 100755 index d2e4e1e..0000000 --- a/srsly/tests/ruamel_yaml/test_api_change.py +++ /dev/null @@ -1,247 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -""" -testing of anchors and the aliases referring to them -""" - -import sys -import textwrap -import pytest -from pathlib import Path - -pytestmark = pytest.mark.filterwarnings( - "ignore::pytest.PytestUnraisableExceptionWarning" -) - - -class TestNewAPI: - def test_duplicate_keys_00(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.constructor import DuplicateKeyError - - yaml = YAML() - with pytest.raises(DuplicateKeyError): - yaml.load("{a: 1, a: 2}") - - def test_duplicate_keys_01(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.constructor import DuplicateKeyError - - yaml = YAML(typ="safe", pure=True) - with pytest.raises(DuplicateKeyError): - yaml.load("{a: 1, a: 2}") - - def test_duplicate_keys_02(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.constructor import DuplicateKeyError - - yaml = YAML(typ="safe") - with pytest.raises(DuplicateKeyError): - yaml.load("{a: 1, a: 2}") - - def test_issue_135(self): - # reported by Andrzej Ostrowski - from srsly.ruamel_yaml import YAML - - data = {"a": 1, "b": 2} - yaml = YAML(typ="safe") - # originally on 2.7: with pytest.raises(TypeError): - yaml.dump(data, sys.stdout) - - def test_issue_135_temporary_workaround(self): - # never raised error - from srsly.ruamel_yaml import YAML - - data = {"a": 1, "b": 2} - yaml = YAML(typ="safe", pure=True) - yaml.dump(data, sys.stdout) - - -class TestWrite: - def test_dump_path(self, 
tmpdir): - from srsly.ruamel_yaml import YAML - - fn = Path(str(tmpdir)) / "test.yaml" - yaml = YAML() - data = yaml.map() - data["a"] = 1 - data["b"] = 2 - yaml.dump(data, fn) - assert fn.read_text() == "a: 1\nb: 2\n" - - def test_dump_file(self, tmpdir): - from srsly.ruamel_yaml import YAML - - fn = Path(str(tmpdir)) / "test.yaml" - yaml = YAML() - data = yaml.map() - data["a"] = 1 - data["b"] = 2 - with open(str(fn), "w") as fp: - yaml.dump(data, fp) - assert fn.read_text() == "a: 1\nb: 2\n" - - def test_dump_missing_stream(self): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - data = yaml.map() - data["a"] = 1 - data["b"] = 2 - with pytest.raises(TypeError): - yaml.dump(data) - - def test_dump_too_many_args(self, tmpdir): - from srsly.ruamel_yaml import YAML - - fn = Path(str(tmpdir)) / "test.yaml" - yaml = YAML() - data = yaml.map() - data["a"] = 1 - data["b"] = 2 - with pytest.raises(TypeError): - yaml.dump(data, fn, True) - - def test_transform(self, tmpdir): - from srsly.ruamel_yaml import YAML - - def tr(s): - return s.replace(" ", " ") - - fn = Path(str(tmpdir)) / "test.yaml" - yaml = YAML() - data = yaml.map() - data["a"] = 1 - data["b"] = 2 - yaml.dump(data, fn, transform=tr) - assert fn.read_text() == "a: 1\nb: 2\n" - - def test_print(self, capsys): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - data = yaml.map() - data["a"] = 1 - data["b"] = 2 - yaml.dump(data, sys.stdout) - out, err = capsys.readouterr() - assert out == "a: 1\nb: 2\n" - - -class TestRead: - def test_multi_load(self): - # make sure reader, scanner, parser get reset - from srsly.ruamel_yaml import YAML - - yaml = YAML() - yaml.load("a: 1") - yaml.load("a: 1") # did not work in 0.15.4 - - def test_parse(self): - # ensure `parse` method is functional and can parse "unsafe" yaml - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.constructor import ConstructorError - - yaml = YAML(typ="safe") - s = "- !User0 {age: 18, name: Anthon}" - # should fail to load - 
with pytest.raises(ConstructorError): - yaml.load(s) - # should parse fine - yaml = YAML(typ="safe") - for _ in yaml.parse(s): - pass - - -class TestLoadAll: - def test_multi_document_load(self, tmpdir): - """this went wrong on 3.7 because of StopIteration, PR 37 and Issue 211""" - from srsly.ruamel_yaml import YAML - - fn = Path(str(tmpdir)) / "test.yaml" - fn.write_text( - textwrap.dedent( - u"""\ - --- - - a - --- - - b - ... - """ - ) - ) - yaml = YAML() - assert list(yaml.load_all(fn)) == [["a"], ["b"]] - - -class TestDuplSet: - def test_dupl_set_00(self): - # round-trip-loader should except - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.constructor import DuplicateKeyError - - yaml = YAML() - with pytest.raises(DuplicateKeyError): - yaml.load( - textwrap.dedent( - """\ - !!set - ? a - ? b - ? c - ? a - """ - ) - ) - - -class TestDumpLoadUnicode: - # test triggered by SamH on stackoverflow (https://stackoverflow.com/q/45281596/1307905) - # and answer by randomir (https://stackoverflow.com/a/45281922/1307905) - def test_write_unicode(self, tmpdir): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - text_dict = {"text": u"HELLO_WORLD©"} - file_name = str(tmpdir) + "/tstFile.yaml" - yaml.dump(text_dict, open(file_name, "w", encoding="utf8", newline="\n")) - assert open(file_name, "rb").read().decode("utf-8") == u"text: HELLO_WORLD©\n" - - def test_read_unicode(self, tmpdir): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - file_name = str(tmpdir) + "/tstFile.yaml" - with open(file_name, "wb") as fp: - fp.write(u"text: HELLO_WORLD©\n".encode("utf-8")) - with open(file_name, "r", encoding="utf8") as fp: - text_dict = yaml.load(fp) - assert text_dict["text"] == u"HELLO_WORLD©" - - -class TestFlowStyle: - def test_flow_style(self, capsys): - # https://stackoverflow.com/questions/45791712/ - from srsly.ruamel_yaml import YAML - - yaml = YAML() - yaml.default_flow_style = None - data = yaml.map() - data["b"] = 1 - data["a"] = [[1, 2], [3, 
4]] - yaml.dump(data, sys.stdout) - out, err = capsys.readouterr() - assert out == "b: 1\na:\n- [1, 2]\n- [3, 4]\n" - - -class TestOldAPI: - @pytest.mark.skipif(sys.version_info >= (3, 0), reason="ok on Py3") - def test_duplicate_keys_02(self): - # Issue 165 unicode keys in error/warning - from srsly.ruamel_yaml import safe_load - from srsly.ruamel_yaml.constructor import DuplicateKeyError - - with pytest.raises(DuplicateKeyError): - safe_load("type: Doméstica\ntype: International") diff --git a/srsly/tests/ruamel_yaml/test_appliance.py b/srsly/tests/ruamel_yaml/test_appliance.py deleted file mode 100755 index bba0210..0000000 --- a/srsly/tests/ruamel_yaml/test_appliance.py +++ /dev/null @@ -1,220 +0,0 @@ -from __future__ import print_function - -import sys -import os -import types -import traceback -import pprint -import argparse -from srsly.ruamel_yaml.compat import PY3 - -# DATA = 'tests/data' -# determine the position of data dynamically relative to program -# this allows running test while the current path is not the top of the -# repository, e.g. 
from the tests/data directory: python ../test_yaml.py -DATA = __file__.rsplit(os.sep, 2)[0] + "/data" - - -def find_test_functions(collections): - if not isinstance(collections, list): - collections = [collections] - functions = [] - for collection in collections: - if not isinstance(collection, dict): - collection = vars(collection) - for key in sorted(collection): - value = collection[key] - if isinstance(value, types.FunctionType) and hasattr(value, "unittest"): - functions.append(value) - return functions - - -def find_test_filenames(directory): - filenames = {} - for filename in os.listdir(directory): - if os.path.isfile(os.path.join(directory, filename)): - base, ext = os.path.splitext(filename) - if base.endswith("-py2" if PY3 else "-py3"): - continue - filenames.setdefault(base, []).append(ext) - filenames = sorted(filenames.items()) - return filenames - - -def parse_arguments(args): - """""" - parser = argparse.ArgumentParser( - usage=""" run the yaml tests. By default - all functions on all appropriate test_files are run. Functions have - unittest attributes that determine the required extensions to filenames - that need to be available in order to run that test. 
E.g.\n\n - python test_yaml.py test_constructor_types\n - python test_yaml.py --verbose test_tokens spec-02-05\n\n - The presence of an extension in the .skip attribute of a function - disables the test for that function.""" - ) - # ToDo: make into int and test > 0 in functions - parser.add_argument( - "--verbose", - "-v", - action="store_true", - default="YAML_TEST_VERBOSE" in os.environ, - help="set verbosity output", - ) - parser.add_argument( - "--list-functions", - action="store_true", - help="""list all functions with required file extensions for test files - """, - ) - parser.add_argument("function", nargs="?", help="""restrict function to run""") - parser.add_argument( - "filenames", - nargs="*", - help="""basename of filename set, extensions (.code, .data) have to - be a superset of those in the unittest attribute of the selected - function""", - ) - args = parser.parse_args(args) - # print('args', args) - verbose = args.verbose - include_functions = [args.function] if args.function else [] - include_filenames = args.filenames - # if args is None: - # args = sys.argv[1:] - # verbose = False - # if '-v' in args: - # verbose = True - # args.remove('-v') - # if '--verbose' in args: - # verbose = True - # args.remove('--verbose') # never worked without this - # if 'YAML_TEST_VERBOSE' in os.environ: - # verbose = True - # include_functions = [] - # if args: - # include_functions.append(args.pop(0)) - if "YAML_TEST_FUNCTIONS" in os.environ: - include_functions.extend(os.environ["YAML_TEST_FUNCTIONS"].split()) - # include_filenames = [] - # include_filenames.extend(args) - if "YAML_TEST_FILENAMES" in os.environ: - include_filenames.extend(os.environ["YAML_TEST_FILENAMES"].split()) - return include_functions, include_filenames, verbose, args - - -def execute(function, filenames, verbose): - if PY3: - name = function.__name__ - else: - if hasattr(function, "unittest_name"): - name = function.unittest_name - else: - name = function.func_name - if verbose: - 
sys.stdout.write("=" * 75 + "\n") - sys.stdout.write("%s(%s)...\n" % (name, ", ".join(filenames))) - try: - function(verbose=verbose, *filenames) - except Exception as exc: - info = sys.exc_info() - if isinstance(exc, AssertionError): - kind = "FAILURE" - else: - kind = "ERROR" - if verbose: - traceback.print_exc(limit=1, file=sys.stdout) - else: - sys.stdout.write(kind[0]) - sys.stdout.flush() - else: - kind = "SUCCESS" - info = None - if not verbose: - sys.stdout.write(".") - sys.stdout.flush() - return (name, filenames, kind, info) - - -def display(results, verbose): - if results and not verbose: - sys.stdout.write("\n") - total = len(results) - failures = 0 - errors = 0 - for name, filenames, kind, info in results: - if kind == "SUCCESS": - continue - if kind == "FAILURE": - failures += 1 - if kind == "ERROR": - errors += 1 - sys.stdout.write("=" * 75 + "\n") - sys.stdout.write("%s(%s): %s\n" % (name, ", ".join(filenames), kind)) - if kind == "ERROR": - traceback.print_exception(file=sys.stdout, *info) - else: - sys.stdout.write("Traceback (most recent call last):\n") - traceback.print_tb(info[2], file=sys.stdout) - sys.stdout.write("%s: see below\n" % info[0].__name__) - sys.stdout.write("~" * 75 + "\n") - for arg in info[1].args: - pprint.pprint(arg, stream=sys.stdout) - for filename in filenames: - sys.stdout.write("-" * 75 + "\n") - sys.stdout.write("%s:\n" % filename) - if PY3: - with open(filename, "r", errors="replace") as fp: - data = fp.read() - else: - with open(filename, "rb") as fp: - data = fp.read() - sys.stdout.write(data) - if data and data[-1] != "\n": - sys.stdout.write("\n") - sys.stdout.write("=" * 75 + "\n") - sys.stdout.write("TESTS: %s\n" % total) - ret_val = 0 - if failures: - sys.stdout.write("FAILURES: %s\n" % failures) - ret_val = 1 - if errors: - sys.stdout.write("ERRORS: %s\n" % errors) - ret_val = 2 - return ret_val - - -def run(collections, args=None): - test_functions = find_test_functions(collections) - test_filenames = 
find_test_filenames(DATA) - include_functions, include_filenames, verbose, a = parse_arguments(args) - if a.list_functions: - print("test functions:") - for f in test_functions: - print(" {:30s} {}".format(f.__name__, f.unittest)) - return - results = [] - for function in test_functions: - if include_functions and function.__name__ not in include_functions: - continue - if function.unittest: - for base, exts in test_filenames: - if include_filenames and base not in include_filenames: - continue - filenames = [] - for ext in function.unittest: - if ext not in exts: - break - filenames.append(os.path.join(DATA, base + ext)) - else: - skip_exts = getattr(function, "skip", []) - for skip_ext in skip_exts: - if skip_ext in exts: - break - else: - result = execute(function, filenames, verbose) - results.append(result) - else: - result = execute(function, [], verbose) - results.append(result) - return display(results, verbose=verbose) diff --git a/srsly/tests/ruamel_yaml/test_class_register.py b/srsly/tests/ruamel_yaml/test_class_register.py deleted file mode 100755 index 190f6a2..0000000 --- a/srsly/tests/ruamel_yaml/test_class_register.py +++ /dev/null @@ -1,141 +0,0 @@ -# coding: utf-8 - -""" -testing of YAML.register_class and @yaml_object -""" - -from .roundtrip import YAML - - -class User0(object): - def __init__(self, name, age): - self.name = name - self.age = age - - -class User1(object): - yaml_tag = u"!user" - - def __init__(self, name, age): - self.name = name - self.age = age - - @classmethod - def to_yaml(cls, representer, node): - return representer.represent_scalar( - cls.yaml_tag, u"{.name}-{.age}".format(node, node) - ) - - @classmethod - def from_yaml(cls, constructor, node): - return cls(*node.value.split("-")) - - -class TestRegisterClass(object): - def test_register_0_rt(self): - yaml = YAML() - yaml.register_class(User0) - ys = """ - - !User0 - name: Anthon - age: 18 - """ - d = yaml.load(ys) - yaml.dump(d, compare=ys, unordered_lines=True) - - def 
test_register_0_safe(self): - # default_flow_style = None - yaml = YAML(typ="safe") - yaml.register_class(User0) - ys = """ - - !User0 {age: 18, name: Anthon} - """ - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_register_0_unsafe(self): - # default_flow_style = None - yaml = YAML(typ="unsafe") - yaml.register_class(User0) - ys = """ - - !User0 {age: 18, name: Anthon} - """ - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_register_1_rt(self): - yaml = YAML() - yaml.register_class(User1) - ys = """ - - !user Anthon-18 - """ - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_register_1_safe(self): - yaml = YAML(typ="safe") - yaml.register_class(User1) - ys = """ - [!user Anthon-18] - """ - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_register_1_unsafe(self): - yaml = YAML(typ="unsafe") - yaml.register_class(User1) - ys = """ - [!user Anthon-18] - """ - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - -class TestDecorator(object): - def test_decorator_implicit(self): - from srsly.ruamel_yaml import yaml_object - - yml = YAML() - - @yaml_object(yml) - class User2(object): - def __init__(self, name, age): - self.name = name - self.age = age - - ys = """ - - !User2 - name: Anthon - age: 18 - """ - d = yml.load(ys) - yml.dump(d, compare=ys, unordered_lines=True) - - def test_decorator_explicit(self): - from srsly.ruamel_yaml import yaml_object - - yml = YAML() - - @yaml_object(yml) - class User3(object): - yaml_tag = u"!USER" - - def __init__(self, name, age): - self.name = name - self.age = age - - @classmethod - def to_yaml(cls, representer, node): - return representer.represent_scalar( - cls.yaml_tag, u"{.name}-{.age}".format(node, node) - ) - - @classmethod - def from_yaml(cls, constructor, node): - return cls(*node.value.split("-")) - - ys = """ - - !USER Anthon-18 - """ - d = yml.load(ys) - yml.dump(d, compare=ys) diff --git a/srsly/tests/ruamel_yaml/test_collections.py b/srsly/tests/ruamel_yaml/test_collections.py 
deleted file mode 100755 index 632aae2..0000000 --- a/srsly/tests/ruamel_yaml/test_collections.py +++ /dev/null @@ -1,21 +0,0 @@ -# coding: utf-8 - -""" -collections.OrderedDict is a new class not supported by PyYAML (issue 83 by Frazer McLean) - -This is now so integrated in Python that it can be mapped to !!omap - -""" - -import pytest # NOQA - - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - - -class TestOrderedDict: - def test_ordereddict(self): - from collections import OrderedDict - import srsly.ruamel_yaml # NOQA - - assert srsly.ruamel_yaml.dump(OrderedDict()) == "!!omap []\n" diff --git a/srsly/tests/ruamel_yaml/test_comment_manipulation.py b/srsly/tests/ruamel_yaml/test_comment_manipulation.py deleted file mode 100755 index 95ce934..0000000 --- a/srsly/tests/ruamel_yaml/test_comment_manipulation.py +++ /dev/null @@ -1,639 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -import pytest # NOQA - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - - -def load(s): - return round_trip_load(dedent(s)) - - -def compare(data, s, **kw): - assert round_trip_dump(data, **kw) == dedent(s) - - -def compare_eol(data, s): - assert "EOL" in s - ds = dedent(s).replace("EOL", "").replace("\n", "|\n") - assert round_trip_dump(data).replace("\n", "|\n") == ds - - -class TestCommentsManipulation: - - # list - def test_seq_set_comment_on_existing_explicit_column(self): - data = load( - """ - - a # comment 1 - - b - - c - """ - ) - data.yaml_add_eol_comment("comment 2", key=1, column=6) - exp = """ - - a # comment 1 - - b # comment 2 - - c - """ - compare(data, exp) - - def test_seq_overwrite_comment_on_existing_explicit_column(self): - data = load( - """ - - a # comment 1 - - b - - c - """ - ) - data.yaml_add_eol_comment("comment 2", key=0, column=6) - exp = """ - - a # comment 2 - - b - - c - """ - compare(data, exp) - - def test_seq_first_comment_explicit_column(self): - data = load( 
- """ - - a - - b - - c - """ - ) - data.yaml_add_eol_comment("comment 1", key=1, column=6) - exp = """ - - a - - b # comment 1 - - c - """ - compare(data, exp) - - def test_seq_set_comment_on_existing_column_prev(self): - data = load( - """ - - a # comment 1 - - b - - c - - d # comment 3 - """ - ) - data.yaml_add_eol_comment("comment 2", key=1) - exp = """ - - a # comment 1 - - b # comment 2 - - c - - d # comment 3 - """ - compare(data, exp) - - def test_seq_set_comment_on_existing_column_next(self): - data = load( - """ - - a # comment 1 - - b - - c - - d # comment 3 - """ - ) - print(data._yaml_comment) - # print(type(data._yaml_comment._items[0][0].start_mark)) - # srsly.ruamel_yaml.error.Mark - # print(type(data._yaml_comment._items[0][0].start_mark)) - data.yaml_add_eol_comment("comment 2", key=2) - exp = """ - - a # comment 1 - - b - - c # comment 2 - - d # comment 3 - """ - compare(data, exp) - - def test_seq_set_comment_on_existing_column_further_away(self): - """ - no comment line before or after, take the latest before - the new position - """ - data = load( - """ - - a # comment 1 - - b - - c - - d - - e - - f # comment 3 - """ - ) - print(data._yaml_comment) - # print(type(data._yaml_comment._items[0][0].start_mark)) - # srsly.ruamel_yaml.error.Mark - # print(type(data._yaml_comment._items[0][0].start_mark)) - data.yaml_add_eol_comment("comment 2", key=3) - exp = """ - - a # comment 1 - - b - - c - - d # comment 2 - - e - - f # comment 3 - """ - compare(data, exp) - - def test_seq_set_comment_on_existing_explicit_column_with_hash(self): - data = load( - """ - - a # comment 1 - - b - - c - """ - ) - data.yaml_add_eol_comment("# comment 2", key=1, column=6) - exp = """ - - a # comment 1 - - b # comment 2 - - c - """ - compare(data, exp) - - # dict - - def test_dict_set_comment_on_existing_explicit_column(self): - data = load( - """ - a: 1 # comment 1 - b: 2 - c: 3 - d: 4 - e: 5 - """ - ) - data.yaml_add_eol_comment("comment 2", key="c", column=7) - exp = 
""" - a: 1 # comment 1 - b: 2 - c: 3 # comment 2 - d: 4 - e: 5 - """ - compare(data, exp) - - def test_dict_overwrite_comment_on_existing_explicit_column(self): - data = load( - """ - a: 1 # comment 1 - b: 2 - c: 3 - d: 4 - e: 5 - """ - ) - data.yaml_add_eol_comment("comment 2", key="a", column=7) - exp = """ - a: 1 # comment 2 - b: 2 - c: 3 - d: 4 - e: 5 - """ - compare(data, exp) - - def test_map_set_comment_on_existing_column_prev(self): - data = load( - """ - a: 1 # comment 1 - b: 2 - c: 3 - d: 4 - e: 5 # comment 3 - """ - ) - data.yaml_add_eol_comment("comment 2", key="b") - exp = """ - a: 1 # comment 1 - b: 2 # comment 2 - c: 3 - d: 4 - e: 5 # comment 3 - """ - compare(data, exp) - - def test_map_set_comment_on_existing_column_next(self): - data = load( - """ - a: 1 # comment 1 - b: 2 - c: 3 - d: 4 - e: 5 # comment 3 - """ - ) - data.yaml_add_eol_comment("comment 2", key="d") - exp = """ - a: 1 # comment 1 - b: 2 - c: 3 - d: 4 # comment 2 - e: 5 # comment 3 - """ - compare(data, exp) - - def test_map_set_comment_on_existing_column_further_away(self): - """ - no comment line before or after, take the latest before - the new position - """ - data = load( - """ - a: 1 # comment 1 - b: 2 - c: 3 - d: 4 - e: 5 # comment 3 - """ - ) - data.yaml_add_eol_comment("comment 2", key="c") - print(round_trip_dump(data)) - exp = """ - a: 1 # comment 1 - b: 2 - c: 3 # comment 2 - d: 4 - e: 5 # comment 3 - """ - compare(data, exp) - - def test_before_top_map_rt(self): - data = load( - """ - a: 1 - b: 2 - """ - ) - data.yaml_set_start_comment("Hello\nWorld\n") - exp = """ - # Hello - # World - a: 1 - b: 2 - """ - compare(data, exp.format(comment="#")) - - def test_before_top_map_replace(self): - data = load( - """ - # abc - # def - a: 1 # 1 - b: 2 - """ - ) - data.yaml_set_start_comment("Hello\nWorld\n") - exp = """ - # Hello - # World - a: 1 # 1 - b: 2 - """ - compare(data, exp.format(comment="#")) - - def test_before_top_map_from_scratch(self): - from 
srsly.ruamel_yaml.comments import CommentedMap - - data = CommentedMap() - data["a"] = 1 - data["b"] = 2 - data.yaml_set_start_comment("Hello\nWorld\n") - # print(data.ca) - # print(data.ca._items) - exp = """ - # Hello - # World - a: 1 - b: 2 - """ - compare(data, exp.format(comment="#")) - - def test_before_top_seq_rt(self): - data = load( - """ - - a - - b - """ - ) - data.yaml_set_start_comment("Hello\nWorld\n") - print(round_trip_dump(data)) - exp = """ - # Hello - # World - - a - - b - """ - compare(data, exp) - - def test_before_top_seq_rt_replace(self): - s = """ - # this - # that - - a - - b - """ - data = load(s.format(comment="#")) - data.yaml_set_start_comment("Hello\nWorld\n") - print(round_trip_dump(data)) - exp = """ - # Hello - # World - - a - - b - """ - compare(data, exp.format(comment="#")) - - def test_before_top_seq_from_scratch(self): - from srsly.ruamel_yaml.comments import CommentedSeq - - data = CommentedSeq() - data.append("a") - data.append("b") - data.yaml_set_start_comment("Hello\nWorld\n") - print(round_trip_dump(data)) - exp = """ - # Hello - # World - - a - - b - """ - compare(data, exp.format(comment="#")) - - # nested variants - def test_before_nested_map_rt(self): - data = load( - """ - a: 1 - b: - c: 2 - d: 3 - """ - ) - data["b"].yaml_set_start_comment("Hello\nWorld\n") - exp = """ - a: 1 - b: - # Hello - # World - c: 2 - d: 3 - """ - compare(data, exp.format(comment="#")) - - def test_before_nested_map_rt_indent(self): - data = load( - """ - a: 1 - b: - c: 2 - d: 3 - """ - ) - data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2) - exp = """ - a: 1 - b: - # Hello - # World - c: 2 - d: 3 - """ - compare(data, exp.format(comment="#")) - print(data["b"].ca) - - def test_before_nested_map_from_scratch(self): - from srsly.ruamel_yaml.comments import CommentedMap - - data = CommentedMap() - datab = CommentedMap() - data["a"] = 1 - data["b"] = datab - datab["c"] = 2 - datab["d"] = 3 - 
data["b"].yaml_set_start_comment("Hello\nWorld\n") - exp = """ - a: 1 - b: - # Hello - # World - c: 2 - d: 3 - """ - compare(data, exp.format(comment="#")) - - def test_before_nested_seq_from_scratch(self): - from srsly.ruamel_yaml.comments import CommentedMap, CommentedSeq - - data = CommentedMap() - datab = CommentedSeq() - data["a"] = 1 - data["b"] = datab - datab.append("c") - datab.append("d") - data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2) - exp = """ - a: 1 - b: - # Hello - # World - - c - - d - """ - compare(data, exp.format(comment="#")) - - def test_before_nested_seq_from_scratch_block_seq_indent(self): - from srsly.ruamel_yaml.comments import CommentedMap, CommentedSeq - - data = CommentedMap() - datab = CommentedSeq() - data["a"] = 1 - data["b"] = datab - datab.append("c") - datab.append("d") - data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2) - exp = """ - a: 1 - b: - # Hello - # World - - c - - d - """ - compare(data, exp.format(comment="#"), indent=4, block_seq_indent=2) - - def test_map_set_comment_before_and_after_non_first_key_00(self): - # http://stackoverflow.com/a/40705671/1307905 - data = load( - """ - xyz: - a: 1 # comment 1 - b: 2 - - test1: - test2: - test3: 3 - """ - ) - data.yaml_set_comment_before_after_key( - "test1", "before test1 (top level)", after="before test2" - ) - data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4) - exp = """ - xyz: - a: 1 # comment 1 - b: 2 - - # before test1 (top level) - test1: - # before test2 - test2: - # after test2 - test3: 3 - """ - compare(data, exp) - - def Xtest_map_set_comment_before_and_after_non_first_key_01(self): - data = load( - """ - xyz: - a: 1 # comment 1 - b: 2 - - test1: - test2: - test3: 3 - """ - ) - data.yaml_set_comment_before_after_key( - "test1", "before test1 (top level)", after="before test2\n\n" - ) - data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4) - # EOL is needed here as dedenting gets rid of spaces (as well as 
does Emacs - exp = """ - xyz: - a: 1 # comment 1 - b: 2 - - # before test1 (top level) - test1: - # before test2 - EOL - test2: - # after test2 - test3: 3 - """ - compare_eol(data, exp) - - # EOL is no longer necessary - # fixed together with issue # 216 - def test_map_set_comment_before_and_after_non_first_key_01(self): - data = load( - """ - xyz: - a: 1 # comment 1 - b: 2 - - test1: - test2: - test3: 3 - """ - ) - data.yaml_set_comment_before_after_key( - "test1", "before test1 (top level)", after="before test2\n\n" - ) - data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4) - exp = """ - xyz: - a: 1 # comment 1 - b: 2 - - # before test1 (top level) - test1: - # before test2 - - test2: - # after test2 - test3: 3 - """ - compare(data, exp) - - def Xtest_map_set_comment_before_and_after_non_first_key_02(self): - data = load( - """ - xyz: - a: 1 # comment 1 - b: 2 - - test1: - test2: - test3: 3 - """ - ) - data.yaml_set_comment_before_after_key( - "test1", - "xyz\n\nbefore test1 (top level)", - after="\nbefore test2", - after_indent=4, - ) - data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4) - # EOL is needed here as dedenting gets rid of spaces (as well as does Emacs - exp = """ - xyz: - a: 1 # comment 1 - b: 2 - - # xyz - - # before test1 (top level) - test1: - EOL - # before test2 - test2: - # after test2 - test3: 3 - """ - compare_eol(data, exp) - - def test_map_set_comment_before_and_after_non_first_key_02(self): - data = load( - """ - xyz: - a: 1 # comment 1 - b: 2 - - test1: - test2: - test3: 3 - """ - ) - data.yaml_set_comment_before_after_key( - "test1", - "xyz\n\nbefore test1 (top level)", - after="\nbefore test2", - after_indent=4, - ) - data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4) - exp = """ - xyz: - a: 1 # comment 1 - b: 2 - - # xyz - - # before test1 (top level) - test1: - - # before test2 - test2: - # after test2 - test3: 3 - """ - compare(data, exp) diff --git 
a/srsly/tests/ruamel_yaml/test_comments.py b/srsly/tests/ruamel_yaml/test_comments.py deleted file mode 100755 index b65198d..0000000 --- a/srsly/tests/ruamel_yaml/test_comments.py +++ /dev/null @@ -1,968 +0,0 @@ -# coding: utf-8 - -""" -comment testing is all about roundtrips -these can be done in the "old" way by creating a file.data and file.roundtrip -but there is little flexibility in doing that - -but some things are not easily tested, eog. how a -roundtrip changes - -""" - -import pytest -import sys - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump - - -class TestComments: - def test_no_end_of_file_eol(self): - """not excluding comments caused some problems if at the end of - the file without a newline. First error, then included \0 """ - x = """\ - - europe: 10 # abc""" - round_trip(x, extra="\n") - with pytest.raises(AssertionError): - round_trip(x, extra="a\n") - - def test_no_comments(self): - round_trip( - """ - - europe: 10 - - usa: - - ohio: 2 - - california: 9 - """ - ) - - def test_round_trip_ordering(self): - round_trip( - """ - a: 1 - b: 2 - c: 3 - b1: 2 - b2: 2 - d: 4 - e: 5 - f: 6 - """ - ) - - def test_complex(self): - round_trip( - """ - - europe: 10 # top - - usa: - - ohio: 2 - - california: 9 # o - """ - ) - - def test_dropped(self): - s = """\ - # comment - scalar - ... 
- """ - round_trip(s, "scalar\n...\n") - - def test_main_mapping_begin_end(self): - round_trip( - """ - # C start a - # C start b - abc: 1 - ghi: 2 - klm: 3 - # C end a - # C end b - """ - ) - - def test_reindent(self): - x = """\ - a: - b: # comment 1 - c: 1 # comment 2 - """ - d = round_trip_load(x) - y = round_trip_dump(d, indent=4) - assert y == dedent( - """\ - a: - b: # comment 1 - c: 1 # comment 2 - """ - ) - - def test_main_mapping_begin_end_items_post(self): - round_trip( - """ - # C start a - # C start b - abc: 1 # abc comment - ghi: 2 - klm: 3 # klm comment - # C end a - # C end b - """ - ) - - def test_main_sequence_begin_end(self): - round_trip( - """ - # C start a - # C start b - - abc - - ghi - - klm - # C end a - # C end b - """ - ) - - def test_main_sequence_begin_end_items_post(self): - round_trip( - """ - # C start a - # C start b - - abc # abc comment - - ghi - - klm # klm comment - # C end a - # C end b - """ - ) - - def test_main_mapping_begin_end_complex(self): - round_trip( - """ - # C start a - # C start b - abc: 1 - ghi: 2 - klm: - 3a: alpha - 3b: beta # it is all greek to me - # C end a - # C end b - """ - ) - - def test_09(self): # 2.9 from the examples in the spec - s = """\ - hr: # 1998 hr ranking - - Mark McGwire - - Sammy Sosa - rbi: - # 1998 rbi ranking - - Sammy Sosa - - Ken Griffey - """ - round_trip(s, indent=4, block_seq_indent=2) - - def test_09a(self): - round_trip( - """ - hr: # 1998 hr ranking - - Mark McGwire - - Sammy Sosa - rbi: - # 1998 rbi ranking - - Sammy Sosa - - Ken Griffey - """ - ) - - def test_simple_map_middle_comment(self): - round_trip( - """ - abc: 1 - # C 3a - # C 3b - ghi: 2 - """ - ) - - def test_map_in_map_0(self): - round_trip( - """ - map1: # comment 1 - # comment 2 - map2: - key1: val1 - """ - ) - - def test_map_in_map_1(self): - # comment is moved from value to key - round_trip( - """ - map1: - # comment 1 - map2: - key1: val1 - """ - ) - - def test_application_arguments(self): - # application 
configur - round_trip( - """ - args: - username: anthon - passwd: secret - fullname: Anthon van der Neut - tmux: - session-name: test - loop: - wait: 10 - """ - ) - - def test_substitute(self): - x = """ - args: - username: anthon # name - passwd: secret # password - fullname: Anthon van der Neut - tmux: - session-name: test - loop: - wait: 10 - """ - data = round_trip_load(x) - data["args"]["passwd"] = "deleted password" - # note the requirement to add spaces for alignment of comment - x = x.replace(": secret ", ": deleted password") - assert round_trip_dump(data) == dedent(x) - - def test_set_comment(self): - round_trip( - """ - !!set - # the beginning - ? a - # next one is B (lowercase) - ? b # You see? Promised you. - ? c - # this is the end - """ - ) - - def test_omap_comment_roundtrip(self): - round_trip( - """ - !!omap - - a: 1 - - b: 2 # two - - c: 3 # three - - d: 4 - """ - ) - - def test_omap_comment_roundtrip_pre_comment(self): - round_trip( - """ - !!omap - - a: 1 - - b: 2 # two - - c: 3 # three - # last one - - d: 4 - """ - ) - - def test_non_ascii(self): - round_trip( - """ - verbosity: 1 # 0 is minimal output, -1 none - base_url: http://gopher.net - special_indices: [1, 5, 8] - also_special: - - a - - 19 - - 32 - asia and europe: &asia_europe - Turkey: Ankara - Russia: Moscow - countries: - Asia: - <<: *asia_europe - Japan: Tokyo # 東京 - Europe: - <<: *asia_europe - Spain: Madrid - Italy: Rome - """ - ) - - def test_dump_utf8(self): - import srsly.ruamel_yaml # NOQA - - x = dedent( - """\ - ab: - - x # comment - - y # more comment - """ - ) - data = round_trip_load(x) - dumper = srsly.ruamel_yaml.RoundTripDumper - for utf in [True, False]: - y = srsly.ruamel_yaml.dump( - data, default_flow_style=False, Dumper=dumper, allow_unicode=utf - ) - assert y == x - - def test_dump_unicode_utf8(self): - import srsly.ruamel_yaml # NOQA - - x = dedent( - u"""\ - ab: - - x # comment - - y # more comment - """ - ) - data = round_trip_load(x) - dumper = 
srsly.ruamel_yaml.RoundTripDumper - for utf in [True, False]: - y = srsly.ruamel_yaml.dump( - data, default_flow_style=False, Dumper=dumper, allow_unicode=utf - ) - assert y == x - - def test_mlget_00(self): - x = """\ - a: - - b: - c: 42 - - d: - f: 196 - e: - g: 3.14 - """ - d = round_trip_load(x) - assert d.mlget(["a", 1, "d", "f"], list_ok=True) == 196 - with pytest.raises(AssertionError): - d.mlget(["a", 1, "d", "f"]) == 196 - - -class TestInsertPopList: - """list insertion is more complex than dict insertion, as you - need to move the values to subsequent keys on insert""" - - @property - def ins(self): - return """\ - ab: - - a # a - - b # b - - c - - d # d - - de: - - 1 - - 2 - """ - - def test_insert_0(self): - d = round_trip_load(self.ins) - d["ab"].insert(0, "xyz") - y = round_trip_dump(d, indent=2) - assert y == dedent( - """\ - ab: - - xyz - - a # a - - b # b - - c - - d # d - - de: - - 1 - - 2 - """ - ) - - def test_insert_1(self): - d = round_trip_load(self.ins) - d["ab"].insert(4, "xyz") - y = round_trip_dump(d, indent=2) - assert y == dedent( - """\ - ab: - - a # a - - b # b - - c - - d # d - - - xyz - de: - - 1 - - 2 - """ - ) - - def test_insert_2(self): - d = round_trip_load(self.ins) - d["ab"].insert(1, "xyz") - y = round_trip_dump(d, indent=2) - assert y == dedent( - """\ - ab: - - a # a - - xyz - - b # b - - c - - d # d - - de: - - 1 - - 2 - """ - ) - - def test_pop_0(self): - d = round_trip_load(self.ins) - d["ab"].pop(0) - y = round_trip_dump(d, indent=2) - print(y) - assert y == dedent( - """\ - ab: - - b # b - - c - - d # d - - de: - - 1 - - 2 - """ - ) - - def test_pop_1(self): - d = round_trip_load(self.ins) - d["ab"].pop(1) - y = round_trip_dump(d, indent=2) - print(y) - assert y == dedent( - """\ - ab: - - a # a - - c - - d # d - - de: - - 1 - - 2 - """ - ) - - def test_pop_2(self): - d = round_trip_load(self.ins) - d["ab"].pop(2) - y = round_trip_dump(d, indent=2) - print(y) - assert y == dedent( - """\ - ab: - - a # a - - b # b - - 
d # d - - de: - - 1 - - 2 - """ - ) - - def test_pop_3(self): - d = round_trip_load(self.ins) - d["ab"].pop(3) - y = round_trip_dump(d, indent=2) - print(y) - assert y == dedent( - """\ - ab: - - a # a - - b # b - - c - de: - - 1 - - 2 - """ - ) - - -# inspired by demux' question on stackoverflow -# http://stackoverflow.com/a/36970608/1307905 -class TestInsertInMapping: - @property - def ins(self): - return """\ - first_name: Art - occupation: Architect # This is an occupation comment - about: Art Vandelay is a fictional character that George invents... - """ - - def test_insert_at_pos_1(self): - d = round_trip_load(self.ins) - d.insert(1, "last name", "Vandelay", comment="new key") - y = round_trip_dump(d) - print(y) - assert y == dedent( - """\ - first_name: Art - last name: Vandelay # new key - occupation: Architect # This is an occupation comment - about: Art Vandelay is a fictional character that George invents... - """ - ) - - def test_insert_at_pos_0(self): - d = round_trip_load(self.ins) - d.insert(0, "last name", "Vandelay", comment="new key") - y = round_trip_dump(d) - print(y) - assert y == dedent( - """\ - last name: Vandelay # new key - first_name: Art - occupation: Architect # This is an occupation comment - about: Art Vandelay is a fictional character that George invents... - """ - ) - - def test_insert_at_pos_3(self): - # much more simple if done with appending. - d = round_trip_load(self.ins) - d.insert(3, "last name", "Vandelay", comment="new key") - y = round_trip_dump(d) - print(y) - assert y == dedent( - """\ - first_name: Art - occupation: Architect # This is an occupation comment - about: Art Vandelay is a fictional character that George invents... 
- last name: Vandelay # new key - """ - ) - - -class TestCommentedMapMerge: - def test_in_operator(self): - data = round_trip_load( - """ - x: &base - a: 1 - b: 2 - c: 3 - y: - <<: *base - k: 4 - l: 5 - """ - ) - assert data["x"]["a"] == 1 - assert "a" in data["x"] - assert data["y"]["a"] == 1 - assert "a" in data["y"] - - def test_issue_60(self): - data = round_trip_load( - """ - x: &base - a: 1 - y: - <<: *base - """ - ) - assert data["x"]["a"] == 1 - assert data["y"]["a"] == 1 - if sys.version_info >= (3, 12): - assert str(data["y"]) == """ordereddict({'a': 1})""" - else: - assert str(data["y"]) == """ordereddict([('a', 1)])""" - - def test_issue_60_1(self): - data = round_trip_load( - """ - x: &base - a: 1 - y: - <<: *base - b: 2 - """ - ) - assert data["x"]["a"] == 1 - assert data["y"]["a"] == 1 - if sys.version_info >= (3, 12): - assert str(data["y"]) == """ordereddict({'b': 2, 'a': 1})""" - else: - assert str(data["y"]) == """ordereddict([('b', 2), ('a', 1)])""" - - -class TestEmptyLines: - # prompted by issue 46 from Alex Harvey - def test_issue_46(self): - yaml_str = dedent( - """\ - --- - # Please add key/value pairs in alphabetical order - - aws_s3_bucket: 'mys3bucket' - - jenkins_ad_credentials: - bind_name: 'CN=svc-AAA-BBB-T,OU=Example,DC=COM,DC=EXAMPLE,DC=Local' - bind_pass: 'xxxxyyyy{' - """ - ) - d = round_trip_load(yaml_str, preserve_quotes=True) - y = round_trip_dump(d, explicit_start=True) - assert yaml_str == y - - def test_multispace_map(self): - round_trip( - """ - a: 1x - - b: 2x - - - c: 3x - - - - d: 4x - - """ - ) - - @pytest.mark.xfail(strict=True) - def test_multispace_map_initial(self): - round_trip( - """ - - a: 1x - - b: 2x - - - c: 3x - - - - d: 4x - - """ - ) - - def test_embedded_map(self): - round_trip( - """ - - a: 1y - b: 2y - - c: 3y - """ - ) - - def test_toplevel_seq(self): - round_trip( - """\ - - 1 - - - 2 - - - 3 - """ - ) - - def test_embedded_seq(self): - round_trip( - """ - a: - b: - - 1 - - - 2 - - - - 3 - """ - ) - - 
def test_line_with_only_spaces(self): - # issue 54 - yaml_str = "---\n\na: 'x'\n \nb: y\n" - d = round_trip_load(yaml_str, preserve_quotes=True) - y = round_trip_dump(d, explicit_start=True) - stripped = "" - for line in yaml_str.splitlines(): - stripped += line.rstrip() + "\n" - print(line + "$") - assert stripped == y - - def test_some_eol_spaces(self): - # spaces after tokens and on empty lines - yaml_str = '--- \n \na: "x" \n \nb: y \n' - d = round_trip_load(yaml_str, preserve_quotes=True) - y = round_trip_dump(d, explicit_start=True) - stripped = "" - for line in yaml_str.splitlines(): - stripped += line.rstrip() + "\n" - print(line + "$") - assert stripped == y - - def test_issue_54_not_ok(self): - yaml_str = dedent( - """\ - toplevel: - - # some comment - sublevel: 300 - """ - ) - d = round_trip_load(yaml_str) - print(d.ca) - y = round_trip_dump(d, indent=4) - print(y.replace("\n", "$\n")) - assert yaml_str == y - - def test_issue_54_ok(self): - yaml_str = dedent( - """\ - toplevel: - # some comment - sublevel: 300 - """ - ) - d = round_trip_load(yaml_str) - y = round_trip_dump(d, indent=4) - assert yaml_str == y - - def test_issue_93(self): - round_trip( - """\ - a: - b: - - c1: cat # a1 - # my comment on catfish - - c2: catfish # a2 - """ - ) - - def test_issue_93_00(self): - round_trip( - """\ - a: - - - c1: cat # a1 - # my comment on catfish - - c2: catfish # a2 - """ - ) - - def test_issue_93_01(self): - round_trip( - """\ - - - c1: cat # a1 - # my comment on catfish - - c2: catfish # a2 - """ - ) - - def test_issue_93_02(self): - # never failed as there is no indent - round_trip( - """\ - - c1: cat - # my comment on catfish - - c2: catfish - """ - ) - - def test_issue_96(self): - # inserted extra line on trailing spaces - round_trip( - """\ - a: - b: - c: c_val - d: - - e: - g: g_val - """ - ) - - -class TestUnicodeComments: - @pytest.mark.skipif(sys.version_info < (2, 7), reason="wide unicode") - def test_issue_55(self): # reported by Haraguroicha Hsu 
- round_trip( - """\ - name: TEST - description: test using - author: Harguroicha - sql: - command: |- - select name from testtbl where no = :no - - ci-test: - - :no: 04043709 # 小花 - - :no: 05161690 # 茶 - - :no: 05293147 # 〇𤋥川 - - :no: 05338777 # 〇〇啓 - - :no: 05273867 # 〇 - - :no: 05205786 # 〇𤦌 - """ - ) - - -class TestEmptyValueBeforeComments: - def test_issue_25a(self): - round_trip( - """\ - - a: b - c: d - d: # foo - - e: f - """ - ) - - def test_issue_25a1(self): - round_trip( - """\ - - a: b - c: d - d: # foo - e: f - """ - ) - - def test_issue_25b(self): - round_trip( - """\ - var1: #empty - var2: something #notempty - """ - ) - - def test_issue_25c(self): - round_trip( - """\ - params: - a: 1 # comment a - b: # comment b - c: 3 # comment c - """ - ) - - def test_issue_25c1(self): - round_trip( - """\ - params: - a: 1 # comment a - b: # comment b - # extra - c: 3 # comment c - """ - ) - - def test_issue_25_00(self): - round_trip( - """\ - params: - a: 1 # comment a - b: # comment b - """ - ) - - def test_issue_25_01(self): - round_trip( - """\ - a: # comment 1 - # comment 2 - - b: # comment 3 - c: 1 # comment 4 - """ - ) - - def test_issue_25_02(self): - round_trip( - """\ - a: # comment 1 - # comment 2 - - b: 2 # comment 3 - """ - ) - - def test_issue_25_03(self): - s = """\ - a: # comment 1 - # comment 2 - - b: 2 # comment 3 - """ - round_trip(s, indent=4, block_seq_indent=2) - - def test_issue_25_04(self): - round_trip( - """\ - a: # comment 1 - # comment 2 - b: 1 # comment 3 - """ - ) - - def test_flow_seq_within_seq(self): - round_trip( - """\ - # comment 1 - - a - - b - # comment 2 - - c - - d - # comment 3 - - [e] - - f - # comment 4 - - [] - """ - ) - - -test_block_scalar_commented_line_template = """\ -y: p -# Some comment - -a: | - x -{}b: y -""" - - -class TestBlockScalarWithComments: - # issue 99 reported by Colm O'Connor - def test_scalar_with_comments(self): - import srsly.ruamel_yaml # NOQA - - for x in [ - "", - "\n", - "\n# Another 
comment\n", - "\n\n", - "\n\n# abc\n#xyz\n", - "\n\n# abc\n#xyz\n", - "# abc\n\n#xyz\n", - "\n\n # abc\n #xyz\n", - ]: - - commented_line = test_block_scalar_commented_line_template.format(x) - data = srsly.ruamel_yaml.round_trip_load(commented_line) - - assert srsly.ruamel_yaml.round_trip_dump(data) == commented_line diff --git a/srsly/tests/ruamel_yaml/test_contextmanager.py b/srsly/tests/ruamel_yaml/test_contextmanager.py deleted file mode 100755 index f94f11c..0000000 --- a/srsly/tests/ruamel_yaml/test_contextmanager.py +++ /dev/null @@ -1,118 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -""" -testing of anchors and the aliases referring to them -""" - -import sys -import pytest - - -single_doc = """\ -- a: 1 -- b: - - 2 - - 3 -""" - -single_data = [dict(a=1), dict(b=[2, 3])] - -multi_doc = """\ ---- -- abc -- xyz ---- -- a: 1 -- b: - - 2 - - 3 -""" - -multi_doc_data = [["abc", "xyz"], single_data] - - -def get_yaml(): - from srsly.ruamel_yaml import YAML - - return YAML() - - -class TestOldStyle: - def test_single_load(self): - d = get_yaml().load(single_doc) - print(d) - print(type(d[0])) - assert d == single_data - - def test_single_load_no_arg(self): - with pytest.raises(TypeError): - assert get_yaml().load() == single_data - - def test_multi_load(self): - data = list(get_yaml().load_all(multi_doc)) - assert data == multi_doc_data - - def test_single_dump(self, capsys): - get_yaml().dump(single_data, sys.stdout) - out, err = capsys.readouterr() - assert out == single_doc - - def test_multi_dump(self, capsys): - yaml = get_yaml() - yaml.explicit_start = True - yaml.dump_all(multi_doc_data, sys.stdout) - out, err = capsys.readouterr() - assert out == multi_doc - - -class TestContextManager: - def test_single_dump(self, capsys): - from srsly.ruamel_yaml import YAML - - with YAML(output=sys.stdout) as yaml: - yaml.dump(single_data) - out, err = capsys.readouterr() - print(err) - assert out == single_doc - - def test_multi_dump(self, 
capsys): - from srsly.ruamel_yaml import YAML - - with YAML(output=sys.stdout) as yaml: - yaml.explicit_start = True - yaml.dump(multi_doc_data[0]) - yaml.dump(multi_doc_data[1]) - - out, err = capsys.readouterr() - print(err) - assert out == multi_doc - - # input is not as simple with a context manager - # you need to indicate what you expect hence load and load_all - - # @pytest.mark.xfail(strict=True) - # def test_single_load(self): - # from srsly.ruamel_yaml import YAML - # with YAML(input=single_doc) as yaml: - # assert yaml.load() == single_data - # - # @pytest.mark.xfail(strict=True) - # def test_multi_load(self): - # from srsly.ruamel_yaml import YAML - # with YAML(input=multi_doc) as yaml: - # for idx, data in enumerate(yaml.load()): - # assert data == multi_doc_data[0] - - def test_roundtrip(self, capsys): - from srsly.ruamel_yaml import YAML - - with YAML(output=sys.stdout) as yaml: - yaml.explicit_start = True - for data in yaml.load_all(multi_doc): - yaml.dump(data) - - out, err = capsys.readouterr() - print(err) - assert out == multi_doc diff --git a/srsly/tests/ruamel_yaml/test_copy.py b/srsly/tests/ruamel_yaml/test_copy.py deleted file mode 100755 index c17cfb9..0000000 --- a/srsly/tests/ruamel_yaml/test_copy.py +++ /dev/null @@ -1,135 +0,0 @@ -# coding: utf-8 - -""" -Testing copy and deepcopy, instigated by Issue 84 (Peter Amstutz) -""" - -import copy - -import pytest # NOQA - -from .roundtrip import dedent, round_trip_load, round_trip_dump - - -class TestDeepCopy: - def test_preserve_flow_style_simple(self): - x = dedent( - """\ - {foo: bar, baz: quux} - """ - ) - data = round_trip_load(x) - data_copy = copy.deepcopy(data) - y = round_trip_dump(data_copy) - print("x [{}]".format(x)) - print("y [{}]".format(y)) - assert y == x - assert data.fa.flow_style() == data_copy.fa.flow_style() - - def test_deepcopy_flow_style_nested_dict(self): - x = dedent( - """\ - a: {foo: bar, baz: quux} - """ - ) - data = round_trip_load(x) - assert 
data["a"].fa.flow_style() is True - data_copy = copy.deepcopy(data) - assert data_copy["a"].fa.flow_style() is True - data_copy["a"].fa.set_block_style() - assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style() - assert data["a"].fa._flow_style is True - assert data_copy["a"].fa._flow_style is False - y = round_trip_dump(data_copy) - - print("x [{}]".format(x)) - print("y [{}]".format(y)) - assert y == dedent( - """\ - a: - foo: bar - baz: quux - """ - ) - - def test_deepcopy_flow_style_nested_list(self): - x = dedent( - """\ - a: [1, 2, 3] - """ - ) - data = round_trip_load(x) - assert data["a"].fa.flow_style() is True - data_copy = copy.deepcopy(data) - assert data_copy["a"].fa.flow_style() is True - data_copy["a"].fa.set_block_style() - assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style() - assert data["a"].fa._flow_style is True - assert data_copy["a"].fa._flow_style is False - y = round_trip_dump(data_copy) - - print("x [{}]".format(x)) - print("y [{}]".format(y)) - assert y == dedent( - """\ - a: - - 1 - - 2 - - 3 - """ - ) - - -class TestCopy: - def test_copy_flow_style_nested_dict(self): - x = dedent( - """\ - a: {foo: bar, baz: quux} - """ - ) - data = round_trip_load(x) - assert data["a"].fa.flow_style() is True - data_copy = copy.copy(data) - assert data_copy["a"].fa.flow_style() is True - data_copy["a"].fa.set_block_style() - assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style() - assert data["a"].fa._flow_style is False - assert data_copy["a"].fa._flow_style is False - y = round_trip_dump(data_copy) - z = round_trip_dump(data) - assert y == z - - assert y == dedent( - """\ - a: - foo: bar - baz: quux - """ - ) - - def test_copy_flow_style_nested_list(self): - x = dedent( - """\ - a: [1, 2, 3] - """ - ) - data = round_trip_load(x) - assert data["a"].fa.flow_style() is True - data_copy = copy.copy(data) - assert data_copy["a"].fa.flow_style() is True - data_copy["a"].fa.set_block_style() - assert 
data["a"].fa.flow_style() == data_copy["a"].fa.flow_style() - assert data["a"].fa._flow_style is False - assert data_copy["a"].fa._flow_style is False - y = round_trip_dump(data_copy) - - print("x [{}]".format(x)) - print("y [{}]".format(y)) - assert y == dedent( - """\ - a: - - 1 - - 2 - - 3 - """ - ) diff --git a/srsly/tests/ruamel_yaml/test_datetime.py b/srsly/tests/ruamel_yaml/test_datetime.py deleted file mode 100755 index db88ef8..0000000 --- a/srsly/tests/ruamel_yaml/test_datetime.py +++ /dev/null @@ -1,157 +0,0 @@ -# coding: utf-8 - -""" -http://yaml.org/type/timestamp.html specifies the regexp to use -for datetime.date and datetime.datetime construction. Date is simple -but datetime can have 'T' or 't' as well as 'Z' or a timezone offset (in -hours and minutes). This information was originally used to create -a UTC datetime and then discarded - -examples from the above: - -canonical: 2001-12-15T02:59:43.1Z -valid iso8601: 2001-12-14t21:59:43.10-05:00 -space separated: 2001-12-14 21:59:43.10 -5 -no time zone (Z): 2001-12-15 2:59:43.10 -date (00:00:00Z): 2002-12-14 - -Please note that a fraction can only be included if not equal to 0 - -""" - -import copy -import pytest # NOQA - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - - -class TestDateTime: - def test_date_only(self): - inp = """ - - 2011-10-02 - """ - exp = """ - - 2011-10-02 - """ - round_trip(inp, exp) - - def test_zero_fraction(self): - inp = """ - - 2011-10-02 16:45:00.0 - """ - exp = """ - - 2011-10-02 16:45:00 - """ - round_trip(inp, exp) - - def test_long_fraction(self): - inp = """ - - 2011-10-02 16:45:00.1234 # expand with zeros - - 2011-10-02 16:45:00.123456 - - 2011-10-02 16:45:00.12345612 # round to microseconds - - 2011-10-02 16:45:00.1234565 # round up - - 2011-10-02 16:45:00.12345678 # round up - """ - exp = """ - - 2011-10-02 16:45:00.123400 # expand with zeros - - 2011-10-02 16:45:00.123456 - - 2011-10-02 16:45:00.123456 # round to microseconds 
- - 2011-10-02 16:45:00.123457 # round up - - 2011-10-02 16:45:00.123457 # round up - """ - round_trip(inp, exp) - - def test_canonical(self): - inp = """ - - 2011-10-02T16:45:00.1Z - """ - exp = """ - - 2011-10-02T16:45:00.100000Z - """ - round_trip(inp, exp) - - def test_spaced_timezone(self): - inp = """ - - 2011-10-02T11:45:00 -5 - """ - exp = """ - - 2011-10-02T11:45:00-5 - """ - round_trip(inp, exp) - - def test_normal_timezone(self): - round_trip( - """ - - 2011-10-02T11:45:00-5 - - 2011-10-02 11:45:00-5 - - 2011-10-02T11:45:00-05:00 - - 2011-10-02 11:45:00-05:00 - """ - ) - - def test_no_timezone(self): - inp = """ - - 2011-10-02 6:45:00 - """ - exp = """ - - 2011-10-02 06:45:00 - """ - round_trip(inp, exp) - - def test_explicit_T(self): - inp = """ - - 2011-10-02T16:45:00 - """ - exp = """ - - 2011-10-02T16:45:00 - """ - round_trip(inp, exp) - - def test_explicit_t(self): # to upper - inp = """ - - 2011-10-02t16:45:00 - """ - exp = """ - - 2011-10-02T16:45:00 - """ - round_trip(inp, exp) - - def test_no_T_multi_space(self): - inp = """ - - 2011-10-02 16:45:00 - """ - exp = """ - - 2011-10-02 16:45:00 - """ - round_trip(inp, exp) - - def test_iso(self): - round_trip( - """ - - 2011-10-02T15:45:00+01:00 - """ - ) - - def test_zero_tz(self): - round_trip( - """ - - 2011-10-02T15:45:00+0 - """ - ) - - def test_issue_45(self): - round_trip( - """ - dt: 2016-08-19T22:45:47Z - """ - ) - - def test_deepcopy_datestring(self): - # reported by Quuxplusone, http://stackoverflow.com/a/41577841/1307905 - x = dedent( - """\ - foo: 2016-10-12T12:34:56 - """ - ) - data = copy.deepcopy(round_trip_load(x)) - assert round_trip_dump(data) == x diff --git a/srsly/tests/ruamel_yaml/test_deprecation.py b/srsly/tests/ruamel_yaml/test_deprecation.py deleted file mode 100755 index 14acd71..0000000 --- a/srsly/tests/ruamel_yaml/test_deprecation.py +++ /dev/null @@ -1,13 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -import sys -import pytest # NOQA - - 
-@pytest.mark.skipif(sys.version_info < (3, 7) or sys.version_info >= (3, 9), - reason='collections not available?') -def test_collections_deprecation(): - with pytest.warns(DeprecationWarning): - from collections import Hashable # NOQA diff --git a/srsly/tests/ruamel_yaml/test_documents.py b/srsly/tests/ruamel_yaml/test_documents.py deleted file mode 100755 index cad6ff2..0000000 --- a/srsly/tests/ruamel_yaml/test_documents.py +++ /dev/null @@ -1,79 +0,0 @@ -# coding: utf-8 - -import pytest # NOQA - -from .roundtrip import round_trip, round_trip_load_all - - -class TestDocument: - def test_single_doc_begin_end(self): - inp = """\ - --- - - a - - b - ... - """ - round_trip(inp, explicit_start=True, explicit_end=True) - - def test_multi_doc_begin_end(self): - from srsly.ruamel_yaml import dump_all, RoundTripDumper - - inp = """\ - --- - - a - ... - --- - - b - ... - """ - docs = list(round_trip_load_all(inp)) - assert docs == [["a"], ["b"]] - out = dump_all( - docs, Dumper=RoundTripDumper, explicit_start=True, explicit_end=True - ) - assert out == "---\n- a\n...\n---\n- b\n...\n" - - def test_multi_doc_no_start(self): - inp = """\ - - a - ... - --- - - b - ... - """ - docs = list(round_trip_load_all(inp)) - assert docs == [["a"], ["b"]] - - def test_multi_doc_no_end(self): - inp = """\ - - a - --- - - b - """ - docs = list(round_trip_load_all(inp)) - assert docs == [["a"], ["b"]] - - def test_multi_doc_ends_only(self): - # this is ok in 1.2 - inp = """\ - - a - ... - - b - ... - """ - docs = list(round_trip_load_all(inp, version=(1, 2))) - assert docs == [["a"], ["b"]] - - def test_multi_doc_ends_only_1_1(self): - from srsly.ruamel_yaml import parser - - # this is not ok in 1.1 - with pytest.raises(parser.ParserError): - inp = """\ - - a - ... - - b - ... 
- """ - docs = list(round_trip_load_all(inp, version=(1, 1))) - assert docs == [["a"], ["b"]] # not True, but not reached diff --git a/srsly/tests/ruamel_yaml/test_fail.py b/srsly/tests/ruamel_yaml/test_fail.py deleted file mode 100755 index 02cef0b..0000000 --- a/srsly/tests/ruamel_yaml/test_fail.py +++ /dev/null @@ -1,255 +0,0 @@ -# coding: utf-8 - -# there is some work to do -# provide a failing test xyz and a non-failing xyz_no_fail ( to see -# what the current failing output is. -# on fix of srsly.ruamel_yaml, move the marked test to the appropriate test (without mark) -# and remove remove the xyz_no_fail - -import pytest - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump - - -class TestCommentFailures: - @pytest.mark.xfail(strict=True) - def test_set_comment_before_tag(self): - # no comments before tags - round_trip( - """ - # the beginning - !!set - # or this one? - ? a - # next one is B (lowercase) - ? b # You see? Promised you. - ? c - # this is the end - """ - ) - - def test_set_comment_before_tag_no_fail(self): - # no comments before tags - inp = """ - # the beginning - !!set - # or this one? - ? a - # next one is B (lowercase) - ? b # You see? Promised you. - ? c - # this is the end - """ - assert round_trip_dump(round_trip_load(inp)) == dedent( - """ - !!set - # or this one? - ? a - # next one is B (lowercase) - ? b # You see? Promised you. - ? 
c - # this is the end - """ - ) - - @pytest.mark.xfail(strict=True) - def test_comment_dash_line(self): - round_trip( - """ - - # abc - a: 1 - b: 2 - """ - ) - - def test_comment_dash_line_fail(self): - x = """ - - # abc - a: 1 - b: 2 - """ - data = round_trip_load(x) - # this is not nice - assert round_trip_dump(data) == dedent( - """ - # abc - - a: 1 - b: 2 - """ - ) - - -class TestIndentFailures: - @pytest.mark.xfail(strict=True) - def test_indent_not_retained(self): - round_trip( - """ - verbosity: 1 # 0 is minimal output, -1 none - base_url: http://gopher.net - special_indices: [1, 5, 8] - also_special: - - a - - 19 - - 32 - asia and europe: &asia_europe - Turkey: Ankara - Russia: Moscow - countries: - Asia: - <<: *asia_europe - Japan: Tokyo # 東京 - Europe: - <<: *asia_europe - Spain: Madrid - Italy: Rome - Antarctica: - - too cold - """ - ) - - def test_indent_not_retained_no_fail(self): - inp = """ - verbosity: 1 # 0 is minimal output, -1 none - base_url: http://gopher.net - special_indices: [1, 5, 8] - also_special: - - a - - 19 - - 32 - asia and europe: &asia_europe - Turkey: Ankara - Russia: Moscow - countries: - Asia: - <<: *asia_europe - Japan: Tokyo # 東京 - Europe: - <<: *asia_europe - Spain: Madrid - Italy: Rome - Antarctica: - - too cold - """ - assert round_trip_dump(round_trip_load(inp), indent=4) == dedent( - """ - verbosity: 1 # 0 is minimal output, -1 none - base_url: http://gopher.net - special_indices: [1, 5, 8] - also_special: - - a - - 19 - - 32 - asia and europe: &asia_europe - Turkey: Ankara - Russia: Moscow - countries: - Asia: - <<: *asia_europe - Japan: Tokyo # 東京 - Europe: - <<: *asia_europe - Spain: Madrid - Italy: Rome - Antarctica: - - too cold - """ - ) - - def Xtest_indent_top_level_no_fail(self): - inp = """ - - a: - - b - """ - round_trip(inp, indent=4) - - -class TestTagFailures: - @pytest.mark.xfail(strict=True) - def test_standard_short_tag(self): - round_trip( - """\ - !!map - name: Anthon - location: Germany - language: 
python - """ - ) - - def test_standard_short_tag_no_fail(self): - inp = """ - !!map - name: Anthon - location: Germany - language: python - """ - exp = """ - name: Anthon - location: Germany - language: python - """ - assert round_trip_dump(round_trip_load(inp)) == dedent(exp) - - -class TestFlowValues: - def test_flow_value_with_colon(self): - inp = """\ - {a: bcd:efg} - """ - round_trip(inp) - - def test_flow_value_with_colon_quoted(self): - inp = """\ - {a: 'bcd:efg'} - """ - round_trip(inp, preserve_quotes=True) - - -class TestMappingKey: - def test_simple_mapping_key(self): - inp = """\ - {a: 1, b: 2}: hello world - """ - round_trip(inp, preserve_quotes=True, dump_data=False) - - def test_set_simple_mapping_key(self): - from srsly.ruamel_yaml.comments import CommentedKeyMap - - d = {CommentedKeyMap([("a", 1), ("b", 2)]): "hello world"} - exp = dedent( - """\ - {a: 1, b: 2}: hello world - """ - ) - assert round_trip_dump(d) == exp - - def test_change_key_simple_mapping_key(self): - from srsly.ruamel_yaml.comments import CommentedKeyMap - - inp = """\ - {a: 1, b: 2}: hello world - """ - d = round_trip_load(inp, preserve_quotes=True) - d[CommentedKeyMap([("b", 1), ("a", 2)])] = d.pop( - CommentedKeyMap([("a", 1), ("b", 2)]) - ) - exp = dedent( - """\ - {b: 1, a: 2}: hello world - """ - ) - assert round_trip_dump(d) == exp - - def test_change_value_simple_mapping_key(self): - from srsly.ruamel_yaml.comments import CommentedKeyMap - - inp = """\ - {a: 1, b: 2}: hello world - """ - d = round_trip_load(inp, preserve_quotes=True) - d = {CommentedKeyMap([("a", 1), ("b", 2)]): "goodbye"} - exp = dedent( - """\ - {a: 1, b: 2}: goodbye - """ - ) - assert round_trip_dump(d) == exp diff --git a/srsly/tests/ruamel_yaml/test_float.py b/srsly/tests/ruamel_yaml/test_float.py deleted file mode 100755 index 4cffc68..0000000 --- a/srsly/tests/ruamel_yaml/test_float.py +++ /dev/null @@ -1,92 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, 
division, unicode_literals - -import pytest # NOQA - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - -# http://yaml.org/type/int.html is where underscores in integers are defined - - -class TestFloat: - def test_round_trip_non_exp(self): - data = round_trip( - """\ - - 1.0 - - 1.00 - - 23.100 - - -1.0 - - -1.00 - - -23.100 - - 42. - - -42. - - +42. - - .5 - - +.5 - - -.5 - """ - ) - print(data) - assert 0.999 < data[0] < 1.001 - assert 0.999 < data[1] < 1.001 - assert 23.099 < data[2] < 23.101 - assert 0.999 < -data[3] < 1.001 - assert 0.999 < -data[4] < 1.001 - assert 23.099 < -data[5] < 23.101 - assert 41.999 < data[6] < 42.001 - assert 41.999 < -data[7] < 42.001 - assert 41.999 < data[8] < 42.001 - assert 0.49 < data[9] < 0.51 - assert 0.49 < data[10] < 0.51 - assert -0.51 < data[11] < -0.49 - - def test_round_trip_zeros_0(self): - data = round_trip( - """\ - - 0. - - +0. - - -0. - - 0.0 - - +0.0 - - -0.0 - - 0.00 - - +0.00 - - -0.00 - """ - ) - print(data) - for d in data: - assert -0.00001 < d < 0.00001 - - def Xtest_round_trip_non_exp_trailing_dot(self): - data = round_trip( - """\ - """ - ) - print(data) - - def test_yaml_1_1_no_dot(self): - from srsly.ruamel_yaml.error import MantissaNoDotYAML1_1Warning - - with pytest.warns(MantissaNoDotYAML1_1Warning): - round_trip_load( - """\ - %YAML 1.1 - --- - - 1e6 - """ - ) - - -class TestCalculations(object): - def test_mul_00(self): - # issue 149 reported by jan.brezina@tul.cz - d = round_trip_load( - """\ - - 0.1 - """ - ) - d[0] *= -1 - x = round_trip_dump(d) - assert x == "- -0.1\n" diff --git a/srsly/tests/ruamel_yaml/test_flowsequencekey.py b/srsly/tests/ruamel_yaml/test_flowsequencekey.py deleted file mode 100755 index 8362bec..0000000 --- a/srsly/tests/ruamel_yaml/test_flowsequencekey.py +++ /dev/null @@ -1,25 +0,0 @@ -# coding: utf-8 - -""" -test flow style sequences as keys roundtrip - -""" - -# import pytest - -from .roundtrip import round_trip # , dedent, 
round_trip_load, round_trip_dump - - -class TestFlowStyleSequenceKey: - def test_so_39595807(self): - inp = """\ - %YAML 1.2 - --- - [2, 3, 4]: - a: - - 1 - - 2 - b: Hello World! - c: 'Voilà!' - """ - round_trip(inp, preserve_quotes=True, explicit_start=True, version=(1, 2)) diff --git a/srsly/tests/ruamel_yaml/test_indentation.py b/srsly/tests/ruamel_yaml/test_indentation.py deleted file mode 100755 index ef1057d..0000000 --- a/srsly/tests/ruamel_yaml/test_indentation.py +++ /dev/null @@ -1,365 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals - - -import pytest # NOQA - -from .roundtrip import round_trip, round_trip_load, round_trip_dump, dedent, YAML - - -def rt(s): - import srsly.ruamel_yaml - - res = srsly.ruamel_yaml.dump( - srsly.ruamel_yaml.load(s, Loader=srsly.ruamel_yaml.RoundTripLoader), - Dumper=srsly.ruamel_yaml.RoundTripDumper, - ) - return res.strip() + "\n" - - -class TestIndent: - def test_roundtrip_inline_list(self): - s = "a: [a, b, c]\n" - output = rt(s) - assert s == output - - def test_roundtrip_mapping_of_inline_lists(self): - s = dedent( - """\ - a: [a, b, c] - j: [k, l, m] - """ - ) - output = rt(s) - assert s == output - - def test_roundtrip_mapping_of_inline_lists_comments(self): - s = dedent( - """\ - # comment A - a: [a, b, c] - # comment B - j: [k, l, m] - """ - ) - output = rt(s) - assert s == output - - def test_roundtrip_mapping_of_inline_sequence_eol_comments(self): - s = dedent( - """\ - # comment A - a: [a, b, c] # comment B - j: [k, l, m] # comment C - """ - ) - output = rt(s) - assert s == output - - # first test by explicitly setting flow style - def test_added_inline_list(self): - import srsly.ruamel_yaml - - s1 = dedent( - """ - a: - - b - - c - - d - """ - ) - s = "a: [b, c, d]\n" - data = srsly.ruamel_yaml.load(s1, Loader=srsly.ruamel_yaml.RoundTripLoader) - val = data["a"] - val.fa.set_flow_style() - # print(type(val), 
'_yaml_format' in dir(val)) - output = srsly.ruamel_yaml.dump(data, Dumper=srsly.ruamel_yaml.RoundTripDumper) - assert s == output - - # ############ flow mappings - - def test_roundtrip_flow_mapping(self): - import srsly.ruamel_yaml - - s = dedent( - """\ - - {a: 1, b: hallo} - - {j: fka, k: 42} - """ - ) - data = srsly.ruamel_yaml.load(s, Loader=srsly.ruamel_yaml.RoundTripLoader) - output = srsly.ruamel_yaml.dump(data, Dumper=srsly.ruamel_yaml.RoundTripDumper) - assert s == output - - def test_roundtrip_sequence_of_inline_mappings_eol_comments(self): - s = dedent( - """\ - # comment A - - {a: 1, b: hallo} # comment B - - {j: fka, k: 42} # comment C - """ - ) - output = rt(s) - assert s == output - - def test_indent_top_level(self): - inp = """ - - a: - - b - """ - round_trip(inp, indent=4) - - def test_set_indent_5_block_list_indent_1(self): - inp = """ - a: - - b: c - - 1 - - d: - - 2 - """ - round_trip(inp, indent=5, block_seq_indent=1) - - def test_set_indent_4_block_list_indent_2(self): - inp = """ - a: - - b: c - - 1 - - d: - - 2 - """ - round_trip(inp, indent=4, block_seq_indent=2) - - def test_set_indent_3_block_list_indent_0(self): - inp = """ - a: - - b: c - - 1 - - d: - - 2 - """ - round_trip(inp, indent=3, block_seq_indent=0) - - def Xtest_set_indent_3_block_list_indent_2(self): - inp = """ - a: - - - b: c - - - 1 - - - d: - - - 2 - """ - round_trip(inp, indent=3, block_seq_indent=2) - - def test_set_indent_3_block_list_indent_2(self): - inp = """ - a: - - b: c - - 1 - - d: - - 2 - """ - round_trip(inp, indent=3, block_seq_indent=2) - - def Xtest_set_indent_2_block_list_indent_2(self): - inp = """ - a: - - - b: c - - - 1 - - - d: - - - 2 - """ - round_trip(inp, indent=2, block_seq_indent=2) - - # this is how it should be: block_seq_indent stretches the indent - def test_set_indent_2_block_list_indent_2(self): - inp = """ - a: - - b: c - - 1 - - d: - - 2 - """ - round_trip(inp, indent=2, block_seq_indent=2) - - # have to set indent! 
- def test_roundtrip_four_space_indents(self): - # fmt: off - s = ( - 'a:\n' - '- foo\n' - '- bar\n' - ) - # fmt: on - round_trip(s, indent=4) - - def test_roundtrip_four_space_indents_no_fail(self): - inp = """ - a: - - foo - - bar - """ - exp = """ - a: - - foo - - bar - """ - assert round_trip_dump(round_trip_load(inp)) == dedent(exp) - - -class TestYpkgIndent: - def test_00(self): - inp = """ - name : nano - version : 2.3.2 - release : 1 - homepage : http://www.nano-editor.org - source : - - http://www.nano-editor.org/dist/v2.3/nano-2.3.2.tar.gz : ff30924807ea289f5b60106be8 - license : GPL-2.0 - summary : GNU nano is an easy-to-use text editor - builddeps : - - ncurses-devel - description: | - GNU nano is an easy-to-use text editor originally designed - as a replacement for Pico, the ncurses-based editor from the non-free mailer - package Pine (itself now available under the Apache License as Alpine). - """ - round_trip( - inp, - indent=4, - block_seq_indent=2, - top_level_colon_align=True, - prefix_colon=" ", - ) - - -def guess(s): - from srsly.ruamel_yaml.util import load_yaml_guess_indent - - x, y, z = load_yaml_guess_indent(dedent(s)) - return y, z - - -class TestGuessIndent: - def test_guess_20(self): - inp = """\ - a: - - 1 - """ - assert guess(inp) == (2, 0) - - def test_guess_42(self): - inp = """\ - a: - - 1 - """ - assert guess(inp) == (4, 2) - - def test_guess_42a(self): - # block seq indent prevails over nested key indent level - inp = """\ - b: - a: - - 1 - """ - assert guess(inp) == (4, 2) - - def test_guess_3None(self): - inp = """\ - b: - a: 1 - """ - assert guess(inp) == (3, None) - - -class TestSeparateMapSeqIndents: - # using uncommon 6 indent with 3 push in as 2 push in automatically - # gets you 4 indent even if not set - def test_00(self): - # old style - yaml = YAML() - yaml.indent = 6 - yaml.block_seq_indent = 3 - inp = """ - a: - - 1 - - [1, 2] - """ - yaml.round_trip(inp) - - def test_01(self): - yaml = YAML() - yaml.indent(sequence=6) 
- yaml.indent(offset=3) - inp = """ - a: - - 1 - - {b: 3} - """ - yaml.round_trip(inp) - - def test_02(self): - yaml = YAML() - yaml.indent(mapping=5, sequence=6, offset=3) - inp = """ - a: - b: - - 1 - - [1, 2] - """ - yaml.round_trip(inp) - - def test_03(self): - inp = """ - a: - b: - c: - - 1 - - [1, 2] - """ - round_trip(inp, indent=4) - - def test_04(self): - yaml = YAML() - yaml.indent(mapping=5, sequence=6) - inp = """ - a: - b: - - 1 - - [1, 2] - - {d: 3.14} - """ - yaml.round_trip(inp) - - def test_issue_51(self): - yaml = YAML() - # yaml.map_indent = 2 # the default - yaml.indent(sequence=4, offset=2) - yaml.preserve_quotes = True - yaml.round_trip( - """ - role::startup::author::rsyslog_inputs: - imfile: - - ruleset: 'AEM-slinglog' - File: '/opt/aem/author/crx-quickstart/logs/error.log' - startmsg.regex: '^[-+T.:[:digit:]]*' - tag: 'error' - - ruleset: 'AEM-slinglog' - File: '/opt/aem/author/crx-quickstart/logs/stdout.log' - startmsg.regex: '^[-+T.:[:digit:]]*' - tag: 'stdout' - """ - ) - - -# ############ indentation diff --git a/srsly/tests/ruamel_yaml/test_int.py b/srsly/tests/ruamel_yaml/test_int.py deleted file mode 100755 index 6635fcf..0000000 --- a/srsly/tests/ruamel_yaml/test_int.py +++ /dev/null @@ -1,36 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -import pytest # NOQA - -from .roundtrip import dedent, round_trip_load, round_trip_dump - -# http://yaml.org/type/int.html is where underscores in integers are defined - - -class TestBinHexOct: - def test_calculate(self): - # make sure type, leading zero(s) and underscore are preserved - s = dedent( - """\ - - 42 - - 0b101010 - - 0x_2a - - 0x2A - - 0o00_52 - """ - ) - d = round_trip_load(s) - for idx, elem in enumerate(d): - elem -= 21 - d[idx] = elem - for idx, elem in enumerate(d): - elem *= 2 - d[idx] = elem - for idx, elem in enumerate(d): - t = elem - elem **= 2 - elem //= t - d[idx] = elem - assert round_trip_dump(d) == s diff 
--git a/srsly/tests/ruamel_yaml/test_issues.py b/srsly/tests/ruamel_yaml/test_issues.py deleted file mode 100755 index 11dfbb7..0000000 --- a/srsly/tests/ruamel_yaml/test_issues.py +++ /dev/null @@ -1,971 +0,0 @@ -# coding: utf-8 - -from __future__ import absolute_import, print_function, unicode_literals - - -import pytest # NOQA -import sys - - -from .roundtrip import ( - round_trip, - na_round_trip, - round_trip_load, - round_trip_dump, - dedent, - save_and_run, - YAML, -) # NOQA - - -class TestIssues: - def test_issue_61(self): - import srsly.ruamel_yaml - - s = dedent( - """ - def1: &ANCHOR1 - key1: value1 - def: &ANCHOR - <<: *ANCHOR1 - key: value - comb: - <<: *ANCHOR - """ - ) - data = srsly.ruamel_yaml.round_trip_load(s) - assert str(data["comb"]) == str(data["def"]) - if sys.version_info >= (3, 12): - assert ( - str(data["comb"]) == "ordereddict({'key': 'value', 'key1': 'value1'})" - ) - else: - assert ( - str(data["comb"]) == "ordereddict([('key', 'value'), ('key1', 'value1')])" - ) - - def test_issue_82(self, tmpdir): - program_src = r''' - from __future__ import print_function - - import srsly.ruamel_yaml as yaml - - import re - - - class SINumber(yaml.YAMLObject): - PREFIXES = {'k': 1e3, 'M': 1e6, 'G': 1e9} - yaml_loader = yaml.Loader - yaml_dumper = yaml.Dumper - yaml_tag = u'!si' - yaml_implicit_pattern = re.compile( - r'^(?P[0-9]+(?:\.[0-9]+)?)(?P[kMG])$') - - @classmethod - def from_yaml(cls, loader, node): - return cls(node.value) - - @classmethod - def to_yaml(cls, dumper, data): - return dumper.represent_scalar(cls.yaml_tag, str(data)) - - def __init__(self, *args): - m = self.yaml_implicit_pattern.match(args[0]) - self.value = float(m.groupdict()['value']) - self.prefix = m.groupdict()['prefix'] - - def __str__(self): - return str(self.value)+self.prefix - - def __int__(self): - return int(self.value*self.PREFIXES[self.prefix]) - - # This fails: - yaml.add_implicit_resolver(SINumber.yaml_tag, SINumber.yaml_implicit_pattern) - - ret = 
yaml.load(""" - [1,2,3, !si 10k, 100G] - """, Loader=yaml.Loader) - for idx, l in enumerate([1, 2, 3, 10000, 100000000000]): - assert int(ret[idx]) == l - ''' - assert save_and_run(dedent(program_src), tmpdir) == 1 - - def test_issue_82rt(self, tmpdir): - yaml_str = "[1, 2, 3, !si 10k, 100G]\n" - x = round_trip(yaml_str, preserve_quotes=True) # NOQA - - def test_issue_102(self): - yaml_str = dedent( - """ - var1: #empty - var2: something #notempty - var3: {} #empty object - var4: {a: 1} #filled object - var5: [] #empty array - """ - ) - x = round_trip(yaml_str, preserve_quotes=True) # NOQA - - def test_issue_150(self): - from srsly.ruamel_yaml import YAML - - inp = """\ - base: &base_key - first: 123 - second: 234 - - child: - <<: *base_key - third: 345 - """ - yaml = YAML() - data = yaml.load(inp) - child = data["child"] - assert "second" in dict(**child) - - def test_issue_160(self): - from srsly.ruamel_yaml.compat import StringIO - - s = dedent( - """\ - root: - # a comment - - {some_key: "value"} - - foo: 32 - bar: 32 - """ - ) - a = round_trip_load(s) - del a["root"][0]["some_key"] - buf = StringIO() - round_trip_dump(a, buf, block_seq_indent=4) - exp = dedent( - """\ - root: - # a comment - - {} - - foo: 32 - bar: 32 - """ - ) - assert buf.getvalue() == exp - - def test_issue_161(self): - yaml_str = dedent( - """\ - mapping-A: - key-A:{} - mapping-B: - """ - ) - for comment in ["", " # no-newline", " # some comment\n", "\n"]: - s = yaml_str.format(comment) - res = round_trip(s) # NOQA - - def test_issue_161a(self): - yaml_str = dedent( - """\ - mapping-A: - key-A:{} - mapping-B: - """ - ) - for comment in ["\n# between"]: - s = yaml_str.format(comment) - res = round_trip(s) # NOQA - - def test_issue_163(self): - s = dedent( - """\ - some-list: - # List comment - - {} - """ - ) - x = round_trip(s, preserve_quotes=True) # NOQA - - json_str = ( - r'{"sshKeys":[{"name":"AETROS\/google-k80-1","uses":0,"getLastUse":0,' - 
'"fingerprint":"MD5:19:dd:41:93:a1:a3:f5:91:4a:8e:9b:d0:ae:ce:66:4c",' - '"created":1509497961}]}' - ) - - json_str2 = '{"abc":[{"a":"1", "uses":0}]}' - - def test_issue_172(self): - x = round_trip_load(TestIssues.json_str2) # NOQA - x = round_trip_load(TestIssues.json_str) # NOQA - - def test_issue_176(self): - # basic request by Stuart Berg - from srsly.ruamel_yaml import YAML - - yaml = YAML() - seq = yaml.load("[1,2,3]") - seq[:] = [1, 2, 3, 4] - - def test_issue_176_preserve_comments_on_extended_slice_assignment(self): - yaml_str = dedent( - """\ - - a - - b # comment - - c # commment c - # comment c+ - - d - - - e # comment - """ - ) - seq = round_trip_load(yaml_str) - seq[1::2] = ["B", "D"] - res = round_trip_dump(seq) - assert res == yaml_str.replace(" b ", " B ").replace(" d\n", " D\n") - - def test_issue_176_test_slicing(self): - from srsly.ruamel_yaml.compat import PY2 - - mss = round_trip_load("[0, 1, 2, 3, 4]") - assert len(mss) == 5 - assert mss[2:2] == [] - assert mss[2:4] == [2, 3] - assert mss[1::2] == [1, 3] - - # slice assignment - m = mss[:] - m[2:2] = [42] - assert m == [0, 1, 42, 2, 3, 4] - - m = mss[:] - m[:3] = [42, 43, 44] - assert m == [42, 43, 44, 3, 4] - m = mss[:] - m[2:] = [42, 43, 44] - assert m == [0, 1, 42, 43, 44] - m = mss[:] - m[:] = [42, 43, 44] - assert m == [42, 43, 44] - - # extend slice assignment - m = mss[:] - m[2:4] = [42, 43, 44] - assert m == [0, 1, 42, 43, 44, 4] - m = mss[:] - m[1::2] = [42, 43] - assert m == [0, 42, 2, 43, 4] - m = mss[:] - if PY2: - with pytest.raises(ValueError, match="attempt to assign"): - m[1::2] = [42, 43, 44] - else: - with pytest.raises(TypeError, match="too many"): - m[1::2] = [42, 43, 44] - if PY2: - with pytest.raises(ValueError, match="attempt to assign"): - m[1::2] = [42] - else: - with pytest.raises(TypeError, match="not enough"): - m[1::2] = [42] - m = mss[:] - m += [5] - m[1::2] = [42, 43, 44] - assert m == [0, 42, 2, 43, 4, 44] - - # deleting - m = mss[:] - del m[1:3] - assert m == 
[0, 3, 4] - m = mss[:] - del m[::2] - assert m == [1, 3] - m = mss[:] - del m[:] - assert m == [] - - def test_issue_184(self): - yaml_str = dedent( - """\ - test::test: - # test - foo: - bar: baz - """ - ) - d = round_trip_load(yaml_str) - d["bar"] = "foo" - d.yaml_add_eol_comment("test1", "bar") - assert round_trip_dump(d) == yaml_str + "bar: foo # test1\n" - - def test_issue_219(self): - yaml_str = dedent( - """\ - [StackName: AWS::StackName] - """ - ) - d = round_trip_load(yaml_str) # NOQA - - def test_issue_219a(self): - yaml_str = dedent( - """\ - [StackName: - AWS::StackName] - """ - ) - d = round_trip_load(yaml_str) # NOQA - - def test_issue_220(self, tmpdir): - program_src = r''' - from srsly.ruamel_yaml import YAML - - yaml_str = u"""\ - --- - foo: ["bar"] - """ - - yaml = YAML(typ='safe', pure=True) - d = yaml.load(yaml_str) - print(d) - ''' - assert save_and_run(dedent(program_src), tmpdir, optimized=True) == 0 - - def test_issue_221_add(self): - from srsly.ruamel_yaml.comments import CommentedSeq - - a = CommentedSeq([1, 2, 3]) - a + [4, 5] - - def test_issue_221_sort(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.compat import StringIO - - yaml = YAML() - inp = dedent( - """\ - - d - - a # 1 - - c # 3 - - e # 5 - - b # 2 - """ - ) - a = yaml.load(dedent(inp)) - a.sort() - buf = StringIO() - yaml.dump(a, buf) - exp = dedent( - """\ - - a # 1 - - b # 2 - - c # 3 - - d - - e # 5 - """ - ) - assert buf.getvalue() == exp - - def test_issue_221_sort_reverse(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.compat import StringIO - - yaml = YAML() - inp = dedent( - """\ - - d - - a # 1 - - c # 3 - - e # 5 - - b # 2 - """ - ) - a = yaml.load(dedent(inp)) - a.sort(reverse=True) - buf = StringIO() - yaml.dump(a, buf) - exp = dedent( - """\ - - e # 5 - - d - - c # 3 - - b # 2 - - a # 1 - """ - ) - assert buf.getvalue() == exp - - def test_issue_221_sort_key(self): - from srsly.ruamel_yaml import YAML - from 
srsly.ruamel_yaml.compat import StringIO - - yaml = YAML() - inp = dedent( - """\ - - four - - One # 1 - - Three # 3 - - five # 5 - - two # 2 - """ - ) - a = yaml.load(dedent(inp)) - a.sort(key=str.lower) - buf = StringIO() - yaml.dump(a, buf) - exp = dedent( - """\ - - five # 5 - - four - - One # 1 - - Three # 3 - - two # 2 - """ - ) - assert buf.getvalue() == exp - - def test_issue_221_sort_key_reverse(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.compat import StringIO - - yaml = YAML() - inp = dedent( - """\ - - four - - One # 1 - - Three # 3 - - five # 5 - - two # 2 - """ - ) - a = yaml.load(dedent(inp)) - a.sort(key=str.lower, reverse=True) - buf = StringIO() - yaml.dump(a, buf) - exp = dedent( - """\ - - two # 2 - - Three # 3 - - One # 1 - - four - - five # 5 - """ - ) - assert buf.getvalue() == exp - - def test_issue_222(self): - import srsly.ruamel_yaml - from srsly.ruamel_yaml.compat import StringIO - - buf = StringIO() - srsly.ruamel_yaml.safe_dump(["012923"], buf) - assert buf.getvalue() == "['012923']\n" - - def test_issue_223(self): - import srsly.ruamel_yaml - - yaml = srsly.ruamel_yaml.YAML(typ="safe") - yaml.load("phone: 0123456789") - - def test_issue_232(self): - import srsly.ruamel_yaml - import srsly.ruamel_yaml as yaml - - with pytest.raises(srsly.ruamel_yaml.parser.ParserError): - yaml.safe_load("]") - with pytest.raises(srsly.ruamel_yaml.parser.ParserError): - yaml.safe_load("{]") - - def test_issue_233(self): - from srsly.ruamel_yaml import YAML - import json - - yaml = YAML() - data = yaml.load("{}") - json_str = json.dumps(data) # NOQA - - def test_issue_233a(self): - from srsly.ruamel_yaml import YAML - import json - - yaml = YAML() - data = yaml.load("[]") - json_str = json.dumps(data) # NOQA - - def test_issue_234(self): - from srsly.ruamel_yaml import YAML - - inp = dedent( - """\ - - key: key1 - ctx: [one, two] - help: one - cmd: > - foo bar - foo bar - """ - ) - yaml = YAML(typ="safe", pure=True) - data = 
yaml.load(inp) - fold = data[0]["cmd"] - print(repr(fold)) - assert "\a" not in fold - - def test_issue_236(self): - inp = """ - conf: - xx: {a: "b", c: []} - asd: "nn" - """ - d = round_trip(inp, preserve_quotes=True) # NOQA - - def test_issue_238(self, tmpdir): - program_src = r""" - import srsly.ruamel_yaml - from srsly.ruamel_yaml.compat import StringIO - - yaml = srsly.ruamel_yaml.YAML(typ='unsafe') - - - class A: - def __setstate__(self, d): - self.__dict__ = d - - - class B: - pass - - - a = A() - b = B() - - a.x = b - b.y = [b] - assert a.x.y[0] == a.x - - buf = StringIO() - yaml.dump(a, buf) - - data = yaml.load(buf.getvalue()) - assert data.x.y[0] == data.x - """ - assert save_and_run(dedent(program_src), tmpdir) == 1 - - def test_issue_239(self): - inp = """ - first_name: Art - occupation: Architect - # I'm safe - about: Art Vandelay is a fictional character that George invents... - # we are not :( - # help me! - --- - # what?! - hello: world - # someone call the Batman - foo: bar # or quz - # Lost again - --- - I: knew - # final words - """ - d = YAML().round_trip_all(inp) # NOQA - - def test_issue_242(self): - from srsly.ruamel_yaml.comments import CommentedMap - - d0 = CommentedMap([("a", "b")]) - assert d0["a"] == "b" - - def test_issue_245(self): - from srsly.ruamel_yaml import YAML - - inp = """ - d: yes - """ - for typ in ["safepure", "rt", "safe"]: - if typ.endswith("pure"): - pure = True - typ = typ[:-4] - else: - pure = None - - yaml = YAML(typ=typ, pure=pure) - yaml.version = (1, 1) - d = yaml.load(inp) - print(typ, yaml.parser, yaml.resolver) - assert d["d"] is True - - def test_issue_249(self): - yaml = YAML() - inp = dedent( - """\ - # comment - - - - 1 - - 2 - - 3 - """ - ) - exp = dedent( - """\ - # comment - - - 1 - - 2 - - 3 - """ - ) - yaml.round_trip(inp, outp=exp) # NOQA - - def test_issue_250(self): - inp = """ - # 1. - - - 1 - # 2. - - map: 2 - # 3. 
- - 4 - """ - d = round_trip(inp) # NOQA - - # @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError) - def test_issue_279(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.compat import StringIO - - yaml = YAML() - yaml.indent(sequence=4, offset=2) - inp = dedent( - """\ - experiments: - - datasets: - # ATLAS EWK - - {dataset: ATLASWZRAP36PB, frac: 1.0} - - {dataset: ATLASZHIGHMASS49FB, frac: 1.0} - """ - ) - a = yaml.load(inp) - buf = StringIO() - yaml.dump(a, buf) - print(buf.getvalue()) - assert buf.getvalue() == inp - - def test_issue_280(self): - from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.representer import RepresenterError - from collections import namedtuple - from sys import stdout - - T = namedtuple("T", ("a", "b")) - t = T(1, 2) - yaml = YAML() - with pytest.raises(RepresenterError, match="cannot represent"): - yaml.dump({"t": t}, stdout) - - def test_issue_282(self): - # update from list of tuples caused AttributeError - import srsly.ruamel_yaml - - yaml_data = srsly.ruamel_yaml.comments.CommentedMap( - [("a", "apple"), ("b", "banana")] - ) - yaml_data.update([("c", "cantaloupe")]) - yaml_data.update({"d": "date", "k": "kiwi"}) - assert "c" in yaml_data.keys() - assert "c" in yaml_data._ok - - def test_issue_284(self): - import srsly.ruamel_yaml - - inp = dedent( - """\ - plain key: in-line value - : # Both empty - "quoted key": - - entry - """ - ) - yaml = srsly.ruamel_yaml.YAML(typ="rt") - yaml.version = (1, 2) - d = yaml.load(inp) - assert d[None] is None - - yaml = srsly.ruamel_yaml.YAML(typ="rt") - yaml.version = (1, 1) - with pytest.raises( - srsly.ruamel_yaml.parser.ParserError, match="expected " - ): - d = yaml.load(inp) - - def test_issue_285(self): - from srsly.ruamel_yaml import YAML - - yaml = YAML() - inp = dedent( - """\ - %YAML 1.1 - --- - - y - - n - - Y - - N - """ - ) - a = yaml.load(inp) - assert a[0] - assert a[2] - assert not a[1] - assert not a[3] - - def test_issue_286(self): - 
from srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.compat import StringIO - - yaml = YAML() - inp = dedent( - """\ - parent_key: - - sub_key: sub_value - - # xxx""" - ) - a = yaml.load(inp) - a["new_key"] = "new_value" - buf = StringIO() - yaml.dump(a, buf) - assert buf.getvalue().endswith("xxx\nnew_key: new_value\n") - - def test_issue_288(self): - import sys - from srsly.ruamel_yaml.compat import StringIO - from srsly.ruamel_yaml import YAML - - yamldoc = dedent( - """\ - --- - # Reusable values - aliases: - # First-element comment - - &firstEntry First entry - # Second-element comment - - &secondEntry Second entry - - # Third-element comment is - # a multi-line value - - &thirdEntry Third entry - - # EOF Comment - """ - ) - - yaml = YAML() - yaml.indent(mapping=2, sequence=4, offset=2) - yaml.explicit_start = True - yaml.preserve_quotes = True - yaml.width = sys.maxsize - data = yaml.load(yamldoc) - buf = StringIO() - yaml.dump(data, buf) - assert buf.getvalue() == yamldoc - - def test_issue_288a(self): - import sys - from srsly.ruamel_yaml.compat import StringIO - from srsly.ruamel_yaml import YAML - - yamldoc = dedent( - """\ - --- - # Reusable values - aliases: - # First-element comment - - &firstEntry First entry - # Second-element comment - - &secondEntry Second entry - - # Third-element comment is - # a multi-line value - - &thirdEntry Third entry - - # EOF Comment - """ - ) - - yaml = YAML() - yaml.indent(mapping=2, sequence=4, offset=2) - yaml.explicit_start = True - yaml.preserve_quotes = True - yaml.width = sys.maxsize - data = yaml.load(yamldoc) - buf = StringIO() - yaml.dump(data, buf) - assert buf.getvalue() == yamldoc - - def test_issue_290(self): - import sys - from srsly.ruamel_yaml.compat import StringIO - from srsly.ruamel_yaml import YAML - - yamldoc = dedent( - """\ - --- - aliases: - # Folded-element comment - # for a multi-line value - - &FoldedEntry > - THIS IS A - FOLDED, MULTI-LINE - VALUE - - # Literal-element comment - # for a 
multi-line value - - &literalEntry | - THIS IS A - LITERAL, MULTI-LINE - VALUE - - # Plain-element comment - - &plainEntry Plain entry - """ - ) - - yaml = YAML() - yaml.indent(mapping=2, sequence=4, offset=2) - yaml.explicit_start = True - yaml.preserve_quotes = True - yaml.width = sys.maxsize - data = yaml.load(yamldoc) - buf = StringIO() - yaml.dump(data, buf) - assert buf.getvalue() == yamldoc - - def test_issue_290a(self): - import sys - from srsly.ruamel_yaml.compat import StringIO - from srsly.ruamel_yaml import YAML - - yamldoc = dedent( - """\ - --- - aliases: - # Folded-element comment - # for a multi-line value - - &FoldedEntry > - THIS IS A - FOLDED, MULTI-LINE - VALUE - - # Literal-element comment - # for a multi-line value - - &literalEntry | - THIS IS A - LITERAL, MULTI-LINE - VALUE - - # Plain-element comment - - &plainEntry Plain entry - """ - ) - - yaml = YAML() - yaml.indent(mapping=2, sequence=4, offset=2) - yaml.explicit_start = True - yaml.preserve_quotes = True - yaml.width = sys.maxsize - data = yaml.load(yamldoc) - buf = StringIO() - yaml.dump(data, buf) - assert buf.getvalue() == yamldoc - - # @pytest.mark.xfail(strict=True, reason='should fail pre 0.15.100', raises=AssertionError) - def test_issue_295(self): - # deepcopy also makes a copy of the start and end mark, and these did not - # have any comparison beyond their ID, which of course changed, breaking - # some old merge_comment code - import copy - - inp = dedent( - """ - A: - b: - # comment - - l1 - - l2 - - C: - d: e - f: - # comment2 - - - l31 - - l32 - - l33: '5' - """ - ) - data = round_trip_load(inp) # NOQA - dc = copy.deepcopy(data) - assert round_trip_dump(dc) == inp - - def test_issue_300(self): - from srsly.ruamel_yaml import YAML - - inp = dedent( - """ - %YAML 1.2 - %TAG ! tag:example.com,2019/path#fragment - --- - null - """ - ) - YAML().load(inp) - - def test_issue_300a(self): - import srsly.ruamel_yaml - - inp = dedent( - """ - %YAML 1.1 - %TAG ! 
tag:example.com,2019/path#fragment - --- - null - """ - ) - yaml = YAML() - with pytest.raises( - srsly.ruamel_yaml.scanner.ScannerError, match="while scanning a directive" - ): - yaml.load(inp) - - def test_issue_304(self): - inp = """ - %YAML 1.2 - %TAG ! tag:example.com,2019: - --- - !foo null - ... - """ - d = na_round_trip(inp) # NOQA - - def test_issue_305(self): - inp = """ - %YAML 1.2 - --- - ! null - ... - """ - d = na_round_trip(inp) # NOQA - - def test_issue_307(self): - inp = """ - %YAML 1.2 - %TAG ! tag:example.com,2019/path# - --- - null - ... - """ - d = na_round_trip(inp) # NOQA - - -# @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError) -# def test_issue_ xxx(self): -# inp = """ -# """ -# d = round_trip(inp) # NOQA diff --git a/srsly/tests/ruamel_yaml/test_json_numbers.py b/srsly/tests/ruamel_yaml/test_json_numbers.py deleted file mode 100755 index 0998e39..0000000 --- a/srsly/tests/ruamel_yaml/test_json_numbers.py +++ /dev/null @@ -1,57 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -import pytest # NOQA - -import json - - -def load(s, typ=float): - import srsly.ruamel_yaml - - x = '{"low": %s }' % (s) - print("input: [%s]" % (s), repr(x)) - # just to check it is loadable json - res = json.loads(x) - assert isinstance(res["low"], typ) - ret_val = srsly.ruamel_yaml.load(x, srsly.ruamel_yaml.RoundTripLoader) - print(ret_val) - return ret_val["low"] - - -class TestJSONNumbers: - # based on http://stackoverflow.com/a/30462009/1307905 - # yaml number regex: http://yaml.org/spec/1.2/spec.html#id2804092 - # - # -? [1-9] ( \. [0-9]* [1-9] )? ( e [-+] [1-9] [0-9]* )? 
- # - # which is not a superset of the JSON numbers - def test_json_number_float(self): - for x in ( - y.split("#")[0].strip() - for y in """ - 1.0 # should fail on YAML spec on 1-9 allowed as single digit - -1.0 - 1e-06 - 3.1e-5 - 3.1e+5 - 3.1e5 # should fail on YAML spec: no +- after e - """.splitlines() - ): - if not x: - continue - res = load(x) - assert isinstance(res, float) - - def test_json_number_int(self): - for x in ( - y.split("#")[0].strip() - for y in """ - 42 - """.splitlines() - ): - if not x: - continue - res = load(x, int) - assert isinstance(res, int) diff --git a/srsly/tests/ruamel_yaml/test_line_col.py b/srsly/tests/ruamel_yaml/test_line_col.py deleted file mode 100755 index 10bbd2e..0000000 --- a/srsly/tests/ruamel_yaml/test_line_col.py +++ /dev/null @@ -1,104 +0,0 @@ -# coding: utf-8 - -import pytest # NOQA - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - - -def load(s): - return round_trip_load(dedent(s)) - - -class TestLineCol: - def test_item_00(self): - data = load( - """ - - a - - e - - [b, d] - - c - """ - ) - assert data[2].lc.line == 2 - assert data[2].lc.col == 2 - - def test_item_01(self): - data = load( - """ - - a - - e - - {x: 3} - - c - """ - ) - assert data[2].lc.line == 2 - assert data[2].lc.col == 2 - - def test_item_02(self): - data = load( - """ - - a - - e - - !!set {x, y} - - c - """ - ) - assert data[2].lc.line == 2 - assert data[2].lc.col == 2 - - def test_item_03(self): - data = load( - """ - - a - - e - - !!omap - - x: 1 - - y: 3 - - c - """ - ) - assert data[2].lc.line == 2 - assert data[2].lc.col == 2 - - def test_item_04(self): - data = load( - """ - # testing line and column based on SO - # http://stackoverflow.com/questions/13319067/ - - key1: item 1 - key2: item 2 - - key3: another item 1 - key4: another item 2 - """ - ) - assert data[0].lc.line == 2 - assert data[0].lc.col == 2 - assert data[1].lc.line == 4 - assert data[1].lc.col == 2 - - def test_pos_mapping(self): - 
data = load( - """ - a: 1 - b: 2 - c: 3 - # comment - klm: 42 - d: 4 - """ - ) - assert data.lc.key("klm") == (4, 0) - assert data.lc.value("klm") == (4, 5) - - def test_pos_sequence(self): - data = load( - """ - - a - - b - - c - # next one! - - klm - - d - """ - ) - assert data.lc.item(3) == (4, 2) diff --git a/srsly/tests/ruamel_yaml/test_literal.py b/srsly/tests/ruamel_yaml/test_literal.py deleted file mode 100755 index a6c1624..0000000 --- a/srsly/tests/ruamel_yaml/test_literal.py +++ /dev/null @@ -1,335 +0,0 @@ -# coding: utf-8 -from __future__ import print_function - -import pytest # NOQA - -from .roundtrip import YAML # does an automatic dedent on load - - -""" -YAML 1.0 allowed root level literal style without indentation: - "Usually top level nodes are not indented" (example 4.21 in 4.6.3) -YAML 1.1 is a bit vague but says: - "Regardless of style, scalar content must always be indented by at least one space" - (4.4.3) - "In general, the document’s node is indented as if it has a parent indented at -1 spaces." - (4.3.3) -YAML 1.2 is again clear about root literal level scalar after directive in example 9.5: - -%YAML 1.2 ---- | -%!PS-Adobe-2.0 -... -%YAML1.2 ---- -# Empty -... 
-""" - - -class TestNoIndent: - def test_root_literal_scalar_indent_example_9_5(self): - yaml = YAML() - s = "%!PS-Adobe-2.0" - inp = """ - --- | - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_literal_scalar_no_indent(self): - yaml = YAML() - s = "testing123" - inp = """ - --- | - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_literal_scalar_no_indent_1_1(self): - yaml = YAML() - s = "testing123" - inp = """ - %YAML 1.1 - --- | - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_literal_scalar_no_indent_1_1_old_style(self): - from textwrap import dedent - from srsly.ruamel_yaml import safe_load - - s = "testing123" - inp = """ - %YAML 1.1 - --- | - {} - """ - d = safe_load(dedent(inp.format(s))) - print(d) - assert d == s + "\n" - - def test_root_literal_scalar_no_indent_1_1_no_raise(self): - # from srsly.ruamel_yaml.parser import ParserError - - yaml = YAML() - yaml.root_level_block_style_scalar_no_indent_error_1_1 = True - s = "testing123" - # with pytest.raises(ParserError): - if True: - inp = """ - %YAML 1.1 - --- | - {} - """ - yaml.load(inp.format(s)) - - def test_root_literal_scalar_indent_offset_one(self): - yaml = YAML() - s = "testing123" - inp = """ - --- |1 - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_literal_scalar_indent_offset_four(self): - yaml = YAML() - s = "testing123" - inp = """ - --- |4 - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_literal_scalar_indent_offset_two_leading_space(self): - yaml = YAML() - s = " testing123" - inp = """ - --- |4 - {s} - {s} - """ - d = yaml.load(inp.format(s=s)) - print(d) - assert d == (s + "\n") * 2 - - def test_root_literal_scalar_no_indent_special(self): - yaml = YAML() - s = "%!PS-Adobe-2.0" - inp = """ - --- | - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + 
"\n" - - def test_root_folding_scalar_indent(self): - yaml = YAML() - s = "%!PS-Adobe-2.0" - inp = """ - --- > - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_folding_scalar_no_indent(self): - yaml = YAML() - s = "testing123" - inp = """ - --- > - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_folding_scalar_no_indent_special(self): - yaml = YAML() - s = "%!PS-Adobe-2.0" - inp = """ - --- > - {} - """ - d = yaml.load(inp.format(s)) - print(d) - assert d == s + "\n" - - def test_root_literal_multi_doc(self): - yaml = YAML(typ="safe", pure=True) - s1 = "abc" - s2 = "klm" - inp = """ - --- |- - {} - --- | - {} - """ - for idx, d1 in enumerate(yaml.load_all(inp.format(s1, s2))): - print("d1:", d1) - assert ["abc", "klm\n"][idx] == d1 - - def test_root_literal_doc_indent_directives_end(self): - yaml = YAML() - yaml.explicit_start = True - inp = """ - --- |- - %YAML 1.3 - --- - this: is a test - """ - yaml.round_trip(inp) - - def test_root_literal_doc_indent_document_end(self): - yaml = YAML() - yaml.explicit_start = True - inp = """ - --- |- - some more - ... 
- text - """ - yaml.round_trip(inp) - - def test_root_literal_doc_indent_marker(self): - yaml = YAML() - yaml.explicit_start = True - inp = """ - --- |2 - some more - text - """ - d = yaml.load(inp) - print(type(d), repr(d)) - yaml.round_trip(inp) - - def test_nested_literal_doc_indent_marker(self): - yaml = YAML() - yaml.explicit_start = True - inp = """ - --- - a: |2 - some more - text - """ - d = yaml.load(inp) - print(type(d), repr(d)) - yaml.round_trip(inp) - - -class Test_RoundTripLiteral: - def test_rt_root_literal_scalar_no_indent(self): - yaml = YAML() - yaml.explicit_start = True - s = "testing123" - ys = """ - --- | - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_rt_root_literal_scalar_indent(self): - yaml = YAML() - yaml.explicit_start = True - yaml.indent = 4 - s = "testing123" - ys = """ - --- | - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_rt_root_plain_scalar_no_indent(self): - yaml = YAML() - yaml.explicit_start = True - yaml.indent = 0 - s = "testing123" - ys = """ - --- - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_rt_root_plain_scalar_expl_indent(self): - yaml = YAML() - yaml.explicit_start = True - yaml.indent = 4 - s = "testing123" - ys = """ - --- - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_rt_root_sq_scalar_expl_indent(self): - yaml = YAML() - yaml.explicit_start = True - yaml.indent = 4 - s = "'testing: 123'" - ys = """ - --- - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_rt_root_dq_scalar_expl_indent(self): - # if yaml.indent is the default (None) - # then write after the directive indicator - yaml = YAML() - yaml.explicit_start = True - yaml.indent = 0 - s = '"\'testing123"' - ys = """ - --- - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def 
test_rt_root_literal_scalar_no_indent_no_eol(self): - yaml = YAML() - yaml.explicit_start = True - s = "testing123" - ys = """ - --- |- - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) - - def test_rt_non_root_literal_scalar(self): - yaml = YAML() - s = "testing123" - ys = """ - - | - {} - """ - ys = ys.format(s) - d = yaml.load(ys) - yaml.dump(d, compare=ys) diff --git a/srsly/tests/ruamel_yaml/test_none.py b/srsly/tests/ruamel_yaml/test_none.py deleted file mode 100755 index 878c7e7..0000000 --- a/srsly/tests/ruamel_yaml/test_none.py +++ /dev/null @@ -1,53 +0,0 @@ -# coding: utf-8 - - -import pytest # NOQA - - -class TestNone: - def test_dump00(self): - import srsly.ruamel_yaml # NOQA - - data = None - s = srsly.ruamel_yaml.round_trip_dump(data) - assert s == "null\n...\n" - d = srsly.ruamel_yaml.round_trip_load(s) - assert d == data - - def test_dump01(self): - import srsly.ruamel_yaml # NOQA - - data = None - s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=True) - assert s == "null\n...\n" - d = srsly.ruamel_yaml.round_trip_load(s) - assert d == data - - def test_dump02(self): - import srsly.ruamel_yaml # NOQA - - data = None - s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=False) - assert s == "null\n...\n" - d = srsly.ruamel_yaml.round_trip_load(s) - assert d == data - - def test_dump03(self): - import srsly.ruamel_yaml # NOQA - - data = None - s = srsly.ruamel_yaml.round_trip_dump(data, explicit_start=True) - assert s == "---\n...\n" - d = srsly.ruamel_yaml.round_trip_load(s) - assert d == data - - def test_dump04(self): - import srsly.ruamel_yaml # NOQA - - data = None - s = srsly.ruamel_yaml.round_trip_dump( - data, explicit_start=True, explicit_end=False - ) - assert s == "---\n...\n" - d = srsly.ruamel_yaml.round_trip_load(s) - assert d == data diff --git a/srsly/tests/ruamel_yaml/test_numpy.py b/srsly/tests/ruamel_yaml/test_numpy.py deleted file mode 100755 index 2c21854..0000000 --- 
a/srsly/tests/ruamel_yaml/test_numpy.py +++ /dev/null @@ -1,24 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, absolute_import, division, unicode_literals - -try: - import numpy -except: # NOQA - numpy = None - - -def Xtest_numpy(): - import srsly.ruamel_yaml - - if numpy is None: - return - data = numpy.arange(10) - print("data", type(data), data) - - yaml_str = srsly.ruamel_yaml.dump(data) - datb = srsly.ruamel_yaml.load(yaml_str) - print("datb", type(datb), datb) - - print("\nYAML", yaml_str) - assert data == datb diff --git a/srsly/tests/ruamel_yaml/test_program_config.py b/srsly/tests/ruamel_yaml/test_program_config.py deleted file mode 100755 index 7e5331b..0000000 --- a/srsly/tests/ruamel_yaml/test_program_config.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest # NOQA - -# import srsly.ruamel_yaml -from .roundtrip import round_trip - - -class TestProgramConfig: - def test_application_arguments(self): - # application configur - round_trip( - """ - args: - username: anthon - passwd: secret - fullname: Anthon van der Neut - tmux: - session-name: test - loop: - wait: 10 - """ - ) - - def test_single(self): - # application configuration - round_trip( - """ - # default arguments for the program - args: # needed to prevent comment wrapping - # this should be your username - username: anthon - passwd: secret # this is plaintext don't reuse \ -# important/system passwords - fullname: Anthon van der Neut - tmux: - session-name: test # make sure this doesn't clash with - # other sessions - loop: # looping related defaults - # experiment with the following - wait: 10 - # no more argument info to pass - """ - ) - - def test_multi(self): - # application configuration - round_trip( - """ - # default arguments for the program - args: # needed to prevent comment wrapping - # this should be your username - username: anthon - passwd: secret # this is plaintext don't reuse - # important/system passwords - fullname: Anthon van der Neut - tmux: - session-name: 
test # make sure this doesn't clash with - # other sessions - loop: # looping related defaults - # experiment with the following - wait: 10 - # no more argument info to pass - """ - ) diff --git a/srsly/tests/ruamel_yaml/test_spec_examples.py b/srsly/tests/ruamel_yaml/test_spec_examples.py deleted file mode 100755 index 666d641..0000000 --- a/srsly/tests/ruamel_yaml/test_spec_examples.py +++ /dev/null @@ -1,334 +0,0 @@ -from .roundtrip import YAML -import pytest # NOQA - - -def test_example_2_1(): - yaml = YAML() - yaml.round_trip( - """ - - Mark McGwire - - Sammy Sosa - - Ken Griffey - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_2(): - yaml = YAML() - yaml.mapping_value_align = True - yaml.round_trip( - """ - hr: 65 # Home runs - avg: 0.278 # Batting average - rbi: 147 # Runs Batted In - """ - ) - - -def test_example_2_3(): - yaml = YAML() - yaml.indent(sequence=4, offset=2) - yaml.round_trip( - """ - american: - - Boston Red Sox - - Detroit Tigers - - New York Yankees - national: - - New York Mets - - Chicago Cubs - - Atlanta Braves - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_4(): - yaml = YAML() - yaml.mapping_value_align = True - yaml.round_trip( - """ - - - name: Mark McGwire - hr: 65 - avg: 0.278 - - - name: Sammy Sosa - hr: 63 - avg: 0.288 - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_5(): - yaml = YAML() - yaml.flow_sequence_element_align = True - yaml.round_trip( - """ - - [name , hr, avg ] - - [Mark McGwire, 65, 0.278] - - [Sammy Sosa , 63, 0.288] - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_6(): - yaml = YAML() - # yaml.flow_mapping_final_comma = False - yaml.flow_mapping_one_element_per_line = True - yaml.round_trip( - """ - Mark McGwire: {hr: 65, avg: 0.278} - Sammy Sosa: { - hr: 63, - avg: 0.288 - } - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_7(): - yaml = YAML() - yaml.round_trip_all( - """ - # Ranking of 1998 home runs - --- - - Mark McGwire 
- - Sammy Sosa - - Ken Griffey - - # Team ranking - --- - - Chicago Cubs - - St Louis Cardinals - """ - ) - - -def test_example_2_8(): - yaml = YAML() - yaml.explicit_start = True - yaml.explicit_end = True - yaml.round_trip_all( - """ - --- - time: 20:03:20 - player: Sammy Sosa - action: strike (miss) - ... - --- - time: 20:03:47 - player: Sammy Sosa - action: grand slam - ... - """ - ) - - -def test_example_2_9(): - yaml = YAML() - yaml.explicit_start = True - yaml.indent(sequence=4, offset=2) - yaml.round_trip( - """ - --- - hr: # 1998 hr ranking - - Mark McGwire - - Sammy Sosa - rbi: - # 1998 rbi ranking - - Sammy Sosa - - Ken Griffey - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_10(): - yaml = YAML() - yaml.explicit_start = True - yaml.indent(sequence=4, offset=2) - yaml.round_trip( - """ - --- - hr: - - Mark McGwire - # Following node labeled SS - - &SS Sammy Sosa - rbi: - - *SS # Subsequent occurrence - - Ken Griffey - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_11(): - yaml = YAML() - yaml.round_trip( - """ - ? - Detroit Tigers - - Chicago cubs - : - - 2001-07-23 - - ? [ New York Yankees, - Atlanta Braves ] - : [ 2001-07-02, 2001-08-12, - 2001-08-14 ] - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_12(): - yaml = YAML() - yaml.explicit_start = True - yaml.round_trip( - """ - --- - # Products purchased - - item : Super Hoop - quantity: 1 - - item : Basketball - quantity: 4 - - item : Big Shoes - quantity: 1 - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_13(): - yaml = YAML() - yaml.round_trip( - r""" - # ASCII Art - --- | - \//||\/|| - // || ||__ - """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_14(): - yaml = YAML() - yaml.explicit_start = True - yaml.indent(root_scalar=2) # needs to be added - yaml.round_trip( - """ - --- > - Mark McGwire's - year was crippled - by a knee injury. 
- """ - ) - - -@pytest.mark.xfail(strict=True) -def test_example_2_15(): - yaml = YAML() - yaml.round_trip( - """ - > - Sammy Sosa completed another - fine season with great stats. - - 63 Home Runs - 0.288 Batting Average - - What a year! - """ - ) - - -def test_example_2_16(): - yaml = YAML() - yaml.round_trip( - """ - name: Mark McGwire - accomplishment: > - Mark set a major league - home run record in 1998. - stats: | - 65 Home Runs - 0.278 Batting Average - """ - ) - - -@pytest.mark.xfail( - strict=True, reason="cannot YAML dump escape sequences (\n) as hex and normal" -) -def test_example_2_17(): - yaml = YAML() - yaml.allow_unicode = False - yaml.preserve_quotes = True - yaml.round_trip( - r""" - unicode: "Sosa did fine.\u263A" - control: "\b1998\t1999\t2000\n" - hex esc: "\x0d\x0a is \r\n" - - single: '"Howdy!" he cried.' - quoted: ' # Not a ''comment''.' - tie-fighter: '|\-*-/|' - """ - ) - - -@pytest.mark.xfail( - strict=True, reason="non-literal/folding multiline scalars not supported" -) -def test_example_2_18(): - yaml = YAML() - yaml.round_trip( - """ - plain: - This unquoted scalar - spans many lines. 
- - quoted: "So does this - quoted scalar.\n" - """ - ) - - -@pytest.mark.xfail(strict=True, reason="leading + on decimal dropped") -def test_example_2_19(): - yaml = YAML() - yaml.round_trip( - """ - canonical: 12345 - decimal: +12345 - octal: 0o14 - hexadecimal: 0xC - """ - ) - - -@pytest.mark.xfail(strict=True, reason="case of NaN not preserved") -def test_example_2_20(): - yaml = YAML() - yaml.round_trip( - """ - canonical: 1.23015e+3 - exponential: 12.3015e+02 - fixed: 1230.15 - negative infinity: -.inf - not a number: .NaN - """ - ) - - -def Xtest_example_2_X(): - yaml = YAML() - yaml.round_trip( - """ - """ - ) diff --git a/srsly/tests/ruamel_yaml/test_string.py b/srsly/tests/ruamel_yaml/test_string.py deleted file mode 100755 index c82eb81..0000000 --- a/srsly/tests/ruamel_yaml/test_string.py +++ /dev/null @@ -1,229 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -""" -various test cases for string scalars in YAML files -'|' for preserved newlines -'>' for folded (newlines become spaces) - -and the chomping modifiers: -'-' for stripping: final line break and any trailing empty lines are excluded -'+' for keeping: final line break and empty lines are preserved -'' for clipping: final line break preserved, empty lines at end not - included in content (no modifier) - -""" - -import pytest -import platform - -# from srsly.ruamel_yaml.compat import ordereddict -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - - -class TestLiteralScalarString: - def test_basic_string(self): - round_trip( - """ - a: abcdefg - """ - ) - - def test_quoted_integer_string(self): - round_trip( - """ - a: '12345' - """ - ) - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_preserve_string(self): - inp = """ - a: | - abc - def - """ - round_trip(inp, intermediate=dict(a="abc\ndef\n")) - - @pytest.mark.skipif( - platform.python_implementation() == 
"Jython", - reason="Jython throws RepresenterError", - ) - def test_preserve_string_strip(self): - s = """ - a: |- - abc - def - - """ - round_trip(s, intermediate=dict(a="abc\ndef")) - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_preserve_string_keep(self): - # with pytest.raises(AssertionError) as excinfo: - inp = """ - a: |+ - ghi - jkl - - - b: x - """ - round_trip(inp, intermediate=dict(a="ghi\njkl\n\n\n", b="x")) - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_preserve_string_keep_at_end(self): - # at EOF you have to specify the ... to get proper "closure" - # of the multiline scalar - inp = """ - a: |+ - ghi - jkl - - ... - """ - round_trip(inp, intermediate=dict(a="ghi\njkl\n\n")) - - def test_fold_string(self): - inp = """ - a: > - abc - def - - """ - round_trip(inp) - - def test_fold_string_strip(self): - inp = """ - a: >- - abc - def - - """ - round_trip(inp) - - def test_fold_string_keep(self): - with pytest.raises(AssertionError) as excinfo: # NOQA - inp = """ - a: >+ - abc - def - - """ - round_trip(inp, intermediate=dict(a="abc def\n\n")) - - -class TestQuotedScalarString: - def test_single_quoted_string(self): - inp = """ - a: 'abc' - """ - round_trip(inp, preserve_quotes=True) - - def test_double_quoted_string(self): - inp = """ - a: "abc" - """ - round_trip(inp, preserve_quotes=True) - - def test_non_preserved_double_quoted_string(self): - inp = """ - a: "abc" - """ - exp = """ - a: abc - """ - round_trip(inp, outp=exp) - - -class TestReplace: - """inspired by issue 110 from sandres23""" - - def test_replace_preserved_scalar_string(self): - import srsly - - s = dedent( - """\ - foo: | - foo - foo - bar - foo - """ - ) - data = round_trip_load(s, preserve_quotes=True) - so = data["foo"].replace("foo", "bar", 2) - assert isinstance(so, 
srsly.ruamel_yaml.scalarstring.LiteralScalarString) - assert so == dedent( - """ - bar - bar - bar - foo - """ - ) - - def test_replace_double_quoted_scalar_string(self): - import srsly - - s = dedent( - """\ - foo: "foo foo bar foo" - """ - ) - data = round_trip_load(s, preserve_quotes=True) - so = data["foo"].replace("foo", "bar", 2) - assert isinstance(so, srsly.ruamel_yaml.scalarstring.DoubleQuotedScalarString) - assert so == "bar bar bar foo" - - -class TestWalkTree: - def test_basic(self): - from srsly.ruamel_yaml.comments import CommentedMap - from srsly.ruamel_yaml.scalarstring import walk_tree - - data = CommentedMap() - data[1] = "a" - data[2] = "with\nnewline\n" - walk_tree(data) - exp = """\ - 1: a - 2: | - with - newline - """ - assert round_trip_dump(data) == dedent(exp) - - def test_map(self): - from srsly.ruamel_yaml.compat import ordereddict - from srsly.ruamel_yaml.comments import CommentedMap - from srsly.ruamel_yaml.scalarstring import walk_tree, preserve_literal - from srsly.ruamel_yaml.scalarstring import DoubleQuotedScalarString as dq - from srsly.ruamel_yaml.scalarstring import SingleQuotedScalarString as sq - - data = CommentedMap() - data[1] = "a" - data[2] = "with\nnew : line\n" - data[3] = "${abc}" - data[4] = "almost:mapping" - m = ordereddict([("\n", preserve_literal), ("${", sq), (":", dq)]) - walk_tree(data, map=m) - exp = """\ - 1: a - 2: | - with - new : line - 3: '${abc}' - 4: "almost:mapping" - """ - assert round_trip_dump(data) == dedent(exp) diff --git a/srsly/tests/ruamel_yaml/test_tag.py b/srsly/tests/ruamel_yaml/test_tag.py deleted file mode 100755 index 4e4860c..0000000 --- a/srsly/tests/ruamel_yaml/test_tag.py +++ /dev/null @@ -1,171 +0,0 @@ -# coding: utf-8 - -import pytest # NOQA - -from .roundtrip import round_trip, round_trip_load, YAML - - -def register_xxx(**kw): - import srsly.ruamel_yaml as yaml - - class XXX(yaml.comments.CommentedMap): - @staticmethod - def yaml_dump(dumper, data): - return 
dumper.represent_mapping(u"!xxx", data) - - @classmethod - def yaml_load(cls, constructor, node): - data = cls() - yield data - constructor.construct_mapping(node, data) - - yaml.add_constructor(u"!xxx", XXX.yaml_load, constructor=yaml.RoundTripConstructor) - yaml.add_representer(XXX, XXX.yaml_dump, representer=yaml.RoundTripRepresenter) - - -class TestIndentFailures: - def test_tag(self): - round_trip( - """\ - !!python/object:__main__.Developer - name: Anthon - location: Germany - language: python - """ - ) - - def test_full_tag(self): - round_trip( - """\ - !!tag:yaml.org,2002:python/object:__main__.Developer - name: Anthon - location: Germany - language: python - """ - ) - - def test_standard_tag(self): - round_trip( - """\ - !!tag:yaml.org,2002:python/object:map - name: Anthon - location: Germany - language: python - """ - ) - - def test_Y1(self): - round_trip( - """\ - !yyy - name: Anthon - location: Germany - language: python - """ - ) - - def test_Y2(self): - round_trip( - """\ - !!yyy - name: Anthon - location: Germany - language: python - """ - ) - - -class TestRoundTripCustom: - def test_X1(self): - register_xxx() - round_trip( - """\ - !xxx - name: Anthon - location: Germany - language: python - """ - ) - - @pytest.mark.xfail(strict=True) - def test_X_pre_tag_comment(self): - register_xxx() - round_trip( - """\ - - - # hello - !xxx - name: Anthon - location: Germany - language: python - """ - ) - - @pytest.mark.xfail(strict=True) - def test_X_post_tag_comment(self): - register_xxx() - round_trip( - """\ - - !xxx - # hello - name: Anthon - location: Germany - language: python - """ - ) - - def test_scalar_00(self): - # https://stackoverflow.com/a/45967047/1307905 - round_trip( - """\ - Outputs: - Vpc: - Value: !Ref: vpc # first tag - Export: - Name: !Sub "${AWS::StackName}-Vpc" # second tag - """ - ) - - -class TestIssue201: - def test_encoded_unicode_tag(self): - round_trip_load( - """ - s: !!python/%75nicode 'abc' - """ - ) - - -class 
TestImplicitTaggedNodes: - def test_scalar(self): - round_trip( - """\ - - !Scalar abcdefg - """ - ) - - def test_mapping(self): - round_trip( - """\ - - !Mapping {a: 1, b: 2} - """ - ) - - def test_sequence(self): - yaml = YAML() - yaml.brace_single_entry_mapping_in_flow_sequence = True - yaml.mapping_value_align = True - yaml.round_trip( - """ - - !Sequence [a, {b: 1}, {c: {d: 3}}] - """ - ) - - def test_sequence2(self): - yaml = YAML() - yaml.mapping_value_align = True - yaml.round_trip( - """ - - !Sequence [a, b: 1, c: {d: 3}] - """ - ) diff --git a/srsly/tests/ruamel_yaml/test_version.py b/srsly/tests/ruamel_yaml/test_version.py deleted file mode 100755 index 2dbf53d..0000000 --- a/srsly/tests/ruamel_yaml/test_version.py +++ /dev/null @@ -1,175 +0,0 @@ -# coding: utf-8 - -import pytest # NOQA - -from .roundtrip import dedent, round_trip, round_trip_load - - -def load(s, version=None): - import srsly.ruamel_yaml # NOQA - - return srsly.ruamel_yaml.round_trip_load(dedent(s), version) - - -class TestVersions: - def test_explicit_1_2(self): - r = load( - """\ - %YAML 1.2 - --- - - 12:34:56 - - 012 - - 012345678 - - 0o12 - - on - - off - - yes - - no - - true - """ - ) - assert r[0] == "12:34:56" - assert r[1] == 12 - assert r[2] == 12345678 - assert r[3] == 10 - assert r[4] == "on" - assert r[5] == "off" - assert r[6] == "yes" - assert r[7] == "no" - assert r[8] is True - - def test_explicit_1_1(self): - r = load( - """\ - %YAML 1.1 - --- - - 12:34:56 - - 012 - - 012345678 - - 0o12 - - on - - off - - yes - - no - - true - """ - ) - assert r[0] == 45296 - assert r[1] == 10 - assert r[2] == "012345678" - assert r[3] == "0o12" - assert r[4] is True - assert r[5] is False - assert r[6] is True - assert r[7] is False - assert r[8] is True - - def test_implicit_1_2(self): - r = load( - """\ - - 12:34:56 - - 12:34:56.78 - - 012 - - 012345678 - - 0o12 - - on - - off - - yes - - no - - true - """ - ) - assert r[0] == "12:34:56" - assert r[1] == "12:34:56.78" - assert r[2] 
== 12 - assert r[3] == 12345678 - assert r[4] == 10 - assert r[5] == "on" - assert r[6] == "off" - assert r[7] == "yes" - assert r[8] == "no" - assert r[9] is True - - def test_load_version_1_1(self): - inp = """\ - - 12:34:56 - - 12:34:56.78 - - 012 - - 012345678 - - 0o12 - - on - - off - - yes - - no - - true - """ - r = load(inp, version="1.1") - assert r[0] == 45296 - assert r[1] == 45296.78 - assert r[2] == 10 - assert r[3] == "012345678" - assert r[4] == "0o12" - assert r[5] is True - assert r[6] is False - assert r[7] is True - assert r[8] is False - assert r[9] is True - - -class TestIssue62: - # bitbucket issue 62, issue_62 - def test_00(self): - import srsly.ruamel_yaml # NOQA - - s = dedent( - """\ - {}# Outside flow collection: - - ::vector - - ": - ()" - - Up, up, and away! - - -123 - - http://example.com/foo#bar - # Inside flow collection: - - [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar] - """ - ) - with pytest.raises(srsly.ruamel_yaml.parser.ParserError): - round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True) - round_trip(s.format(""), preserve_quotes=True) - - def test_00_single_comment(self): - import srsly.ruamel_yaml # NOQA - - s = dedent( - """\ - {}# Outside flow collection: - - ::vector - - ": - ()" - - Up, up, and away! - - -123 - - http://example.com/foo#bar - - [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar] - """ - ) - with pytest.raises(srsly.ruamel_yaml.parser.ParserError): - round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True) - round_trip(s.format(""), preserve_quotes=True) - # round_trip(s.format('%YAML 1.2\n---\n'), preserve_quotes=True, version=(1, 2)) - - def test_01(self): - import srsly.ruamel_yaml # NOQA - - s = dedent( - """\ - {}[random plain value that contains a ? 
character] - """ - ) - with pytest.raises(srsly.ruamel_yaml.parser.ParserError): - round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True) - round_trip(s.format(""), preserve_quotes=True) - # note the flow seq on the --- line! - round_trip(s.format("%YAML 1.2\n--- "), preserve_quotes=True, version="1.2") - - def test_so_45681626(self): - # was not properly parsing - round_trip_load('{"in":{},"out":{}}') diff --git a/srsly/tests/ruamel_yaml/test_yamlfile.py b/srsly/tests/ruamel_yaml/test_yamlfile.py deleted file mode 100755 index f30c47b..0000000 --- a/srsly/tests/ruamel_yaml/test_yamlfile.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import print_function - -""" -various test cases for YAML files -""" - -import sys -import pytest # NOQA -import platform - -from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA - - -class TestYAML: - def test_backslash(self): - round_trip( - """ - handlers: - static_files: applications/\\1/static/\\2 - """ - ) - - def test_omap_out(self): - # ordereddict mapped to !!omap - from srsly.ruamel_yaml.compat import ordereddict - import srsly.ruamel_yaml # NOQA - - x = ordereddict([("a", 1), ("b", 2)]) - res = srsly.ruamel_yaml.dump(x, default_flow_style=False) - assert res == dedent( - """ - !!omap - - a: 1 - - b: 2 - """ - ) - - def test_omap_roundtrip(self): - round_trip( - """ - !!omap - - a: 1 - - b: 2 - - c: 3 - - d: 4 - """ - ) - - @pytest.mark.skipif(sys.version_info < (2, 7), reason="collections not available") - def test_dump_collections_ordereddict(self): - from collections import OrderedDict - import srsly.ruamel_yaml # NOQA - - # OrderedDict mapped to !!omap - x = OrderedDict([("a", 1), ("b", 2)]) - res = srsly.ruamel_yaml.dump( - x, Dumper=srsly.ruamel_yaml.RoundTripDumper, default_flow_style=False - ) - assert res == dedent( - """ - !!omap - - a: 1 - - b: 2 - """ - ) - - @pytest.mark.skipif( - sys.version_info >= (3, 0) or platform.python_implementation() != "CPython", - 
reason="srsly.ruamel_yaml not available", - ) - def test_dump_ruamel_ordereddict(self): - from srsly.ruamel_yaml.compat import ordereddict - import srsly.ruamel_yaml # NOQA - - # OrderedDict mapped to !!omap - x = ordereddict([("a", 1), ("b", 2)]) - res = srsly.ruamel_yaml.dump( - x, Dumper=srsly.ruamel_yaml.RoundTripDumper, default_flow_style=False - ) - assert res == dedent( - """ - !!omap - - a: 1 - - b: 2 - """ - ) - - def test_CommentedSet(self): - from srsly.ruamel_yaml.constructor import CommentedSet - - s = CommentedSet(["a", "b", "c"]) - s.remove("b") - s.add("d") - assert s == CommentedSet(["a", "c", "d"]) - s.add("e") - s.add("f") - s.remove("e") - assert s == CommentedSet(["a", "c", "d", "f"]) - - def test_set_out(self): - # preferable would be the shorter format without the ': null' - import srsly.ruamel_yaml # NOQA - - x = set(["a", "b", "c"]) - res = srsly.ruamel_yaml.dump(x, default_flow_style=False) - assert res == dedent( - """ - !!set - a: null - b: null - c: null - """ - ) - - # ordering is not preserved in a set - def test_set_compact(self): - # this format is read and also should be written by default - round_trip( - """ - !!set - ? a - ? b - ? c - """ - ) - - def test_blank_line_after_comment(self): - round_trip( - """ - # Comment with spaces after it. - - - a: 1 - """ - ) - - def test_blank_line_between_seq_items(self): - round_trip( - """ - # Seq with empty lines in between items. - b: - - bar - - - - baz - """ - ) - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_blank_line_after_literal_chip(self): - s = """ - c: - - | - This item - has a blank line - following it. - - - | - To visually separate it from this item. - - This item contains a blank line. 
- - - """ - d = round_trip_load(dedent(s)) - print(d) - round_trip(s) - assert d["c"][0].split("it.")[1] == "\n" - assert d["c"][1].split("line.")[1] == "\n" - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_blank_line_after_literal_keep(self): - """ have to insert an eof marker in YAML to test this""" - s = """ - c: - - |+ - This item - has a blank line - following it. - - - |+ - To visually separate it from this item. - - This item contains a blank line. - - - ... - """ - d = round_trip_load(dedent(s)) - print(d) - round_trip(s) - assert d["c"][0].split("it.")[1] == "\n\n" - assert d["c"][1].split("line.")[1] == "\n\n\n" - - @pytest.mark.skipif( - platform.python_implementation() == "Jython", - reason="Jython throws RepresenterError", - ) - def test_blank_line_after_literal_strip(self): - s = """ - c: - - |- - This item - has a blank line - following it. - - - |- - To visually separate it from this item. - - This item contains a blank line. 
- - - """ - d = round_trip_load(dedent(s)) - print(d) - round_trip(s) - assert d["c"][0].split("it.")[1] == "" - assert d["c"][1].split("line.")[1] == "" - - def test_load_all_perserve_quotes(self): - import srsly.ruamel_yaml # NOQA - - s = dedent( - """\ - a: 'hello' - --- - b: "goodbye" - """ - ) - data = [] - for x in srsly.ruamel_yaml.round_trip_load_all(s, preserve_quotes=True): - data.append(x) - out = srsly.ruamel_yaml.dump_all(data, Dumper=srsly.ruamel_yaml.RoundTripDumper) - print(type(data[0]["a"]), data[0]["a"]) - # out = srsly.ruamel_yaml.round_trip_dump_all(data) - print(out) - assert out == s diff --git a/srsly/tests/ruamel_yaml/test_yamlobject.py b/srsly/tests/ruamel_yaml/test_yamlobject.py deleted file mode 100755 index b0d64b9..0000000 --- a/srsly/tests/ruamel_yaml/test_yamlobject.py +++ /dev/null @@ -1,90 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function - -import sys -import pytest # NOQA - -from .roundtrip import save_and_run # NOQA - - -def test_monster(tmpdir): - program_src = u'''\ - import srsly.ruamel_yaml - from textwrap import dedent - - class Monster(srsly.ruamel_yaml.YAMLObject): - yaml_tag = u'!Monster' - - def __init__(self, name, hp, ac, attacks): - self.name = name - self.hp = hp - self.ac = ac - self.attacks = attacks - - def __repr__(self): - return "%s(name=%r, hp=%r, ac=%r, attacks=%r)" % ( - self.__class__.__name__, self.name, self.hp, self.ac, self.attacks) - - data = srsly.ruamel_yaml.load(dedent("""\\ - --- !Monster - name: Cave spider - hp: [2,6] # 2d6 - ac: 16 - attacks: [BITE, HURT] - """), Loader=srsly.ruamel_yaml.Loader) - # normal dump, keys will be sorted - assert srsly.ruamel_yaml.dump(data) == dedent("""\\ - !Monster - ac: 16 - attacks: [BITE, HURT] - hp: [2, 6] - name: Cave spider - """) - ''' - assert save_and_run(program_src, tmpdir) == 1 - - -@pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__") -def test_qualified_name00(tmpdir): - """issue 214""" - program_src = u"""\ - from 
srsly.ruamel_yaml import YAML - from srsly.ruamel_yaml.compat import StringIO - - class A: - def f(self): - pass - - yaml = YAML(typ='unsafe', pure=True) - yaml.explicit_end = True - buf = StringIO() - yaml.dump(A.f, buf) - res = buf.getvalue() - print('res', repr(res)) - assert res == "!!python/name:__main__.A.f ''\\n...\\n" - x = yaml.load(res) - assert x == A.f - """ - assert save_and_run(program_src, tmpdir) == 1 - - -@pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__") -def test_qualified_name01(tmpdir): - """issue 214""" - from srsly.ruamel_yaml import YAML - import srsly.ruamel_yaml.comments - from srsly.ruamel_yaml.compat import StringIO - - with pytest.raises(ValueError): - yaml = YAML(typ="unsafe", pure=True) - yaml.explicit_end = True - buf = StringIO() - yaml.dump(srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor, buf) - res = buf.getvalue() - assert ( - res - == "!!python/name:srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor ''\n...\n" - ) - x = yaml.load(res) - assert x == srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor diff --git a/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py b/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py deleted file mode 100755 index 3a3c835..0000000 --- a/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py +++ /dev/null @@ -1,39 +0,0 @@ -# coding: utf-8 - -import sys -import pytest # NOQA - -from .roundtrip import round_trip_load, round_trip_dump, dedent - - -class TestLeftOverDebug: - # idea here is to capture round_trip_output via pytest stdout capture - # if there is are any leftover debug statements they should show up - def test_00(self, capsys): - s = dedent( - """ - a: 1 - b: [] - c: [a, 1] - d: {f: 3.14, g: 42} - """ - ) - d = round_trip_load(s) - round_trip_dump(d, sys.stdout) - out, err = capsys.readouterr() - assert out == s - - def test_01(self, capsys): - s = dedent( - """ - - 1 - - [] - - [a, 1] - - {f: 3.14, g: 42} - - - 123 - """ - ) - d = round_trip_load(s) 
- round_trip_dump(d, sys.stdout) - out, err = capsys.readouterr() - assert out == s diff --git a/srsly/tests/ruamel_yaml/test_z_data.py b/srsly/tests/ruamel_yaml/test_z_data.py deleted file mode 100755 index dbd13a1..0000000 --- a/srsly/tests/ruamel_yaml/test_z_data.py +++ /dev/null @@ -1,243 +0,0 @@ -# coding: utf-8 - -from __future__ import print_function, unicode_literals - -import sys -import pytest # NOQA -import warnings # NOQA - -from pathlib import Path - -base_path = Path("data") # that is srsly.ruamel_yaml.data -PY2 = sys.version_info[0] == 2 - - -class YAMLData(object): - yaml_tag = "!YAML" - - def __init__(self, s): - self._s = s - - # Conversion tables for input. E.g. "" is replaced by "\t" - # fmt: off - special = { - 'SPC': ' ', - 'TAB': '\t', - '---': '---', - '...': '...', - } - # fmt: on - - @property - def value(self): - if hasattr(self, "_p"): - return self._p - assert " \n" not in self._s - assert "\t\n" not in self._s - self._p = self._s - for k, v in YAMLData.special.items(): - k = "<" + k + ">" - self._p = self._p.replace(k, v) - return self._p - - def test_rewrite(self, s): - assert " \n" not in s - assert "\t\n" not in s - for k, v in YAMLData.special.items(): - k = "<" + k + ">" - s = s.replace(k, v) - return s - - @classmethod - def from_yaml(cls, constructor, node): - from srsly.ruamel_yaml.nodes import MappingNode - - if isinstance(node, MappingNode): - return cls(constructor.construct_mapping(node)) - return cls(node.value) - - -class Python(YAMLData): - yaml_tag = "!Python" - - -class Output(YAMLData): - yaml_tag = "!Output" - - -class Assert(YAMLData): - yaml_tag = "!Assert" - - @property - def value(self): - from srsly.ruamel_yaml.compat import Mapping - - if hasattr(self, "_pa"): - return self._pa - if isinstance(self._s, Mapping): - self._s["lines"] = self.test_rewrite(self._s["lines"]) - self._pa = self._s - return self._pa - - -def pytest_generate_tests(metafunc): - test_yaml = [] - paths = sorted(base_path.glob("**/*.yaml")) - 
idlist = [] - for path in paths: - stem = path.stem - if stem.startswith(".#"): # skip emacs temporary file - continue - idlist.append(stem) - test_yaml.append([path]) - metafunc.parametrize(["yaml"], test_yaml, ids=idlist, scope="class") - - -class TestYAMLData(object): - def yaml(self, yaml_version=None): - from srsly.ruamel_yaml import YAML - - y = YAML() - y.preserve_quotes = True - if yaml_version: - y.version = yaml_version - return y - - def docs(self, path): - from srsly.ruamel_yaml import YAML - - tyaml = YAML(typ="safe", pure=True) - tyaml.register_class(YAMLData) - tyaml.register_class(Python) - tyaml.register_class(Output) - tyaml.register_class(Assert) - return list(tyaml.load_all(path)) - - def yaml_load(self, value, yaml_version=None): - yaml = self.yaml(yaml_version=yaml_version) - data = yaml.load(value) - return yaml, data - - def round_trip(self, input, output=None, yaml_version=None): - from srsly.ruamel_yaml.compat import StringIO - - yaml, data = self.yaml_load(input.value, yaml_version=yaml_version) - buf = StringIO() - yaml.dump(data, buf) - expected = input.value if output is None else output.value - value = buf.getvalue() - if PY2: - value = value.decode("utf-8") - print("value", value) - # print('expected', expected) - assert value == expected - - def load_assert(self, input, confirm, yaml_version=None): - from srsly.ruamel_yaml.compat import Mapping - - d = self.yaml_load(input.value, yaml_version=yaml_version)[1] # NOQA - print("confirm.value", confirm.value, type(confirm.value)) - if isinstance(confirm.value, Mapping): - r = range(confirm.value["range"]) - lines = confirm.value["lines"].splitlines() - for idx in r: # NOQA - for line in lines: - line = "assert " + line - print(line) - exec(line) - else: - for line in confirm.value.splitlines(): - line = "assert " + line - print(line) - exec(line) - - def run_python(self, python, data, tmpdir): - from .roundtrip import save_and_run - - assert save_and_run(python.value, base_dir=tmpdir, 
output=data.value) == 0 - - # this is executed by pytest the methods with names not starting with test_ - # are helpers - def test_yaml_data(self, yaml, tmpdir): - from srsly.ruamel_yaml.compat import Mapping - - idx = 0 - typ = None - yaml_version = None - - docs = self.docs(yaml) - if isinstance(docs[0], Mapping): - d = docs[0] - typ = d.get("type") - yaml_version = d.get("yaml_version") - if "python" in d: - if not check_python_version(d["python"]): - pytest.skip("unsupported version") - idx += 1 - data = output = confirm = python = None - for doc in docs[idx:]: - if isinstance(doc, Output): - output = doc - elif isinstance(doc, Assert): - confirm = doc - elif isinstance(doc, Python): - python = doc - if typ is None: - typ = "python_run" - elif isinstance(doc, YAMLData): - data = doc - else: - print("no handler for type:", type(doc), repr(doc)) - raise AssertionError() - if typ is None: - if data is not None and output is not None: - typ = "rt" - elif data is not None and confirm is not None: - typ = "load_assert" - else: - assert data is not None - typ = "rt" - print("type:", typ) - if data is not None: - print("data:", data.value, end="") - print("output:", output.value if output is not None else output) - if typ == "rt": - self.round_trip(data, output, yaml_version=yaml_version) - elif typ == "python_run": - self.run_python(python, output if output is not None else data, tmpdir) - elif typ == "load_assert": - self.load_assert(data, confirm, yaml_version=yaml_version) - else: - print("\nrun type unknown:", typ) - raise AssertionError() - - -def check_python_version(match, current=None): - """ - version indication, return True if version matches. - match should be something like 3.6+, or [2.7, 3.3] etc. Floats - are converted to strings. Single values are made into lists. 
- """ - if current is None: - current = list(sys.version_info[:3]) - if not isinstance(match, list): - match = [match] - for m in match: - minimal = False - if isinstance(m, float): - m = str(m) - if m.endswith("+"): - minimal = True - m = m[:-1] - # assert m[0].isdigit() - # assert m[-1].isdigit() - m = [int(x) for x in m.split(".")] - current_len = current[: len(m)] - # print(m, current, current_len) - if minimal: - if current_len >= m: - return True - else: - if current_len == m: - return True - return False diff --git a/srsly/tests/test_json_api.py b/srsly/tests/test_json_api.py index 89ce400..85f0164 100644 --- a/srsly/tests/test_json_api.py +++ b/srsly/tests/test_json_api.py @@ -2,7 +2,6 @@ from io import StringIO from pathlib import Path import gzip -import numpy from .._json_api import ( read_json, @@ -55,8 +54,8 @@ def test_write_json_file(): data = {"hello": "world", "test": 123} # Provide two expected options, depending on how keys are ordered expected = [ - '{\n "hello":"world",\n "test":123\n}', - '{\n "test":123,\n "hello":"world"\n}', + '{\n "hello": "world",\n "test": 123\n}', + '{\n "test": 123,\n "hello": "world"\n}', ] with make_tempdir() as temp_dir: file_path = temp_dir / "tmp.json" @@ -69,8 +68,8 @@ def test_write_json_file_gzip(): data = {"hello": "world", "test": 123} # Provide two expected options, depending on how keys are ordered expected = [ - '{\n "hello":"world",\n "test":123\n}', - '{\n "test":123,\n "hello":"world"\n}', + '{\n "hello": "world",\n "test": 123\n}', + '{\n "test": 123,\n "hello": "world"\n}', ] with make_tempdir() as temp_dir: file_path = force_string(temp_dir / "tmp.json") @@ -83,8 +82,8 @@ def test_write_json_stdout(capsys): data = {"hello": "world", "test": 123} # Provide two expected options, depending on how keys are ordered expected = [ - '{\n "hello":"world",\n "test":123\n}\n', - '{\n "test":123,\n "hello":"world"\n}\n', + '{\n "hello": "world",\n "test": 123\n}\n', + '{\n "test": 123,\n "hello": "world"\n}\n', 
] write_json("-", data) captured = capsys.readouterr() @@ -208,8 +207,14 @@ def test_json_loads_raises(obj): def test_unsupported_type_error(): + with pytest.raises(TypeError, match="is not JSON serializable"): + s = json_dumps({1, 2}) + + +def test_unsupported_type_error_numpy(): + numpy = pytest.importorskip("numpy") f = numpy.float32() - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="is not JSON serializable"): s = json_dumps(f) diff --git a/srsly/tests/test_msgpack_api.py b/srsly/tests/test_msgpack_api.py index b516038..13d0f66 100644 --- a/srsly/tests/test_msgpack_api.py +++ b/srsly/tests/test_msgpack_api.py @@ -1,8 +1,8 @@ -import pytest -from pathlib import Path import datetime -from mock import patch -import numpy +from collections import namedtuple +from pathlib import Path + +import pytest from .._msgpack_api import read_msgpack, write_msgpack from .._msgpack_api import msgpack_loads, msgpack_dumps @@ -54,14 +54,63 @@ def test_write_msgpack_file(): assert f.read() in expected -@patch("srsly.msgpack._msgpack_numpy.np", None) -@patch("srsly.msgpack._msgpack_numpy.has_numpy", False) +def test_msgpack_complex(): + inp = {"a": 1 + 2j} + out = msgpack_loads(msgpack_dumps(inp)) + assert out == inp + # Test that we didn't accidentally convert to np.complex128, + # which is a subclass of complex + assert type(out["a"]) is complex + + def test_msgpack_without_numpy(): - """Test that msgpack works without numpy and raises correct errors (e.g. + """Test that msgpack works with and without numpy and raises correct errors (e.g. 
when serializing datetime objects, the error should be msgpack's TypeError, not a "'np' is not defined error").""" - with pytest.raises(TypeError): - msgpack_loads(msgpack_dumps(datetime.datetime.now())) + with pytest.raises(TypeError, match="datetime.datetime"): + msgpack_dumps(datetime.datetime.now()) + + +@pytest.mark.parametrize( + "base_cls,data", + [ + (int, 1), + (float, 1), + (list, [1, 2]), + (dict, {"x": 1}), + (str, "foo"), + (bytes, b"foo"), + ], +) +def test_msgpack_subtypes(base_cls, data): + """Subtypes of base types are cast to their parents.""" + + class SubType(base_cls): + pass + + inp = SubType(data) + out = msgpack_loads(msgpack_dumps(inp)) + assert type(out) is base_cls + assert out == base_cls(data) + + +def test_msgpack_tuple(): + # There is no difference between list and tuple in msgpack + # Outcome is controlled by the use_list=True parameter + class MyTuple(tuple): + pass + + Named = namedtuple("Named", ["x", "y", "z"]) + + b1 = msgpack_dumps((1, 2, 3)) + b2 = msgpack_dumps([1, 2, 3]) + b3 = msgpack_dumps(MyTuple((1, 2, 3))) + b4 = msgpack_dumps(Named(x=1, y=2, z=3)) + assert b2 == b1 + assert b3 == b1 + assert b4 == b1 + assert msgpack_loads(b1) == [1, 2, 3] + assert msgpack_loads(b1, use_list=False) == (1, 2, 3) def test_msgpack_custom_encoder_decoder(): @@ -69,15 +118,15 @@ class CustomObject: def __init__(self, value): self.value = value - def serialize_obj(obj, chain=None): + def serialize_obj(obj): if isinstance(obj, CustomObject): return {"__custom__": obj.value} - return obj if chain is None else chain(obj) + return obj - def deserialize_obj(obj, chain=None): + def deserialize_obj(obj): if "__custom__" in obj: return CustomObject(obj["__custom__"]) - return obj if chain is None else chain(obj) + return obj data = {"a": 123, "b": CustomObject({"foo": "bar"})} with pytest.raises(TypeError): @@ -91,10 +140,70 @@ def deserialize_obj(obj, chain=None): assert new_data["a"] == 123 assert isinstance(new_data["b"], CustomObject) assert 
new_data["b"].value == {"foo": "bar"} - # Test that it also works with combinations of encoders/decoders (e.g. numpy) - data = {"a": numpy.zeros((1, 2, 3)), "b": CustomObject({"foo": "bar"})} + + # Test that it also works with combinations of encoders/decoders (e.g. complex) + data = {"a": 1 + 2j, "b": CustomObject({"foo": "bar"})} bytes_data = msgpack_dumps(data) new_data = msgpack_loads(bytes_data) - assert isinstance(new_data["a"], numpy.ndarray) + assert isinstance(new_data["a"], complex) assert isinstance(new_data["b"], CustomObject) assert new_data["b"].value == {"foo": "bar"} + + # Clean up + msgpack_encoders.deregister("custom_object") + msgpack_decoders.deregister("custom_object") + + +def test_msgpack_custom_subtype_handler(): + """By default, subtypes of base types are cast to their parents. + Test that the user can define a custom encoder/decoder to preserve + the subtype. + """ + + class MyInt(int): + pass + + def encode_myint(obj): + if isinstance(obj, MyInt): + return {"MyInt": int(obj)} + return obj + + def decode_myint(obj): + if "MyInt" in obj: + return MyInt(obj["MyInt"]) + return obj + + inp = MyInt(5) + out = msgpack_loads(msgpack_dumps(inp)) + assert out == 5 + assert type(out) is int + + msgpack_encoders.register("myint", func=encode_myint) + msgpack_decoders.register("myint", func=decode_myint) + out = msgpack_loads(msgpack_dumps(inp)) + assert out == MyInt(5) + assert type(out) is MyInt + + # Cleanup + msgpack_encoders.deregister("myint") + msgpack_decoders.deregister("myint") + + +def test_msgpack_numpy_not_installed(): + """Test that we get a clean ModuleNotFoundError when + trying to decode numpy data when numpy is not installed. 
+ """ + # Output of np.float64(1) + bin = ( + b"\x83\xc4\x02nd\xc2\xc4\x04type\xa3", encode_html_chars=True), - ) - - def test_doubleLongIssue(self): - sut = {"a": -4342969734183514} - encoded = json.dumps(sut) - decoded = json.loads(encoded) - self.assertEqual(sut, decoded) - encoded = ujson.encode(sut) - decoded = ujson.decode(encoded) - self.assertEqual(sut, decoded) - - def test_doubleLongDecimalIssue(self): - sut = {"a": -12345678901234.56789012} - encoded = json.dumps(sut) - decoded = json.loads(encoded) - self.assertEqual(sut, decoded) - encoded = ujson.encode(sut) - decoded = ujson.decode(encoded) - self.assertEqual(sut, decoded) - - def test_encodeDecodeLongDecimal(self): - sut = {"a": -528656961.4399388} - encoded = ujson.dumps(sut) - ujson.decode(encoded) - - def test_decimalDecodeTest(self): - sut = {"a": 4.56} - encoded = ujson.encode(sut) - decoded = ujson.decode(encoded) - self.assertAlmostEqual(sut[u"a"], decoded[u"a"]) - - def test_encodeDictWithUnicodeKeys(self): - input = { - "key1": "value1", - "key1": "value1", - "key1": "value1", - "key1": "value1", - "key1": "value1", - "key1": "value1", - } - ujson.encode(input) - - input = { - "بن": "value1", - "بن": "value1", - "بن": "value1", - "بن": "value1", - "بن": "value1", - "بن": "value1", - "بن": "value1", - } - ujson.encode(input) - - def test_encodeDoubleConversion(self): - input = math.pi - output = ujson.encode(input) - self.assertEqual(round(input, 5), round(json.loads(output), 5)) - self.assertEqual(round(input, 5), round(ujson.decode(output), 5)) - - def test_encodeWithDecimal(self): - input = 1.0 - output = ujson.encode(input) - self.assertEqual(output, "1.0") - - def test_encodeDoubleNegConversion(self): - input = -math.pi - output = ujson.encode(input) - - self.assertEqual(round(input, 5), round(json.loads(output), 5)) - self.assertEqual(round(input, 5), round(ujson.decode(output), 5)) - - def test_encodeArrayOfNestedArrays(self): - input = [[[[]]]] * 20 - output = ujson.encode(input) - 
self.assertEqual(input, json.loads(output)) - # self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeArrayOfDoubles(self): - input = [31337.31337, 31337.31337, 31337.31337, 31337.31337] * 10 - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - # self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeStringConversion2(self): - input = "A string \\ / \b \f \n \r \t" - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, '"A string \\\\ \\/ \\b \\f \\n \\r \\t"') - self.assertEqual(input, ujson.decode(output)) - - def test_decodeUnicodeConversion(self): - pass - - def test_encodeUnicodeConversion1(self): - input = "Räksmörgås اسامة بن محمد بن عوض بن لادن" - enc = ujson.encode(input) - dec = ujson.decode(enc) - self.assertEqual(enc, json_unicode(input)) - self.assertEqual(dec, json.loads(enc)) - - def test_encodeControlEscaping(self): - input = "\x19" - enc = ujson.encode(input) - dec = ujson.decode(enc) - self.assertEqual(input, dec) - self.assertEqual(enc, json_unicode(input)) - - def test_encodeUnicodeConversion2(self): - input = "\xe6\x97\xa5\xd1\x88" - enc = ujson.encode(input) - dec = ujson.decode(enc) - self.assertEqual(enc, json_unicode(input)) - self.assertEqual(dec, json.loads(enc)) - - def test_encodeUnicodeSurrogatePair(self): - input = "\xf0\x90\x8d\x86" - enc = ujson.encode(input) - dec = ujson.decode(enc) - - self.assertEqual(enc, json_unicode(input)) - self.assertEqual(dec, json.loads(enc)) - - def test_encodeUnicode4BytesUTF8(self): - input = "\xf0\x91\x80\xb0TRAILINGNORMAL" - enc = ujson.encode(input) - dec = ujson.decode(enc) - - self.assertEqual(enc, json_unicode(input)) - self.assertEqual(dec, json.loads(enc)) - - def test_encodeUnicode4BytesUTF8Highest(self): - input = "\xf3\xbf\xbf\xbfTRAILINGNORMAL" - enc = ujson.encode(input) - dec = ujson.decode(enc) - - 
self.assertEqual(enc, json_unicode(input)) - self.assertEqual(dec, json.loads(enc)) - - # Characters outside of Basic Multilingual Plane(larger than - # 16 bits) are represented as \UXXXXXXXX in python but should be encoded - # as \uXXXX\uXXXX in json. - def testEncodeUnicodeBMP(self): - s = "\U0001f42e\U0001f42e\U0001F42D\U0001F42D" # 🐮🐮🐭🐭 - encoded = ujson.dumps(s) - encoded_json = json.dumps(s) - - if len(s) == 4: - self.assertEqual(len(encoded), len(s) * 12 + 2) - else: - self.assertEqual(len(encoded), len(s) * 6 + 2) - - self.assertEqual(encoded, encoded_json) - decoded = ujson.loads(encoded) - self.assertEqual(s, decoded) - - # ujson outputs an UTF-8 encoded str object - encoded = ujson.dumps(s, ensure_ascii=False) - # json outputs an unicode object - encoded_json = json.dumps(s, ensure_ascii=False) - self.assertEqual(len(encoded), len(s) + 2) # original length + quotes - self.assertEqual(encoded, encoded_json) - decoded = ujson.loads(encoded) - self.assertEqual(s, decoded) - - def testEncodeSymbols(self): - s = "\u273f\u2661\u273f" # ✿♡✿ - encoded = ujson.dumps(s) - encoded_json = json.dumps(s) - self.assertEqual(len(encoded), len(s) * 6 + 2) # 6 characters + quotes - self.assertEqual(encoded, encoded_json) - decoded = ujson.loads(encoded) - self.assertEqual(s, decoded) - - # ujson outputs an UTF-8 encoded str object - encoded = ujson.dumps(s, ensure_ascii=False) - # json outputs an unicode object - encoded_json = json.dumps(s, ensure_ascii=False) - self.assertEqual(len(encoded), len(s) + 2) # original length + quotes - self.assertEqual(encoded, encoded_json) - decoded = ujson.loads(encoded) - self.assertEqual(s, decoded) - - def test_encodeArrayInArray(self): - input = [[[[]]]] - output = ujson.encode(input) - - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeIntConversion(self): - input = 31337 - output = ujson.encode(input) - 
self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeIntNegConversion(self): - input = -31337 - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeLongNegConversion(self): - input = -9223372036854775808 - output = ujson.encode(input) - - json.loads(output) - ujson.decode(output) - - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeListConversion(self): - input = [1, 2, 3, 4] - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeDictConversion(self): - input = {"k1": 1, "k2": 2, "k3": 3, "k4": 4} - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(input, ujson.decode(output)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeNoneConversion(self): - input = None - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeTrueConversion(self): - input = True - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeFalseConversion(self): - input = False - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeToUTF8(self): - input = b"\xe6\x97\xa5\xd1\x88" - input = input.decode("utf-8") - enc = ujson.encode(input, ensure_ascii=False) - dec = ujson.decode(enc) - self.assertEqual(enc, 
json.dumps(input, ensure_ascii=False)) - self.assertEqual(dec, json.loads(enc)) - - def test_decodeFromUnicode(self): - input = '{"obj": 31337}' - dec1 = ujson.decode(input) - dec2 = ujson.decode(str(input)) - self.assertEqual(dec1, dec2) - - def test_encodeRecursionMax(self): - # 8 is the max recursion depth - class O2: - member = 0 - - def toDict(self): - return {"member": self.member} - - class O1: - member = 0 - - def toDict(self): - return {"member": self.member} - - input = O1() - input.member = O2() - input.member.member = input - self.assertRaises(OverflowError, ujson.encode, input) - - def test_encodeDoubleNan(self): - input = float("nan") - self.assertRaises(OverflowError, ujson.encode, input) - - def test_encodeDoubleInf(self): - input = float("inf") - self.assertRaises(OverflowError, ujson.encode, input) - - def test_encodeDoubleNegInf(self): - input = -float("inf") - self.assertRaises(OverflowError, ujson.encode, input) - - def test_encodeOrderedDict(self): - from collections import OrderedDict - - input = OrderedDict([(1, 1), (0, 0), (8, 8), (2, 2)]) - self.assertEqual('{"1":1,"0":0,"8":8,"2":2}', ujson.encode(input)) - - def test_decodeJibberish(self): - input = "fdsa sda v9sa fdsa" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenArrayStart(self): - input = "[" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenObjectStart(self): - input = "{" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenArrayEnd(self): - input = "]" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeArrayDepthTooBig(self): - input = "[" * (1024 * 1024) - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenObjectEnd(self): - input = "}" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeObjectTrailingCommaFail(self): - input = '{"one":1,}' - self.assertRaises(ValueError, ujson.decode, input) - - def 
test_decodeObjectDepthTooBig(self): - input = "{" * (1024 * 1024) - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeStringUnterminated(self): - input = '"TESTING' - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeStringUntermEscapeSequence(self): - input = '"TESTING\\"' - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeStringBadEscape(self): - input = '"TESTING\\"' - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeTrueBroken(self): - input = "tru" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeFalseBroken(self): - input = "fa" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeNullBroken(self): - input = "n" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenDictKeyTypeLeakTest(self): - input = '{{1337:""}}' - for x in range(1000): - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenDictLeakTest(self): - input = '{{"key":"}' - for x in range(1000): - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeBrokenListLeakTest(self): - input = "[[[true" - for x in range(1000): - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeDictWithNoKey(self): - input = "{{{{31337}}}}" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeDictWithNoColonOrValue(self): - input = '{{{{"key"}}}}' - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeDictWithNoValue(self): - input = '{{{{"key":}}}}' - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeNumericIntPos(self): - input = "31337" - self.assertEqual(31337, ujson.decode(input)) - - def test_decodeNumericIntNeg(self): - input = "-31337" - self.assertEqual(-31337, ujson.decode(input)) - - def test_encodeUnicode4BytesUTF8Fail(self): - input = b"\xfd\xbf\xbf\xbf\xbf\xbf" - self.assertRaises(OverflowError, ujson.encode, input) - - def 
test_encodeNullCharacter(self): - input = "31337 \x00 1337" - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - input = "\x00" - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - self.assertEqual('" \\u0000\\r\\n "', ujson.dumps(" \u0000\r\n ")) - - def test_decodeNullCharacter(self): - input = '"31337 \\u0000 31337"' - self.assertEqual(ujson.decode(input), json.loads(input)) - - def test_encodeListLongConversion(self): - input = [ - 9223372036854775807, - 9223372036854775807, - 9223372036854775807, - 9223372036854775807, - 9223372036854775807, - 9223372036854775807, - ] - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeListLongUnsignedConversion(self): - input = [18446744073709551615, 18446744073709551615, 18446744073709551615] - output = ujson.encode(input) - - self.assertEqual(input, json.loads(output)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeLongConversion(self): - input = 9223372036854775807 - output = ujson.encode(input) - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_encodeLongUnsignedConversion(self): - input = 18446744073709551615 - output = ujson.encode(input) - - self.assertEqual(input, json.loads(output)) - self.assertEqual(output, json.dumps(input)) - self.assertEqual(input, ujson.decode(output)) - - def test_numericIntExp(self): - input = "1337E40" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_numericIntFrcExp(self): - input = "1.337E40" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def 
test_decodeNumericIntExpEPLUS(self): - input = "1337E+9" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_decodeNumericIntExpePLUS(self): - input = "1.337e+40" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_decodeNumericIntExpE(self): - input = "1337E40" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_decodeNumericIntExpe(self): - input = "1337e40" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_decodeNumericIntExpEMinus(self): - input = "1.337E-4" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_decodeNumericIntExpeMinus(self): - input = "1.337e-4" - output = ujson.decode(input) - self.assertEqual(output, json.loads(input)) - - def test_dumpToFile(self): - f = StringIO() - ujson.dump([1, 2, 3], f) - self.assertEqual("[1,2,3]", f.getvalue()) - - def test_dumpToFileLikeObject(self): - class filelike: - def __init__(self): - self.bytes = "" - - def write(self, bytes): - self.bytes += bytes - - f = filelike() - ujson.dump([1, 2, 3], f) - self.assertEqual("[1,2,3]", f.bytes) - - def test_dumpFileArgsError(self): - self.assertRaises(TypeError, ujson.dump, [], "") - - def test_loadFile(self): - f = StringIO("[1,2,3,4]") - self.assertEqual([1, 2, 3, 4], ujson.load(f)) - - def test_loadFileLikeObject(self): - class filelike: - def read(self): - try: - self.end - except AttributeError: - self.end = True - return "[1,2,3,4]" - - f = filelike() - self.assertEqual([1, 2, 3, 4], ujson.load(f)) - - def test_loadFileArgsError(self): - self.assertRaises(TypeError, ujson.load, "[]") - - def test_encodeNumericOverflow(self): - self.assertRaises(OverflowError, ujson.encode, 12839128391289382193812939) - - def test_decodeNumberWith32bitSignBit(self): - # Test that numbers that fit within 32 bits but would have the - # sign bit set (2**31 <= x < 2**32) are decoded properly. 
- docs = ( - '{"id": 3590016419}', - '{"id": %s}' % 2 ** 31, - '{"id": %s}' % 2 ** 32, - '{"id": %s}' % ((2 ** 32) - 1), - ) - results = (3590016419, 2 ** 31, 2 ** 32, 2 ** 32 - 1) - for doc, result in zip(docs, results): - self.assertEqual(ujson.decode(doc)["id"], result) - - def test_encodeBigEscape(self): - for x in range(10): - base = "\u00e5".encode("utf-8") - input = base * 1024 * 1024 * 2 - ujson.encode(input) - - def test_decodeBigEscape(self): - for x in range(10): - base = "\u00e5".encode("utf-8") - quote = '"'.encode() - input = quote + (base * 1024 * 1024 * 2) + quote - ujson.decode(input) - - def test_toDict(self): - d = {"key": 31337} - - class DictTest: - def toDict(self): - return d - - def __json__(self): - return '"json defined"' # Fallback and shouldn't be called. - - o = DictTest() - output = ujson.encode(o) - dec = ujson.decode(output) - self.assertEqual(dec, d) - - def test_object_with_json(self): - # If __json__ returns a string, then that string - # will be used as a raw JSON snippet in the object. - output_text = "this is the correct output" - - class JSONTest: - def __json__(self): - return '"' + output_text + '"' - - d = {u"key": JSONTest()} - output = ujson.encode(d) - dec = ujson.decode(output) - self.assertEqual(dec, {u"key": output_text}) - - def test_object_with_json_unicode(self): - # If __json__ returns a string, then that string - # will be used as a raw JSON snippet in the object. - output_text = u"this is the correct output" - - class JSONTest: - def __json__(self): - return u'"' + output_text + u'"' - - d = {u"key": JSONTest()} - output = ujson.encode(d) - dec = ujson.decode(output) - self.assertEqual(dec, {u"key": output_text}) - - def test_object_with_complex_json(self): - # If __json__ returns a string, then that string - # will be used as a raw JSON snippet in the object. 
- obj = {u"foo": [u"bar", u"baz"]} - - class JSONTest: - def __json__(self): - return ujson.encode(obj) - - d = {u"key": JSONTest()} - output = ujson.encode(d) - dec = ujson.decode(output) - self.assertEqual(dec, {u"key": obj}) - - def test_object_with_json_type_error(self): - # __json__ must return a string, otherwise it should raise an error. - for return_value in (None, 1234, 12.34, True, {}): - - class JSONTest: - def __json__(self): - return return_value - - d = {u"key": JSONTest()} - self.assertRaises(TypeError, ujson.encode, d) - - def test_object_with_json_attribute_error(self): - # If __json__ raises an error, make sure python actually raises it. - class JSONTest: - def __json__(self): - raise AttributeError - - d = {u"key": JSONTest()} - self.assertRaises(AttributeError, ujson.encode, d) - - def test_decodeArrayTrailingCommaFail(self): - input = "[31337,]" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeArrayLeadingCommaFail(self): - input = "[,31337]" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeArrayOnlyCommaFail(self): - input = "[,]" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeArrayUnmatchedBracketFail(self): - input = "[]]" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeArrayEmpty(self): - input = "[]" - obj = ujson.decode(input) - self.assertEqual([], obj) - - def test_decodeArrayOneItem(self): - input = "[31337]" - ujson.decode(input) - - def test_decodeLongUnsignedValue(self): - input = "18446744073709551615" - ujson.decode(input) - - def test_decodeBigValue(self): - input = "9223372036854775807" - ujson.decode(input) - - def test_decodeSmallValue(self): - input = "-9223372036854775808" - ujson.decode(input) - - def test_decodeTooBigValue(self): - input = "18446744073709551616" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeTooSmallValue(self): - input = "-90223372036854775809" - self.assertRaises(ValueError, 
ujson.decode, input) - - def test_decodeVeryTooBigValue(self): - input = "18446744073709551616" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeVeryTooSmallValue(self): - input = "-90223372036854775809" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeWithTrailingWhitespaces(self): - input = "{}\n\t " - ujson.decode(input) - - def test_decodeWithTrailingNonWhitespaces(self): - input = "{}\n\t a" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeArrayWithBigInt(self): - input = "[18446744073709551616]" - self.assertRaises(ValueError, ujson.decode, input) - - def test_decodeFloatingPointAdditionalTests(self): - self.assertAlmostEqual(-1.1234567893, ujson.loads("-1.1234567893")) - self.assertAlmostEqual(-1.234567893, ujson.loads("-1.234567893")) - self.assertAlmostEqual(-1.34567893, ujson.loads("-1.34567893")) - self.assertAlmostEqual(-1.4567893, ujson.loads("-1.4567893")) - self.assertAlmostEqual(-1.567893, ujson.loads("-1.567893")) - self.assertAlmostEqual(-1.67893, ujson.loads("-1.67893")) - self.assertAlmostEqual(-1.7894, ujson.loads("-1.7894")) - self.assertAlmostEqual(-1.893, ujson.loads("-1.893")) - self.assertAlmostEqual(-1.3, ujson.loads("-1.3")) - - self.assertAlmostEqual(1.1234567893, ujson.loads("1.1234567893")) - self.assertAlmostEqual(1.234567893, ujson.loads("1.234567893")) - self.assertAlmostEqual(1.34567893, ujson.loads("1.34567893")) - self.assertAlmostEqual(1.4567893, ujson.loads("1.4567893")) - self.assertAlmostEqual(1.567893, ujson.loads("1.567893")) - self.assertAlmostEqual(1.67893, ujson.loads("1.67893")) - self.assertAlmostEqual(1.7894, ujson.loads("1.7894")) - self.assertAlmostEqual(1.893, ujson.loads("1.893")) - self.assertAlmostEqual(1.3, ujson.loads("1.3")) - - def test_ReadBadObjectSyntax(self): - input = '{"age", 44}' - self.assertRaises(ValueError, ujson.decode, input) - - def test_ReadTrue(self): - self.assertEqual(True, ujson.loads("true")) - - def 
test_ReadFalse(self): - self.assertEqual(False, ujson.loads("false")) - - def test_ReadNull(self): - self.assertEqual(None, ujson.loads("null")) - - def test_WriteTrue(self): - self.assertEqual("true", ujson.dumps(True)) - - def test_WriteFalse(self): - self.assertEqual("false", ujson.dumps(False)) - - def test_WriteNull(self): - self.assertEqual("null", ujson.dumps(None)) - - def test_ReadArrayOfSymbols(self): - self.assertEqual([True, False, None], ujson.loads(" [ true, false,null] ")) - - def test_WriteArrayOfSymbolsFromList(self): - self.assertEqual("[true,false,null]", ujson.dumps([True, False, None])) - - def test_WriteArrayOfSymbolsFromTuple(self): - self.assertEqual("[true,false,null]", ujson.dumps((True, False, None))) - - def test_encodingInvalidUnicodeCharacter(self): - s = "\udc7f" - self.assertRaises(UnicodeEncodeError, ujson.dumps, s) - - def test_sortKeys(self): - data = {"a": 1, "c": 1, "b": 1, "e": 1, "f": 1, "d": 1} - sortedKeys = ujson.dumps(data, sort_keys=True) - self.assertEqual(sortedKeys, '{"a":1,"b":1,"c":1,"d":1,"e":1,"f":1}') - - @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount") - def test_does_not_leak_dictionary_values(self): - import gc - - gc.collect() - value = ["abc"] - data = {"1": value} - ref_count = sys.getrefcount(value) - ujson.dumps(data) - self.assertEqual(ref_count, sys.getrefcount(value)) - - @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount") - def test_does_not_leak_dictionary_keys(self): - import gc - - gc.collect() - key1 = "1" - key2 = "1" - value1 = ["abc"] - value2 = [1, 2, 3] - data = {key1: value1, key2: value2} - ref_count1 = sys.getrefcount(key1) - ref_count2 = sys.getrefcount(key2) - ujson.dumps(data) - self.assertEqual(ref_count1, sys.getrefcount(key1)) - self.assertEqual(ref_count2, sys.getrefcount(key2)) - - @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount") - def 
test_does_not_leak_dictionary_string_key(self): - import gc - - gc.collect() - key1 = "1" - value1 = 1 - data = {key1: value1} - ref_count1 = sys.getrefcount(key1) - ujson.dumps(data) - self.assertEqual(ref_count1, sys.getrefcount(key1)) - - @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount") - def test_does_not_leak_dictionary_tuple_key(self): - import gc - - gc.collect() - key1 = ("a",) - value1 = 1 - data = {key1: value1} - ref_count1 = sys.getrefcount(key1) - ujson.dumps(data) - self.assertEqual(ref_count1, sys.getrefcount(key1)) - - @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount") - def test_does_not_leak_dictionary_bytes_key(self): - import gc - - gc.collect() - key1 = b"1" - value1 = 1 - data = {key1: value1} - ref_count1 = sys.getrefcount(key1) - ujson.dumps(data) - self.assertEqual(ref_count1, sys.getrefcount(key1)) - - @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount") - def test_does_not_leak_dictionary_None_key(self): - import gc - - gc.collect() - key1 = None - value1 = 1 - data = {key1: value1} - ref_count1 = sys.getrefcount(key1) - ujson.dumps(data) - self.assertEqual(ref_count1, sys.getrefcount(key1)) - - -""" -def test_decodeNumericIntFrcOverflow(self): -input = "X.Y" -raise NotImplementedError("Implement this test!") - - -def test_decodeStringUnicodeEscape(self): -input = "\u3131" -raise NotImplementedError("Implement this test!") - -def test_decodeStringUnicodeBrokenEscape(self): -input = "\u3131" -raise NotImplementedError("Implement this test!") - -def test_decodeStringUnicodeInvalidEscape(self): -input = "\u3131" -raise NotImplementedError("Implement this test!") - -def test_decodeStringUTF8(self): -input = "someutfcharacters" -raise NotImplementedError("Implement this test!") - -""" - -if __name__ == "__main__": - unittest.main() - -""" -# Use this to look for memory leaks -if __name__ == '__main__': - from 
guppy import hpy - hp = hpy() - hp.setrelheap() - while True: - try: - unittest.main() - except SystemExit: - pass - heap = hp.heapu() - print(heap) -""" - - -@pytest.mark.parametrize("indent", list(range(65537, 65542))) -def test_dump_huge_indent(indent): - ujson.encode({"a": True}, indent=indent) - - -@pytest.mark.parametrize("first_length", list(range(2, 7))) -@pytest.mark.parametrize("second_length", list(range(10919, 10924))) -def test_dump_long_string(first_length, second_length): - ujson.dumps(["a" * first_length, "\x00" * second_length]) - - -def test_dump_indented_nested_list(): - a = _a = [] - for i in range(20): - _a.append(list(range(i))) - _a = _a[-1] - ujson.dumps(a, indent=i) - - -@pytest.mark.parametrize("indent", [0, 1, 2, 4, 5, 8, 49]) -def test_issue_334(indent): - path = Path(__file__).with_name("334-reproducer.json") - a = ujson.loads(path.read_bytes()) - ujson.dumps(a, indent=indent) - - -@pytest.mark.parametrize( - "test_input, expected", - [ - # Normal cases - (r'"\uD83D\uDCA9"', "\U0001F4A9"), - (r'"a\uD83D\uDCA9b"', "a\U0001F4A9b"), - # Unpaired surrogates - (r'"\uD800"', "\uD800"), - (r'"a\uD800b"', "a\uD800b"), - (r'"\uDEAD"', "\uDEAD"), - (r'"a\uDEADb"', "a\uDEADb"), - (r'"\uD83D\uD83D\uDCA9"', "\uD83D\U0001F4A9"), - (r'"\uDCA9\uD83D\uDCA9"', "\uDCA9\U0001F4A9"), - (r'"\uD83D\uDCA9\uD83D"', "\U0001F4A9\uD83D"), - (r'"\uD83D\uDCA9\uDCA9"', "\U0001F4A9\uDCA9"), - (r'"\uD83D \uDCA9"', "\uD83D \uDCA9"), - # No decoding of actual surrogate characters (rather than escaped ones) - ('"\uD800"', "\uD800"), - ('"\uDEAD"', "\uDEAD"), - ('"\uD800a\uDEAD"', "\uD800a\uDEAD"), - ('"\uD83D\uDCA9"', "\uD83D\uDCA9"), - ], -) -def test_decode_surrogate_characters(test_input, expected): - assert ujson.loads(test_input) == expected - assert ujson.loads(test_input.encode("utf-8", "surrogatepass")) == expected - - # Ensure that this matches stdlib's behaviour - assert json.loads(test_input) == expected diff --git a/srsly/ujson/JSONtoObj.c 
b/srsly/ujson/JSONtoObj.c deleted file mode 100644 index 8563970..0000000 --- a/srsly/ujson/JSONtoObj.c +++ /dev/null @@ -1,261 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. 
- -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms - * Copyright (c) 1988-1993 The Regents of the University of California. - * Copyright (c) 1994 Sun Microsystems, Inc. -*/ - -#include "py_defines.h" -#include - - -//#define PRINTMARK() fprintf(stderr, "%s: MARK(%d)\n", __FILE__, __LINE__) -#define PRINTMARK() - -void Object_objectAddKey(void *prv, JSOBJ obj, JSOBJ name, JSOBJ value) -{ - PyDict_SetItem (obj, name, value); - Py_DECREF( (PyObject *) name); - Py_DECREF( (PyObject *) value); - return; -} - -void Object_arrayAddItem(void *prv, JSOBJ obj, JSOBJ value) -{ - PyList_Append(obj, value); - Py_DECREF( (PyObject *) value); - return; -} - -/* -Check that Py_UCS4 is the same as JSUINT32, else Object_newString will fail. -Based on Linux's check in vbox_vmmdev_types.h. -This should be replaced with - _Static_assert(sizeof(Py_UCS4) == sizeof(JSUINT32)); -when C11 is made mandatory (CPython 3.11+, PyPy ?). -*/ -typedef char assert_py_ucs4_is_jsuint32[1 - 2*!(sizeof(Py_UCS4) == sizeof(JSUINT32))]; - -static JSOBJ Object_newString(void *prv, JSUINT32 *start, JSUINT32 *end) -{ - return PyUnicode_FromKindAndData (PyUnicode_4BYTE_KIND, (Py_UCS4 *) start, (end - start)); -} - -JSOBJ Object_newTrue(void *prv) -{ - Py_RETURN_TRUE; -} - -JSOBJ Object_newFalse(void *prv) -{ - Py_RETURN_FALSE; -} - -JSOBJ Object_newNull(void *prv) -{ - Py_RETURN_NONE; -} - -JSOBJ Object_newObject(void *prv) -{ - return PyDict_New(); -} - -JSOBJ Object_newArray(void *prv) -{ - return PyList_New(0); -} - -JSOBJ Object_newInteger(void *prv, JSINT32 value) -{ - return PyInt_FromLong( (long) value); -} - -JSOBJ Object_newLong(void *prv, JSINT64 value) -{ - return PyLong_FromLongLong (value); -} - -JSOBJ Object_newUnsignedLong(void *prv, JSUINT64 value) -{ - return PyLong_FromUnsignedLongLong (value); -} - -JSOBJ Object_newDouble(void *prv, double value) -{ - return PyFloat_FromDouble(value); -} - -static void 
Object_releaseObject(void *prv, JSOBJ obj) -{ - Py_DECREF( ((PyObject *)obj)); -} - -static char *g_kwlist[] = {"obj", "precise_float", NULL}; - -PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs) -{ - PyObject *ret; - PyObject *sarg; - PyObject *arg; - PyObject *opreciseFloat = NULL; - JSONObjectDecoder decoder = - { - Object_newString, - Object_objectAddKey, - Object_arrayAddItem, - Object_newTrue, - Object_newFalse, - Object_newNull, - Object_newObject, - Object_newArray, - Object_newInteger, - Object_newLong, - Object_newUnsignedLong, - Object_newDouble, - Object_releaseObject, - PyObject_Malloc, - PyObject_Free, - PyObject_Realloc - }; - - decoder.preciseFloat = 0; - decoder.prv = NULL; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", g_kwlist, &arg, &opreciseFloat)) - { - return NULL; - } - - if (opreciseFloat && PyObject_IsTrue(opreciseFloat)) - { - decoder.preciseFloat = 1; - } - - if (PyString_Check(arg)) - { - sarg = arg; - } - else - if (PyUnicode_Check(arg)) - { - sarg = PyUnicode_AsEncodedString(arg, NULL, "surrogatepass"); - if (sarg == NULL) - { - //Exception raised above us by codec according to docs - return NULL; - } - } - else - { - PyErr_Format(PyExc_TypeError, "Expected String or Unicode"); - return NULL; - } - - decoder.errorStr = NULL; - decoder.errorOffset = NULL; - - ret = JSON_DecodeObject(&decoder, PyString_AS_STRING(sarg), PyString_GET_SIZE(sarg)); - - if (sarg != arg) - { - Py_DECREF(sarg); - } - - if (decoder.errorStr) - { - /* - FIXME: It's possible to give a much nicer error message here with actual failing element in input etc*/ - - PyErr_Format (PyExc_ValueError, "%s", decoder.errorStr); - - if (ret) - { - Py_DECREF( (PyObject *) ret); - } - - return NULL; - } - - return ret; -} - -PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs) -{ - PyObject *read; - PyObject *string; - PyObject *result; - PyObject *file = NULL; - PyObject *argtuple; - - if (!PyArg_ParseTuple (args, "O", 
&file)) - { - return NULL; - } - - if (!PyObject_HasAttrString (file, "read")) - { - PyErr_Format (PyExc_TypeError, "expected file"); - return NULL; - } - - read = PyObject_GetAttrString (file, "read"); - - if (!PyCallable_Check (read)) { - Py_XDECREF(read); - PyErr_Format (PyExc_TypeError, "expected file"); - return NULL; - } - - string = PyObject_CallObject (read, NULL); - Py_XDECREF(read); - - if (string == NULL) - { - return NULL; - } - - argtuple = PyTuple_Pack(1, string); - - result = JSONToObj (self, argtuple, kwargs); - - Py_XDECREF(argtuple); - Py_XDECREF(string); - - if (result == NULL) { - return NULL; - } - - return result; -} diff --git a/srsly/ujson/__init__.py b/srsly/ujson/__init__.py deleted file mode 100644 index 744ff70..0000000 --- a/srsly/ujson/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ujson import decode, encode, dump, dumps, load, loads # noqa: F401 diff --git a/srsly/ujson/lib/dconv_wrapper.cc b/srsly/ujson/lib/dconv_wrapper.cc deleted file mode 100644 index 27e1b9b..0000000 --- a/srsly/ujson/lib/dconv_wrapper.cc +++ /dev/null @@ -1,58 +0,0 @@ -#include "double-conversion.h" - -namespace double_conversion -{ - static StringToDoubleConverter* s2d_instance = NULL; - static DoubleToStringConverter* d2s_instance = NULL; - - extern "C" - { - void dconv_d2s_init(int flags, - const char* infinity_symbol, - const char* nan_symbol, - char exponent_character, - int decimal_in_shortest_low, - int decimal_in_shortest_high, - int max_leading_padding_zeroes_in_precision_mode, - int max_trailing_padding_zeroes_in_precision_mode) - { - d2s_instance = new DoubleToStringConverter(flags, infinity_symbol, nan_symbol, - exponent_character, decimal_in_shortest_low, - decimal_in_shortest_high, max_leading_padding_zeroes_in_precision_mode, - max_trailing_padding_zeroes_in_precision_mode); - } - - int dconv_d2s(double value, char* buf, int buflen, int* strlength) - { - StringBuilder sb(buf, buflen); - int success = static_cast(d2s_instance->ToShortest(value, 
&sb)); - *strlength = success ? sb.position() : -1; - return success; - } - - void dconv_d2s_free() - { - delete d2s_instance; - d2s_instance = NULL; - } - - void dconv_s2d_init(int flags, double empty_string_value, - double junk_string_value, const char* infinity_symbol, - const char* nan_symbol) - { - s2d_instance = new StringToDoubleConverter(flags, empty_string_value, - junk_string_value, infinity_symbol, nan_symbol); - } - - double dconv_s2d(const char* buffer, int length, int* processed_characters_count) - { - return s2d_instance->StringToDouble(buffer, length, processed_characters_count); - } - - void dconv_s2d_free() - { - delete s2d_instance; - s2d_instance = NULL; - } - } -} diff --git a/srsly/ujson/lib/ultrajson.h b/srsly/ujson/lib/ultrajson.h deleted file mode 100644 index a117901..0000000 --- a/srsly/ujson/lib/ultrajson.h +++ /dev/null @@ -1,324 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. 
BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. - -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms - * Copyright (c) 1988-1993 The Regents of the University of California. - * Copyright (c) 1994 Sun Microsystems, Inc. -*/ - -/* -Ultra fast JSON encoder and decoder -Developed by Jonas Tarnstrom (jonas@esn.me). - -Encoder notes: ------------------- - -:: Cyclic references :: -Cyclic referenced objects are not detected. -Set JSONObjectEncoder.recursionMax to suitable value or make sure input object -tree doesn't have cyclic references. 
- -*/ - -#ifndef __ULTRAJSON_H__ -#define __ULTRAJSON_H__ - -#include - -// Max decimals to encode double floating point numbers with -#ifndef JSON_DOUBLE_MAX_DECIMALS -#define JSON_DOUBLE_MAX_DECIMALS 15 -#endif - -// Max recursion depth, default for encoder -#ifndef JSON_MAX_RECURSION_DEPTH -#define JSON_MAX_RECURSION_DEPTH 1024 -#endif - -// Max recursion depth, default for decoder -#ifndef JSON_MAX_OBJECT_DEPTH -#define JSON_MAX_OBJECT_DEPTH 1024 -#endif - -/* -Dictates and limits how much stack space for buffers UltraJSON will use before resorting to provided heap functions */ -#ifndef JSON_MAX_STACK_BUFFER_SIZE -#define JSON_MAX_STACK_BUFFER_SIZE 131072 -#endif - -#ifdef _WIN32 - -typedef __int64 JSINT64; -typedef unsigned __int64 JSUINT64; - -typedef __int32 JSINT32; -typedef unsigned __int32 JSUINT32; -typedef unsigned __int8 JSUINT8; -typedef unsigned __int16 JSUTF16; -typedef unsigned __int32 JSUTF32; -typedef __int64 JSLONG; - -#define EXPORTFUNCTION __declspec(dllexport) - -#define FASTCALL_MSVC __fastcall -#define FASTCALL_ATTR -#define INLINE_PREFIX __inline - -#else - -#include -typedef int64_t JSINT64; -typedef uint64_t JSUINT64; - -typedef int32_t JSINT32; -typedef uint32_t JSUINT32; - -#define FASTCALL_MSVC - -#if !defined __x86_64__ -#define FASTCALL_ATTR __attribute__((fastcall)) -#else -#define FASTCALL_ATTR -#endif - -#define INLINE_PREFIX inline - -typedef uint8_t JSUINT8; -typedef uint16_t JSUTF16; -typedef uint32_t JSUTF32; - -typedef int64_t JSLONG; - -#define EXPORTFUNCTION -#endif - -#if !(defined(__LITTLE_ENDIAN__) || defined(__BIG_ENDIAN__)) - -#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ -#define __LITTLE_ENDIAN__ -#else - -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#define __BIG_ENDIAN__ -#endif - -#endif - -#endif - -#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__) -#error "Endianess not supported" -#endif - -enum JSTYPES -{ - JT_NULL, // NULL - JT_TRUE, // boolean true - JT_FALSE, // boolean false - JT_INT, // (JSINT32 
(signed 32-bit)) - JT_LONG, // (JSINT64 (signed 64-bit)) - JT_ULONG, // (JSUINT64 (unsigned 64-bit)) - JT_DOUBLE, // (double) - JT_UTF8, // (char 8-bit) - JT_RAW, // (raw char 8-bit) - JT_ARRAY, // Array structure - JT_OBJECT, // Key/Value structure - JT_INVALID, // Internal, do not return nor expect -}; - -typedef void * JSOBJ; -typedef void * JSITER; - -typedef struct __JSONTypeContext -{ - int type; - void *prv; - void *encoder_prv; -} JSONTypeContext; - -/* -Function pointer declarations, suitable for implementing UltraJSON */ -typedef int (*JSPFN_ITERNEXT)(JSOBJ obj, JSONTypeContext *tc); -typedef void (*JSPFN_ITEREND)(JSOBJ obj, JSONTypeContext *tc); -typedef JSOBJ (*JSPFN_ITERGETVALUE)(JSOBJ obj, JSONTypeContext *tc); -typedef char *(*JSPFN_ITERGETNAME)(JSOBJ obj, JSONTypeContext *tc, size_t *outLen); -typedef void *(*JSPFN_MALLOC)(size_t size); -typedef void (*JSPFN_FREE)(void *pptr); -typedef void *(*JSPFN_REALLOC)(void *base, size_t size); - - -struct __JSONObjectEncoder; - -typedef struct __JSONObjectEncoder -{ - void (*beginTypeContext)(JSOBJ obj, JSONTypeContext *tc, struct __JSONObjectEncoder *enc); - void (*endTypeContext)(JSOBJ obj, JSONTypeContext *tc); - const char *(*getStringValue)(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen); - JSINT64 (*getLongValue)(JSOBJ obj, JSONTypeContext *tc); - JSUINT64 (*getUnsignedLongValue)(JSOBJ obj, JSONTypeContext *tc); - JSINT32 (*getIntValue)(JSOBJ obj, JSONTypeContext *tc); - double (*getDoubleValue)(JSOBJ obj, JSONTypeContext *tc); - - /* - Retrieve next object in an iteration. Should return 0 to indicate iteration has reached end or 1 if there are more items. - Implementor is responsible for keeping state of the iteration. Use ti->prv fields for this - */ - JSPFN_ITERNEXT iterNext; - - /* - Ends the iteration of an iteratable object. 
- Any iteration state stored in ti->prv can be freed here - */ - JSPFN_ITEREND iterEnd; - - /* - Returns a reference to the value object of an iterator - The is responsible for the life-cycle of the returned string. Use iterNext/iterEnd and ti->prv to keep track of current object - */ - JSPFN_ITERGETVALUE iterGetValue; - - /* - Return name of iterator. - The is responsible for the life-cycle of the returned string. Use iterNext/iterEnd and ti->prv to keep track of current object - */ - JSPFN_ITERGETNAME iterGetName; - - /* - Release a value as indicated by setting ti->release = 1 in the previous getValue call. - The ti->prv array should contain the necessary context to release the value - */ - void (*releaseObject)(JSOBJ obj); - - /* Library functions - Set to NULL to use STDLIB malloc,realloc,free */ - JSPFN_MALLOC malloc; - JSPFN_REALLOC realloc; - JSPFN_FREE free; - - /* - Configuration for max recursion, set to 0 to use default (see JSON_MAX_RECURSION_DEPTH)*/ - int recursionMax; - - /* - Configuration for max decimals of double floating point numbers to encode (0-9) */ - int doublePrecision; - - /* - If true output will be ASCII with all characters above 127 encoded as \uXXXX. If false output will be UTF-8 or what ever charset strings are brought as */ - int forceASCII; - - /* - If true, '<', '>', and '&' characters will be encoded as \u003c, \u003e, and \u0026, respectively. If false, no special encoding will be used. */ - int encodeHTMLChars; - - /* - If true, '/' will be encoded as \/. If false, no escaping. */ - int escapeForwardSlashes; - - /* - If true, dictionaries are iterated through in sorted key order. */ - int sortKeys; - - /* - Configuration for spaces of indent */ - int indent; - - /* - Private pointer to be used by the caller. 
Passed as encoder_prv in JSONTypeContext */ - void *prv; - - /* - Set to an error message if error occured */ - const char *errorMsg; - JSOBJ errorObj; - - /* Buffer stuff */ - char *start; - char *offset; - char *end; - int heap; - int level; - -} JSONObjectEncoder; - - -/* -Encode an object structure into JSON. - -Arguments: -obj - An anonymous type representing the object -enc - Function definitions for querying JSOBJ type -buffer - Preallocated buffer to store result in. If NULL function allocates own buffer -cbBuffer - Length of buffer (ignored if buffer is NULL) - -Returns: -Encoded JSON object as a null terminated char string. - -NOTE: -If the supplied buffer wasn't enough to hold the result the function will allocate a new buffer. -Life cycle of the provided buffer must still be handled by caller. - -If the return value doesn't equal the specified buffer caller must release the memory using -JSONObjectEncoder.free or free() as specified when calling this function. -*/ -EXPORTFUNCTION char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *buffer, size_t cbBuffer); - - - -typedef struct __JSONObjectDecoder -{ - JSOBJ (*newString)(void *prv, JSUINT32 *start, JSUINT32 *end); - void (*objectAddKey)(void *prv, JSOBJ obj, JSOBJ name, JSOBJ value); - void (*arrayAddItem)(void *prv, JSOBJ obj, JSOBJ value); - JSOBJ (*newTrue)(void *prv); - JSOBJ (*newFalse)(void *prv); - JSOBJ (*newNull)(void *prv); - JSOBJ (*newObject)(void *prv); - JSOBJ (*newArray)(void *prv); - JSOBJ (*newInt)(void *prv, JSINT32 value); - JSOBJ (*newLong)(void *prv, JSINT64 value); - JSOBJ (*newUnsignedLong)(void *prv, JSUINT64 value); - JSOBJ (*newDouble)(void *prv, double value); - void (*releaseObject)(void *prv, JSOBJ obj); - JSPFN_MALLOC malloc; - JSPFN_FREE free; - JSPFN_REALLOC realloc; - char *errorStr; - char *errorOffset; - int preciseFloat; - void *prv; -} JSONObjectDecoder; - -EXPORTFUNCTION JSOBJ JSON_DecodeObject(JSONObjectDecoder *dec, const char *buffer, size_t 
cbBuffer); - -#endif diff --git a/srsly/ujson/lib/ultrajsondec.c b/srsly/ujson/lib/ultrajsondec.c deleted file mode 100644 index a88a008..0000000 --- a/srsly/ujson/lib/ultrajsondec.c +++ /dev/null @@ -1,891 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. 
- -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms -* Copyright (c) 1988-1993 The Regents of the University of California. -* Copyright (c) 1994 Sun Microsystems, Inc. -*/ - -#include "ultrajson.h" -#include -#include -#include -#include -#include -#include - -#ifndef TRUE -#define TRUE 1 -#define FALSE 0 -#endif -#ifndef NULL -#define NULL 0 -#endif - -struct DecoderState -{ - char *start; - char *end; - JSUINT32 *escStart; - JSUINT32 *escEnd; - int escHeap; - int lastType; - JSUINT32 objDepth; - void *prv; - JSONObjectDecoder *dec; -}; - -JSOBJ FASTCALL_MSVC decode_any( struct DecoderState *ds) FASTCALL_ATTR; -typedef JSOBJ (*PFN_DECODER)( struct DecoderState *ds); - -static JSOBJ SetError( struct DecoderState *ds, int offset, const char *message) -{ - ds->dec->errorOffset = ds->start + offset; - ds->dec->errorStr = (char *) message; - return NULL; -} - -double createDouble(double intNeg, double intValue, double frcValue, int frcDecimalCount) -{ - static const double g_pow10[] = {1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001,0.0000001, 0.00000001, 0.000000001, 0.0000000001, 0.00000000001, 0.000000000001, 0.0000000000001, 0.00000000000001, 0.000000000000001}; - return (intValue + (frcValue * g_pow10[frcDecimalCount])) * intNeg; -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decodePreciseFloat(struct DecoderState *ds) -{ - char *end; - double value; - errno = 0; - - value = strtod(ds->start, &end); - - if (errno == ERANGE) - { - return SetError(ds, -1, "Range error when decoding numeric as double"); - } - - ds->start = end; - return ds->dec->newDouble(ds->prv, value); -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_numeric (struct DecoderState *ds) -{ - int intNeg = 1; - int mantSize = 0; - JSUINT64 intValue; - JSUINT64 prevIntValue; - int chr; - int decimalCount = 0; - double frcValue = 0.0; - double expNeg; - double expValue; - char *offset = ds->start; - - JSUINT64 overflowLimit = LLONG_MAX; - - if 
(*(offset) == '-') - { - offset ++; - intNeg = -1; - overflowLimit = LLONG_MIN; - } - - // Scan integer part - intValue = 0; - - while (1) - { - chr = (int) (unsigned char) *(offset); - - switch (chr) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - //PERF: Don't do 64-bit arithmetic here unless we know we have to - prevIntValue = intValue; - intValue = intValue * 10ULL + (JSLONG) (chr - 48); - - if (intNeg == 1 && prevIntValue > intValue) - { - return SetError(ds, -1, "Value is too big!"); - } - else if (intNeg == -1 && intValue > overflowLimit) - { - return SetError(ds, -1, overflowLimit == LLONG_MAX ? "Value is too big!" : "Value is too small"); - } - - offset ++; - mantSize ++; - break; - } - case '.': - { - offset ++; - goto DECODE_FRACTION; - break; - } - case 'e': - case 'E': - { - offset ++; - goto DECODE_EXPONENT; - break; - } - - default: - { - goto BREAK_INT_LOOP; - break; - } - } - } - -BREAK_INT_LOOP: - - ds->lastType = JT_INT; - ds->start = offset; - - if (intNeg == 1 && (intValue & 0x8000000000000000ULL) != 0) - { - return ds->dec->newUnsignedLong(ds->prv, intValue); - } - else if ((intValue >> 31)) - { - return ds->dec->newLong(ds->prv, (JSINT64) (intValue * (JSINT64) intNeg)); - } - else - { - return ds->dec->newInt(ds->prv, (JSINT32) (intValue * intNeg)); - } - -DECODE_FRACTION: - - if (ds->dec->preciseFloat) - { - return decodePreciseFloat(ds); - } - - // Scan fraction part - frcValue = 0.0; - for (;;) - { - chr = (int) (unsigned char) *(offset); - - switch (chr) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - if (decimalCount < JSON_DOUBLE_MAX_DECIMALS) - { - frcValue = frcValue * 10.0 + (double) (chr - 48); - decimalCount ++; - } - offset ++; - break; - } - case 'e': - case 'E': - { - offset ++; - goto DECODE_EXPONENT; - break; - } - default: - { - goto BREAK_FRC_LOOP; - } - } - } - 
-BREAK_FRC_LOOP: - //FIXME: Check for arithemtic overflow here - ds->lastType = JT_DOUBLE; - ds->start = offset; - return ds->dec->newDouble (ds->prv, createDouble( (double) intNeg, (double) intValue, frcValue, decimalCount)); - -DECODE_EXPONENT: - if (ds->dec->preciseFloat) - { - return decodePreciseFloat(ds); - } - - expNeg = 1.0; - - if (*(offset) == '-') - { - expNeg = -1.0; - offset ++; - } - else - if (*(offset) == '+') - { - expNeg = +1.0; - offset ++; - } - - expValue = 0.0; - - for (;;) - { - chr = (int) (unsigned char) *(offset); - - switch (chr) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - expValue = expValue * 10.0 + (double) (chr - 48); - offset ++; - break; - } - default: - { - goto BREAK_EXP_LOOP; - } - } - } - -BREAK_EXP_LOOP: - //FIXME: Check for arithemtic overflow here - ds->lastType = JT_DOUBLE; - ds->start = offset; - return ds->dec->newDouble (ds->prv, createDouble( (double) intNeg, (double) intValue , frcValue, decimalCount) * pow(10.0, expValue * expNeg)); -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_true ( struct DecoderState *ds) -{ - char *offset = ds->start; - offset ++; - - if (*(offset++) != 'r') - goto SETERROR; - if (*(offset++) != 'u') - goto SETERROR; - if (*(offset++) != 'e') - goto SETERROR; - - ds->lastType = JT_TRUE; - ds->start = offset; - return ds->dec->newTrue(ds->prv); - -SETERROR: - return SetError(ds, -1, "Unexpected character found when decoding 'true'"); -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_false ( struct DecoderState *ds) -{ - char *offset = ds->start; - offset ++; - - if (*(offset++) != 'a') - goto SETERROR; - if (*(offset++) != 'l') - goto SETERROR; - if (*(offset++) != 's') - goto SETERROR; - if (*(offset++) != 'e') - goto SETERROR; - - ds->lastType = JT_FALSE; - ds->start = offset; - return ds->dec->newFalse(ds->prv); - -SETERROR: - return SetError(ds, -1, "Unexpected character found when decoding 'false'"); -} - 
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_null ( struct DecoderState *ds) -{ - char *offset = ds->start; - offset ++; - - if (*(offset++) != 'u') - goto SETERROR; - if (*(offset++) != 'l') - goto SETERROR; - if (*(offset++) != 'l') - goto SETERROR; - - ds->lastType = JT_NULL; - ds->start = offset; - return ds->dec->newNull(ds->prv); - -SETERROR: - return SetError(ds, -1, "Unexpected character found when decoding 'null'"); -} - -FASTCALL_ATTR void FASTCALL_MSVC SkipWhitespace(struct DecoderState *ds) -{ - char *offset = ds->start; - - for (;;) - { - switch (*offset) - { - case ' ': - case '\t': - case '\r': - case '\n': - offset ++; - break; - - default: - ds->start = offset; - return; - } - } -} - -enum DECODESTRINGSTATE -{ - DS_ISNULL = 0x32, - DS_ISQUOTE, - DS_ISESCAPE, - DS_UTFLENERROR, - -}; - -static const JSUINT8 g_decoderLookup[256] = -{ - /* 0x00 */ DS_ISNULL, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x10 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x20 */ 1, 1, DS_ISQUOTE, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, DS_ISESCAPE, 1, 1, 1, - /* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x80 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0x90 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0xa0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0xb0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - /* 0xc0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - /* 0xd0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - /* 0xe0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - /* 0xf0 */ 4, 4, 4, 4, 4, 4, 4, 4, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, -}; - -FASTCALL_ATTR JSOBJ 
FASTCALL_MSVC decode_string ( struct DecoderState *ds) -{ - int index; - JSUINT32 *escOffset; - JSUINT32 *escStart; - size_t escLen = (ds->escEnd - ds->escStart); - JSUINT8 *inputOffset; - JSUTF16 ch = 0; - JSUINT8 *lastHighSurrogate = NULL; - JSUINT8 oct; - JSUTF32 ucs; - ds->lastType = JT_INVALID; - ds->start ++; - - if ( (size_t) (ds->end - ds->start) > escLen) - { - size_t newSize = (ds->end - ds->start); - - if (ds->escHeap) - { - if (newSize > (SIZE_MAX / sizeof(JSUINT32))) - { - return SetError(ds, -1, "Could not reserve memory block"); - } - escStart = (JSUINT32 *)ds->dec->realloc(ds->escStart, newSize * sizeof(JSUINT32)); - if (!escStart) - { - ds->dec->free(ds->escStart); - return SetError(ds, -1, "Could not reserve memory block"); - } - ds->escStart = escStart; - } - else - { - JSUINT32 *oldStart = ds->escStart; - if (newSize > (SIZE_MAX / sizeof(JSUINT32))) - { - return SetError(ds, -1, "Could not reserve memory block"); - } - ds->escStart = (JSUINT32 *) ds->dec->malloc(newSize * sizeof(JSUINT32)); - if (!ds->escStart) - { - return SetError(ds, -1, "Could not reserve memory block"); - } - ds->escHeap = 1; - memcpy(ds->escStart, oldStart, escLen * sizeof(JSUINT32)); - } - - ds->escEnd = ds->escStart + newSize; - } - - escOffset = ds->escStart; - inputOffset = (JSUINT8 *) ds->start; - - for (;;) - { - switch (g_decoderLookup[(JSUINT8)(*inputOffset)]) - { - case DS_ISNULL: - { - return SetError(ds, -1, "Unmatched ''\"' when when decoding 'string'"); - } - case DS_ISQUOTE: - { - ds->lastType = JT_UTF8; - inputOffset ++; - ds->start += ( (char *) inputOffset - (ds->start)); - return ds->dec->newString(ds->prv, ds->escStart, escOffset); - } - case DS_UTFLENERROR: - { - return SetError (ds, -1, "Invalid UTF-8 sequence length when decoding 'string'"); - } - case DS_ISESCAPE: - inputOffset ++; - switch (*inputOffset) - { - case '\\': *(escOffset++) = '\\'; inputOffset++; continue; - case '\"': *(escOffset++) = '\"'; inputOffset++; continue; - case '/': 
*(escOffset++) = '/'; inputOffset++; continue; - case 'b': *(escOffset++) = '\b'; inputOffset++; continue; - case 'f': *(escOffset++) = '\f'; inputOffset++; continue; - case 'n': *(escOffset++) = '\n'; inputOffset++; continue; - case 'r': *(escOffset++) = '\r'; inputOffset++; continue; - case 't': *(escOffset++) = '\t'; inputOffset++; continue; - - case 'u': - { - int index; - inputOffset ++; - - for (index = 0; index < 4; index ++) - { - switch (*inputOffset) - { - case '\0': return SetError (ds, -1, "Unterminated unicode escape sequence when decoding 'string'"); - default: return SetError (ds, -1, "Unexpected character in unicode escape sequence when decoding 'string'"); - - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - ch = (ch << 4) + (JSUTF16) (*inputOffset - '0'); - break; - - case 'a': - case 'b': - case 'c': - case 'd': - case 'e': - case 'f': - ch = (ch << 4) + 10 + (JSUTF16) (*inputOffset - 'a'); - break; - - case 'A': - case 'B': - case 'C': - case 'D': - case 'E': - case 'F': - ch = (ch << 4) + 10 + (JSUTF16) (*inputOffset - 'A'); - break; - } - - inputOffset ++; - } - - if ((ch & 0xfc00) == 0xdc00 && lastHighSurrogate == inputOffset - 6 * sizeof(*inputOffset)) - { - // Low surrogate immediately following a high surrogate - // Overwrite existing high surrogate with combined character - *(escOffset-1) = (((*(escOffset-1) - 0xd800) <<10) | (ch - 0xdc00)) + 0x10000; - } - else - { - *(escOffset++) = (JSUINT32) ch; - } - if ((ch & 0xfc00) == 0xd800) - { - lastHighSurrogate = inputOffset; - } - break; - } - - case '\0': return SetError(ds, -1, "Unterminated escape sequence when decoding 'string'"); - default: return SetError(ds, -1, "Unrecognized escape sequence when decoding 'string'"); - } - break; - - case 1: - { - *(escOffset++) = (JSUINT32) (*inputOffset++); - break; - } - - case 2: - { - ucs = (*inputOffset++) & 0x1f; - ucs <<= 6; - if (((*inputOffset) & 0x80) != 0x80) - { - 
return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'"); - } - ucs |= (*inputOffset++) & 0x3f; - if (ucs < 0x80) return SetError (ds, -1, "Overlong 2 byte UTF-8 sequence detected when decoding 'string'"); - *(escOffset++) = (JSUINT32) ucs; - break; - } - - case 3: - { - JSUTF32 ucs = 0; - ucs |= (*inputOffset++) & 0x0f; - - for (index = 0; index < 2; index ++) - { - ucs <<= 6; - oct = (*inputOffset++); - - if ((oct & 0x80) != 0x80) - { - return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'"); - } - - ucs |= oct & 0x3f; - } - - if (ucs < 0x800) return SetError (ds, -1, "Overlong 3 byte UTF-8 sequence detected when encoding string"); - *(escOffset++) = (JSUINT32) ucs; - break; - } - - case 4: - { - JSUTF32 ucs = 0; - ucs |= (*inputOffset++) & 0x07; - - for (index = 0; index < 3; index ++) - { - ucs <<= 6; - oct = (*inputOffset++); - - if ((oct & 0x80) != 0x80) - { - return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'"); - } - - ucs |= oct & 0x3f; - } - - if (ucs < 0x10000) return SetError (ds, -1, "Overlong 4 byte UTF-8 sequence detected when decoding 'string'"); - - *(escOffset++) = (JSUINT32) ucs; - break; - } - } - } -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_array(struct DecoderState *ds) -{ - JSOBJ itemValue; - JSOBJ newObj; - int len; - ds->objDepth++; - if (ds->objDepth > JSON_MAX_OBJECT_DEPTH) { - return SetError(ds, -1, "Reached object decoding depth limit"); - } - - newObj = ds->dec->newArray(ds->prv); - len = 0; - - ds->lastType = JT_INVALID; - ds->start ++; - - for (;;) - { - SkipWhitespace(ds); - - if ((*ds->start) == ']') - { - ds->objDepth--; - if (len == 0) - { - ds->start ++; - return newObj; - } - - ds->dec->releaseObject(ds->prv, newObj); - return SetError(ds, -1, "Unexpected character found when decoding array value (1)"); - } - - itemValue = decode_any(ds); - - if (itemValue == NULL) - { - ds->dec->releaseObject(ds->prv, newObj); - return NULL; - } - - 
ds->dec->arrayAddItem (ds->prv, newObj, itemValue); - - SkipWhitespace(ds); - - switch (*(ds->start++)) - { - case ']': - { - ds->objDepth--; - return newObj; - } - case ',': - break; - - default: - ds->dec->releaseObject(ds->prv, newObj); - return SetError(ds, -1, "Unexpected character found when decoding array value (2)"); - } - - len ++; - } -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_object( struct DecoderState *ds) -{ - JSOBJ itemName; - JSOBJ itemValue; - JSOBJ newObj; - int len; - - ds->objDepth++; - if (ds->objDepth > JSON_MAX_OBJECT_DEPTH) { - return SetError(ds, -1, "Reached object decoding depth limit"); - } - - newObj = ds->dec->newObject(ds->prv); - len = 0; - - ds->start ++; - - for (;;) - { - SkipWhitespace(ds); - - if ((*ds->start) == '}') - { - ds->objDepth--; - if (len == 0) - { - ds->start ++; - return newObj; - } - - ds->dec->releaseObject(ds->prv, newObj); - return SetError(ds, -1, "Unexpected character in found when decoding object value"); - } - - ds->lastType = JT_INVALID; - itemName = decode_any(ds); - - if (itemName == NULL) - { - ds->dec->releaseObject(ds->prv, newObj); - return NULL; - } - - if (ds->lastType != JT_UTF8) - { - ds->dec->releaseObject(ds->prv, newObj); - ds->dec->releaseObject(ds->prv, itemName); - return SetError(ds, -1, "Key name of object must be 'string' when decoding 'object'"); - } - - SkipWhitespace(ds); - - if (*(ds->start++) != ':') - { - ds->dec->releaseObject(ds->prv, newObj); - ds->dec->releaseObject(ds->prv, itemName); - return SetError(ds, -1, "No ':' found when decoding object value"); - } - - SkipWhitespace(ds); - - itemValue = decode_any(ds); - - if (itemValue == NULL) - { - ds->dec->releaseObject(ds->prv, newObj); - ds->dec->releaseObject(ds->prv, itemName); - return NULL; - } - - ds->dec->objectAddKey (ds->prv, newObj, itemName, itemValue); - - SkipWhitespace(ds); - - switch (*(ds->start++)) - { - case '}': - { - ds->objDepth--; - return newObj; - } - case ',': - break; - - default: - 
ds->dec->releaseObject(ds->prv, newObj); - return SetError(ds, -1, "Unexpected character in found when decoding object value"); - } - - len++; - } -} - -FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_any(struct DecoderState *ds) -{ - for (;;) - { - switch (*ds->start) - { - case '\"': - return decode_string (ds); - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - case '-': - return decode_numeric (ds); - - case '[': return decode_array (ds); - case '{': return decode_object (ds); - case 't': return decode_true (ds); - case 'f': return decode_false (ds); - case 'n': return decode_null (ds); - - case ' ': - case '\t': - case '\r': - case '\n': - // White space - ds->start ++; - break; - - default: - return SetError(ds, -1, "Expected object or value"); - } - } -} - -JSOBJ JSON_DecodeObject(JSONObjectDecoder *dec, const char *buffer, size_t cbBuffer) -{ - /* - FIXME: Base the size of escBuffer of that of cbBuffer so that the unicode escaping doesn't run into the wall each time */ - struct DecoderState ds; - JSUINT32 escBuffer[(JSON_MAX_STACK_BUFFER_SIZE / sizeof(JSUINT32))]; - JSOBJ ret; - - ds.start = (char *) buffer; - ds.end = ds.start + cbBuffer; - - ds.escStart = escBuffer; - ds.escEnd = ds.escStart + (JSON_MAX_STACK_BUFFER_SIZE / sizeof(JSUINT32)); - ds.escHeap = 0; - ds.prv = dec->prv; - ds.dec = dec; - ds.dec->errorStr = NULL; - ds.dec->errorOffset = NULL; - ds.objDepth = 0; - - ds.dec = dec; - - ret = decode_any (&ds); - - if (ds.escHeap) - { - dec->free(ds.escStart); - } - - if (!(dec->errorStr)) - { - if ((ds.end - ds.start) > 0) - { - SkipWhitespace(&ds); - } - - if (ds.start != ds.end && ret) - { - dec->releaseObject(ds.prv, ret); - return SetError(&ds, -1, "Trailing data"); - } - } - - return ret; -} diff --git a/srsly/ujson/lib/ultrajsonenc.c b/srsly/ujson/lib/ultrajsonenc.c deleted file mode 100644 index ea6372b..0000000 --- a/srsly/ujson/lib/ultrajsonenc.c +++ /dev/null @@ -1,1067 +0,0 @@ 
-/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. - -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms - * Copyright (c) 1988-1993 The Regents of the University of California. - * Copyright (c) 1994 Sun Microsystems, Inc. 
-*/ - -#include "ultrajson.h" -#include -#include -#include -#include -#include -#include - -#include - -#ifndef TRUE -#define TRUE 1 -#endif -#ifndef FALSE -#define FALSE 0 -#endif - -#if ( (defined(_WIN32) || defined(WIN32) ) && ( defined(_MSC_VER) ) ) -#define snprintf sprintf_s -#endif - -/* -Worst cases being: - -Control characters (ASCII < 32) -0x00 (1 byte) input => \u0000 output (6 bytes) -1 * 6 => 6 (6 bytes required) - -or UTF-16 surrogate pairs -4 bytes input in UTF-8 => \uXXXX\uYYYY (12 bytes). - -4 * 6 => 24 bytes (12 bytes required) - -The extra 2 bytes are for the quotes around the string - -*/ -#define RESERVE_STRING(_len) (2 + ((_len) * 6)) - -static const double g_pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000, 10000000000, 100000000000, 1000000000000, 10000000000000, 100000000000000, 1000000000000000}; -static const char g_hexChars[] = "0123456789abcdef"; -static const char g_escapeChars[] = "0123456789\\b\\t\\n\\f\\r\\\"\\\\\\/"; - -/* -FIXME: While this is fine dandy and working it's a magic value mess which probably only the author understands. 
-Needs a cleanup and more documentation */ - -/* -Table for pure ascii output escaping all characters above 127 to \uXXXX */ -static const JSUINT8 g_asciiOutputTable[256] = -{ -/* 0x00 */ 0, 30, 30, 30, 30, 30, 30, 30, 10, 12, 14, 30, 16, 18, 30, 30, -/* 0x10 */ 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, -/* 0x20 */ 1, 1, 20, 1, 1, 1, 29, 1, 1, 1, 1, 1, 1, 1, 1, 24, -/* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 29, 1, 29, 1, -/* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 22, 1, 1, 1, -/* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0x80 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0x90 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0xa0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0xb0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -/* 0xc0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, -/* 0xd0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, -/* 0xe0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, -/* 0xf0 */ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 1, 1 -}; - -static void SetError (JSOBJ obj, JSONObjectEncoder *enc, const char *message) -{ - enc->errorMsg = message; - enc->errorObj = obj; -} - -/* -FIXME: Keep track of how big these get across several encoder calls and try to make an estimate -That way we won't run our head into the wall each call */ -void Buffer_Realloc (JSONObjectEncoder *enc, size_t cbNeeded) -{ - size_t free_space = enc->end - enc->offset; - if (free_space >= cbNeeded) - { - return; - } - size_t curSize = enc->end - enc->start; - size_t newSize = curSize; - size_t offset = enc->offset - enc->start; - -#ifdef DEBUG - // In debug mode, allocate only what is requested so that any miscalculation - // shows up plainly as a crash. 
- newSize = (enc->offset - enc->start) + cbNeeded; -#else - while (newSize < curSize + cbNeeded) - { - newSize *= 2; - } -#endif - - if (enc->heap) - { - enc->start = (char *) enc->realloc (enc->start, newSize); - if (!enc->start) - { - SetError (NULL, enc, "Could not reserve memory block"); - return; - } - } - else - { - char *oldStart = enc->start; - enc->heap = 1; - enc->start = (char *) enc->malloc (newSize); - if (!enc->start) - { - SetError (NULL, enc, "Could not reserve memory block"); - return; - } - memcpy (enc->start, oldStart, offset); - } - enc->offset = enc->start + offset; - enc->end = enc->start + newSize; -} - -#define Buffer_Reserve(__enc, __len) \ - if ( (size_t) ((__enc)->end - (__enc)->offset) < (size_t) (__len)) \ - { \ - Buffer_Realloc((__enc), (__len));\ - } \ - -FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC Buffer_AppendShortHexUnchecked (char *outputOffset, unsigned short value) -{ - *(outputOffset++) = g_hexChars[(value & 0xf000) >> 12]; - *(outputOffset++) = g_hexChars[(value & 0x0f00) >> 8]; - *(outputOffset++) = g_hexChars[(value & 0x00f0) >> 4]; - *(outputOffset++) = g_hexChars[(value & 0x000f) >> 0]; -} - -int Buffer_EscapeStringUnvalidated (JSONObjectEncoder *enc, const char *io, const char *end) -{ - char *of = (char *) enc->offset; - - for (;;) - { - switch (*io) - { - case 0x00: - { - if (io < end) - { - *(of++) = '\\'; - *(of++) = 'u'; - *(of++) = '0'; - *(of++) = '0'; - *(of++) = '0'; - *(of++) = '0'; - break; - } - else - { - enc->offset += (of - enc->offset); - return TRUE; - } - } - case '\"': (*of++) = '\\'; (*of++) = '\"'; break; - case '\\': (*of++) = '\\'; (*of++) = '\\'; break; - case '\b': (*of++) = '\\'; (*of++) = 'b'; break; - case '\f': (*of++) = '\\'; (*of++) = 'f'; break; - case '\n': (*of++) = '\\'; (*of++) = 'n'; break; - case '\r': (*of++) = '\\'; (*of++) = 'r'; break; - case '\t': (*of++) = '\\'; (*of++) = 't'; break; - - case '/': - { - if (enc->escapeForwardSlashes) - { - (*of++) = '\\'; - (*of++) = '/'; - } 
- else - { - // Same as default case below. - (*of++) = (*io); - } - break; - } - case 0x26: // '&' - case 0x3c: // '<' - case 0x3e: // '>' - { - if (enc->encodeHTMLChars) - { - // Fall through to \u00XX case below. - } - else - { - // Same as default case below. - (*of++) = (*io); - break; - } - } - case 0x01: - case 0x02: - case 0x03: - case 0x04: - case 0x05: - case 0x06: - case 0x07: - case 0x0b: - case 0x0e: - case 0x0f: - case 0x10: - case 0x11: - case 0x12: - case 0x13: - case 0x14: - case 0x15: - case 0x16: - case 0x17: - case 0x18: - case 0x19: - case 0x1a: - case 0x1b: - case 0x1c: - case 0x1d: - case 0x1e: - case 0x1f: - { - *(of++) = '\\'; - *(of++) = 'u'; - *(of++) = '0'; - *(of++) = '0'; - *(of++) = g_hexChars[ (unsigned char) (((*io) & 0xf0) >> 4)]; - *(of++) = g_hexChars[ (unsigned char) ((*io) & 0x0f)]; - break; - } - default: (*of++) = (*io); break; - } - io++; - } -} - -int Buffer_EscapeStringValidated (JSOBJ obj, JSONObjectEncoder *enc, const char *io, const char *end) -{ - JSUTF32 ucs; - char *of = (char *) enc->offset; - - for (;;) - { -#ifdef DEBUG - // 6 is the maximum length of a single character (cf. RESERVE_STRING). 
- if ((io < end) && (enc->end - of < 6)) { - fprintf(stderr, "Ran out of buffer space during Buffer_EscapeStringValidated()\n"); - abort(); - } -#endif - JSUINT8 utflen = g_asciiOutputTable[(unsigned char) *io]; - - switch (utflen) - { - case 0: - { - if (io < end) - { - *(of++) = '\\'; - *(of++) = 'u'; - *(of++) = '0'; - *(of++) = '0'; - *(of++) = '0'; - *(of++) = '0'; - io ++; - continue; - } - else - { - enc->offset += (of - enc->offset); - return TRUE; - } - } - - case 1: - { - *(of++)= (*io++); - continue; - } - - case 2: - { - JSUTF32 in; - JSUTF16 in16; - - if (end - io < 1) - { - enc->offset += (of - enc->offset); - SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string"); - return FALSE; - } - - memcpy(&in16, io, sizeof(JSUTF16)); - in = (JSUTF32) in16; - -#ifdef __LITTLE_ENDIAN__ - ucs = ((in & 0x1f) << 6) | ((in >> 8) & 0x3f); -#else - ucs = ((in & 0x1f00) >> 2) | (in & 0x3f); -#endif - - if (ucs < 0x80) - { - enc->offset += (of - enc->offset); - SetError (obj, enc, "Overlong 2 byte UTF-8 sequence detected when encoding string"); - return FALSE; - } - - io += 2; - break; - } - - case 3: - { - JSUTF32 in; - JSUTF16 in16; - JSUINT8 in8; - - if (end - io < 2) - { - enc->offset += (of - enc->offset); - SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string"); - return FALSE; - } - - memcpy(&in16, io, sizeof(JSUTF16)); - memcpy(&in8, io + 2, sizeof(JSUINT8)); -#ifdef __LITTLE_ENDIAN__ - in = (JSUTF32) in16; - in |= in8 << 16; - ucs = ((in & 0x0f) << 12) | ((in & 0x3f00) >> 2) | ((in & 0x3f0000) >> 16); -#else - in = in16 << 8; - in |= in8; - ucs = ((in & 0x0f0000) >> 4) | ((in & 0x3f00) >> 2) | (in & 0x3f); -#endif - - if (ucs < 0x800) - { - enc->offset += (of - enc->offset); - SetError (obj, enc, "Overlong 3 byte UTF-8 sequence detected when encoding string"); - return FALSE; - } - - io += 3; - break; - } - case 4: - { - JSUTF32 in; - - if (end - io < 3) - { - enc->offset += (of - enc->offset); - SetError (obj, enc, 
"Unterminated UTF-8 sequence when encoding string"); - return FALSE; - } - - memcpy(&in, io, sizeof(JSUTF32)); -#ifdef __LITTLE_ENDIAN__ - ucs = ((in & 0x07) << 18) | ((in & 0x3f00) << 4) | ((in & 0x3f0000) >> 10) | ((in & 0x3f000000) >> 24); -#else - ucs = ((in & 0x07000000) >> 6) | ((in & 0x3f0000) >> 4) | ((in & 0x3f00) >> 2) | (in & 0x3f); -#endif - if (ucs < 0x10000) - { - enc->offset += (of - enc->offset); - SetError (obj, enc, "Overlong 4 byte UTF-8 sequence detected when encoding string"); - return FALSE; - } - - io += 4; - break; - } - - - case 5: - case 6: - { - enc->offset += (of - enc->offset); - SetError (obj, enc, "Unsupported UTF-8 sequence length when encoding string"); - return FALSE; - } - - case 29: - { - if (enc->encodeHTMLChars) - { - // Fall through to \u00XX case 30 below. - } - else - { - // Same as case 1 above. - *(of++) = (*io++); - continue; - } - } - - case 30: - { - // \uXXXX encode - *(of++) = '\\'; - *(of++) = 'u'; - *(of++) = '0'; - *(of++) = '0'; - *(of++) = g_hexChars[ (unsigned char) (((*io) & 0xf0) >> 4)]; - *(of++) = g_hexChars[ (unsigned char) ((*io) & 0x0f)]; - io ++; - continue; - } - case 10: - case 12: - case 14: - case 16: - case 18: - case 20: - case 22: - { - *(of++) = *( (char *) (g_escapeChars + utflen + 0)); - *(of++) = *( (char *) (g_escapeChars + utflen + 1)); - io ++; - continue; - } - case 24: - { - if (enc->escapeForwardSlashes) - { - *(of++) = *( (char *) (g_escapeChars + utflen + 0)); - *(of++) = *( (char *) (g_escapeChars + utflen + 1)); - io ++; - } - else - { - // Same as case 1 above. 
- *(of++) = (*io++); - } - continue; - } - // This can never happen, it's here to make L4 VC++ happy - default: - { - ucs = 0; - break; - } - } - - /* - If the character is a UTF8 sequence of length > 1 we end up here */ - if (ucs >= 0x10000) - { - ucs -= 0x10000; - *(of++) = '\\'; - *(of++) = 'u'; - Buffer_AppendShortHexUnchecked(of, (unsigned short) (ucs >> 10) + 0xd800); - of += 4; - - *(of++) = '\\'; - *(of++) = 'u'; - Buffer_AppendShortHexUnchecked(of, (unsigned short) (ucs & 0x3ff) + 0xdc00); - of += 4; - } - else - { - *(of++) = '\\'; - *(of++) = 'u'; - Buffer_AppendShortHexUnchecked(of, (unsigned short) ucs); - of += 4; - } - } -} - -static FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC Buffer_AppendCharUnchecked(JSONObjectEncoder *enc, char chr) -{ -#ifdef DEBUG - if (enc->end <= enc->offset) - { - fprintf(stderr, "Overflow writing byte %d '%c'. The last few characters were:\n'''", chr, chr); - char * recent = enc->offset - 1000; - if (enc->start > recent) - { - recent = enc->start; - } - for (; recent < enc->offset; recent++) - { - fprintf(stderr, "%c", *recent); - } - fprintf(stderr, "'''\n"); - abort(); - } -#endif - *(enc->offset++) = chr; -} - -FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC strreverse(char* begin, char* end) -{ - char aux; - while (end > begin) - aux = *end, *end-- = *begin, *begin++ = aux; -} - -void Buffer_AppendIndentNewlineUnchecked(JSONObjectEncoder *enc) -{ - if (enc->indent > 0) Buffer_AppendCharUnchecked(enc, '\n'); -} - -void Buffer_AppendIndentUnchecked(JSONObjectEncoder *enc, JSINT32 value) -{ - int i; - if (enc->indent > 0) - while (value-- > 0) - for (i = 0; i < enc->indent; i++) - Buffer_AppendCharUnchecked(enc, ' '); -} - -void Buffer_AppendIntUnchecked(JSONObjectEncoder *enc, JSINT32 value) -{ - char* wstr; - JSUINT32 uvalue = (value < 0) ? -value : value; - - wstr = enc->offset; - // Conversion. Number is reversed. 
- - do *wstr++ = (char)(48 + (uvalue % 10)); while(uvalue /= 10); - if (value < 0) *wstr++ = '-'; - - // Reverse string - strreverse(enc->offset,wstr - 1); - enc->offset += (wstr - (enc->offset)); -} - -void Buffer_AppendLongUnchecked(JSONObjectEncoder *enc, JSINT64 value) -{ - char* wstr; - JSUINT64 uvalue = (value < 0) ? -value : value; - - wstr = enc->offset; - // Conversion. Number is reversed. - - do *wstr++ = (char)(48 + (uvalue % 10ULL)); while(uvalue /= 10ULL); - if (value < 0) *wstr++ = '-'; - - // Reverse string - strreverse(enc->offset,wstr - 1); - enc->offset += (wstr - (enc->offset)); -} - -void Buffer_AppendUnsignedLongUnchecked(JSONObjectEncoder *enc, JSUINT64 value) -{ - char* wstr; - JSUINT64 uvalue = value; - - wstr = enc->offset; - // Conversion. Number is reversed. - - do *wstr++ = (char)(48 + (uvalue % 10ULL)); while(uvalue /= 10ULL); - - // Reverse string - strreverse(enc->offset,wstr - 1); - enc->offset += (wstr - (enc->offset)); -} - -int Buffer_AppendDoubleUnchecked(JSOBJ obj, JSONObjectEncoder *enc, double value) -{ - /* if input is larger than thres_max, revert to exponential */ - const double thres_max = (double) 1e16 - 1; - int count; - double diff = 0.0; - char* str = enc->offset; - char* wstr = str; - unsigned long long whole; - double tmp; - unsigned long long frac; - int neg; - double pow10; - - if (value == HUGE_VAL || value == -HUGE_VAL) - { - SetError (obj, enc, "Invalid Inf value when encoding double"); - return FALSE; - } - - if (!(value == value)) - { - SetError (obj, enc, "Invalid Nan value when encoding double"); - return FALSE; - } - - /* we'll work in positive values and deal with the - negative sign issue later */ - neg = 0; - if (value < 0) - { - neg = 1; - value = -value; - } - - pow10 = g_pow10[enc->doublePrecision]; - - whole = (unsigned long long) value; - tmp = (value - whole) * pow10; - frac = (unsigned long long)(tmp); - diff = tmp - frac; - - if (diff > 0.5) - { - ++frac; - /* handle rollover, e.g. 
case 0.99 with prec 1 is 1.0 */ - if (frac >= pow10) - { - frac = 0; - ++whole; - } - } - else - if (diff == 0.5 && ((frac == 0) || (frac & 1))) - { - /* if halfway, round up if odd, OR - if last digit is 0. That last part is strange */ - ++frac; - } - - /* for very large numbers switch back to native sprintf for exponentials. - anyone want to write code to replace this? */ - /* - normal printf behavior is to print EVERY whole number digit - which can be 100s of characters overflowing your buffers == bad - */ - if (value > thres_max) - { - enc->offset += snprintf(str, enc->end - enc->offset, "%.15e", neg ? -value : value); - return TRUE; - } - - if (enc->doublePrecision == 0) - { - diff = value - whole; - - if (diff > 0.5) - { - /* greater than 0.5, round up, e.g. 1.6 -> 2 */ - ++whole; - } - else - if (diff == 0.5 && (whole & 1)) - { - /* exactly 0.5 and ODD, then round up */ - /* 1.5 -> 2, but 2.5 -> 2 */ - ++whole; - } - - //vvvvvvvvvvvvvvvvvvv Diff from modp_dto2 - } - else - if (frac) - { - count = enc->doublePrecision; - // now do fractional part, as an unsigned number - // we know it is not 0 but we can have leading zeros, these - // should be removed - while (!(frac % 10)) - { - --count; - frac /= 10; - } - //^^^^^^^^^^^^^^^^^^^ Diff from modp_dto2 - - // now do fractional part, as an unsigned number - do - { - --count; - *wstr++ = (char)(48 + (frac % 10)); - } while (frac /= 10); - // add extra 0s - while (count-- > 0) - { - *wstr++ = '0'; - } - // add decimal - *wstr++ = '.'; - } - else - { - *wstr++ = '0'; - *wstr++ = '.'; - } - - // do whole part - // Take care of sign - // Conversion. Number is reversed. 
- do *wstr++ = (char)(48 + (whole % 10)); while (whole /= 10); - - if (neg) - { - *wstr++ = '-'; - } - strreverse(str, wstr-1); - enc->offset += (wstr - (enc->offset)); - - return TRUE; -} - -/* -FIXME: -Handle integration functions returning NULL here */ - -/* -FIXME: -Perhaps implement recursion detection */ - -void encode(JSOBJ obj, JSONObjectEncoder *enc, const char *name, size_t cbName) -{ - const char *value; - char *objName; - int count; - JSOBJ iterObj; - size_t szlen; - JSONTypeContext tc; - - if (enc->level > enc->recursionMax) - { - SetError (obj, enc, "Maximum recursion level reached"); - return; - } - - if (enc->errorMsg) - { - return; - } - - if (name) - { - // 2 extra for the colon and optional space after it - Buffer_Reserve(enc, RESERVE_STRING(cbName) + 2); - Buffer_AppendCharUnchecked(enc, '\"'); - - if (enc->forceASCII) - { - if (!Buffer_EscapeStringValidated(obj, enc, name, name + cbName)) - { - return; - } - } - else - { - if (!Buffer_EscapeStringUnvalidated(enc, name, name + cbName)) - { - return; - } - } - - Buffer_AppendCharUnchecked(enc, '\"'); - - Buffer_AppendCharUnchecked (enc, ':'); - } - - tc.encoder_prv = enc->prv; - enc->beginTypeContext(obj, &tc, enc); - - /* - This reservation covers any additions on non-variable parts below, specifically: - - Opening brackets for JT_ARRAY and JT_OBJECT - - Number representation for JT_LONG, JT_ULONG, JT_INT, and JT_DOUBLE - - Constant value for JT_TRUE, JT_FALSE, JT_NULL - The length of 128 is the worst case length of the Buffer_AppendDoubleDconv addition. - The other types above all have smaller representations. - */ - Buffer_Reserve (enc, 128); - - switch (tc.type) - { - case JT_INVALID: - { - return; - } - - case JT_ARRAY: - { - count = 0; - - Buffer_AppendCharUnchecked (enc, '['); - Buffer_AppendIndentNewlineUnchecked (enc); - - while (enc->iterNext(obj, &tc)) - { - // The extra 2 bytes cover the comma and (optional) newline. 
- Buffer_Reserve (enc, enc->indent * (enc->level + 1) + 2); - - if (count > 0) - { - Buffer_AppendCharUnchecked (enc, ','); - Buffer_AppendIndentNewlineUnchecked (enc); - } - - iterObj = enc->iterGetValue(obj, &tc); - - enc->level ++; - Buffer_AppendIndentUnchecked (enc, enc->level); - encode (iterObj, enc, NULL, 0); - count ++; - } - - enc->iterEnd(obj, &tc); - // Reserve space for the indentation plus the newline. - Buffer_Reserve (enc, enc->indent * enc->level + 1); - Buffer_AppendIndentNewlineUnchecked (enc); - Buffer_AppendIndentUnchecked (enc, enc->level); - Buffer_Reserve (enc, 1); - Buffer_AppendCharUnchecked (enc, ']'); - break; - } - - case JT_OBJECT: - { - count = 0; - - Buffer_AppendCharUnchecked (enc, '{'); - Buffer_AppendIndentNewlineUnchecked (enc); - - while (enc->iterNext(obj, &tc)) - { - // The extra 2 bytes cover the comma and optional newline. - Buffer_Reserve (enc, enc->indent * (enc->level + 1) + 2); - - if (count > 0) - { - Buffer_AppendCharUnchecked (enc, ','); - Buffer_AppendIndentNewlineUnchecked (enc); - } - - iterObj = enc->iterGetValue(obj, &tc); - objName = enc->iterGetName(obj, &tc, &szlen); - - enc->level ++; - Buffer_AppendIndentUnchecked (enc, enc->level); - encode (iterObj, enc, objName, szlen); - count ++; - } - - enc->iterEnd(obj, &tc); - Buffer_Reserve (enc, enc->indent * enc->level + 1); - Buffer_AppendIndentNewlineUnchecked (enc); - Buffer_AppendIndentUnchecked (enc, enc->level); - Buffer_Reserve (enc, 1); - Buffer_AppendCharUnchecked (enc, '}'); - break; - } - - case JT_LONG: - { - Buffer_AppendLongUnchecked (enc, enc->getLongValue(obj, &tc)); - break; - } - - case JT_ULONG: - { - Buffer_AppendUnsignedLongUnchecked (enc, enc->getUnsignedLongValue(obj, &tc)); - break; - } - - case JT_INT: - { - Buffer_AppendIntUnchecked (enc, enc->getIntValue(obj, &tc)); - break; - } - - case JT_TRUE: - { - Buffer_AppendCharUnchecked (enc, 't'); - Buffer_AppendCharUnchecked (enc, 'r'); - Buffer_AppendCharUnchecked (enc, 'u'); - 
Buffer_AppendCharUnchecked (enc, 'e'); - break; - } - - case JT_FALSE: - { - Buffer_AppendCharUnchecked (enc, 'f'); - Buffer_AppendCharUnchecked (enc, 'a'); - Buffer_AppendCharUnchecked (enc, 'l'); - Buffer_AppendCharUnchecked (enc, 's'); - Buffer_AppendCharUnchecked (enc, 'e'); - break; - } - - - case JT_NULL: - { - Buffer_AppendCharUnchecked (enc, 'n'); - Buffer_AppendCharUnchecked (enc, 'u'); - Buffer_AppendCharUnchecked (enc, 'l'); - Buffer_AppendCharUnchecked (enc, 'l'); - break; - } - - case JT_DOUBLE: - { - if (!Buffer_AppendDoubleUnchecked (obj, enc, enc->getDoubleValue(obj, &tc))) - { - enc->endTypeContext(obj, &tc); - enc->level --; - return; - } - break; - } - - case JT_UTF8: - { - value = enc->getStringValue(obj, &tc, &szlen); - if(!value) - { - SetError(obj, enc, "utf-8 encoding error"); - return; - } - - Buffer_Reserve(enc, RESERVE_STRING(szlen)); - if (enc->errorMsg) - { - enc->endTypeContext(obj, &tc); - return; - } - Buffer_AppendCharUnchecked (enc, '\"'); - - if (enc->forceASCII) - { - if (!Buffer_EscapeStringValidated(obj, enc, value, value + szlen)) - { - enc->endTypeContext(obj, &tc); - enc->level --; - return; - } - } - else - { - if (!Buffer_EscapeStringUnvalidated(enc, value, value + szlen)) - { - enc->endTypeContext(obj, &tc); - enc->level --; - return; - } - } - - Buffer_AppendCharUnchecked (enc, '\"'); - break; - } - - case JT_RAW: - { - value = enc->getStringValue(obj, &tc, &szlen); - if(!value) - { - SetError(obj, enc, "utf-8 encoding error"); - return; - } - - Buffer_Reserve(enc, szlen); - if (enc->errorMsg) - { - enc->endTypeContext(obj, &tc); - return; - } - - memcpy(enc->offset, value, szlen); - enc->offset += szlen; - - break; - } - } - - enc->endTypeContext(obj, &tc); - enc->level --; -} - -char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *_buffer, size_t _cbBuffer) -{ - enc->malloc = enc->malloc ? enc->malloc : malloc; - enc->free = enc->free ? enc->free : free; - enc->realloc = enc->realloc ? 
enc->realloc : realloc; - enc->errorMsg = NULL; - enc->errorObj = NULL; - enc->level = 0; - - if (enc->recursionMax < 1) - { - enc->recursionMax = JSON_MAX_RECURSION_DEPTH; - } - - if (enc->doublePrecision < 0 || - enc->doublePrecision > JSON_DOUBLE_MAX_DECIMALS) - { - enc->doublePrecision = JSON_DOUBLE_MAX_DECIMALS; - } - - if (_buffer == NULL) - { - _cbBuffer = 32768; - enc->start = (char *) enc->malloc (_cbBuffer); - if (!enc->start) - { - SetError(obj, enc, "Could not reserve memory block"); - return NULL; - } - enc->heap = 1; - } - else - { - enc->start = _buffer; - enc->heap = 0; - } - - enc->end = enc->start + _cbBuffer; - enc->offset = enc->start; - - encode (obj, enc, NULL, 0); - - Buffer_Reserve(enc, 1); - if (enc->errorMsg) - { - return NULL; - } - Buffer_AppendCharUnchecked(enc, '\0'); - - return enc->start; -} diff --git a/srsly/ujson/objToJSON.c b/srsly/ujson/objToJSON.c deleted file mode 100644 index c22f861..0000000 --- a/srsly/ujson/objToJSON.c +++ /dev/null @@ -1,984 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. 
- -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. - -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms -* Copyright (c) 1988-1993 The Regents of the University of California. -* Copyright (c) 1994 Sun Microsystems, Inc. 
-*/ - -#include "py_defines.h" -#include -#include -#include - -#define EPOCH_ORD 719163 -static PyObject* type_decimal = NULL; - -typedef void *(*PFN_PyTypeToJSON)(JSOBJ obj, JSONTypeContext *ti, void *outValue, size_t *_outLen); - -#if (PY_VERSION_HEX < 0x02050000) -typedef ssize_t Py_ssize_t; -#endif - -typedef struct __TypeContext -{ - JSPFN_ITEREND iterEnd; - JSPFN_ITERNEXT iterNext; - JSPFN_ITERGETNAME iterGetName; - JSPFN_ITERGETVALUE iterGetValue; - PFN_PyTypeToJSON PyTypeToJSON; - PyObject *newObj; - PyObject *dictObj; - Py_ssize_t index; - Py_ssize_t size; - PyObject *itemValue; - PyObject *itemName; - PyObject *attrList; - PyObject *iterator; - - union - { - PyObject *rawJSONValue; - JSINT64 longValue; - JSUINT64 unsignedLongValue; - }; -} TypeContext; - -#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv)) - -struct PyDictIterState -{ - PyObject *keys; - size_t i; - size_t sz; -}; - -//#define PRINTMARK() fprintf(stderr, "%s: MARK(%d)\n", __FILE__, __LINE__) -#define PRINTMARK() - -void initObjToJSON(void) -{ - PyObject* mod_decimal = PyImport_ImportModule("decimal"); - if (mod_decimal) - { - type_decimal = PyObject_GetAttrString(mod_decimal, "Decimal"); - Py_INCREF(type_decimal); - Py_DECREF(mod_decimal); - } - else - PyErr_Clear(); - - PyDateTime_IMPORT; -} - -#ifdef _LP64 -static void *PyIntToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - *((JSINT64 *) outValue) = PyInt_AS_LONG (obj); - return NULL; -} -#else -static void *PyIntToINT32(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - *((JSINT32 *) outValue) = PyInt_AS_LONG (obj); - return NULL; -} -#endif - -static void *PyLongToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - *((JSINT64 *) outValue) = GET_TC(tc)->longValue; - return NULL; -} - -static void *PyLongToUINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - 
*((JSUINT64 *) outValue) = GET_TC(tc)->unsignedLongValue; - return NULL; -} - -static void *PyFloatToDOUBLE(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - *((double *) outValue) = PyFloat_AsDouble (obj); - return NULL; -} - -static void *PyStringToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - *_outLen = PyString_GET_SIZE(obj); - return PyString_AS_STRING(obj); -} - -static void *PyUnicodeToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - PyObject *newObj; -#if (PY_VERSION_HEX >= 0x03030000) - if(PyUnicode_IS_COMPACT_ASCII(obj)) - { - Py_ssize_t len; - char *data = PyUnicode_AsUTF8AndSize(obj, &len); - *_outLen = len; - return data; - } -#endif - newObj = PyUnicode_AsUTF8String(obj); - if(!newObj) - { - return NULL; - } - - GET_TC(tc)->newObj = newObj; - - *_outLen = PyString_GET_SIZE(newObj); - return PyString_AS_STRING(newObj); -} - -static void *PyRawJSONToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = GET_TC(tc)->rawJSONValue; - if (PyUnicode_Check(obj)) { - return PyUnicodeToUTF8(obj, tc, outValue, _outLen); - } - else { - return PyStringToUTF8(obj, tc, outValue, _outLen); - } -} - -static void *PyDateTimeToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - PyObject *date, *ord, *utcoffset; - int y, m, d, h, mn, s, days; - - utcoffset = PyObject_CallMethod(obj, "utcoffset", NULL); - if(utcoffset != Py_None){ - obj = PyNumber_Subtract(obj, utcoffset); - } - - y = PyDateTime_GET_YEAR(obj); - m = PyDateTime_GET_MONTH(obj); - d = PyDateTime_GET_DAY(obj); - h = PyDateTime_DATE_GET_HOUR(obj); - mn = PyDateTime_DATE_GET_MINUTE(obj); - s = PyDateTime_DATE_GET_SECOND(obj); - - date = PyDate_FromDate(y, m, 1); - ord = PyObject_CallMethod(date, "toordinal", NULL); - days = 
PyInt_AS_LONG(ord) - EPOCH_ORD + d - 1; - Py_DECREF(date); - Py_DECREF(ord); - *( (JSINT64 *) outValue) = (((JSINT64) ((days * 24 + h) * 60 + mn)) * 60 + s); - return NULL; -} - -static void *PyDateToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen) -{ - PyObject *obj = (PyObject *) _obj; - PyObject *date, *ord; - int y, m, d, days; - - y = PyDateTime_GET_YEAR(obj); - m = PyDateTime_GET_MONTH(obj); - d = PyDateTime_GET_DAY(obj); - - date = PyDate_FromDate(y, m, 1); - ord = PyObject_CallMethod(date, "toordinal", NULL); - days = PyInt_AS_LONG(ord) - EPOCH_ORD + d - 1; - Py_DECREF(date); - Py_DECREF(ord); - *( (JSINT64 *) outValue) = ((JSINT64) days * 86400); - - return NULL; -} - -int Tuple_iterNext(JSOBJ obj, JSONTypeContext *tc) -{ - PyObject *item; - - if (GET_TC(tc)->index >= GET_TC(tc)->size) - { - return 0; - } - - item = PyTuple_GET_ITEM (obj, GET_TC(tc)->index); - - GET_TC(tc)->itemValue = item; - GET_TC(tc)->index ++; - return 1; -} - -void Tuple_iterEnd(JSOBJ obj, JSONTypeContext *tc) -{ -} - -JSOBJ Tuple_iterGetValue(JSOBJ obj, JSONTypeContext *tc) -{ - return GET_TC(tc)->itemValue; -} - -char *Tuple_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) -{ - return NULL; -} - -int List_iterNext(JSOBJ obj, JSONTypeContext *tc) -{ - if (GET_TC(tc)->index >= GET_TC(tc)->size) - { - PRINTMARK(); - return 0; - } - - GET_TC(tc)->itemValue = PyList_GET_ITEM (obj, GET_TC(tc)->index); - GET_TC(tc)->index ++; - return 1; -} - -void List_iterEnd(JSOBJ obj, JSONTypeContext *tc) -{ -} - -JSOBJ List_iterGetValue(JSOBJ obj, JSONTypeContext *tc) -{ - return GET_TC(tc)->itemValue; -} - -char *List_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) -{ - return NULL; -} - -//============================================================================= -// Dict iteration functions -// itemName might converted to string (Python_Str). Do refCounting -// itemValue is borrowed from object (which is dict). 
No refCounting -//============================================================================= - -int Dict_iterNext(JSOBJ obj, JSONTypeContext *tc) -{ - PyObject* itemNameTmp; - - if (GET_TC(tc)->itemName) - { - Py_DECREF(GET_TC(tc)->itemName); - GET_TC(tc)->itemName = NULL; - } - - if (!(GET_TC(tc)->itemName = PyIter_Next(GET_TC(tc)->iterator))) - { - PRINTMARK(); - return 0; - } - - if (GET_TC(tc)->itemValue) { - Py_DECREF(GET_TC(tc)->itemValue); - GET_TC(tc)->itemValue = NULL; - } - - if (!(GET_TC(tc)->itemValue = PyObject_GetItem(GET_TC(tc)->dictObj, GET_TC(tc)->itemName))) { - PRINTMARK(); - return 0; - } - - if (PyUnicode_Check(GET_TC(tc)->itemName)) - { - itemNameTmp = GET_TC(tc)->itemName; - GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp); - Py_DECREF(itemNameTmp); - } - else - if (!PyString_Check(GET_TC(tc)->itemName)) - { - itemNameTmp = GET_TC(tc)->itemName; - GET_TC(tc)->itemName = PyObject_Str(itemNameTmp); - Py_DECREF(itemNameTmp); -#if PY_MAJOR_VERSION >= 3 - itemNameTmp = GET_TC(tc)->itemName; - GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp); - Py_DECREF(itemNameTmp); -#endif - } - PRINTMARK(); - return 1; -} - -void Dict_iterEnd(JSOBJ obj, JSONTypeContext *tc) -{ - if (GET_TC(tc)->itemName) { - Py_DECREF(GET_TC(tc)->itemName); - GET_TC(tc)->itemName = NULL; - } - if (GET_TC(tc)->itemValue) { - Py_DECREF(GET_TC(tc)->itemValue); - GET_TC(tc)->itemValue = NULL; - } - Py_CLEAR(GET_TC(tc)->iterator); - Py_DECREF(GET_TC(tc)->dictObj); - PRINTMARK(); -} - -JSOBJ Dict_iterGetValue(JSOBJ obj, JSONTypeContext *tc) -{ - return GET_TC(tc)->itemValue; -} - -char *Dict_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) -{ - *outLen = PyString_GET_SIZE(GET_TC(tc)->itemName); - return PyString_AS_STRING(GET_TC(tc)->itemName); -} - -int SortedDict_iterNext(JSOBJ obj, JSONTypeContext *tc) -{ - PyObject *items = NULL, *item = NULL, *key = NULL, *value = NULL; - Py_ssize_t i, nitems; -#if PY_MAJOR_VERSION >= 3 - PyObject* keyTmp; 
-#endif - - // Upon first call, obtain a list of the keys and sort them. This follows the same logic as the - // stanard library's _json.c sort_keys handler. - if (GET_TC(tc)->newObj == NULL) - { - // Obtain the list of keys from the dictionary. - items = PyMapping_Keys(GET_TC(tc)->dictObj); - if (items == NULL) - { - goto error; - } - else if (!PyList_Check(items)) - { - PyErr_SetString(PyExc_ValueError, "keys must return list"); - goto error; - } - - // Sort the list. - if (PyList_Sort(items) < 0) - { - goto error; - } - - // Obtain the value for each key, and pack a list of (key, value) 2-tuples. - nitems = PyList_GET_SIZE(items); - for (i = 0; i < nitems; i++) - { - key = PyList_GET_ITEM(items, i); - value = PyDict_GetItem(GET_TC(tc)->dictObj, key); - - // Subject the key to the same type restrictions and conversions as in Dict_iterGetValue. - if (PyUnicode_Check(key)) - { - key = PyUnicode_AsUTF8String(key); - } - else if (!PyString_Check(key)) - { - key = PyObject_Str(key); -#if PY_MAJOR_VERSION >= 3 - keyTmp = key; - key = PyUnicode_AsUTF8String(key); - Py_DECREF(keyTmp); -#endif - } - else - { - Py_INCREF(key); - } - - item = PyTuple_Pack(2, key, value); - if (item == NULL) - { - goto error; - } - if (PyList_SetItem(items, i, item)) - { - goto error; - } - Py_DECREF(key); - } - - // Store the sorted list of tuples in the newObj slot. 
- GET_TC(tc)->newObj = items; - GET_TC(tc)->size = nitems; - } - - if (GET_TC(tc)->index >= GET_TC(tc)->size) - { - PRINTMARK(); - return 0; - } - - item = PyList_GET_ITEM(GET_TC(tc)->newObj, GET_TC(tc)->index); - GET_TC(tc)->itemName = PyTuple_GET_ITEM(item, 0); - GET_TC(tc)->itemValue = PyTuple_GET_ITEM(item, 1); - GET_TC(tc)->index++; - return 1; - -error: - Py_XDECREF(item); - Py_XDECREF(key); - Py_XDECREF(value); - Py_XDECREF(items); - return -1; -} - -void SortedDict_iterEnd(JSOBJ obj, JSONTypeContext *tc) -{ - GET_TC(tc)->itemName = NULL; - GET_TC(tc)->itemValue = NULL; - Py_DECREF(GET_TC(tc)->newObj); - Py_DECREF(GET_TC(tc)->dictObj); - PRINTMARK(); -} - -JSOBJ SortedDict_iterGetValue(JSOBJ obj, JSONTypeContext *tc) -{ - return GET_TC(tc)->itemValue; -} - -char *SortedDict_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) -{ - *outLen = PyString_GET_SIZE(GET_TC(tc)->itemName); - return PyString_AS_STRING(GET_TC(tc)->itemName); -} - - -void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder *enc) -{ - pc->dictObj = dictObj; - if (enc->sortKeys) - { - pc->iterEnd = SortedDict_iterEnd; - pc->iterNext = SortedDict_iterNext; - pc->iterGetValue = SortedDict_iterGetValue; - pc->iterGetName = SortedDict_iterGetName; - pc->index = 0; - } - else - { - pc->iterEnd = Dict_iterEnd; - pc->iterNext = Dict_iterNext; - pc->iterGetValue = Dict_iterGetValue; - pc->iterGetName = Dict_iterGetName; - pc->iterator = PyObject_GetIter(dictObj); - } -} - -void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc) -{ - PyObject *obj, *objRepr, *exc; - TypeContext *pc; - PRINTMARK(); - if (!_obj) - { - tc->type = JT_INVALID; - return; - } - - obj = (PyObject*) _obj; - - tc->prv = PyObject_Malloc(sizeof(TypeContext)); - pc = (TypeContext *) tc->prv; - if (!pc) - { - tc->type = JT_INVALID; - PyErr_NoMemory(); - return; - } - pc->newObj = NULL; - pc->dictObj = NULL; - pc->itemValue = NULL; - pc->itemName = NULL; - pc->iterator = 
NULL; - pc->attrList = NULL; - pc->index = 0; - pc->size = 0; - pc->longValue = 0; - pc->rawJSONValue = NULL; - - if (PyIter_Check(obj)) - { - PRINTMARK(); - goto ISITERABLE; - } - - if (PyBool_Check(obj)) - { - PRINTMARK(); - tc->type = (obj == Py_True) ? JT_TRUE : JT_FALSE; - return; - } - else - if (PyLong_Check(obj)) - { - PRINTMARK(); - pc->PyTypeToJSON = PyLongToINT64; - tc->type = JT_LONG; - GET_TC(tc)->longValue = PyLong_AsLongLong(obj); - - exc = PyErr_Occurred(); - if (!exc) - { - return; - } - - if (exc && PyErr_ExceptionMatches(PyExc_OverflowError)) - { - PyErr_Clear(); - pc->PyTypeToJSON = PyLongToUINT64; - tc->type = JT_ULONG; - GET_TC(tc)->unsignedLongValue = PyLong_AsUnsignedLongLong(obj); - - exc = PyErr_Occurred(); - if (exc && PyErr_ExceptionMatches(PyExc_OverflowError)) - { - PRINTMARK(); - goto INVALID; - } - } - - return; - } - else - if (PyInt_Check(obj)) - { - PRINTMARK(); -#ifdef _LP64 - pc->PyTypeToJSON = PyIntToINT64; tc->type = JT_LONG; -#else - pc->PyTypeToJSON = PyIntToINT32; tc->type = JT_INT; -#endif - return; - } - else - if (PyString_Check(obj) && !PyObject_HasAttrString(obj, "__json__")) - { - PRINTMARK(); - pc->PyTypeToJSON = PyStringToUTF8; tc->type = JT_UTF8; - return; - } - else - if (PyUnicode_Check(obj)) - { - PRINTMARK(); - pc->PyTypeToJSON = PyUnicodeToUTF8; tc->type = JT_UTF8; - return; - } - else - if (PyFloat_Check(obj) || (type_decimal && PyObject_IsInstance(obj, type_decimal))) - { - PRINTMARK(); - pc->PyTypeToJSON = PyFloatToDOUBLE; tc->type = JT_DOUBLE; - return; - } - else - if (PyDateTime_Check(obj)) - { - PRINTMARK(); - pc->PyTypeToJSON = PyDateTimeToINT64; tc->type = JT_LONG; - return; - } - else - if (PyDate_Check(obj)) - { - PRINTMARK(); - pc->PyTypeToJSON = PyDateToINT64; tc->type = JT_LONG; - return; - } - else - if (obj == Py_None) - { - PRINTMARK(); - tc->type = JT_NULL; - return; - } - -ISITERABLE: - if (PyDict_Check(obj)) - { - PRINTMARK(); - tc->type = JT_OBJECT; - SetupDictIter(obj, pc, enc); - 
Py_INCREF(obj); - return; - } - else - if (PyList_Check(obj)) - { - PRINTMARK(); - tc->type = JT_ARRAY; - pc->iterEnd = List_iterEnd; - pc->iterNext = List_iterNext; - pc->iterGetValue = List_iterGetValue; - pc->iterGetName = List_iterGetName; - GET_TC(tc)->index = 0; - GET_TC(tc)->size = PyList_GET_SIZE( (PyObject *) obj); - return; - } - else - if (PyTuple_Check(obj)) - { - PRINTMARK(); - tc->type = JT_ARRAY; - pc->iterEnd = Tuple_iterEnd; - pc->iterNext = Tuple_iterNext; - pc->iterGetValue = Tuple_iterGetValue; - pc->iterGetName = Tuple_iterGetName; - GET_TC(tc)->index = 0; - GET_TC(tc)->size = PyTuple_GET_SIZE( (PyObject *) obj); - GET_TC(tc)->itemValue = NULL; - - return; - } - - if (PyObject_HasAttrString(obj, "toDict")) - { - PyObject* toDictFunc = PyObject_GetAttrString(obj, "toDict"); - PyObject* tuple = PyTuple_New(0); - PyObject* toDictResult = PyObject_Call(toDictFunc, tuple, NULL); - Py_DECREF(tuple); - Py_DECREF(toDictFunc); - - if (toDictResult == NULL) - { - goto INVALID; - } - - if (!PyDict_Check(toDictResult)) - { - Py_DECREF(toDictResult); - tc->type = JT_NULL; - return; - } - - PRINTMARK(); - tc->type = JT_OBJECT; - SetupDictIter(toDictResult, pc, enc); - return; - } - else - if (PyObject_HasAttrString(obj, "__json__")) - { - PyObject* toJSONFunc = PyObject_GetAttrString(obj, "__json__"); - PyObject* tuple = PyTuple_New(0); - PyObject* toJSONResult = PyObject_Call(toJSONFunc, tuple, NULL); - Py_DECREF(tuple); - Py_DECREF(toJSONFunc); - - if (toJSONResult == NULL) - { - goto INVALID; - } - - if (PyErr_Occurred()) - { - Py_DECREF(toJSONResult); - goto INVALID; - } - - if (!PyString_Check(toJSONResult) && !PyUnicode_Check(toJSONResult)) - { - Py_DECREF(toJSONResult); - PyErr_Format (PyExc_TypeError, "expected string"); - goto INVALID; - } - - PRINTMARK(); - pc->PyTypeToJSON = PyRawJSONToUTF8; - tc->type = JT_RAW; - GET_TC(tc)->rawJSONValue = toJSONResult; - return; - } - - PRINTMARK(); - PyErr_Clear(); - - objRepr = PyObject_Repr(obj); -#if 
PY_MAJOR_VERSION >= 3 - PyObject* str = PyUnicode_AsEncodedString(objRepr, "utf-8", "~E~"); - PyErr_Format (PyExc_TypeError, "%s is not JSON serializable", PyString_AS_STRING(str)); - Py_XDECREF(str); -#else - PyErr_Format (PyExc_TypeError, "%s is not JSON serializable", PyString_AS_STRING(objRepr)); -#endif - Py_DECREF(objRepr); - -INVALID: - PRINTMARK(); - tc->type = JT_INVALID; - PyObject_Free(tc->prv); - tc->prv = NULL; - return; -} - -void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc) -{ - Py_XDECREF(GET_TC(tc)->newObj); - - if (tc->type == JT_RAW) - { - Py_XDECREF(GET_TC(tc)->rawJSONValue); - } - PyObject_Free(tc->prv); - tc->prv = NULL; -} - -const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen) -{ - return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen); -} - -JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc) -{ - JSINT64 ret; - GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); - return ret; -} - -JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc) -{ - JSUINT64 ret; - GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); - return ret; -} - -JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc) -{ - JSINT32 ret; - GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); - return ret; -} - -double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc) -{ - double ret; - GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); - return ret; -} - -static void Object_releaseObject(JSOBJ _obj) -{ - Py_DECREF( (PyObject *) _obj); -} - -int Object_iterNext(JSOBJ obj, JSONTypeContext *tc) -{ - return GET_TC(tc)->iterNext(obj, tc); -} - -void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc) -{ - GET_TC(tc)->iterEnd(obj, tc); -} - -JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc) -{ - return GET_TC(tc)->iterGetValue(obj, tc); -} - -char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) -{ - return GET_TC(tc)->iterGetName(obj, tc, outLen); -} - -PyObject* objToJSON(PyObject* self, PyObject *args, 
PyObject *kwargs) -{ - static char *kwlist[] = { "obj", "ensure_ascii", "double_precision", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", NULL }; - - char buffer[65536]; - char *ret; - PyObject *newobj; - PyObject *oinput = NULL; - PyObject *oensureAscii = NULL; - PyObject *oencodeHTMLChars = NULL; - PyObject *oescapeForwardSlashes = NULL; - PyObject *osortKeys = NULL; - - JSONObjectEncoder encoder = - { - Object_beginTypeContext, - Object_endTypeContext, - Object_getStringValue, - Object_getLongValue, - Object_getUnsignedLongValue, - Object_getIntValue, - Object_getDoubleValue, - Object_iterNext, - Object_iterEnd, - Object_iterGetValue, - Object_iterGetName, - Object_releaseObject, - PyObject_Malloc, - PyObject_Realloc, - PyObject_Free, - -1, //recursionMax - 10, // default double precision setting - 1, //forceAscii - 0, //encodeHTMLChars - 1, //escapeForwardSlashes - 0, //sortKeys - 0, //indent - NULL, //prv - }; - - - PRINTMARK(); - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OiOOOi", kwlist, &oinput, &oensureAscii, &encoder.doublePrecision, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent)) - { - return NULL; - } - - if (oensureAscii != NULL && !PyObject_IsTrue(oensureAscii)) - { - encoder.forceASCII = 0; - } - - if (oencodeHTMLChars != NULL && PyObject_IsTrue(oencodeHTMLChars)) - { - encoder.encodeHTMLChars = 1; - } - - if (oescapeForwardSlashes != NULL && !PyObject_IsTrue(oescapeForwardSlashes)) - { - encoder.escapeForwardSlashes = 0; - } - - if (osortKeys != NULL && PyObject_IsTrue(osortKeys)) - { - encoder.sortKeys = 1; - } - - PRINTMARK(); - ret = JSON_EncodeObject (oinput, &encoder, buffer, sizeof (buffer)); - PRINTMARK(); - - if (PyErr_Occurred()) - { - return NULL; - } - - if (encoder.errorMsg) - { - if (ret != buffer) - { - encoder.free (ret); - } - - PyErr_Format (PyExc_OverflowError, "%s", encoder.errorMsg); - return NULL; - } - - newobj = PyString_FromString (ret); - - if (ret != buffer) - { - 
encoder.free (ret); - } - - PRINTMARK(); - - return newobj; -} - -PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs) -{ - PyObject *data; - PyObject *file; - PyObject *string; - PyObject *write; - PyObject *argtuple; - PyObject *write_result; - - PRINTMARK(); - - if (!PyArg_ParseTuple (args, "OO", &data, &file)) - { - return NULL; - } - - if (!PyObject_HasAttrString (file, "write")) - { - PyErr_Format (PyExc_TypeError, "expected file"); - return NULL; - } - - write = PyObject_GetAttrString (file, "write"); - - if (!PyCallable_Check (write)) - { - Py_XDECREF(write); - PyErr_Format (PyExc_TypeError, "expected file"); - return NULL; - } - - argtuple = PyTuple_Pack(1, data); - - string = objToJSON (self, argtuple, kwargs); - - if (string == NULL) - { - Py_XDECREF(write); - Py_XDECREF(argtuple); - return NULL; - } - - Py_XDECREF(argtuple); - - argtuple = PyTuple_Pack (1, string); - if (argtuple == NULL) - { - Py_XDECREF(write); - return NULL; - } - write_result = PyObject_CallObject (write, argtuple); - if (write_result == NULL) - { - Py_XDECREF(write); - Py_XDECREF(argtuple); - return NULL; - } - - Py_DECREF(write_result); - Py_XDECREF(write); - Py_DECREF(argtuple); - Py_XDECREF(string); - - PRINTMARK(); - - Py_RETURN_NONE; -} diff --git a/srsly/ujson/py_defines.h b/srsly/ujson/py_defines.h deleted file mode 100644 index 2b38b41..0000000 --- a/srsly/ujson/py_defines.h +++ /dev/null @@ -1,53 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. 
-* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. - -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms - * Copyright (c) 1988-1993 The Regents of the University of California. - * Copyright (c) 1994 Sun Microsystems, Inc. 
-*/ - -#include - -#if PY_MAJOR_VERSION >= 3 - -#define PyInt_Check PyLong_Check -#define PyInt_AS_LONG PyLong_AsLong -#define PyInt_FromLong PyLong_FromLong - -#define PyString_Check PyBytes_Check -#define PyString_GET_SIZE PyBytes_GET_SIZE -#define PyString_AS_STRING PyBytes_AS_STRING - -#define PyString_FromString PyUnicode_FromString - -#endif diff --git a/srsly/ujson/ujson.c b/srsly/ujson/ujson.c deleted file mode 100644 index d0b15c6..0000000 --- a/srsly/ujson/ujson.c +++ /dev/null @@ -1,113 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. 
BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. - -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms -* Copyright (c) 1988-1993 The Regents of the University of California. -* Copyright (c) 1994 Sun Microsystems, Inc. -*/ - -#include "py_defines.h" -#include "version.h" - -/* objToJSON */ -PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs); -void initObjToJSON(void); - -/* JSONToObj */ -PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs); - -/* objToJSONFile */ -PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs); - -/* JSONFileToObj */ -PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs); - - -#define ENCODER_HELP_TEXT "Use ensure_ascii=false to output UTF-8. Pass in double_precision to alter the maximum digit precision of doubles. Set encode_html_chars=True to encode < > & as unicode escape sequences. Set escape_forward_slashes=False to prevent escaping / characters." - -static PyMethodDef ujsonMethods[] = { - {"encode", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT}, - {"decode", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. 
Use precise_float=True to use high precision float decoder."}, - {"dumps", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT}, - {"loads", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."}, - {"dump", (PyCFunction) objToJSONFile, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON file. " ENCODER_HELP_TEXT}, - {"load", (PyCFunction) JSONFileToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as file to dict object structure. Use precise_float=True to use high precision float decoder."}, - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -#if PY_MAJOR_VERSION >= 3 - -static struct PyModuleDef moduledef = { - PyModuleDef_HEAD_INIT, - "ujson", - 0, /* m_doc */ - -1, /* m_size */ - ujsonMethods, /* m_methods */ - NULL, /* m_reload */ - NULL, /* m_traverse */ - NULL, /* m_clear */ - NULL /* m_free */ -}; - -#define PYMODINITFUNC PyObject *PyInit_ujson(void) -#define PYMODULE_CREATE() PyModule_Create(&moduledef) -#define MODINITERROR return NULL - -#else - -#define PYMODINITFUNC PyMODINIT_FUNC initujson(void) -#define PYMODULE_CREATE() Py_InitModule("ujson", ujsonMethods) -#define MODINITERROR return - -#endif - -PYMODINITFUNC -{ - PyObject *module; - PyObject *version_string; - - initObjToJSON(); - module = PYMODULE_CREATE(); - - if (module == NULL) - { - MODINITERROR; - } - - version_string = PyString_FromString (UJSON_VERSION); - PyModule_AddObject (module, "__version__", version_string); - -#if PY_MAJOR_VERSION >= 3 - return module; -#endif -} diff --git a/srsly/ujson/version.h b/srsly/ujson/version.h deleted file mode 100644 index f0ce6bb..0000000 --- a/srsly/ujson/version.h +++ /dev/null @@ -1,39 +0,0 @@ -/* -Developed by ESN, an Electronic Arts Inc. studio. -Copyright (c) 2014, Electronic Arts Inc. -All rights reserved. 
- -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -* Neither the name of ESN, Electronic Arts Inc. nor the -names of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc) -http://code.google.com/p/stringencoders/ -Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved. - -Numeric decoder derived from from TCL library -http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms - * Copyright (c) 1988-1993 The Regents of the University of California. - * Copyright (c) 1994 Sun Microsystems, Inc. -*/ - -#define UJSON_VERSION "1.35"