diff --git a/.github/workflows/cibuildwheel.yml b/.github/workflows/cibuildwheel.yml
index c8dcbab..1ab2818 100644
--- a/.github/workflows/cibuildwheel.yml
+++ b/.github/workflows/cibuildwheel.yml
@@ -21,7 +21,7 @@ jobs:
actions: read
with:
wheel-name-pattern: "srsly-*.whl"
- pure-python: false
+ pure-python: true
create-release: ${{ startsWith(github.ref, 'refs/tags/release-') || startsWith(github.ref, 'refs/tags/prerelease-') }}
secrets:
gh-token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml
index 84b3ceb..c31bff0 100644
--- a/.github/workflows/publish_pypi.yml
+++ b/.github/workflows/publish_pypi.yml
@@ -1,5 +1,5 @@
-# The cibuildwheel action triggers on creation of a release, this
-# triggers on publication.
+# The cibuildwheel action triggers on all pushes and PRs;
+# this action triggers on release publication.
# The expected workflow is to create a draft release and let the wheels
# upload, and then hit 'publish', which uploads to PyPi.
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 2bfc614..e208c3a 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -14,17 +14,26 @@ concurrency:
cancel-in-progress: true
env:
- MODULE_NAME: 'srsly'
RUN_MYPY: 'false'
jobs:
tests:
+ name: ${{ matrix.python_version }} ${{ matrix.os }} numpy=${{ matrix.numpy }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
- # FIXME: ujson segfault on 3.14
- python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+ # Note: ruamel.yaml does not support Python 3.13t
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"]
+ numpy: [true]
+ include:
+ - os: ubuntu-latest
+ python_version: "3.9"
+ numpy: false
+ - os: ubuntu-latest
+ python_version: "3.14"
+ numpy: false
+
runs-on: ${{ matrix.os }}
steps:
@@ -35,46 +44,26 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python_version }}
- architecture: x64
-
- - name: Build sdist
- run: |
- python -m pip install -U build pip setuptools
- python -m pip install -U -r requirements.txt
- python -m build --sdist
- name: Run mypy
- shell: bash
if: ${{ env.RUN_MYPY == 'true' }}
run: |
- python -m mypy $MODULE_NAME
-
- - name: Delete source directory
- shell: bash
- run: |
- rm -rf $MODULE_NAME
+ python -m pip install mypy
+ python -m mypy srsly
- - name: Uninstall all packages
- run: |
- python -m pip freeze > installed.txt
- python -m pip uninstall -y -r installed.txt
+ - name: Install package
+ run: python -m pip install .
- - name: Install from sdist
+ - name: Test that numpy was not installed
shell: bash
- run: |
- SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
- python -m pip install dist/$SDIST
+ run: if pip list | grep -q '^numpy'; then exit 1; fi
- - name: Test import
- shell: bash
- run: |
- python -c "import $MODULE_NAME" -Werror
+ - name: Install numpy
+ if: ${{ matrix.numpy }}
+ run: python -m pip install numpy
- name: Install test requirements
- run: |
- python -m pip install -U -r requirements.txt
+ run: python -m pip install pytest
- name: Run tests
- shell: bash
- run: |
- python -m pytest --pyargs $MODULE_NAME -Werror
+ run: pytest
diff --git a/.gitignore b/.gitignore
index a327cc0..b5cbd43 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,6 @@
.env/
.env*
.vscode/
-cythonize.json
# Byte-compiled / optimized / DLL files
__pycache__/
@@ -108,8 +107,5 @@ venv.bak/
# mypy
.mypy_cache/
-# Cython intermediate files
-*.cpp
-
# Vim files
*.sw*
diff --git a/README.md b/README.md
index b5e49d5..56a691d 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
# srsly: Modern high-performance serialization utilities for Python
This package bundles some of the best Python serialization libraries into one
-standalone package, with a high-level API that makes it easy to write code
+convenience package, with a high-level API that makes it easy to write code
that's correct across platforms and Pythons. This allows us to provide all the
serialization utilities we need in a single binary wheel. Currently supports
**JSON**, **JSONL**, **MessagePack**, **Pickle** and **YAML**.
@@ -24,31 +24,24 @@ wrap the multiple serialization formats we need to support (especially `json`,
`msgpack` and `pickle`). These wrapping functions ended up duplicated across our
codebases, so we wanted to put them in one place.
-At the same time, we noticed that having a lot of small dependencies was making
-maintenance harder, and making installation slower. To solve this, we've made
-`srsly` standalone, by including the component packages directly within it. This
-way we can provide all the serialization utilities we need in a single binary
-wheel.
-
-`srsly` currently includes forks of the following packages:
+`srsly` currently includes wrappers around the following packages:
- [`ujson`](https://github.com/esnme/ultrajson)
- [`msgpack`](https://github.com/msgpack/msgpack-python)
-- [`msgpack-numpy`](https://github.com/lebedov/msgpack-numpy)
- [`cloudpickle`](https://github.com/cloudpipe/cloudpickle)
- [`ruamel.yaml`](https://github.com/pycontribs/ruamel-yaml) (without unsafe
implementations!)
-## Installation
+Additionally, it includes a heavily customized fork of
+[`msgpack-numpy`](https://github.com/lebedov/msgpack-numpy), with corrected
+round-trip behaviour for np.float64 objects.
+
-> ⚠️ Note that `v2.x` is only compatible with **Python 3.6+**. For 2.7+
-> compatibility, use `v1.x`.
+## Installation
-`srsly` can be installed from pip. Before installing, make sure that your `pip`,
-`setuptools` and `wheel` are up to date.
+`srsly` can be installed from pip.
```bash
-python -m pip install -U pip setuptools wheel
python -m pip install srsly
```
@@ -58,12 +51,15 @@ Or from conda via conda-forge:
conda install -c conda-forge srsly
```
-Alternatively, you can also compile the library from source. You'll need to make
-sure that you have a development environment with a Python distribution
-including header files, a compiler (XCode command-line tools on macOS / OS X or
-Visual C++ build tools on Windows), pip and git installed.
+This will automatically install/upgrade all dependencies.
+
+numpy and cupy are optional dependencies for msgpack.
+If numpy is installed, numpy objects can be serialized.
+If cupy is installed, cupy objects will be automatically converted
+to numpy and then serialized.
+
-Install from source:
+Alternatively, you can also install the library from the repository:
```bash
# clone the repo
@@ -74,10 +70,7 @@ cd srsly
python -m venv .env
source .env/bin/activate
-# update pip
-python -m pip install -U pip setuptools wheel
-
-# compile and install from source
+# install from source
python -m pip install .
```
@@ -86,7 +79,6 @@ mode without build isolation:
```bash
# install in editable mode
-python -m pip install -r requirements.txt
python -m pip install --no-build-isolation --editable .
# run test suite
@@ -97,9 +89,6 @@ python -m pytest --pyargs srsly
### JSON
-> 📦 The underlying module is exposed via `srsly.ujson`. However, we normally
-> interact with it via the utility functions only.
-
#### function `srsly.json_dumps`
Serialize an object to a JSON string. Falls back to `json` if `sort_keys=True`
@@ -264,9 +253,6 @@ assert srsly.is_json_serializable(lambda x: x) is False
### msgpack
-> 📦 The underlying module is exposed via `srsly.msgpack`. However, we normally
-> interact with it via the utility functions only.
-
#### function `srsly.msgpack_dumps`
Serialize an object to a msgpack byte string.
@@ -326,9 +312,6 @@ data = srsly.read_msgpack("/path/to/file.msg")
### pickle
-> 📦 The underlying module is exposed via `srsly.cloudpickle`. However, we
-> normally interact with it via the utility functions only.
-
#### function `srsly.pickle_dumps`
Serialize a Python object with pickle.
@@ -360,9 +343,6 @@ data = srsly.pickle_loads(pickled_data)
### YAML
-> 📦 The underlying module is exposed via `srsly.ruamel_yaml`. However, we
-> normally interact with it via the utility functions only.
-
#### function `srsly.yaml_dumps`
Serialize an object to a YAML string. See the
diff --git a/pyproject.toml b/pyproject.toml
index 483235c..fed528d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,63 +1,3 @@
[build-system]
-requires = [
- "setuptools",
- "cython>=0.29.1",
-]
+requires = ["setuptools"]
build-backend = "setuptools.build_meta"
-
-[tool.cibuildwheel]
-build = "*"
-skip = [
- "cp38*", # Obsolete
- "cp314-*", # FIXME ujson segfaults
- "cp314t-*", # TODO free-threading support (note: 3.13t is skipped by default)
-]
-test-skip = ""
-
-archs = ["native"]
-
-build-frontend = "default"
-config-settings = {}
-dependency-versions = "pinned"
-environment = {}
-environment-pass = []
-build-verbosity = 0
-
-before-all = ""
-before-build = ""
-repair-wheel-command = ""
-
-test-command = ""
-before-test = ""
-test-requires = []
-test-extras = []
-
-container-engine = "docker"
-
-manylinux-x86_64-image = "manylinux2014"
-manylinux-i686-image = "manylinux2014"
-manylinux-aarch64-image = "manylinux2014"
-manylinux-ppc64le-image = "manylinux2014"
-manylinux-s390x-image = "manylinux2014"
-manylinux-pypy_x86_64-image = "manylinux2014"
-manylinux-pypy_i686-image = "manylinux2014"
-manylinux-pypy_aarch64-image = "manylinux2014"
-
-musllinux-x86_64-image = "musllinux_1_2"
-musllinux-i686-image = "musllinux_1_2"
-musllinux-aarch64-image = "musllinux_1_2"
-musllinux-ppc64le-image = "musllinux_1_2"
-musllinux-s390x-image = "musllinux_1_2"
-
-
-[tool.cibuildwheel.linux]
-repair-wheel-command = "auditwheel repair -w {dest_dir} {wheel}"
-
-[tool.cibuildwheel.macos]
-repair-wheel-command = "delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel}"
-
-[tool.cibuildwheel.windows]
-
-[tool.cibuildwheel.pyodide]
-
-
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 7de940c..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-catalogue>=2.0.3,<2.1.0
-# Development requirements
-cython>=0.29.1
-pytest>=4.6.5
-pytest-timeout>=1.3.3
-mock>=2.0.0,<3.0.0
-numpy>=1.15.0
-psutil
diff --git a/setup.cfg b/setup.cfg
index 2b3679a..8097293 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -12,27 +12,25 @@ classifiers =
Intended Audience :: Developers
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
- Operating System :: POSIX :: Linux
- Operating System :: MacOS :: MacOS X
- Operating System :: Microsoft :: Windows
- Programming Language :: Cython
+ Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3.13
+ Programming Language :: Python :: 3.14
Topic :: Scientific/Engineering
[options]
zip_safe = true
include_package_data = true
-# FIXME ujson segfaults on 3.14
-python_requires = >=3.9,<3.14
-setup_requires =
- cython>=0.29.1
+python_requires = >=3.9
install_requires =
- catalogue>=2.0.3,<2.1.0
+ cloudpickle >=3.1.2,<4
+ msgpack >=1.1,<2
+ ruamel.yaml >=0.18.16,<1
+ ujson >=5.11.0,<6
[options.entry_points]
# If spaCy is installed in the same environment as srsly, it will automatically
@@ -44,7 +42,7 @@ spacy_readers =
srsly.read_msgpack.v1 = srsly:read_msgpack
[bdist_wheel]
-universal = false
+universal = true
[sdist]
formats = gztar
@@ -55,11 +53,6 @@ max-line-length = 80
select = B,C,E,F,W,T4,B9
exclude =
srsly/__init__.py
- srsly/msgpack/__init__.py
- srsly/cloudpickle/__init__.py
[mypy]
ignore_missing_imports = True
-
-[mypy-srsly.cloudpickle.*]
-ignore_errors=True
diff --git a/setup.py b/setup.py
index 7fd4252..4346012 100644
--- a/setup.py
+++ b/setup.py
@@ -1,141 +1,19 @@
#!/usr/bin/env python
-import sys
-from setuptools.command.build_ext import build_ext
-from sysconfig import get_path
-from setuptools import Extension, setup, find_packages
+from setuptools import setup, find_packages
from pathlib import Path
-from Cython.Build import cythonize
-from Cython.Compiler import Options
-import contextlib
-import os
-
-
-# Preserve `__doc__` on functions and classes
-# http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options
-Options.docstrings = True
-
-
-PACKAGE_DATA = {"": ["*.pyx", "*.pxd", "*.c", "*.h", "*.cpp"]}
-PACKAGES = find_packages()
-# msgpack has this whacky build where it only builds _cmsgpack which textually includes
-# _packer and _unpacker. I refactored this.
-MOD_NAMES = ["srsly.msgpack._epoch", "srsly.msgpack._packer", "srsly.msgpack._unpacker"]
-COMPILE_OPTIONS = {
- "msvc": ["/Ox", "/EHsc"],
- "mingw32": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
- "other": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
-}
-COMPILER_DIRECTIVES = {
- "language_level": -3,
- "embedsignature": True,
- "annotation_typing": False,
-}
-LINK_OPTIONS = {"msvc": [], "mingw32": [], "other": ["-lstdc++", "-lm"]}
-
-if sys.byteorder == "big":
- macros = [("__BIG_ENDIAN__", "1")]
-else:
- macros = [("__LITTLE_ENDIAN__", "1")]
-
-
-# By subclassing build_extensions we have the actual compiler that will be used
-# which is really known only after finalize_options
-# http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used
-class build_ext_options:
- def build_options(self):
- if hasattr(self.compiler, "initialize"):
- self.compiler.initialize()
- self.compiler.platform = sys.platform[:6]
- for e in self.extensions:
- e.extra_compile_args += COMPILE_OPTIONS.get(
- self.compiler.compiler_type, COMPILE_OPTIONS["other"]
- )
- e.extra_link_args += LINK_OPTIONS.get(
- self.compiler.compiler_type, LINK_OPTIONS["other"]
- )
-
-
-class build_ext_subclass(build_ext, build_ext_options):
- def build_extensions(self):
- build_ext_options.build_options(self)
- build_ext.build_extensions(self)
-
-
-def clean(path):
- n_cleaned = 0
- for name in MOD_NAMES:
- name = name.replace(".", "/")
- for ext in ["so", "html", "cpp", "c"]:
- file_path = path / f"{name}.{ext}"
- if file_path.exists():
- file_path.unlink()
- n_cleaned += 1
- print(f"Cleaned {n_cleaned} files")
-
-
-@contextlib.contextmanager
-def chdir(new_dir):
- old_dir = os.getcwd()
- try:
- os.chdir(new_dir)
- sys.path.insert(0, new_dir)
- yield
- finally:
- del sys.path[0]
- os.chdir(old_dir)
def setup_package():
root = Path(__file__).parent
-
- if len(sys.argv) > 1 and sys.argv[1] == "clean":
- return clean(root)
-
with (root / "srsly" / "about.py").open("r") as f:
about = {}
exec(f.read(), about)
- with chdir(str(root)):
- include_dirs = [get_path("include"), ".", "srsly"]
- ext_modules = []
- for name in MOD_NAMES:
- mod_path = name.replace(".", "/") + ".pyx"
- ext_modules.append(
- Extension(
- name,
- [mod_path],
- language="c++",
- include_dirs=include_dirs,
- define_macros=macros,
- )
- )
- ext_modules.append(
- Extension(
- "srsly.ujson.ujson",
- sources=[
- "./srsly/ujson/ujson.c",
- "./srsly/ujson/objToJSON.c",
- "./srsly/ujson/JSONtoObj.c",
- "./srsly/ujson/lib/ultrajsonenc.c",
- "./srsly/ujson/lib/ultrajsondec.c",
- ],
- include_dirs=["./srsly/ujson", "./srsly/ujson/lib"],
- extra_compile_args=["-D_GNU_SOURCE"],
- )
- )
- print("Cythonizing sources")
- ext_modules = cythonize(
- ext_modules, compiler_directives=COMPILER_DIRECTIVES, language_level=2
- )
-
- setup(
- name="srsly",
- packages=PACKAGES,
- version=about["__version__"],
- ext_modules=ext_modules,
- cmdclass={"build_ext": build_ext_subclass},
- package_data=PACKAGE_DATA,
- )
+ setup(
+ name="srsly",
+ packages=find_packages(),
+ version=about["__version__"],
+ )
if __name__ == "__main__":
diff --git a/srsly/_json_api.py b/srsly/_json_api.py
index 24d25fd..e755217 100644
--- a/srsly/_json_api.py
+++ b/srsly/_json_api.py
@@ -1,9 +1,10 @@
-from typing import Union, Iterable, Sequence, Any, Optional, Iterator
+from typing import Union, Iterable, Any, Optional, Iterator
import sys
import json as _builtin_json
import gzip
-from . import ujson
+import ujson
+
from .util import force_path, force_string, FilePath, JSONInput, JSONOutput
diff --git a/srsly/_msgpack_api.py b/srsly/_msgpack_api.py
index 3da0fe6..5259380 100644
--- a/srsly/_msgpack_api.py
+++ b/srsly/_msgpack_api.py
@@ -1,8 +1,103 @@
import gc
+from contextlib import contextmanager
+
+import msgpack
-from . import msgpack
-from .msgpack import msgpack_encoders, msgpack_decoders # noqa: F401
from .util import force_path, FilePath, JSONInputBin, JSONOutputBin
+from ._msgpack_numpy import encode_numpy, decode_numpy
+
+
+class _MsgpackExtensions:
+ """API for extending msgpack (de)serialization:
+
+ srsly.msgpack_encoders.register(name, func)
+ srsly.msgpack_decoders.register(name, func)
+
+ where `name` is a unique ID and `func` is a callable that accepts a single
+ argument:
+
+ - For encoders, the argument is the object to serialize. The callable should
+ return a new object (typically a dict) or the original object if the callback
+ does not recognize it.
+ - For decoders, the argument is the dict to deserialize, as returned by the encoders.
+ The callable should return a new object or the original dict if the callback
+ does not recognize it.
+ """
+
+ __slots__ = ("_ext",)
+
+ def __init__(self):
+ self._ext = {}
+
+ def register(self, name, func):
+ """Register a custom encoder/decoder function"""
+ self._ext[name] = func
+
+ def deregister(self, name):
+ del self._ext[name]
+
+ def _run(self, obj):
+ for func in self._ext.values():
+ out = func(obj)
+ if out is not obj:
+ return out
+ return obj
+
+
+class _MsgpackEncoderExtensions(_MsgpackExtensions):
+ def _run(self, obj):
+ out = super()._run(obj)
+ if out is not obj:
+ return out
+
+ # Convert subtypes of base types and tuples to lists.
+ # Effectively this undoes the strict_types=True option of msgpack.
+ # This is needed to support np.float64, which is a subclass of builtin float.
+ # Run this last to allow the user to register their own handlers first.
+
+ if isinstance(obj, tuple):
+ return list(obj)
+ # Note: bool and memoryview can't be subclassed
+ # set and frozenset are not supported by msgpack
+ for cls in (int, float, list, dict, str, bytes):
+ if isinstance(obj, cls):
+ return cls(obj)
+
+ return obj
+
+
+msgpack_encoders = _MsgpackEncoderExtensions()
+msgpack_decoders = _MsgpackExtensions()
+
+
+def encode_complex(obj):
+ if isinstance(obj, complex):
+ return {b"complex": True, b"data": repr(obj)}
+ return obj
+
+
+def decode_complex(obj):
+ if b"complex" in obj:
+ return complex(obj[b"data"])
+ return obj
+
+
+msgpack_encoders.register("numpy", func=encode_numpy)
+msgpack_decoders.register("numpy", func=decode_numpy)
+# Note: np.complex128 is a subclass of built-in complex, so
+# encode_complex must be registered after encode_numpy.
+msgpack_encoders.register("complex", func=encode_complex)
+msgpack_decoders.register("complex", func=decode_complex)
+
+
+@contextmanager
+def _without_gc():
+ """msgpack-python docs suggest disabling gc before unpacking large messages"""
+ gc.disable()
+ try:
+ yield
+ finally:
+ gc.enable()
def msgpack_dumps(data: JSONInputBin) -> bytes:
@@ -11,7 +106,13 @@ def msgpack_dumps(data: JSONInputBin) -> bytes:
data: The data to serialize.
RETURNS (bytes): The serialized bytes.
"""
- return msgpack.dumps(data, use_bin_type=True)
+ return msgpack.dumps(
+ data,
+ # strict_types=True routes tuples and builtin subclasses (e.g. np.float64,
+ # np.complex128) to `default`, which encodes them or falls back to the base type
+ strict_types=True,
+ default=msgpack_encoders._run,
+ )
def msgpack_loads(data: bytes, use_list: bool = True) -> JSONOutputBin:
@@ -22,11 +123,10 @@ def msgpack_loads(data: bytes, use_list: bool = True) -> JSONOutputBin:
deserialization slower.
RETURNS: The deserialized Python object.
"""
- # msgpack-python docs suggest disabling gc before unpacking large messages
- gc.disable()
- msg = msgpack.loads(data, raw=False, use_list=use_list)
- gc.enable()
- return msg
+ with _without_gc():
+ return msgpack.loads(
+ data, raw=False, use_list=use_list, object_hook=msgpack_decoders._run
+ )
def write_msgpack(path: FilePath, data: JSONInputBin) -> None:
@@ -37,7 +137,7 @@ def write_msgpack(path: FilePath, data: JSONInputBin) -> None:
"""
file_path = force_path(path, require_exists=False)
with file_path.open("wb") as f:
- msgpack.dump(data, f, use_bin_type=True)
+ msgpack.dump(data, f, strict_types=True, default=msgpack_encoders._run)
def read_msgpack(path: FilePath, use_list: bool = True) -> JSONOutputBin:
@@ -49,9 +149,7 @@ def read_msgpack(path: FilePath, use_list: bool = True) -> JSONOutputBin:
RETURNS (JSONOutputBin): The loaded and deserialized content.
"""
file_path = force_path(path)
- with file_path.open("rb") as f:
- # msgpack-python docs suggest disabling gc before unpacking large messages
- gc.disable()
- msg = msgpack.load(f, raw=False, use_list=use_list)
- gc.enable()
- return msg
+ with file_path.open("rb") as f, _without_gc():
+ return msgpack.load(
+ f, raw=False, use_list=use_list, object_hook=msgpack_decoders._run
+ )
diff --git a/srsly/_msgpack_numpy.py b/srsly/_msgpack_numpy.py
new file mode 100644
index 0000000..a3a842c
--- /dev/null
+++ b/srsly/_msgpack_numpy.py
@@ -0,0 +1,80 @@
+"""
+Support for serialization of numpy data types with msgpack.
+This is a heavily modified fork of a very old version of msgpack-numpy.
+"""
+
+# Copyright (c) 2013-2018, Lev E. Givon
+# All rights reserved.
+# Distributed under the terms of the BSD license:
+# http://www.opensource.org/licenses/bsd-license
+try:
+ import numpy as np
+
+ has_numpy = True
+except ImportError:
+ has_numpy = False
+
+try:
+ import cupy
+
+ has_cupy = True
+except ImportError:
+ has_cupy = False
+
+
+def encode_numpy(obj):
+ """
+ Data encoder for serializing numpy data types.
+ """
+ if not has_numpy:
+ return obj
+ if has_cupy and isinstance(obj, cupy.ndarray):
+ obj = obj.get()
+ if isinstance(obj, np.ndarray):
+ # If the dtype is structured, store the interface description;
+ # otherwise, store the corresponding array protocol type string:
+ if obj.dtype.kind == "V":
+ kind = b"V"
+ descr = obj.dtype.descr
+ else:
+ kind = b""
+ descr = obj.dtype.str
+ return {
+ b"nd": True,
+ b"type": descr,
+ b"kind": kind,
+ b"shape": obj.shape,
+ b"data": obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes(),
+ }
+ if isinstance(obj, (np.bool_, np.number)):
+ return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data}
+
+ return obj
+
+
+def decode_numpy(obj):
+ """
+ Decoder for deserializing numpy data types.
+ """
+ if b"nd" not in obj:
+ return obj
+
+ # Crash with a clean ModuleNotFoundError if numpy is not available
+ # instead of a NameError on the unbound module-level `np` alias
+ import numpy # noqa: F401
+
+ if obj[b"nd"]:
+ # Check if "kind" is in obj to enable decoding of data
+ # serialized with older versions (#20):
+ if b"kind" in obj and obj[b"kind"] == b"V":
+ descr = [
+ tuple(t.decode() if type(t) is bytes else t for t in d)
+ for d in obj[b"type"]
+ ]
+ else:
+ descr = obj[b"type"]
+ return np.frombuffer(obj[b"data"], dtype=np.dtype(descr)).reshape(obj[b"shape"])
+ else:
+ # NumPy scalar
+ descr = obj[b"type"]
+ return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0]
diff --git a/srsly/_pickle_api.py b/srsly/_pickle_api.py
index 0e894d9..3a2df96 100644
--- a/srsly/_pickle_api.py
+++ b/srsly/_pickle_api.py
@@ -1,6 +1,7 @@
from typing import Optional
-from . import cloudpickle
+import cloudpickle
+
from .util import JSONInput, JSONOutput
diff --git a/srsly/_yaml_api.py b/srsly/_yaml_api.py
index 36f24aa..101d3f5 100644
--- a/srsly/_yaml_api.py
+++ b/srsly/_yaml_api.py
@@ -2,8 +2,9 @@
from io import StringIO
import sys
-from .ruamel_yaml import YAML
-from .ruamel_yaml.representer import RepresenterError
+from ruamel.yaml import YAML
+from ruamel.yaml.representer import RepresenterError
+
from .util import force_path, FilePath, YAMLInput, YAMLOutput
diff --git a/srsly/about.py b/srsly/about.py
index 667b52f..528787c 100644
--- a/srsly/about.py
+++ b/srsly/about.py
@@ -1 +1 @@
-__version__ = "2.5.2"
+__version__ = "3.0.0"
diff --git a/srsly/cloudpickle/__init__.py b/srsly/cloudpickle/__init__.py
deleted file mode 100644
index c802221..0000000
--- a/srsly/cloudpickle/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .cloudpickle import * # noqa
-from .cloudpickle_fast import CloudPickler, dumps, dump # noqa
-
-# Conform to the convention used by python serialization libraries, which
-# expose their Pickler subclass at top-level under the "Pickler" name.
-Pickler = CloudPickler
-
-__version__ = '2.2.0'
diff --git a/srsly/cloudpickle/cloudpickle.py b/srsly/cloudpickle/cloudpickle.py
deleted file mode 100644
index 317be69..0000000
--- a/srsly/cloudpickle/cloudpickle.py
+++ /dev/null
@@ -1,948 +0,0 @@
-"""
-This class is defined to override standard pickle functionality
-
-The goals of it follow:
--Serialize lambdas and nested functions to compiled byte code
--Deal with main module correctly
--Deal with other non-serializable objects
-
-It does not include an unpickler, as standard python unpickling suffices.
-
-This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
-`_.
-
-Copyright (c) 2012, Regents of the University of California.
-Copyright (c) 2009 `PiCloud, Inc. `_.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions
-are met:
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in the
- documentation and/or other materials provided with the distribution.
- * Neither the name of the University of California, Berkeley nor the
- names of its contributors may be used to endorse or promote
- products derived from this software without specific prior written
- permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
-TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
-NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""
-
-import builtins
-import dis
-import opcode
-import platform
-import sys
-import types
-import weakref
-import uuid
-import threading
-import typing
-import warnings
-
-from .compat import pickle
-from collections import OrderedDict
-from typing import ClassVar, Generic, Union, Tuple, Callable
-from pickle import _getattribute
-from importlib._bootstrap import _find_spec
-
-try: # pragma: no branch
- import typing_extensions as _typing_extensions
- from typing_extensions import Literal, Final
-except ImportError:
- _typing_extensions = Literal = Final = None
-
-if sys.version_info >= (3, 8):
- from types import CellType
-else:
- def f():
- a = 1
-
- def g():
- return a
- return g
- CellType = type(f().__closure__[0])
-
-
-# cloudpickle is meant for inter process communication: we expect all
-# communicating processes to run the same Python version hence we favor
-# communication speed over compatibility:
-DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
-
-# Names of modules whose resources should be treated as dynamic.
-_PICKLE_BY_VALUE_MODULES = set()
-
-# Track the provenance of reconstructed dynamic classes to make it possible to
-# reconstruct instances from the matching singleton class definition when
-# appropriate and preserve the usual "isinstance" semantics of Python objects.
-_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()
-_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary()
-_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock()
-
-PYPY = platform.python_implementation() == "PyPy"
-
-builtin_code_type = None
-if PYPY:
- # builtin-code objects only exist in pypy
- builtin_code_type = type(float.__new__.__code__)
-
-_extract_code_globals_cache = weakref.WeakKeyDictionary()
-
-
-def _get_or_create_tracker_id(class_def):
- with _DYNAMIC_CLASS_TRACKER_LOCK:
- class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
- if class_tracker_id is None:
- class_tracker_id = uuid.uuid4().hex
- _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
- _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
- return class_tracker_id
-
-
-def _lookup_class_or_track(class_tracker_id, class_def):
- if class_tracker_id is not None:
- with _DYNAMIC_CLASS_TRACKER_LOCK:
- class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault(
- class_tracker_id, class_def)
- _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
- return class_def
-
-
-def register_pickle_by_value(module):
- """Register a module to make it functions and classes picklable by value.
-
- By default, functions and classes that are attributes of an importable
- module are to be pickled by reference, that is relying on re-importing
- the attribute from the module at load time.
-
- If `register_pickle_by_value(module)` is called, all its functions and
- classes are subsequently to be pickled by value, meaning that they can
- be loaded in Python processes where the module is not importable.
-
- This is especially useful when developing a module in a distributed
- execution environment: restarting the client Python process with the new
- source code is enough: there is no need to re-install the new version
- of the module on all the worker nodes nor to restart the workers.
-
- Note: this feature is considered experimental. See the cloudpickle
- README.md file for more details and limitations.
- """
- if not isinstance(module, types.ModuleType):
- raise ValueError(
- f"Input should be a module object, got {str(module)} instead"
- )
- # In the future, cloudpickle may need a way to access any module registered
- # for pickling by value in order to introspect relative imports inside
- # functions pickled by value. (see
- # https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633).
- # This access can be ensured by checking that module is present in
- # sys.modules at registering time and assuming that it will still be in
- # there when accessed during pickling. Another alternative would be to
- # store a weakref to the module. Even though cloudpickle does not implement
- # this introspection yet, in order to avoid a possible breaking change
- # later, we still enforce the presence of module inside sys.modules.
- if module.__name__ not in sys.modules:
- raise ValueError(
- f"{module} was not imported correctly, have you used an "
- f"`import` statement to access it?"
- )
- _PICKLE_BY_VALUE_MODULES.add(module.__name__)
-
-
-def unregister_pickle_by_value(module):
- """Unregister that the input module should be pickled by value."""
- if not isinstance(module, types.ModuleType):
- raise ValueError(
- f"Input should be a module object, got {str(module)} instead"
- )
- if module.__name__ not in _PICKLE_BY_VALUE_MODULES:
- raise ValueError(f"{module} is not registered for pickle by value")
- else:
- _PICKLE_BY_VALUE_MODULES.remove(module.__name__)
-
-
-def list_registry_pickle_by_value():
- return _PICKLE_BY_VALUE_MODULES.copy()
-
-
-def _is_registered_pickle_by_value(module):
- module_name = module.__name__
- if module_name in _PICKLE_BY_VALUE_MODULES:
- return True
- while True:
- parent_name = module_name.rsplit(".", 1)[0]
- if parent_name == module_name:
- break
- if parent_name in _PICKLE_BY_VALUE_MODULES:
- return True
- module_name = parent_name
- return False
-
-
-def _whichmodule(obj, name):
- """Find the module an object belongs to.
-
- This function differs from ``pickle.whichmodule`` in two ways:
- - it does not mangle the cases where obj's module is __main__ and obj was
- not found in any module.
- - Errors arising during module introspection are ignored, as those errors
- are considered unwanted side effects.
- """
- if sys.version_info[:2] < (3, 7) and isinstance(obj, typing.TypeVar): # pragma: no branch # noqa
- # Workaround bug in old Python versions: prior to Python 3.7,
- # T.__module__ would always be set to "typing" even when the TypeVar T
- # would be defined in a different module.
- if name is not None and getattr(typing, name, None) is obj:
- # Built-in TypeVar defined in typing such as AnyStr
- return 'typing'
- else:
- # User defined or third-party TypeVar: __module__ attribute is
- # irrelevant, thus trigger a exhaustive search for obj in all
- # modules.
- module_name = None
- else:
- module_name = getattr(obj, '__module__', None)
-
- if module_name is not None:
- return module_name
- # Protect the iteration by using a copy of sys.modules against dynamic
- # modules that trigger imports of other modules upon calls to getattr or
- # other threads importing at the same time.
- for module_name, module in sys.modules.copy().items():
- # Some modules such as coverage can inject non-module objects inside
- # sys.modules
- if (
- module_name == '__main__' or
- module is None or
- not isinstance(module, types.ModuleType)
- ):
- continue
- try:
- if _getattribute(module, name)[0] is obj:
- return module_name
- except Exception:
- pass
- return None
-
-
-def _should_pickle_by_reference(obj, name=None):
- """Test whether an function or a class should be pickled by reference
-
- Pickling by reference means by that the object (typically a function or a
- class) is an attribute of a module that is assumed to be importable in the
- target Python environment. Loading will therefore rely on importing the
- module and then calling `getattr` on it to access the function or class.
-
- Pickling by reference is the only option to pickle functions and classes
- in the standard library. In cloudpickle the alternative option is to
- pickle by value (for instance for interactively or locally defined
- functions and classes or for attributes of modules that have been
- explicitly registered to be pickled by value.
- """
- if isinstance(obj, types.FunctionType) or issubclass(type(obj), type):
- module_and_name = _lookup_module_and_qualname(obj, name=name)
- if module_and_name is None:
- return False
- module, name = module_and_name
- return not _is_registered_pickle_by_value(module)
-
- elif isinstance(obj, types.ModuleType):
- # We assume that sys.modules is primarily used as a cache mechanism for
- # the Python import machinery. Checking if a module has been added in
- # is sys.modules therefore a cheap and simple heuristic to tell us
- # whether we can assume that a given module could be imported by name
- # in another Python process.
- if _is_registered_pickle_by_value(obj):
- return False
- return obj.__name__ in sys.modules
- else:
- raise TypeError(
- "cannot check importability of {} instances".format(
- type(obj).__name__)
- )
-
-
-def _lookup_module_and_qualname(obj, name=None):
- if name is None:
- name = getattr(obj, '__qualname__', None)
- if name is None: # pragma: no cover
- # This used to be needed for Python 2.7 support but is probably not
- # needed anymore. However we keep the __name__ introspection in case
- # users of cloudpickle rely on this old behavior for unknown reasons.
- name = getattr(obj, '__name__', None)
-
- module_name = _whichmodule(obj, name)
-
- if module_name is None:
- # In this case, obj.__module__ is None AND obj was not found in any
- # imported module. obj is thus treated as dynamic.
- return None
-
- if module_name == "__main__":
- return None
-
- # Note: if module_name is in sys.modules, the corresponding module is
- # assumed importable at unpickling time. See #357
- module = sys.modules.get(module_name, None)
- if module is None:
- # The main reason why obj's module would not be imported is that this
- # module has been dynamically created, using for example
- # types.ModuleType. The other possibility is that module was removed
- # from sys.modules after obj was created/imported. But this case is not
- # supported, as the standard pickle does not support it either.
- return None
-
- try:
- obj2, parent = _getattribute(module, name)
- except AttributeError:
- # obj was not found inside the module it points to
- return None
- if obj2 is not obj:
- return None
- return module, name
-
-
-def _extract_code_globals(co):
- """
- Find all globals names read or written to by codeblock co
- """
- out_names = _extract_code_globals_cache.get(co)
- if out_names is None:
- # We use a dict with None values instead of a set to get a
- # deterministic order (assuming Python 3.6+) and avoid introducing
- # non-deterministic pickle bytes as a results.
- out_names = {name: None for name in _walk_global_ops(co)}
-
- # Declaring a function inside another one using the "def ..."
- # syntax generates a constant code object corresponding to the one
- # of the nested function's As the nested function may itself need
- # global variables, we need to introspect its code, extract its
- # globals, (look for code object in it's co_consts attribute..) and
- # add the result to code_globals
- if co.co_consts:
- for const in co.co_consts:
- if isinstance(const, types.CodeType):
- out_names.update(_extract_code_globals(const))
-
- _extract_code_globals_cache[co] = out_names
-
- return out_names
-
-
-def _find_imported_submodules(code, top_level_dependencies):
- """
- Find currently imported submodules used by a function.
-
- Submodules used by a function need to be detected and referenced for the
- function to work correctly at depickling time. Because submodules can be
- referenced as attribute of their parent package (``package.submodule``), we
- need a special introspection technique that does not rely on GLOBAL-related
- opcodes to find references of them in a code object.
-
- Example:
- ```
- import concurrent.futures
- import cloudpickle
- def func():
- x = concurrent.futures.ThreadPoolExecutor
- if __name__ == '__main__':
- cloudpickle.dumps(func)
- ```
- The globals extracted by cloudpickle in the function's state include the
- concurrent package, but not its submodule (here, concurrent.futures), which
- is the module used by func. Find_imported_submodules will detect the usage
- of concurrent.futures. Saving this module alongside with func will ensure
- that calling func once depickled does not fail due to concurrent.futures
- not being imported
- """
-
- subimports = []
- # check if any known dependency is an imported package
- for x in top_level_dependencies:
- if (isinstance(x, types.ModuleType) and
- hasattr(x, '__package__') and x.__package__):
- # check if the package has any currently loaded sub-imports
- prefix = x.__name__ + '.'
- # A concurrent thread could mutate sys.modules,
- # make sure we iterate over a copy to avoid exceptions
- for name in list(sys.modules):
- # Older versions of pytest will add a "None" module to
- # sys.modules.
- if name is not None and name.startswith(prefix):
- # check whether the function can address the sub-module
- tokens = set(name[len(prefix):].split('.'))
- if not tokens - set(code.co_names):
- subimports.append(sys.modules[name])
- return subimports
-
-
-def cell_set(cell, value):
- """Set the value of a closure cell.
-
- The point of this function is to set the cell_contents attribute of a cell
- after its creation. This operation is necessary in case the cell contains a
- reference to the function the cell belongs to, as when calling the
- function's constructor
- ``f = types.FunctionType(code, globals, name, argdefs, closure)``,
- closure will not be able to contain the yet-to-be-created f.
-
- In Python3.7, cell_contents is writeable, so setting the contents of a cell
- can be done simply using
- >>> cell.cell_contents = value
-
- In earlier Python3 versions, the cell_contents attribute of a cell is read
- only, but this limitation can be worked around by leveraging the Python 3
- ``nonlocal`` keyword.
-
- In Python2 however, this attribute is read only, and there is no
- ``nonlocal`` keyword. For this reason, we need to come up with more
- complicated hacks to set this attribute.
-
- The chosen approach is to create a function with a STORE_DEREF opcode,
- which sets the content of a closure variable. Typically:
-
- >>> def inner(value):
- ... lambda: cell # the lambda makes cell a closure
- ... cell = value # cell is a closure, so this triggers a STORE_DEREF
-
- (Note that in Python2, A STORE_DEREF can never be triggered from an inner
- function. The function g for example here
- >>> def f(var):
- ... def g():
- ... var += 1
- ... return g
-
- will not modify the closure variable ``var```inplace, but instead try to
- load a local variable var and increment it. As g does not assign the local
- variable ``var`` any initial value, calling f(1)() will fail at runtime.)
-
- Our objective is to set the value of a given cell ``cell``. So we need to
- somewhat reference our ``cell`` object into the ``inner`` function so that
- this object (and not the smoke cell of the lambda function) gets affected
- by the STORE_DEREF operation.
-
- In inner, ``cell`` is referenced as a cell variable (an enclosing variable
- that is referenced by the inner function). If we create a new function
- cell_set with the exact same code as ``inner``, but with ``cell`` marked as
- a free variable instead, the STORE_DEREF will be applied on its closure -
- ``cell``, which we can specify explicitly during construction! The new
- cell_set variable thus actually sets the contents of a specified cell!
-
- Note: we do not make use of the ``nonlocal`` keyword to set the contents of
- a cell in early python3 versions to limit possible syntax errors in case
- test and checker libraries decide to parse the whole file.
- """
-
- if sys.version_info[:2] >= (3, 7): # pragma: no branch
- cell.cell_contents = value
- else:
- _cell_set = types.FunctionType(
- _cell_set_template_code, {}, '_cell_set', (), (cell,),)
- _cell_set(value)
-
-
-def _make_cell_set_template_code():
- def _cell_set_factory(value):
- lambda: cell
- cell = value
-
- co = _cell_set_factory.__code__
-
- _cell_set_template_code = types.CodeType(
- co.co_argcount,
- co.co_kwonlyargcount, # Python 3 only argument
- co.co_nlocals,
- co.co_stacksize,
- co.co_flags,
- co.co_code,
- co.co_consts,
- co.co_names,
- co.co_varnames,
- co.co_filename,
- co.co_name,
- co.co_firstlineno,
- co.co_lnotab,
- co.co_cellvars, # co_freevars is initialized with co_cellvars
- (), # co_cellvars is made empty
- )
- return _cell_set_template_code
-
-
-if sys.version_info[:2] < (3, 7):
- _cell_set_template_code = _make_cell_set_template_code()
-
-# relevant opcodes
-STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
-DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
-LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
-GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
-HAVE_ARGUMENT = dis.HAVE_ARGUMENT
-EXTENDED_ARG = dis.EXTENDED_ARG
-
-
-_BUILTIN_TYPE_NAMES = {}
-for k, v in types.__dict__.items():
- if type(v) is type:
- _BUILTIN_TYPE_NAMES[v] = k
-
-
-def _builtin_type(name):
- if name == "ClassType": # pragma: no cover
- # Backward compat to load pickle files generated with cloudpickle
- # < 1.3 even if loading pickle files from older versions is not
- # officially supported.
- return type
- return getattr(types, name)
-
-
-def _walk_global_ops(code):
- """
- Yield referenced name for all global-referencing instructions in *code*.
- """
- for instr in dis.get_instructions(code):
- op = instr.opcode
- if op in GLOBAL_OPS:
- yield instr.argval
-
-
-def _extract_class_dict(cls):
- """Retrieve a copy of the dict of a class without the inherited methods"""
- clsdict = dict(cls.__dict__) # copy dict proxy to a dict
- if len(cls.__bases__) == 1:
- inherited_dict = cls.__bases__[0].__dict__
- else:
- inherited_dict = {}
- for base in reversed(cls.__bases__):
- inherited_dict.update(base.__dict__)
- to_remove = []
- for name, value in clsdict.items():
- try:
- base_value = inherited_dict[name]
- if value is base_value:
- to_remove.append(name)
- except KeyError:
- pass
- for name in to_remove:
- clsdict.pop(name)
- return clsdict
-
-
-if sys.version_info[:2] < (3, 7): # pragma: no branch
- def _is_parametrized_type_hint(obj):
- # This is very cheap but might generate false positives. So try to
- # narrow it down is good as possible.
- type_module = getattr(type(obj), '__module__', None)
- from_typing_extensions = type_module == 'typing_extensions'
- from_typing = type_module == 'typing'
-
- # general typing Constructs
- is_typing = getattr(obj, '__origin__', None) is not None
-
- # typing_extensions.Literal
- is_literal = (
- (getattr(obj, '__values__', None) is not None)
- and from_typing_extensions
- )
-
- # typing_extensions.Final
- is_final = (
- (getattr(obj, '__type__', None) is not None)
- and from_typing_extensions
- )
-
- # typing.ClassVar
- is_classvar = (
- (getattr(obj, '__type__', None) is not None) and from_typing
- )
-
- # typing.Union/Tuple for old Python 3.5
- is_union = getattr(obj, '__union_params__', None) is not None
- is_tuple = getattr(obj, '__tuple_params__', None) is not None
- is_callable = (
- getattr(obj, '__result__', None) is not None and
- getattr(obj, '__args__', None) is not None
- )
- return any((is_typing, is_literal, is_final, is_classvar, is_union,
- is_tuple, is_callable))
-
- def _create_parametrized_type_hint(origin, args):
- return origin[args]
-else:
- _is_parametrized_type_hint = None
- _create_parametrized_type_hint = None
-
-
-def parametrized_type_hint_getinitargs(obj):
- # The distorted type check sematic for typing construct becomes:
- # ``type(obj) is type(TypeHint)``, which means "obj is a
- # parametrized TypeHint"
- if type(obj) is type(Literal): # pragma: no branch
- initargs = (Literal, obj.__values__)
- elif type(obj) is type(Final): # pragma: no branch
- initargs = (Final, obj.__type__)
- elif type(obj) is type(ClassVar):
- initargs = (ClassVar, obj.__type__)
- elif type(obj) is type(Generic):
- initargs = (obj.__origin__, obj.__args__)
- elif type(obj) is type(Union):
- initargs = (Union, obj.__args__)
- elif type(obj) is type(Tuple):
- initargs = (Tuple, obj.__args__)
- elif type(obj) is type(Callable):
- (*args, result) = obj.__args__
- if len(args) == 1 and args[0] is Ellipsis:
- args = Ellipsis
- else:
- args = list(args)
- initargs = (Callable, (args, result))
- else: # pragma: no cover
- raise pickle.PicklingError(
- f"Cloudpickle Error: Unknown type {type(obj)}"
- )
- return initargs
-
-
-# Tornado support
-
-def is_tornado_coroutine(func):
- """
- Return whether *func* is a Tornado coroutine function.
- Running coroutines are not supported.
- """
- if 'tornado.gen' not in sys.modules:
- return False
- gen = sys.modules['tornado.gen']
- if not hasattr(gen, "is_coroutine_function"):
- # Tornado version is too old
- return False
- return gen.is_coroutine_function(func)
-
-
-def _rebuild_tornado_coroutine(func):
- from tornado import gen
- return gen.coroutine(func)
-
-
-# including pickles unloading functions in this namespace
-load = pickle.load
-loads = pickle.loads
-
-
-def subimport(name):
- # We cannot do simply: `return __import__(name)`: Indeed, if ``name`` is
- # the name of a submodule, __import__ will return the top-level root module
- # of this submodule. For instance, __import__('os.path') returns the `os`
- # module.
- __import__(name)
- return sys.modules[name]
-
-
-def dynamic_subimport(name, vars):
- mod = types.ModuleType(name)
- mod.__dict__.update(vars)
- mod.__dict__['__builtins__'] = builtins.__dict__
- return mod
-
-
-def _gen_ellipsis():
- return Ellipsis
-
-
-def _gen_not_implemented():
- return NotImplemented
-
-
-def _get_cell_contents(cell):
- try:
- return cell.cell_contents
- except ValueError:
- # sentinel used by ``_fill_function`` which will leave the cell empty
- return _empty_cell_value
-
-
-def instance(cls):
- """Create a new instance of a class.
-
- Parameters
- ----------
- cls : type
- The class to create an instance of.
-
- Returns
- -------
- instance : cls
- A new instance of ``cls``.
- """
- return cls()
-
-
-@instance
-class _empty_cell_value:
- """sentinel for empty closures
- """
- @classmethod
- def __reduce__(cls):
- return cls.__name__
-
-
-def _fill_function(*args):
- """Fills in the rest of function data into the skeleton function object
-
- The skeleton itself is create by _make_skel_func().
- """
- if len(args) == 2:
- func = args[0]
- state = args[1]
- elif len(args) == 5:
- # Backwards compat for cloudpickle v0.4.0, after which the `module`
- # argument was introduced
- func = args[0]
- keys = ['globals', 'defaults', 'dict', 'closure_values']
- state = dict(zip(keys, args[1:]))
- elif len(args) == 6:
- # Backwards compat for cloudpickle v0.4.1, after which the function
- # state was passed as a dict to the _fill_function it-self.
- func = args[0]
- keys = ['globals', 'defaults', 'dict', 'module', 'closure_values']
- state = dict(zip(keys, args[1:]))
- else:
- raise ValueError(f'Unexpected _fill_value arguments: {args!r}')
-
- # - At pickling time, any dynamic global variable used by func is
- # serialized by value (in state['globals']).
- # - At unpickling time, func's __globals__ attribute is initialized by
- # first retrieving an empty isolated namespace that will be shared
- # with other functions pickled from the same original module
- # by the same CloudPickler instance and then updated with the
- # content of state['globals'] to populate the shared isolated
- # namespace with all the global variables that are specifically
- # referenced for this function.
- func.__globals__.update(state['globals'])
-
- func.__defaults__ = state['defaults']
- func.__dict__ = state['dict']
- if 'annotations' in state:
- func.__annotations__ = state['annotations']
- if 'doc' in state:
- func.__doc__ = state['doc']
- if 'name' in state:
- func.__name__ = state['name']
- if 'module' in state:
- func.__module__ = state['module']
- if 'qualname' in state:
- func.__qualname__ = state['qualname']
- if 'kwdefaults' in state:
- func.__kwdefaults__ = state['kwdefaults']
- # _cloudpickle_subimports is a set of submodules that must be loaded for
- # the pickled function to work correctly at unpickling time. Now that these
- # submodules are depickled (hence imported), they can be removed from the
- # object's state (the object state only served as a reference holder to
- # these submodules)
- if '_cloudpickle_submodules' in state:
- state.pop('_cloudpickle_submodules')
-
- cells = func.__closure__
- if cells is not None:
- for cell, value in zip(cells, state['closure_values']):
- if value is not _empty_cell_value:
- cell_set(cell, value)
-
- return func
-
-
-def _make_function(code, globals, name, argdefs, closure):
- # Setting __builtins__ in globals is needed for nogil CPython.
- globals["__builtins__"] = __builtins__
- return types.FunctionType(code, globals, name, argdefs, closure)
-
-
-def _make_empty_cell():
- if False:
- # trick the compiler into creating an empty cell in our lambda
- cell = None
- raise AssertionError('this route should not be executed')
-
- return (lambda: cell).__closure__[0]
-
-
-def _make_cell(value=_empty_cell_value):
- cell = _make_empty_cell()
- if value is not _empty_cell_value:
- cell_set(cell, value)
- return cell
-
-
-def _make_skel_func(code, cell_count, base_globals=None):
- """ Creates a skeleton function object that contains just the provided
- code and the correct number of cells in func_closure. All other
- func attributes (e.g. func_globals) are empty.
- """
- # This function is deprecated and should be removed in cloudpickle 1.7
- warnings.warn(
- "A pickle file created using an old (<=1.4.1) version of cloudpickle "
- "is currently being loaded. This is not supported by cloudpickle and "
- "will break in cloudpickle 1.7", category=UserWarning
- )
- # This is backward-compatibility code: for cloudpickle versions between
- # 0.5.4 and 0.7, base_globals could be a string or None. base_globals
- # should now always be a dictionary.
- if base_globals is None or isinstance(base_globals, str):
- base_globals = {}
-
- base_globals['__builtins__'] = __builtins__
-
- closure = (
- tuple(_make_empty_cell() for _ in range(cell_count))
- if cell_count >= 0 else
- None
- )
- return types.FunctionType(code, base_globals, None, None, closure)
-
-
-def _make_skeleton_class(type_constructor, name, bases, type_kwargs,
- class_tracker_id, extra):
- """Build dynamic class with an empty __dict__ to be filled once memoized
-
- If class_tracker_id is not None, try to lookup an existing class definition
- matching that id. If none is found, track a newly reconstructed class
- definition under that id so that other instances stemming from the same
- class id will also reuse this class definition.
-
- The "extra" variable is meant to be a dict (or None) that can be used for
- forward compatibility shall the need arise.
- """
- skeleton_class = types.new_class(
- name, bases, {'metaclass': type_constructor},
- lambda ns: ns.update(type_kwargs)
- )
- return _lookup_class_or_track(class_tracker_id, skeleton_class)
-
-
-def _rehydrate_skeleton_class(skeleton_class, class_dict):
- """Put attributes from `class_dict` back on `skeleton_class`.
-
- See CloudPickler.save_dynamic_class for more info.
- """
- registry = None
- for attrname, attr in class_dict.items():
- if attrname == "_abc_impl":
- registry = attr
- else:
- setattr(skeleton_class, attrname, attr)
- if registry is not None:
- for subclass in registry:
- skeleton_class.register(subclass)
-
- return skeleton_class
-
-
-def _make_skeleton_enum(bases, name, qualname, members, module,
- class_tracker_id, extra):
- """Build dynamic enum with an empty __dict__ to be filled once memoized
-
- The creation of the enum class is inspired by the code of
- EnumMeta._create_.
-
- If class_tracker_id is not None, try to lookup an existing enum definition
- matching that id. If none is found, track a newly reconstructed enum
- definition under that id so that other instances stemming from the same
- class id will also reuse this enum definition.
-
- The "extra" variable is meant to be a dict (or None) that can be used for
- forward compatibility shall the need arise.
- """
- # enums always inherit from their base Enum class at the last position in
- # the list of base classes:
- enum_base = bases[-1]
- metacls = enum_base.__class__
- classdict = metacls.__prepare__(name, bases)
-
- for member_name, member_value in members.items():
- classdict[member_name] = member_value
- enum_class = metacls.__new__(metacls, name, bases, classdict)
- enum_class.__module__ = module
- enum_class.__qualname__ = qualname
-
- return _lookup_class_or_track(class_tracker_id, enum_class)
-
-
-def _make_typevar(name, bound, constraints, covariant, contravariant,
- class_tracker_id):
- tv = typing.TypeVar(
- name, *constraints, bound=bound,
- covariant=covariant, contravariant=contravariant
- )
- if class_tracker_id is not None:
- return _lookup_class_or_track(class_tracker_id, tv)
- else: # pragma: nocover
- # Only for Python 3.5.3 compat.
- return tv
-
-
-def _decompose_typevar(obj):
- return (
- obj.__name__, obj.__bound__, obj.__constraints__,
- obj.__covariant__, obj.__contravariant__,
- _get_or_create_tracker_id(obj),
- )
-
-
-def _typevar_reduce(obj):
- # TypeVar instances require the module information hence why we
- # are not using the _should_pickle_by_reference directly
- module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)
-
- if module_and_name is None:
- return (_make_typevar, _decompose_typevar(obj))
- elif _is_registered_pickle_by_value(module_and_name[0]):
- return (_make_typevar, _decompose_typevar(obj))
-
- return (getattr, module_and_name)
-
-
-def _get_bases(typ):
- if '__orig_bases__' in getattr(typ, '__dict__', {}):
- # For generic types (see PEP 560)
- # Note that simply checking `hasattr(typ, '__orig_bases__')` is not
- # correct. Subclasses of a fully-parameterized generic class does not
- # have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')`
- # will return True because it's defined in the base class.
- bases_attr = '__orig_bases__'
- else:
- # For regular class objects
- bases_attr = '__bases__'
- return getattr(typ, bases_attr)
-
-
-def _make_dict_keys(obj, is_ordered=False):
- if is_ordered:
- return OrderedDict.fromkeys(obj).keys()
- else:
- return dict.fromkeys(obj).keys()
-
-
-def _make_dict_values(obj, is_ordered=False):
- if is_ordered:
- return OrderedDict((i, _) for i, _ in enumerate(obj)).values()
- else:
- return {i: _ for i, _ in enumerate(obj)}.values()
-
-
-def _make_dict_items(obj, is_ordered=False):
- if is_ordered:
- return OrderedDict(obj).items()
- else:
- return obj.items()
diff --git a/srsly/cloudpickle/cloudpickle_fast.py b/srsly/cloudpickle/cloudpickle_fast.py
deleted file mode 100644
index 8741dcb..0000000
--- a/srsly/cloudpickle/cloudpickle_fast.py
+++ /dev/null
@@ -1,844 +0,0 @@
-"""
-New, fast version of the CloudPickler.
-
-This new CloudPickler class can now extend the fast C Pickler instead of the
-previous Python implementation of the Pickler class. Because this functionality
-is only available for Python versions 3.8+, a lot of backward-compatibility
-code is also removed.
-
-Note that the C Pickler subclassing API is CPython-specific. Therefore, some
-guards present in cloudpickle.py that were written to handle PyPy specificities
-are not present in cloudpickle_fast.py
-"""
-import _collections_abc
-import abc
-import copyreg
-import io
-import itertools
-import logging
-import sys
-import struct
-import types
-import weakref
-import typing
-
-from enum import Enum
-from collections import ChainMap, OrderedDict
-
-from .compat import pickle, Pickler
-from .cloudpickle import (
- _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
- _find_imported_submodules, _get_cell_contents, _should_pickle_by_reference,
- _builtin_type, _get_or_create_tracker_id, _make_skeleton_class,
- _make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport,
- _typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType,
- _is_parametrized_type_hint, PYPY, cell_set,
- parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
- builtin_code_type,
- _make_dict_keys, _make_dict_values, _make_dict_items, _make_function,
-)
-
-
-if pickle.HIGHEST_PROTOCOL >= 5:
- # Shorthands similar to pickle.dump/pickle.dumps
-
- def dump(obj, file, protocol=None, buffer_callback=None):
- """Serialize obj as bytes streamed into file
-
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
- speed between processes running the same Python version.
-
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- CloudPickler(
- file, protocol=protocol, buffer_callback=buffer_callback
- ).dump(obj)
-
- def dumps(obj, protocol=None, buffer_callback=None):
- """Serialize obj as a string of bytes allocated in memory
-
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
- speed between processes running the same Python version.
-
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- with io.BytesIO() as file:
- cp = CloudPickler(
- file, protocol=protocol, buffer_callback=buffer_callback
- )
- cp.dump(obj)
- return file.getvalue()
-
-else:
- # Shorthands similar to pickle.dump/pickle.dumps
- def dump(obj, file, protocol=None):
- """Serialize obj as bytes streamed into file
-
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
- speed between processes running the same Python version.
-
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- CloudPickler(file, protocol=protocol).dump(obj)
-
- def dumps(obj, protocol=None):
- """Serialize obj as a string of bytes allocated in memory
-
- protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
- pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
- speed between processes running the same Python version.
-
- Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
- compatibility with older versions of Python.
- """
- with io.BytesIO() as file:
- cp = CloudPickler(file, protocol=protocol)
- cp.dump(obj)
- return file.getvalue()
-
-
-load, loads = pickle.load, pickle.loads
-
-
-# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS
-# -------------------------------------------------
-
-def _class_getnewargs(obj):
- type_kwargs = {}
- if "__slots__" in obj.__dict__:
- type_kwargs["__slots__"] = obj.__slots__
-
- __dict__ = obj.__dict__.get('__dict__', None)
- if isinstance(__dict__, property):
- type_kwargs['__dict__'] = __dict__
-
- return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
- _get_or_create_tracker_id(obj), None)
-
-
-def _enum_getnewargs(obj):
- members = {e.name: e.value for e in obj}
- return (obj.__bases__, obj.__name__, obj.__qualname__, members,
- obj.__module__, _get_or_create_tracker_id(obj), None)
-
-
-# COLLECTION OF OBJECTS RECONSTRUCTORS
-# ------------------------------------
-def _file_reconstructor(retval):
- return retval
-
-
-# COLLECTION OF OBJECTS STATE GETTERS
-# -----------------------------------
-def _function_getstate(func):
- # - Put func's dynamic attributes (stored in func.__dict__) in state. These
- # attributes will be restored at unpickling time using
- # f.__dict__.update(state)
- # - Put func's members into slotstate. Such attributes will be restored at
- # unpickling time by iterating over slotstate and calling setattr(func,
- # slotname, slotvalue)
- slotstate = {
- "__name__": func.__name__,
- "__qualname__": func.__qualname__,
- "__annotations__": func.__annotations__,
- "__kwdefaults__": func.__kwdefaults__,
- "__defaults__": func.__defaults__,
- "__module__": func.__module__,
- "__doc__": func.__doc__,
- "__closure__": func.__closure__,
- }
-
- f_globals_ref = _extract_code_globals(func.__code__)
- f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
- func.__globals__}
-
- closure_values = (
- list(map(_get_cell_contents, func.__closure__))
- if func.__closure__ is not None else ()
- )
-
- # Extract currently-imported submodules used by func. Storing these modules
- # in a smoke _cloudpickle_subimports attribute of the object's state will
- # trigger the side effect of importing these modules at unpickling time
- # (which is necessary for func to work correctly once depickled)
- slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
- func.__code__, itertools.chain(f_globals.values(), closure_values))
- slotstate["__globals__"] = f_globals
-
- state = func.__dict__
- return state, slotstate
-
-
-def _class_getstate(obj):
- clsdict = _extract_class_dict(obj)
- clsdict.pop('__weakref__', None)
-
- if issubclass(type(obj), abc.ABCMeta):
- # If obj is an instance of an ABCMeta subclass, don't pickle the
- # cache/negative caches populated during isinstance/issubclass
- # checks, but pickle the list of registered subclasses of obj.
- clsdict.pop('_abc_cache', None)
- clsdict.pop('_abc_negative_cache', None)
- clsdict.pop('_abc_negative_cache_version', None)
- registry = clsdict.pop('_abc_registry', None)
- if registry is None:
- # in Python3.7+, the abc caches and registered subclasses of a
- # class are bundled into the single _abc_impl attribute
- clsdict.pop('_abc_impl', None)
- (registry, _, _, _) = abc._get_dump(obj)
-
- clsdict["_abc_impl"] = [subclass_weakref()
- for subclass_weakref in registry]
- else:
- # In the above if clause, registry is a set of weakrefs -- in
- # this case, registry is a WeakSet
- clsdict["_abc_impl"] = [type_ for type_ in registry]
-
- if "__slots__" in clsdict:
- # pickle string length optimization: member descriptors of obj are
- # created automatically from obj's __slots__ attribute, no need to
- # save them in obj's state
- if isinstance(obj.__slots__, str):
- clsdict.pop(obj.__slots__)
- else:
- for k in obj.__slots__:
- clsdict.pop(k, None)
-
- clsdict.pop('__dict__', None) # unpicklable property object
-
- return (clsdict, {})
-
-
-def _enum_getstate(obj):
- clsdict, slotstate = _class_getstate(obj)
-
- members = {e.name: e.value for e in obj}
- # Cleanup the clsdict that will be passed to _rehydrate_skeleton_class:
- # Those attributes are already handled by the metaclass.
- for attrname in ["_generate_next_value_", "_member_names_",
- "_member_map_", "_member_type_",
- "_value2member_map_"]:
- clsdict.pop(attrname, None)
- for member in members:
- clsdict.pop(member)
- # Special handling of Enum subclasses
- return clsdict, slotstate
-
-
-# COLLECTIONS OF OBJECTS REDUCERS
-# -------------------------------
-# A reducer is a function taking a single argument (obj), and that returns a
-# tuple with all the necessary data to re-construct obj. Apart from a few
-# exceptions (list, dict, bytes, int, etc.), a reducer is necessary to
-# correctly pickle an object.
-# While many built-in objects (Exceptions objects, instances of the "object"
-# class, etc), are shipped with their own built-in reducer (invoked using
-# obj.__reduce__), some do not. The following methods were created to "fill
-# these holes".
-
-def _code_reduce(obj):
- """codeobject reducer"""
- # If you are not sure about the order of arguments, take a look at help
- # of the specific type from types, for example:
- # >>> from types import CodeType
- # >>> help(CodeType)
- if hasattr(obj, "co_exceptiontable"): # pragma: no branch
- # Python 3.11 and later: there are some new attributes
- # related to the enhanced exceptions.
- args = (
- obj.co_argcount, obj.co_posonlyargcount,
- obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
- obj.co_varnames, obj.co_filename, obj.co_name, obj.co_qualname,
- obj.co_firstlineno, obj.co_linetable, obj.co_exceptiontable,
- obj.co_freevars, obj.co_cellvars,
- )
- elif hasattr(obj, "co_linetable"): # pragma: no branch
- # Python 3.10 and later: obj.co_lnotab is deprecated and constructor
- # expects obj.co_linetable instead.
- args = (
- obj.co_argcount, obj.co_posonlyargcount,
- obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
- obj.co_varnames, obj.co_filename, obj.co_name,
- obj.co_firstlineno, obj.co_linetable, obj.co_freevars,
- obj.co_cellvars
- )
- elif hasattr(obj, "co_nmeta"): # pragma: no cover
- # "nogil" Python: modified attributes from 3.9
- args = (
- obj.co_argcount, obj.co_posonlyargcount,
- obj.co_kwonlyargcount, obj.co_nlocals, obj.co_framesize,
- obj.co_ndefaultargs, obj.co_nmeta,
- obj.co_flags, obj.co_code, obj.co_consts,
- obj.co_varnames, obj.co_filename, obj.co_name,
- obj.co_firstlineno, obj.co_lnotab, obj.co_exc_handlers,
- obj.co_jump_table, obj.co_freevars, obj.co_cellvars,
- obj.co_free2reg, obj.co_cell2reg
- )
- elif hasattr(obj, "co_posonlyargcount"):
- # Backward compat for 3.9 and older
- args = (
- obj.co_argcount, obj.co_posonlyargcount,
- obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
- obj.co_varnames, obj.co_filename, obj.co_name,
- obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
- obj.co_cellvars
- )
- else:
- # Backward compat for even older versions of Python
- args = (
- obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
- obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts,
- obj.co_names, obj.co_varnames, obj.co_filename,
- obj.co_name, obj.co_firstlineno, obj.co_lnotab,
- obj.co_freevars, obj.co_cellvars
- )
- return types.CodeType, args
-
-
-def _cell_reduce(obj):
- """Cell (containing values of a function's free variables) reducer"""
- try:
- obj.cell_contents
- except ValueError: # cell is empty
- return _make_empty_cell, ()
- else:
- return _make_cell, (obj.cell_contents, )
-
-
-def _classmethod_reduce(obj):
- orig_func = obj.__func__
- return type(obj), (orig_func,)
-
-
-def _file_reduce(obj):
- """Save a file"""
- import io
-
- if not hasattr(obj, "name") or not hasattr(obj, "mode"):
- raise pickle.PicklingError(
- "Cannot pickle files that do not map to an actual file"
- )
- if obj is sys.stdout:
- return getattr, (sys, "stdout")
- if obj is sys.stderr:
- return getattr, (sys, "stderr")
- if obj is sys.stdin:
- raise pickle.PicklingError("Cannot pickle standard input")
- if obj.closed:
- raise pickle.PicklingError("Cannot pickle closed files")
- if hasattr(obj, "isatty") and obj.isatty():
- raise pickle.PicklingError(
- "Cannot pickle files that map to tty objects"
- )
- if "r" not in obj.mode and "+" not in obj.mode:
- raise pickle.PicklingError(
- "Cannot pickle files that are not opened for reading: %s"
- % obj.mode
- )
-
- name = obj.name
-
- retval = io.StringIO()
-
- try:
- # Read the whole file
- curloc = obj.tell()
- obj.seek(0)
- contents = obj.read()
- obj.seek(curloc)
- except IOError as e:
- raise pickle.PicklingError(
- "Cannot pickle file %s as it cannot be read" % name
- ) from e
- retval.write(contents)
- retval.seek(curloc)
-
- retval.name = name
- return _file_reconstructor, (retval,)
-
-
-def _getset_descriptor_reduce(obj):
- return getattr, (obj.__objclass__, obj.__name__)
-
-
-def _mappingproxy_reduce(obj):
- return types.MappingProxyType, (dict(obj),)
-
-
-def _memoryview_reduce(obj):
- return bytes, (obj.tobytes(),)
-
-
-def _module_reduce(obj):
- if _should_pickle_by_reference(obj):
- return subimport, (obj.__name__,)
- else:
- # Some external libraries can populate the "__builtins__" entry of a
- # module's `__dict__` with unpicklable objects (see #316). For that
- # reason, we do not attempt to pickle the "__builtins__" entry, and
- # restore a default value for it at unpickling time.
- state = obj.__dict__.copy()
- state.pop('__builtins__', None)
- return dynamic_subimport, (obj.__name__, state)
-
-
-def _method_reduce(obj):
- return (types.MethodType, (obj.__func__, obj.__self__))
-
-
-def _logger_reduce(obj):
- return logging.getLogger, (obj.name,)
-
-
-def _root_logger_reduce(obj):
- return logging.getLogger, ()
-
-
-def _property_reduce(obj):
- return property, (obj.fget, obj.fset, obj.fdel, obj.__doc__)
-
-
-def _weakset_reduce(obj):
- return weakref.WeakSet, (list(obj),)
-
-
-def _dynamic_class_reduce(obj):
- """
- Save a class that can't be stored as module global.
-
- This method is used to serialize classes that are defined inside
- functions, or that otherwise can't be serialized as attribute lookups
- from global modules.
- """
- if Enum is not None and issubclass(obj, Enum):
- return (
- _make_skeleton_enum, _enum_getnewargs(obj), _enum_getstate(obj),
- None, None, _class_setstate
- )
- else:
- return (
- _make_skeleton_class, _class_getnewargs(obj), _class_getstate(obj),
- None, None, _class_setstate
- )
-
-
-def _class_reduce(obj):
- """Select the reducer depending on the dynamic nature of the class obj"""
- if obj is type(None): # noqa
- return type, (None,)
- elif obj is type(Ellipsis):
- return type, (Ellipsis,)
- elif obj is type(NotImplemented):
- return type, (NotImplemented,)
- elif obj in _BUILTIN_TYPE_NAMES:
- return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],)
- elif not _should_pickle_by_reference(obj):
- return _dynamic_class_reduce(obj)
- return NotImplemented
-
-
-def _dict_keys_reduce(obj):
- # Safer not to ship the full dict as sending the rest might
- # be unintended and could potentially cause leaking of
- # sensitive information
- return _make_dict_keys, (list(obj), )
-
-
-def _dict_values_reduce(obj):
- # Safer not to ship the full dict as sending the rest might
- # be unintended and could potentially cause leaking of
- # sensitive information
- return _make_dict_values, (list(obj), )
-
-
-def _dict_items_reduce(obj):
- return _make_dict_items, (dict(obj), )
-
-
-def _odict_keys_reduce(obj):
- # Safer not to ship the full dict as sending the rest might
- # be unintended and could potentially cause leaking of
- # sensitive information
- return _make_dict_keys, (list(obj), True)
-
-
-def _odict_values_reduce(obj):
- # Safer not to ship the full dict as sending the rest might
- # be unintended and could potentially cause leaking of
- # sensitive information
- return _make_dict_values, (list(obj), True)
-
-
-def _odict_items_reduce(obj):
- return _make_dict_items, (dict(obj), True)
-
-
-# COLLECTIONS OF OBJECTS STATE SETTERS
-# ------------------------------------
-# state setters are called at unpickling time, once the object is created and
-# it has to be updated to how it was at unpickling time.
-
-
-def _function_setstate(obj, state):
- """Update the state of a dynamic function.
-
- As __closure__ and __globals__ are readonly attributes of a function, we
- cannot rely on the native setstate routine of pickle.load_build, that calls
- setattr on items of the slotstate. Instead, we have to modify them inplace.
- """
- state, slotstate = state
- obj.__dict__.update(state)
-
- obj_globals = slotstate.pop("__globals__")
- obj_closure = slotstate.pop("__closure__")
- # _cloudpickle_subimports is a set of submodules that must be loaded for
- # the pickled function to work correctly at unpickling time. Now that these
- # submodules are depickled (hence imported), they can be removed from the
- # object's state (the object state only served as a reference holder to
- # these submodules)
- slotstate.pop("_cloudpickle_submodules")
-
- obj.__globals__.update(obj_globals)
- obj.__globals__["__builtins__"] = __builtins__
-
- if obj_closure is not None:
- for i, cell in enumerate(obj_closure):
- try:
- value = cell.cell_contents
- except ValueError: # cell is empty
- continue
- cell_set(obj.__closure__[i], value)
-
- for k, v in slotstate.items():
- setattr(obj, k, v)
-
-
-def _class_setstate(obj, state):
- state, slotstate = state
- registry = None
- for attrname, attr in state.items():
- if attrname == "_abc_impl":
- registry = attr
- else:
- setattr(obj, attrname, attr)
- if registry is not None:
- for subclass in registry:
- obj.register(subclass)
-
- return obj
-
-
-class CloudPickler(Pickler):
- # set of reducers defined and used by cloudpickle (private)
- _dispatch_table = {}
- _dispatch_table[classmethod] = _classmethod_reduce
- _dispatch_table[io.TextIOWrapper] = _file_reduce
- _dispatch_table[logging.Logger] = _logger_reduce
- _dispatch_table[logging.RootLogger] = _root_logger_reduce
- _dispatch_table[memoryview] = _memoryview_reduce
- _dispatch_table[property] = _property_reduce
- _dispatch_table[staticmethod] = _classmethod_reduce
- _dispatch_table[CellType] = _cell_reduce
- _dispatch_table[types.CodeType] = _code_reduce
- _dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce
- _dispatch_table[types.ModuleType] = _module_reduce
- _dispatch_table[types.MethodType] = _method_reduce
- _dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
- _dispatch_table[weakref.WeakSet] = _weakset_reduce
- _dispatch_table[typing.TypeVar] = _typevar_reduce
- _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
- _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
- _dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
- _dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce
- _dispatch_table[type(OrderedDict().values())] = _odict_values_reduce
- _dispatch_table[type(OrderedDict().items())] = _odict_items_reduce
- _dispatch_table[abc.abstractmethod] = _classmethod_reduce
- _dispatch_table[abc.abstractclassmethod] = _classmethod_reduce
- _dispatch_table[abc.abstractstaticmethod] = _classmethod_reduce
- _dispatch_table[abc.abstractproperty] = _property_reduce
-
- dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
-
- # function reducers are defined as instance methods of CloudPickler
- # objects, as they rely on a CloudPickler attribute (globals_ref)
- def _dynamic_function_reduce(self, func):
- """Reduce a function that is not pickleable via attribute lookup."""
- newargs = self._function_getnewargs(func)
- state = _function_getstate(func)
- return (_make_function, newargs, state, None, None,
- _function_setstate)
-
- def _function_reduce(self, obj):
- """Reducer for function objects.
-
- If obj is a top-level attribute of a file-backed module, this
- reducer returns NotImplemented, making the CloudPickler fallback to
- traditional _pickle.Pickler routines to save obj. Otherwise, it reduces
- obj using a custom cloudpickle reducer designed specifically to handle
- dynamic functions.
-
- As opposed to cloudpickle.py, There no special handling for builtin
- pypy functions because cloudpickle_fast is CPython-specific.
- """
- if _should_pickle_by_reference(obj):
- return NotImplemented
- else:
- return self._dynamic_function_reduce(obj)
-
- def _function_getnewargs(self, func):
- code = func.__code__
-
- # base_globals represents the future global namespace of func at
- # unpickling time. Looking it up and storing it in
- # CloudpiPickler.globals_ref allow functions sharing the same globals
- # at pickling time to also share them once unpickled, at one condition:
- # since globals_ref is an attribute of a CloudPickler instance, and
- # that a new CloudPickler is created each time pickle.dump or
- # pickle.dumps is called, functions also need to be saved within the
- # same invocation of cloudpickle.dump/cloudpickle.dumps (for example:
- # cloudpickle.dumps([f1, f2])). There is no such limitation when using
- # CloudPickler.dump, as long as the multiple invocations are bound to
- # the same CloudPickler.
- base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
-
- if base_globals == {}:
- # Add module attributes used to resolve relative imports
- # instructions inside func.
- for k in ["__package__", "__name__", "__path__", "__file__"]:
- if k in func.__globals__:
- base_globals[k] = func.__globals__[k]
-
- # Do not bind the free variables before the function is created to
- # avoid infinite recursion.
- if func.__closure__ is None:
- closure = None
- else:
- closure = tuple(
- _make_empty_cell() for _ in range(len(code.co_freevars)))
-
- return code, base_globals, None, None, closure
-
- def dump(self, obj):
- try:
- return Pickler.dump(self, obj)
- except RuntimeError as e:
- if "recursion" in e.args[0]:
- msg = (
- "Could not pickle object as excessively deep recursion "
- "required."
- )
- raise pickle.PicklingError(msg) from e
- else:
- raise
-
- if pickle.HIGHEST_PROTOCOL >= 5:
- def __init__(self, file, protocol=None, buffer_callback=None):
- if protocol is None:
- protocol = DEFAULT_PROTOCOL
- Pickler.__init__(
- self, file, protocol=protocol, buffer_callback=buffer_callback
- )
- # map functions __globals__ attribute ids, to ensure that functions
- # sharing the same global namespace at pickling time also share
- # their global namespace at unpickling time.
- self.globals_ref = {}
- self.proto = int(protocol)
- else:
- def __init__(self, file, protocol=None):
- if protocol is None:
- protocol = DEFAULT_PROTOCOL
- Pickler.__init__(self, file, protocol=protocol)
- # map functions __globals__ attribute ids, to ensure that functions
- # sharing the same global namespace at pickling time also share
- # their global namespace at unpickling time.
- self.globals_ref = {}
- assert hasattr(self, 'proto')
-
- if pickle.HIGHEST_PROTOCOL >= 5 and not PYPY:
- # Pickler is the C implementation of the CPython pickler and therefore
- # we rely on reduce_override method to customize the pickler behavior.
-
- # `CloudPickler.dispatch` is only left for backward compatibility - note
- # that when using protocol 5, `CloudPickler.dispatch` is not an
- # extension of `Pickler.dispatch` dictionary, because CloudPickler
- # subclasses the C-implemented Pickler, which does not expose a
- # `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler
- # used `CloudPickler.dispatch` as a class-level attribute storing all
- # reducers implemented by cloudpickle, but the attribute name was not a
- # great choice given the meaning of `CloudPickler.dispatch` when
- # `CloudPickler` extends the pure-python pickler.
- dispatch = dispatch_table
-
- # Implementation of the reducer_override callback, in order to
- # efficiently serialize dynamic functions and classes by subclassing
- # the C-implemented Pickler.
- # TODO: decorrelate reducer_override (which is tied to CPython's
- # implementation - would it make sense to backport it to pypy? - and
- # pickle's protocol 5 which is implementation agnostic. Currently, the
- # availability of both notions coincide on CPython's pickle and the
- # pickle5 backport, but it may not be the case anymore when pypy
- # implements protocol 5
-
- def reducer_override(self, obj):
- """Type-agnostic reducing callback for function and classes.
-
- For performance reasons, subclasses of the C _pickle.Pickler class
- cannot register custom reducers for functions and classes in the
- dispatch_table. Reducer for such types must instead implemented in
- the special reducer_override method.
-
- Note that method will be called for any object except a few
- builtin-types (int, lists, dicts etc.), which differs from reducers
- in the Pickler's dispatch_table, each of them being invoked for
- objects of a specific type only.
-
- This property comes in handy for classes: although most classes are
- instances of the ``type`` metaclass, some of them can be instances
- of other custom metaclasses (such as enum.EnumMeta for example). In
- particular, the metaclass will likely not be known in advance, and
- thus cannot be special-cased using an entry in the dispatch_table.
- reducer_override, among other things, allows us to register a
- reducer that will be called for any class, independently of its
- type.
-
-
- Notes:
-
- * reducer_override has the priority over dispatch_table-registered
- reducers.
- * reducer_override can be used to fix other limitations of
- cloudpickle for other types that suffered from type-specific
- reducers, such as Exceptions. See
- https://github.com/cloudpipe/cloudpickle/issues/248
- """
- if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
- return (
- _create_parametrized_type_hint,
- parametrized_type_hint_getinitargs(obj)
- )
- t = type(obj)
- try:
- is_anyclass = issubclass(t, type)
- except TypeError: # t is not a class (old Boost; see SF #502085)
- is_anyclass = False
-
- if is_anyclass:
- return _class_reduce(obj)
- elif isinstance(obj, types.FunctionType):
- return self._function_reduce(obj)
- else:
- # fallback to save_global, including the Pickler's
- # dispatch_table
- return NotImplemented
-
- else:
- # When reducer_override is not available, hack the pure-Python
- # Pickler's types.FunctionType and type savers. Note: the type saver
- # must override Pickler.save_global, because pickle.py contains a
- # hard-coded call to save_global when pickling meta-classes.
- dispatch = Pickler.dispatch.copy()
-
- def _save_reduce_pickle5(self, func, args, state=None, listitems=None,
- dictitems=None, state_setter=None, obj=None):
- save = self.save
- write = self.write
- self.save_reduce(
- func, args, state=None, listitems=listitems,
- dictitems=dictitems, obj=obj
- )
- # backport of the Python 3.8 state_setter pickle operations
- save(state_setter)
- save(obj) # simple BINGET opcode as obj is already memoized.
- save(state)
- write(pickle.TUPLE2)
- # Trigger a state_setter(obj, state) function call.
- write(pickle.REDUCE)
- # The purpose of state_setter is to carry-out an
- # inplace modification of obj. We do not care about what the
- # method might return, so its output is eventually removed from
- # the stack.
- write(pickle.POP)
-
- def save_global(self, obj, name=None, pack=struct.pack):
- """
- Save a "global".
-
- The name of this method is somewhat misleading: all types get
- dispatched here.
- """
- if obj is type(None): # noqa
- return self.save_reduce(type, (None,), obj=obj)
- elif obj is type(Ellipsis):
- return self.save_reduce(type, (Ellipsis,), obj=obj)
- elif obj is type(NotImplemented):
- return self.save_reduce(type, (NotImplemented,), obj=obj)
- elif obj in _BUILTIN_TYPE_NAMES:
- return self.save_reduce(
- _builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
-
- if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
- # Parametrized typing constructs in Python < 3.7 are not
- # compatible with type checks and ``isinstance`` semantics. For
- # this reason, it is easier to detect them using a
- # duck-typing-based check (``_is_parametrized_type_hint``) than
- # to populate the Pickler's dispatch with type-specific savers.
- self.save_reduce(
- _create_parametrized_type_hint,
- parametrized_type_hint_getinitargs(obj),
- obj=obj
- )
- elif name is not None:
- Pickler.save_global(self, obj, name=name)
- elif not _should_pickle_by_reference(obj, name=name):
- self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
- else:
- Pickler.save_global(self, obj, name=name)
- dispatch[type] = save_global
-
- def save_function(self, obj, name=None):
- """ Registered with the dispatch to handle all function types.
-
- Determines what kind of function obj is (e.g. lambda, defined at
- interactive prompt, etc) and handles the pickling appropriately.
- """
- if _should_pickle_by_reference(obj, name=name):
- return Pickler.save_global(self, obj, name=name)
- elif PYPY and isinstance(obj.__code__, builtin_code_type):
- return self.save_pypy_builtin_func(obj)
- else:
- return self._save_reduce_pickle5(
- *self._dynamic_function_reduce(obj), obj=obj
- )
-
- def save_pypy_builtin_func(self, obj):
- """Save pypy equivalent of builtin functions.
- PyPy does not have the concept of builtin-functions. Instead,
- builtin-functions are simple function instances, but with a
- builtin-code attribute.
- Most of the time, builtin functions should be pickled by attribute.
- But PyPy has flaky support for __qualname__, so some builtin
- functions such as float.__new__ will be classified as dynamic. For
- this reason only, we created this special routine. Because
- builtin-functions are not expected to have closure or globals,
- there is no additional hack (compared the one already implemented
- in pickle) to protect ourselves from reference cycles. A simple
- (reconstructor, newargs, obj.__dict__) tuple is save_reduced. Note
- also that PyPy improved their support for __qualname__ in v3.6, so
- this routing should be removed when cloudpickle supports only PyPy
- 3.6 and later.
- """
- rv = (types.FunctionType, (obj.__code__, {}, obj.__name__,
- obj.__defaults__, obj.__closure__),
- obj.__dict__)
- self.save_reduce(*rv, obj=obj)
-
- dispatch[types.FunctionType] = save_function
diff --git a/srsly/cloudpickle/compat.py b/srsly/cloudpickle/compat.py
deleted file mode 100644
index 5e9b527..0000000
--- a/srsly/cloudpickle/compat.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import sys
-
-
-if sys.version_info < (3, 8):
- try:
- import pickle5 as pickle # noqa: F401
- from pickle5 import Pickler # noqa: F401
- except ImportError:
- import pickle # noqa: F401
-
- # Use the Python pickler for old CPython versions
- from pickle import _Pickler as Pickler # noqa: F401
-else:
- import pickle # noqa: F401
-
- # Pickler will the C implementation in CPython and the Python
- # implementation in PyPy
- from pickle import Pickler # noqa: F401
diff --git a/srsly/msgpack/__init__.py b/srsly/msgpack/__init__.py
deleted file mode 100644
index 9bf2473..0000000
--- a/srsly/msgpack/__init__.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# coding: utf-8
-
-import functools
-import catalogue
-
-# These need to be imported before packer and unpacker
-from ._epoch import utc, epoch # noqa
-
-from ._version import version
-from .exceptions import *
-
-# In msgpack-python these are put under a _cmsgpack module that textually includes
-# them. I dislike this so I refactored it.
-from ._packer import Packer as _Packer
-from ._unpacker import unpackb as _unpackb
-from ._unpacker import Unpacker as _Unpacker
-from .ext import ExtType
-from ._msgpack_numpy import encode_numpy as _encode_numpy
-from ._msgpack_numpy import decode_numpy as _decode_numpy
-
-
-msgpack_encoders = catalogue.create("srsly", "msgpack_encoders", entry_points=True)
-msgpack_decoders = catalogue.create("srsly", "msgpack_decoders", entry_points=True)
-
-msgpack_encoders.register("numpy", func=_encode_numpy)
-msgpack_decoders.register("numpy", func=_decode_numpy)
-
-
-# msgpack_numpy extensions
-class Packer(_Packer):
- def __init__(self, *args, **kwargs):
- default = kwargs.get("default")
- for encoder in msgpack_encoders.get_all().values():
- default = functools.partial(encoder, chain=default)
- kwargs["default"] = default
- super(Packer, self).__init__(*args, **kwargs)
-
-
-class Unpacker(_Unpacker):
- def __init__(self, *args, **kwargs):
- object_hook = kwargs.get("object_hook")
- for decoder in msgpack_decoders.get_all().values():
- object_hook = functools.partial(decoder, chain=object_hook)
- kwargs["object_hook"] = object_hook
- super(Unpacker, self).__init__(*args, **kwargs)
-
-
-def pack(o, stream, **kwargs):
- """
- Pack an object and write it to a stream.
- """
- packer = Packer(**kwargs)
- stream.write(packer.pack(o))
-
-
-def packb(o, **kwargs):
- """
- Pack an object and return the packed bytes.
- """
- return Packer(**kwargs).pack(o)
-
-
-def unpack(stream, **kwargs):
- """
- Unpack a packed object from a stream.
- """
- if "object_pairs_hook" not in kwargs:
- object_hook = kwargs.get("object_hook")
- for decoder in msgpack_decoders.get_all().values():
- object_hook = functools.partial(decoder, chain=object_hook)
- kwargs["object_hook"] = object_hook
- data = stream.read()
- return _unpackb(data, **kwargs)
-
-
-def unpackb(packed, **kwargs):
- """
- Unpack a packed object.
- """
- if "object_pairs_hook" not in kwargs:
- object_hook = kwargs.get("object_hook")
- for decoder in msgpack_decoders.get_all().values():
- object_hook = functools.partial(decoder, chain=object_hook)
- kwargs["object_hook"] = object_hook
- return _unpackb(packed, **kwargs)
-
-
-# alias for compatibility to simplejson/marshal/pickle.
-load = unpack
-loads = unpackb
-
-dump = pack
-dumps = packb
diff --git a/srsly/msgpack/_epoch.pyx b/srsly/msgpack/_epoch.pyx
deleted file mode 100644
index 27b4bea..0000000
--- a/srsly/msgpack/_epoch.pyx
+++ /dev/null
@@ -1,7 +0,0 @@
-from cpython.datetime cimport import_datetime, datetime_new
-
-import_datetime()
-import datetime
-
-utc = datetime.timezone.utc
-epoch = datetime_new(1970, 1, 1, 0, 0, 0, 0, tz=utc)
diff --git a/srsly/msgpack/_msgpack_numpy.py b/srsly/msgpack/_msgpack_numpy.py
deleted file mode 100644
index 4748bed..0000000
--- a/srsly/msgpack/_msgpack_numpy.py
+++ /dev/null
@@ -1,94 +0,0 @@
-#!/usr/bin/env python
-
-"""
-Support for serialization of numpy data types with msgpack.
-"""
-
-# Copyright (c) 2013-2018, Lev E. Givon
-# All rights reserved.
-# Distributed under the terms of the BSD license:
-# http://www.opensource.org/licenses/bsd-license
-try:
- import numpy as np
-
- has_numpy = True
-except ImportError:
- has_numpy = False
-
-try:
- import cupy
-
- has_cupy = True
-except ImportError:
- has_cupy = False
-
-
-def encode_numpy(obj, chain=None):
- """
- Data encoder for serializing numpy data types.
- """
- if not has_numpy:
- return obj if chain is None else chain(obj)
- if has_cupy and isinstance(obj, cupy.ndarray):
- obj = obj.get()
- if isinstance(obj, np.ndarray):
- # If the dtype is structured, store the interface description;
- # otherwise, store the corresponding array protocol type string:
- if obj.dtype.kind == "V":
- kind = b"V"
- descr = obj.dtype.descr
- else:
- kind = b""
- descr = obj.dtype.str
- return {
- b"nd": True,
- b"type": descr,
- b"kind": kind,
- b"shape": obj.shape,
- b"data": obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes(),
- }
- elif isinstance(obj, (np.bool_, np.number)):
- return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data}
- elif isinstance(obj, complex):
- return {b"complex": True, b"data": obj.__repr__()}
- else:
- return obj if chain is None else chain(obj)
-
-
-def tostr(x):
- if isinstance(x, bytes):
- return x.decode()
- else:
- return str(x)
-
-
-def decode_numpy(obj, chain=None):
- """
- Decoder for deserializing numpy data types.
- """
-
- try:
- if b"nd" in obj:
- if obj[b"nd"] is True:
-
- # Check if b'kind' is in obj to enable decoding of data
- # serialized with older versions (#20):
- if b"kind" in obj and obj[b"kind"] == b"V":
- descr = [
- tuple(tostr(t) if type(t) is bytes else t for t in d)
- for d in obj[b"type"]
- ]
- else:
- descr = obj[b"type"]
- return np.frombuffer(obj[b"data"], dtype=np.dtype(descr)).reshape(
- obj[b"shape"]
- )
- else:
- descr = obj[b"type"]
- return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0]
- elif b"complex" in obj:
- return complex(tostr(obj[b"data"]))
- else:
- return obj if chain is None else chain(obj)
- except KeyError:
- return obj if chain is None else chain(obj)
diff --git a/srsly/msgpack/_packer.pyx b/srsly/msgpack/_packer.pyx
deleted file mode 100644
index 25f10a8..0000000
--- a/srsly/msgpack/_packer.pyx
+++ /dev/null
@@ -1,376 +0,0 @@
-# coding: utf-8
-
-from cpython cimport *
-from cpython.bytearray cimport PyByteArray_Check, PyByteArray_CheckExact
-from cpython.datetime cimport (
- PyDateTime_CheckExact, PyDelta_CheckExact,
- datetime_tzinfo, timedelta_days, timedelta_seconds, timedelta_microseconds,
-)
-from ._epoch import utc, epoch
-
-cdef ExtType
-cdef Timestamp
-
-from .ext import ExtType, Timestamp
-from .util import ensure_bytes
-
-
-cdef extern from "Python.h":
-
- int PyMemoryView_Check(object obj)
-
-cdef extern from "pack.h":
- struct msgpack_packer:
- char* buf
- size_t length
- size_t buf_size
- bint use_bin_type
-
- int msgpack_pack_nil(msgpack_packer* pk) except -1
- int msgpack_pack_true(msgpack_packer* pk) except -1
- int msgpack_pack_false(msgpack_packer* pk) except -1
- int msgpack_pack_long_long(msgpack_packer* pk, long long d) except -1
- int msgpack_pack_unsigned_long_long(msgpack_packer* pk, unsigned long long d) except -1
- int msgpack_pack_float(msgpack_packer* pk, float d) except -1
- int msgpack_pack_double(msgpack_packer* pk, double d) except -1
- int msgpack_pack_array(msgpack_packer* pk, size_t l) except -1
- int msgpack_pack_map(msgpack_packer* pk, size_t l) except -1
- int msgpack_pack_raw(msgpack_packer* pk, size_t l) except -1
- int msgpack_pack_bin(msgpack_packer* pk, size_t l) except -1
- int msgpack_pack_raw_body(msgpack_packer* pk, char* body, size_t l) except -1
- int msgpack_pack_ext(msgpack_packer* pk, char typecode, size_t l) except -1
- int msgpack_pack_timestamp(msgpack_packer* x, long long seconds, unsigned long nanoseconds) except -1
-
-
-cdef int DEFAULT_RECURSE_LIMIT=511
-cdef long long ITEM_LIMIT = (2**32)-1
-
-
-cdef inline int PyBytesLike_Check(object o):
- return PyBytes_Check(o) or PyByteArray_Check(o)
-
-
-cdef inline int PyBytesLike_CheckExact(object o):
- return PyBytes_CheckExact(o) or PyByteArray_CheckExact(o)
-
-
-cdef class Packer:
- """
- MessagePack Packer
-
- Usage::
-
- packer = Packer()
- astream.write(packer.pack(a))
- astream.write(packer.pack(b))
-
- Packer's constructor has some keyword arguments:
-
- :param default:
- When specified, it should be callable.
- Convert user type to builtin type that Packer supports.
- See also simplejson's document.
-
- :param bool use_single_float:
- Use single precision float type for float. (default: False)
-
- :param bool autoreset:
- Reset buffer after each pack and return its content as `bytes`. (default: True).
- If set this to false, use `bytes()` to get content and `.reset()` to clear buffer.
-
- :param bool use_bin_type:
- Use bin type introduced in msgpack spec 2.0 for bytes.
- It also enables str8 type for unicode. (default: True)
-
- :param bool strict_types:
- If set to true, types will be checked to be exact. Derived classes
- from serializeable types will not be serialized and will be
- treated as unsupported type and forwarded to default.
- Additionally tuples will not be serialized as lists.
- This is useful when trying to implement accurate serialization
- for python types.
-
- :param bool datetime:
- If set to true, datetime with tzinfo is packed into Timestamp type.
- Note that the tzinfo is stripped in the timestamp.
- You can get UTC datetime with `timestamp=3` option of the Unpacker.
-
- :param str unicode_errors:
- The error handler for encoding unicode. (default: 'strict')
- DO NOT USE THIS!! This option is kept for very specific usage.
-
- :param int buf_size:
- The size of the internal buffer. (default: 256*1024)
- Useful if serialisation size can be correctly estimated,
- avoid unnecessary reallocations.
- """
- cdef msgpack_packer pk
- cdef object _default
- cdef size_t exports # number of exported buffers
- cdef bint strict_types
- cdef bint use_float
- cdef bint autoreset
- cdef bint datetime
- cdef object _bencoding
- cdef object _berrors
- cdef const char *encoding
- cdef const char *unicode_errors
-
-
- def __cinit__(self, buf_size=256*1024, **_kwargs):
- self.pk.buf = PyMem_Malloc(buf_size)
- if self.pk.buf == NULL:
- raise MemoryError("Unable to allocate internal buffer.")
- self.pk.buf_size = buf_size
- self.pk.length = 0
- self.exports = 0
-
- def __dealloc__(self):
- PyMem_Free(self.pk.buf)
- self.pk.buf = NULL
- assert self.exports == 0
-
- cdef _check_exports(self):
- if self.exports > 0:
- raise BufferError("Existing exports of data: Packer cannot be changed")
-
- def __init__(self, *, default=None, encoding=None,
- bint use_single_float=False, bint autoreset=True, bint use_bin_type=False,
- bint strict_types=False, bint datetime=False, unicode_errors=None,
- buf_size=256*1024):
- self.use_float = use_single_float
- self.strict_types = strict_types
- self.autoreset = autoreset
- self.datetime = datetime
- self.pk.use_bin_type = use_bin_type
- if default is not None:
- if not PyCallable_Check(default):
- raise TypeError("default must be a callable.")
- self._default = default
-
- if encoding is None:
- if PY_MAJOR_VERSION < 3:
- encoding = 'utf-8'
- if encoding is None:
- self._bencoding = None
- self.encoding = NULL
- else:
- self._bencoding = ensure_bytes(encoding)
- self.encoding = self._bencoding
- else:
- self._bencoding = ensure_bytes(encoding)
- self.encoding = self._bencoding
- unicode_errors = ensure_bytes(unicode_errors)
- self._berrors = unicode_errors
- if unicode_errors is None:
- self.unicode_errors = NULL
- else:
- self.unicode_errors = self._berrors
-
- # returns -2 when default should(o) be called
- cdef int _pack_inner(self, object o, bint will_default, int nest_limit) except -1:
- cdef long long llval
- cdef unsigned long long ullval
- cdef unsigned long ulval
- cdef const char* rawval
- cdef Py_ssize_t L
- cdef Py_buffer view
- cdef bint strict = self.strict_types
-
- if o is None:
- msgpack_pack_nil(&self.pk)
- elif o is True:
- msgpack_pack_true(&self.pk)
- elif o is False:
- msgpack_pack_false(&self.pk)
- elif PyLong_CheckExact(o) if strict else PyLong_Check(o):
- try:
- if o > 0:
- ullval = o
- msgpack_pack_unsigned_long_long(&self.pk, ullval)
- else:
- llval = o
- msgpack_pack_long_long(&self.pk, llval)
- except OverflowError as oe:
- if will_default:
- return -2
- else:
- raise OverflowError("Integer value out of range")
- elif PyFloat_CheckExact(o) if strict else PyFloat_Check(o):
- if self.use_float:
- msgpack_pack_float(&self.pk, o)
- else:
- msgpack_pack_double(&self.pk, o)
- elif PyBytesLike_CheckExact(o) if strict else PyBytesLike_Check(o):
- L = Py_SIZE(o)
- if L > ITEM_LIMIT:
- PyErr_Format(ValueError, b"%.200s object is too large", Py_TYPE(o).tp_name)
- rawval = o
- msgpack_pack_bin(&self.pk, L)
- msgpack_pack_raw_body(&self.pk, rawval, L)
- elif PyUnicode_CheckExact(o) if strict else PyUnicode_Check(o):
- if self.unicode_errors == NULL and self.encoding == NULL:
- rawval = PyUnicode_AsUTF8AndSize(o, &L)
- if L >ITEM_LIMIT:
- raise ValueError("unicode string is too large")
- else:
- o = PyUnicode_AsEncodedString(o, self.encoding, self.unicode_errors)
- L = Py_SIZE(o)
- if L > ITEM_LIMIT:
- raise ValueError("unicode string is too large")
- rawval = o
- msgpack_pack_raw(&self.pk, L)
- msgpack_pack_raw_body(&self.pk, rawval, L)
- elif PyDict_CheckExact(o) if strict else PyDict_Check(o):
- L = len(o)
- if L > ITEM_LIMIT:
- raise ValueError("dict is too large")
- msgpack_pack_map(&self.pk, L)
- for k, v in o.items():
- self._pack(k, nest_limit)
- self._pack(v, nest_limit)
- elif type(o) is ExtType if strict else isinstance(o, ExtType):
- # This should be before Tuple because ExtType is namedtuple.
- rawval = o.data
- L = len(o.data)
- if L > ITEM_LIMIT:
- raise ValueError("EXT data is too large")
- msgpack_pack_ext(&self.pk, o.code, L)
- msgpack_pack_raw_body(&self.pk, rawval, L)
- elif type(o) is Timestamp:
- llval = o.seconds
- ulval = o.nanoseconds
- msgpack_pack_timestamp(&self.pk, llval, ulval)
- elif PyList_CheckExact(o) if strict else (PyTuple_Check(o) or PyList_Check(o)):
- L = Py_SIZE(o)
- if L > ITEM_LIMIT:
- raise ValueError("list is too large")
- msgpack_pack_array(&self.pk, L)
- for v in o:
- self._pack(v, nest_limit)
- elif PyMemoryView_Check(o):
- PyObject_GetBuffer(o, &view, PyBUF_SIMPLE)
- L = view.len
- if L > ITEM_LIMIT:
- PyBuffer_Release(&view);
- raise ValueError("memoryview is too large")
- try:
- msgpack_pack_bin(&self.pk, L)
- msgpack_pack_raw_body(&self.pk, view.buf, L)
- finally:
- PyBuffer_Release(&view);
- elif self.datetime and PyDateTime_CheckExact(o) and datetime_tzinfo(o) is not None:
- delta = o - epoch
- if not PyDelta_CheckExact(delta):
- raise ValueError("failed to calculate delta")
- llval = timedelta_days(delta) * (24*60*60) + timedelta_seconds(delta)
- ulval = timedelta_microseconds(delta) * 1000
- msgpack_pack_timestamp(&self.pk, llval, ulval)
- elif will_default:
- return -2
- elif self.datetime and PyDateTime_CheckExact(o):
- # this should be later than will_default
- PyErr_Format(ValueError, b"can not serialize '%.200s' object where tzinfo=None", Py_TYPE(o).tp_name)
- else:
- PyErr_Format(TypeError, b"can not serialize '%.200s' object", Py_TYPE(o).tp_name)
-
- cdef int _pack(self, object o, int nest_limit=DEFAULT_RECURSE_LIMIT) except -1:
- cdef int ret
- if nest_limit < 0:
- raise ValueError("recursion limit exceeded.")
- nest_limit -= 1
- if self._default is not None:
- ret = self._pack_inner(o, 1, nest_limit)
- if ret == -2:
- o = self._default(o)
- else:
- return ret
- return self._pack_inner(o, 0, nest_limit)
-
- def pack(self, object obj):
- cdef int ret
- self._check_exports()
- try:
- ret = self._pack(obj, DEFAULT_RECURSE_LIMIT)
- except:
- self.pk.length = 0
- raise
- if ret: # should not happen.
- raise RuntimeError("internal error")
- if self.autoreset:
- buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
- self.pk.length = 0
- return buf
-
- def pack_ext_type(self, typecode, data):
- self._check_exports()
- if len(data) > ITEM_LIMIT:
- raise ValueError("ext data too large")
- msgpack_pack_ext(&self.pk, typecode, len(data))
- msgpack_pack_raw_body(&self.pk, data, len(data))
-
- def pack_array_header(self, long long size):
- self._check_exports()
- if size > ITEM_LIMIT:
- raise ValueError("array too large")
- msgpack_pack_array(&self.pk, size)
- if self.autoreset:
- buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
- self.pk.length = 0
- return buf
-
- def pack_map_header(self, long long size):
- self._check_exports()
- if size > ITEM_LIMIT:
- raise ValueError("map too learge")
- msgpack_pack_map(&self.pk, size)
- if self.autoreset:
- buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
- self.pk.length = 0
- return buf
-
- def pack_map_pairs(self, object pairs):
- """
- Pack *pairs* as msgpack map type.
-
- *pairs* should be a sequence of pairs.
- (`len(pairs)` and `for k, v in pairs:` should be supported.)
- """
- self._check_exports()
- size = len(pairs)
- if size > ITEM_LIMIT:
- raise ValueError("map too large")
- msgpack_pack_map(&self.pk, size)
- for k, v in pairs:
- self._pack(k)
- self._pack(v)
- if self.autoreset:
- buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
- self.pk.length = 0
- return buf
-
- def reset(self):
- """Reset internal buffer.
-
- This method is useful only when autoreset=False.
- """
- self._check_exports()
- self.pk.length = 0
-
- def bytes(self):
- """Return internal buffer contents as bytes object"""
- return PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
-
- def getbuffer(self):
- """Return memoryview of internal buffer.
-
- Note: Packer now supports buffer protocol. You can use memoryview(packer).
- """
- return memoryview(self)
-
- def __getbuffer__(self, Py_buffer *buffer, int flags):
- PyBuffer_FillInfo(buffer, self, self.pk.buf, self.pk.length, 1, flags)
- self.exports += 1
-
- def __releasebuffer__(self, Py_buffer *buffer):
- self.exports -= 1
diff --git a/srsly/msgpack/_unpacker.pyx b/srsly/msgpack/_unpacker.pyx
deleted file mode 100644
index 631b9b5..0000000
--- a/srsly/msgpack/_unpacker.pyx
+++ /dev/null
@@ -1,560 +0,0 @@
-# coding: utf-8
-
-from cpython cimport *
-cdef extern from "Python.h":
- ctypedef struct PyObject
- object PyMemoryView_GetContiguous(object obj, int buffertype, char order)
-
-from libc.stdlib cimport *
-from libc.string cimport *
-from libc.limits cimport *
-from libc.stdint cimport uint64_t
-
-from .exceptions import (
- BufferFull,
- OutOfData,
- ExtraData,
- FormatError,
- StackError,
-)
-from .ext import ExtType, Timestamp
-from .util import ensure_bytes
-from ._epoch import utc, epoch
-
-cdef object giga = 1_000_000_000
-
-
-cdef extern from "unpack.h":
- ctypedef struct msgpack_user:
- bint use_list
- bint raw
- bint has_pairs_hook # call object_hook with k-v pairs
- bint strict_map_key
- int timestamp
- PyObject* object_hook
- PyObject* list_hook
- PyObject* ext_hook
- PyObject* timestamp_t
- PyObject *giga;
- PyObject *utc;
- const char *unicode_errors
- const char *encoding
- Py_ssize_t max_str_len
- Py_ssize_t max_bin_len
- Py_ssize_t max_array_len
- Py_ssize_t max_map_len
- Py_ssize_t max_ext_len
-
- ctypedef struct unpack_context:
- msgpack_user user
- PyObject* obj
- Py_ssize_t count
-
- ctypedef int (*execute_fn)(unpack_context* ctx, const char* data,
- Py_ssize_t len, Py_ssize_t* off) except? -1
- execute_fn unpack_construct
- execute_fn unpack_skip
- execute_fn read_array_header
- execute_fn read_map_header
- void unpack_init(unpack_context* ctx)
- object unpack_data(unpack_context* ctx)
- void unpack_clear(unpack_context* ctx)
-
-cdef inline init_ctx(unpack_context *ctx,
- object object_hook, object object_pairs_hook,
- object list_hook, object ext_hook,
- bint use_list, bint raw, int timestamp,
- bint strict_map_key,
- const char* encoding,
- const char* unicode_errors,
- Py_ssize_t max_str_len, Py_ssize_t max_bin_len,
- Py_ssize_t max_array_len, Py_ssize_t max_map_len,
- Py_ssize_t max_ext_len):
- unpack_init(ctx)
- ctx.user.use_list = use_list
- ctx.user.raw = raw
- ctx.user.strict_map_key = strict_map_key
- ctx.user.object_hook = ctx.user.list_hook = NULL
- ctx.user.max_str_len = max_str_len
- ctx.user.max_bin_len = max_bin_len
- ctx.user.max_array_len = max_array_len
- ctx.user.max_map_len = max_map_len
- ctx.user.max_ext_len = max_ext_len
-
- if object_hook is not None and object_pairs_hook is not None:
- raise TypeError("object_pairs_hook and object_hook are mutually exclusive.")
-
- if object_hook is not None:
- if not PyCallable_Check(object_hook):
- raise TypeError("object_hook must be a callable.")
- ctx.user.object_hook = object_hook
-
- if object_pairs_hook is None:
- ctx.user.has_pairs_hook = False
- else:
- if not PyCallable_Check(object_pairs_hook):
- raise TypeError("object_pairs_hook must be a callable.")
- ctx.user.object_hook = object_pairs_hook
- ctx.user.has_pairs_hook = True
-
- if list_hook is not None:
- if not PyCallable_Check(list_hook):
- raise TypeError("list_hook must be a callable.")
- ctx.user.list_hook = list_hook
-
- if ext_hook is not None:
- if not PyCallable_Check(ext_hook):
- raise TypeError("ext_hook must be a callable.")
- ctx.user.ext_hook = ext_hook
-
- if timestamp < 0 or 3 < timestamp:
- raise ValueError("timestamp must be 0..3")
-
- # Add Timestamp type to the user object so it may be used in unpack.h
- ctx.user.timestamp = timestamp
- ctx.user.timestamp_t = Timestamp
- ctx.user.giga = giga
- ctx.user.utc = utc
- ctx.user.unicode_errors = unicode_errors
- ctx.user.encoding = encoding
-
-def default_read_extended_type(typecode, data):
- raise NotImplementedError("Cannot decode extended type with typecode=%d" % typecode)
-
-cdef inline int get_data_from_buffer(object obj,
- Py_buffer *view,
- char **buf,
- Py_ssize_t *buffer_len) except 0:
- cdef object contiguous
- cdef Py_buffer tmp
- if PyObject_GetBuffer(obj, view, PyBUF_FULL_RO) == -1:
- raise
- if view.itemsize != 1:
- PyBuffer_Release(view)
- raise BufferError("cannot unpack from multi-byte object")
- if PyBuffer_IsContiguous(view, b'A') == 0:
- PyBuffer_Release(view)
- # create a contiguous copy and get buffer
- contiguous = PyMemoryView_GetContiguous(obj, PyBUF_READ, b'C')
- PyObject_GetBuffer(contiguous, view, PyBUF_SIMPLE)
- # view must hold the only reference to contiguous,
- # so memory is freed when view is released
- Py_DECREF(contiguous)
- buffer_len[0] = view.len
- buf[0] = view.buf
- return 1
-
-
-def unpackb(object packed, *, object object_hook=None, object list_hook=None,
- bint use_list=True, bint raw=True, int timestamp=0, bint strict_map_key=False,
- encoding=None,
- unicode_errors=None,
- object_pairs_hook=None, ext_hook=ExtType,
- Py_ssize_t max_str_len=-1,
- Py_ssize_t max_bin_len=-1,
- Py_ssize_t max_array_len=-1,
- Py_ssize_t max_map_len=-1,
- Py_ssize_t max_ext_len=-1):
- """
- Unpack packed_bytes to object. Returns an unpacked object.
-
- Raises ``ExtraData`` when *packed* contains extra bytes.
- Raises ``ValueError`` when *packed* is incomplete.
- Raises ``FormatError`` when *packed* is not valid msgpack.
- Raises ``StackError`` when *packed* contains too nested.
- Other exceptions can be raised during unpacking.
-
- See :class:`Unpacker` for options.
-
- *max_xxx_len* options are configured automatically from ``len(packed)``.
- """
- cdef unpack_context ctx
- cdef Py_ssize_t off = 0
- cdef int ret
-
- cdef Py_buffer view
- cdef char* buf = NULL
- cdef Py_ssize_t buf_len
- cdef const char* cenc = NULL
- cdef const char* cerr = NULL
-
- if encoding is not None:
- encoding = ensure_bytes(encoding)
- cenc = encoding
-
- if unicode_errors is not None:
- unicode_errors = ensure_bytes(unicode_errors)
- cerr = unicode_errors
-
- get_data_from_buffer(packed, &view, &buf, &buf_len)
-
- if max_str_len == -1:
- max_str_len = buf_len
- if max_bin_len == -1:
- max_bin_len = buf_len
- if max_array_len == -1:
- max_array_len = buf_len
- if max_map_len == -1:
- max_map_len = buf_len//2
- if max_ext_len == -1:
- max_ext_len = buf_len
-
- try:
- init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook,
- use_list, raw, timestamp, strict_map_key, cenc, cerr,
- max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len)
- ret = unpack_construct(&ctx, buf, buf_len, &off)
- finally:
- PyBuffer_Release(&view);
-
- if ret == 1:
- obj = unpack_data(&ctx)
- if off < buf_len:
- raise ExtraData(obj, PyBytes_FromStringAndSize(buf+off, buf_len-off))
- return obj
- unpack_clear(&ctx)
- if ret == 0:
- raise ValueError("Unpack failed: incomplete input")
- elif ret == -2:
- raise FormatError
- elif ret == -3:
- raise StackError
- raise ValueError("Unpack failed: error = %d" % (ret,))
-
-
-cdef class Unpacker:
- """Streaming unpacker.
-
- Arguments:
-
- :param file_like:
- File-like object having `.read(n)` method.
- If specified, unpacker reads serialized data from it and `.feed()` is not usable.
-
- :param int read_size:
- Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`)
-
- :param bool use_list:
- If true, unpack msgpack array to Python list.
- Otherwise, unpack to Python tuple. (default: True)
-
- :param bool raw:
- If true, unpack msgpack raw to Python bytes.
- Otherwise, unpack to Python str by decoding with UTF-8 encoding (default).
-
- :param int timestamp:
- Control how timestamp type is unpacked:
-
- 0 - Timestamp
- 1 - float (Seconds from the EPOCH)
- 2 - int (Nanoseconds from the EPOCH)
- 3 - datetime.datetime (UTC).
-
- :param bool strict_map_key:
- If true (default), only str or bytes are accepted for map (dict) keys.
-
- :param object_hook:
- When specified, it should be callable.
- Unpacker calls it with a dict argument after unpacking msgpack map.
- (See also simplejson)
-
- :param object_pairs_hook:
- When specified, it should be callable.
- Unpacker calls it with a list of key-value pairs after unpacking msgpack map.
- (See also simplejson)
-
- :param str unicode_errors:
- The error handler for decoding unicode. (default: 'strict')
- This option should be used only when you have msgpack data which
- contains invalid UTF-8 string.
-
- :param int max_buffer_size:
- Limits size of data waiting unpacked. 0 means 2**32-1.
- The default value is 100*1024*1024 (100MiB).
- Raises `BufferFull` exception when it is insufficient.
- You should set this parameter when unpacking data from untrusted source.
-
- :param int max_str_len:
- Deprecated, use *max_buffer_size* instead.
- Limits max length of str. (default: max_buffer_size)
-
- :param int max_bin_len:
- Deprecated, use *max_buffer_size* instead.
- Limits max length of bin. (default: max_buffer_size)
-
- :param int max_array_len:
- Limits max length of array.
- (default: max_buffer_size)
-
- :param int max_map_len:
- Limits max length of map.
- (default: max_buffer_size//2)
-
- :param int max_ext_len:
- Deprecated, use *max_buffer_size* instead.
- Limits max size of ext type. (default: max_buffer_size)
-
- Example of streaming deserialize from file-like object::
-
- unpacker = Unpacker(file_like)
- for o in unpacker:
- process(o)
-
- Example of streaming deserialize from socket::
-
- unpacker = Unpacker()
- while True:
- buf = sock.recv(1024**2)
- if not buf:
- break
- unpacker.feed(buf)
- for o in unpacker:
- process(o)
-
- Raises ``ExtraData`` when *packed* contains extra bytes.
- Raises ``OutOfData`` when *packed* is incomplete.
- Raises ``FormatError`` when *packed* is not valid msgpack.
- Raises ``StackError`` when *packed* contains too nested.
- Other exceptions can be raised during unpacking.
- """
- cdef unpack_context ctx
- cdef char* buf
- cdef Py_ssize_t buf_size, buf_head, buf_tail
- cdef object file_like
- cdef object file_like_read
- cdef Py_ssize_t read_size
- # To maintain refcnt.
- cdef object object_hook, object_pairs_hook, list_hook, ext_hook
- cdef object unicode_errors
- cdef Py_ssize_t max_buffer_size
- cdef uint64_t stream_offset
-
- def __cinit__(self):
- self.buf = NULL
-
- def __dealloc__(self):
- PyMem_Free(self.buf)
- self.buf = NULL
-
- def __init__(self, file_like=None, *, Py_ssize_t read_size=0,
- bint use_list=True, bint raw=True, int timestamp=0, bint strict_map_key=False,
- object object_hook=None, object object_pairs_hook=None, object list_hook=None,
- unicode_errors=None, Py_ssize_t max_buffer_size=100*1024*1024,
- object ext_hook=ExtType,
- Py_ssize_t max_str_len=-1,
- Py_ssize_t max_bin_len=-1,
- Py_ssize_t max_array_len=-1,
- Py_ssize_t max_map_len=-1,
- Py_ssize_t max_ext_len=-1):
- cdef const char* cerr = NULL
- cdef const char* cenc = NULL
-
- self.object_hook = object_hook
- self.object_pairs_hook = object_pairs_hook
- self.list_hook = list_hook
- self.ext_hook = ext_hook
-
- self.file_like = file_like
- if file_like:
- self.file_like_read = file_like.read
- if not PyCallable_Check(self.file_like_read):
- raise TypeError("`file_like.read` must be a callable.")
-
- if not max_buffer_size:
- max_buffer_size = INT_MAX
- if max_str_len == -1:
- max_str_len = max_buffer_size
- if max_bin_len == -1:
- max_bin_len = max_buffer_size
- if max_array_len == -1:
- max_array_len = max_buffer_size
- if max_map_len == -1:
- max_map_len = max_buffer_size//2
- if max_ext_len == -1:
- max_ext_len = max_buffer_size
-
- if read_size > max_buffer_size:
- raise ValueError("read_size should be less or equal to max_buffer_size")
- if not read_size:
- read_size = min(max_buffer_size, 1024**2)
-
- self.max_buffer_size = max_buffer_size
- self.read_size = read_size
- self.buf = PyMem_Malloc(read_size)
- if self.buf == NULL:
- raise MemoryError("Unable to allocate internal buffer.")
- self.buf_size = read_size
- self.buf_head = 0
- self.buf_tail = 0
- self.stream_offset = 0
-
- if unicode_errors is not None:
- self.unicode_errors = unicode_errors
- cerr = unicode_errors
-
- init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook,
- ext_hook, use_list, raw, timestamp, strict_map_key, cenc, cerr,
- max_str_len, max_bin_len, max_array_len,
- max_map_len, max_ext_len)
-
- def feed(self, object next_bytes):
- """Append `next_bytes` to internal buffer."""
- cdef Py_buffer pybuff
- cdef char* buf
- cdef Py_ssize_t buf_len
-
- if self.file_like is not None:
- raise AssertionError(
- "unpacker.feed() is not be able to use with `file_like`.")
-
- get_data_from_buffer(next_bytes, &pybuff, &buf, &buf_len)
- try:
- self.append_buffer(buf, buf_len)
- finally:
- PyBuffer_Release(&pybuff)
-
- cdef append_buffer(self, void* _buf, Py_ssize_t _buf_len):
- cdef:
- char* buf = self.buf
- char* new_buf
- Py_ssize_t head = self.buf_head
- Py_ssize_t tail = self.buf_tail
- Py_ssize_t buf_size = self.buf_size
- Py_ssize_t new_size
-
- if tail + _buf_len > buf_size:
- if ((tail - head) + _buf_len) <= buf_size:
- # move to front.
- memmove(buf, buf + head, tail - head)
- tail -= head
- head = 0
- else:
- # expand buffer.
- new_size = (tail-head) + _buf_len
- if new_size > self.max_buffer_size:
- raise BufferFull
- new_size = min(new_size*2, self.max_buffer_size)
- new_buf = PyMem_Malloc(new_size)
- if new_buf == NULL:
- # self.buf still holds old buffer and will be freed during
- # obj destruction
- raise MemoryError("Unable to enlarge internal buffer.")
- memcpy(new_buf, buf + head, tail - head)
- PyMem_Free(buf)
-
- buf = new_buf
- buf_size = new_size
- tail -= head
- head = 0
-
- memcpy(buf + tail, (_buf), _buf_len)
- self.buf = buf
- self.buf_head = head
- self.buf_size = buf_size
- self.buf_tail = tail + _buf_len
-
- cdef int read_from_file(self) except -1:
- cdef Py_ssize_t remains = self.max_buffer_size - (self.buf_tail - self.buf_head)
- if remains <= 0:
- raise BufferFull
-
- next_bytes = self.file_like_read(min(self.read_size, remains))
- if next_bytes:
- self.append_buffer(PyBytes_AsString(next_bytes), PyBytes_Size(next_bytes))
- else:
- self.file_like = None
- return 0
-
- cdef object _unpack(self, execute_fn execute, bint iter=0):
- cdef int ret
- cdef object obj
- cdef Py_ssize_t prev_head
-
- while 1:
- prev_head = self.buf_head
- if prev_head < self.buf_tail:
- ret = execute(&self.ctx, self.buf, self.buf_tail, &self.buf_head)
- self.stream_offset += self.buf_head - prev_head
- else:
- ret = 0
-
- if ret == 1:
- obj = unpack_data(&self.ctx)
- unpack_init(&self.ctx)
- return obj
- elif ret == 0:
- if self.file_like is not None:
- self.read_from_file()
- continue
- if iter:
- raise StopIteration("No more data to unpack.")
- else:
- raise OutOfData("No more data to unpack.")
- elif ret == -2:
- raise FormatError
- elif ret == -3:
- raise StackError
- else:
- raise ValueError("Unpack failed: error = %d" % (ret,))
-
- def read_bytes(self, Py_ssize_t nbytes):
- """Read a specified number of raw bytes from the stream"""
- cdef Py_ssize_t nread
- nread = min(self.buf_tail - self.buf_head, nbytes)
- ret = PyBytes_FromStringAndSize(self.buf + self.buf_head, nread)
- self.buf_head += nread
- if nread < nbytes and self.file_like is not None:
- ret += self.file_like.read(nbytes - nread)
- nread = len(ret)
- self.stream_offset += nread
- return ret
-
- def unpack(self):
- """Unpack one object
-
- Raises `OutOfData` when there are no more bytes to unpack.
- """
- return self._unpack(unpack_construct)
-
- def skip(self):
- """Read and ignore one object, returning None
-
- Raises `OutOfData` when there are no more bytes to unpack.
- """
- return self._unpack(unpack_skip)
-
- def read_array_header(self):
- """assuming the next object is an array, return its size n, such that
- the next n unpack() calls will iterate over its contents.
-
- Raises `OutOfData` when there are no more bytes to unpack.
- """
- return self._unpack(read_array_header)
-
- def read_map_header(self):
- """assuming the next object is a map, return its size n, such that the
- next n * 2 unpack() calls will iterate over its key-value pairs.
-
- Raises `OutOfData` when there are no more bytes to unpack.
- """
- return self._unpack(read_map_header)
-
- def tell(self):
- """Returns the current position of the Unpacker in bytes, i.e., the
- number of bytes that were read from the input, also the starting
- position of the next object.
- """
- return self.stream_offset
-
- def __iter__(self):
- return self
-
- def __next__(self):
- return self._unpack(unpack_construct, 1)
-
- # for debug.
- #def _buf(self):
- # return PyString_FromStringAndSize(self.buf, self.buf_tail)
-
- #def _off(self):
- # return self.buf_head
diff --git a/srsly/msgpack/_version.py b/srsly/msgpack/_version.py
deleted file mode 100644
index 1b75e0a..0000000
--- a/srsly/msgpack/_version.py
+++ /dev/null
@@ -1 +0,0 @@
-version = (1, 1, 0)
diff --git a/srsly/msgpack/exceptions.py b/srsly/msgpack/exceptions.py
deleted file mode 100644
index d6d2615..0000000
--- a/srsly/msgpack/exceptions.py
+++ /dev/null
@@ -1,48 +0,0 @@
-class UnpackException(Exception):
- """Base class for some exceptions raised while unpacking.
-
- NOTE: unpack may raise exception other than subclass of
- UnpackException. If you want to catch all error, catch
- Exception instead.
- """
-
-
-class BufferFull(UnpackException):
- pass
-
-
-class OutOfData(UnpackException):
- pass
-
-
-class FormatError(ValueError, UnpackException):
- """Invalid msgpack format"""
-
-
-class StackError(ValueError, UnpackException):
- """Too nested"""
-
-
-# Deprecated. Use ValueError instead
-UnpackValueError = ValueError
-
-
-class ExtraData(UnpackValueError):
- """ExtraData is raised when there is trailing data.
-
- This exception is raised while only one-shot (not streaming)
- unpack.
- """
-
- def __init__(self, unpacked, extra):
- self.unpacked = unpacked
- self.extra = extra
-
- def __str__(self):
- return "unpack(b) received extra data."
-
-
-# Deprecated. Use Exception instead to catch all exception during packing.
-PackException = Exception
-PackValueError = ValueError
-PackOverflowError = OverflowError
diff --git a/srsly/msgpack/ext.py b/srsly/msgpack/ext.py
deleted file mode 100644
index 9694819..0000000
--- a/srsly/msgpack/ext.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import datetime
-import struct
-from collections import namedtuple
-
-
-class ExtType(namedtuple("ExtType", "code data")):
- """ExtType represents ext type in msgpack."""
-
- def __new__(cls, code, data):
- if not isinstance(code, int):
- raise TypeError("code must be int")
- if not isinstance(data, bytes):
- raise TypeError("data must be bytes")
- if not 0 <= code <= 127:
- raise ValueError("code must be 0~127")
- return super().__new__(cls, code, data)
-
-
-class Timestamp:
- """Timestamp represents the Timestamp extension type in msgpack.
-
- When built with Cython, msgpack uses C methods to pack and unpack `Timestamp`.
- When using pure-Python msgpack, :func:`to_bytes` and :func:`from_bytes` are used to pack and
- unpack `Timestamp`.
-
- This class is immutable: Do not override seconds and nanoseconds.
- """
-
- __slots__ = ["seconds", "nanoseconds"]
-
- def __init__(self, seconds, nanoseconds=0):
- """Initialize a Timestamp object.
-
- :param int seconds:
- Number of seconds since the UNIX epoch (00:00:00 UTC Jan 1 1970, minus leap seconds).
- May be negative.
-
- :param int nanoseconds:
- Number of nanoseconds to add to `seconds` to get fractional time.
- Maximum is 999_999_999. Default is 0.
-
- Note: Negative times (before the UNIX epoch) are represented as neg. seconds + pos. ns.
- """
- if not isinstance(seconds, int):
- raise TypeError("seconds must be an integer")
- if not isinstance(nanoseconds, int):
- raise TypeError("nanoseconds must be an integer")
- if not (0 <= nanoseconds < 10**9):
- raise ValueError("nanoseconds must be a non-negative integer less than 999999999.")
- self.seconds = seconds
- self.nanoseconds = nanoseconds
-
- def __repr__(self):
- """String representation of Timestamp."""
- return f"Timestamp(seconds={self.seconds}, nanoseconds={self.nanoseconds})"
-
- def __eq__(self, other):
- """Check for equality with another Timestamp object"""
- if type(other) is self.__class__:
- return self.seconds == other.seconds and self.nanoseconds == other.nanoseconds
- return False
-
- def __ne__(self, other):
- """not-equals method (see :func:`__eq__()`)"""
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash((self.seconds, self.nanoseconds))
-
- @staticmethod
- def from_bytes(b):
- """Unpack bytes into a `Timestamp` object.
-
- Used for pure-Python msgpack unpacking.
-
- :param b: Payload from msgpack ext message with code -1
- :type b: bytes
-
- :returns: Timestamp object unpacked from msgpack ext payload
- :rtype: Timestamp
- """
- if len(b) == 4:
- seconds = struct.unpack("!L", b)[0]
- nanoseconds = 0
- elif len(b) == 8:
- data64 = struct.unpack("!Q", b)[0]
- seconds = data64 & 0x00000003FFFFFFFF
- nanoseconds = data64 >> 34
- elif len(b) == 12:
- nanoseconds, seconds = struct.unpack("!Iq", b)
- else:
- raise ValueError(
- "Timestamp type can only be created from 32, 64, or 96-bit byte objects"
- )
- return Timestamp(seconds, nanoseconds)
-
- def to_bytes(self):
- """Pack this Timestamp object into bytes.
-
- Used for pure-Python msgpack packing.
-
- :returns data: Payload for EXT message with code -1 (timestamp type)
- :rtype: bytes
- """
- if (self.seconds >> 34) == 0: # seconds is non-negative and fits in 34 bits
- data64 = self.nanoseconds << 34 | self.seconds
- if data64 & 0xFFFFFFFF00000000 == 0:
- # nanoseconds is zero and seconds < 2**32, so timestamp 32
- data = struct.pack("!L", data64)
- else:
- # timestamp 64
- data = struct.pack("!Q", data64)
- else:
- # timestamp 96
- data = struct.pack("!Iq", self.nanoseconds, self.seconds)
- return data
-
- @staticmethod
- def from_unix(unix_sec):
- """Create a Timestamp from posix timestamp in seconds.
-
- :param unix_float: Posix timestamp in seconds.
- :type unix_float: int or float
- """
- seconds = int(unix_sec // 1)
- nanoseconds = int((unix_sec % 1) * 10**9)
- return Timestamp(seconds, nanoseconds)
-
- def to_unix(self):
- """Get the timestamp as a floating-point value.
-
- :returns: posix timestamp
- :rtype: float
- """
- return self.seconds + self.nanoseconds / 1e9
-
- @staticmethod
- def from_unix_nano(unix_ns):
- """Create a Timestamp from posix timestamp in nanoseconds.
-
- :param int unix_ns: Posix timestamp in nanoseconds.
- :rtype: Timestamp
- """
- return Timestamp(*divmod(unix_ns, 10**9))
-
- def to_unix_nano(self):
- """Get the timestamp as a unixtime in nanoseconds.
-
- :returns: posix timestamp in nanoseconds
- :rtype: int
- """
- return self.seconds * 10**9 + self.nanoseconds
-
- def to_datetime(self):
- """Get the timestamp as a UTC datetime.
-
- :rtype: `datetime.datetime`
- """
- utc = datetime.timezone.utc
- return datetime.datetime.fromtimestamp(0, utc) + datetime.timedelta(
- seconds=self.seconds, microseconds=self.nanoseconds // 1000
- )
-
- @staticmethod
- def from_datetime(dt):
- """Create a Timestamp from datetime with tzinfo.
-
- :rtype: Timestamp
- """
- return Timestamp(seconds=int(dt.timestamp()), nanoseconds=dt.microsecond * 1000)
diff --git a/srsly/msgpack/fallback.py b/srsly/msgpack/fallback.py
deleted file mode 100644
index b02e47c..0000000
--- a/srsly/msgpack/fallback.py
+++ /dev/null
@@ -1,929 +0,0 @@
-"""Fallback pure Python implementation of msgpack"""
-
-import struct
-import sys
-from datetime import datetime as _DateTime
-
-if hasattr(sys, "pypy_version_info"):
- from __pypy__ import newlist_hint
- from __pypy__.builders import BytesBuilder
-
- _USING_STRINGBUILDER = True
-
- class BytesIO:
- def __init__(self, s=b""):
- if s:
- self.builder = BytesBuilder(len(s))
- self.builder.append(s)
- else:
- self.builder = BytesBuilder()
-
- def write(self, s):
- if isinstance(s, memoryview):
- s = s.tobytes()
- elif isinstance(s, bytearray):
- s = bytes(s)
- self.builder.append(s)
-
- def getvalue(self):
- return self.builder.build()
-
-else:
- from io import BytesIO
-
- _USING_STRINGBUILDER = False
-
- def newlist_hint(size):
- return []
-
-
-from .exceptions import BufferFull, ExtraData, FormatError, OutOfData, StackError
-from .ext import ExtType, Timestamp
-
-EX_SKIP = 0
-EX_CONSTRUCT = 1
-EX_READ_ARRAY_HEADER = 2
-EX_READ_MAP_HEADER = 3
-
-TYPE_IMMEDIATE = 0
-TYPE_ARRAY = 1
-TYPE_MAP = 2
-TYPE_RAW = 3
-TYPE_BIN = 4
-TYPE_EXT = 5
-
-DEFAULT_RECURSE_LIMIT = 511
-
-
-def _check_type_strict(obj, t, type=type, tuple=tuple):
- if type(t) is tuple:
- return type(obj) in t
- else:
- return type(obj) is t
-
-
-def _get_data_from_buffer(obj):
- view = memoryview(obj)
- if view.itemsize != 1:
- raise ValueError("cannot unpack from multi-byte object")
- return view
-
-
-def unpackb(packed, **kwargs):
- """
- Unpack an object from `packed`.
-
- Raises ``ExtraData`` when *packed* contains extra bytes.
- Raises ``ValueError`` when *packed* is incomplete.
- Raises ``FormatError`` when *packed* is not valid msgpack.
- Raises ``StackError`` when *packed* contains too nested.
- Other exceptions can be raised during unpacking.
-
- See :class:`Unpacker` for options.
- """
- unpacker = Unpacker(None, max_buffer_size=len(packed), **kwargs)
- unpacker.feed(packed)
- try:
- ret = unpacker._unpack()
- except OutOfData:
- raise ValueError("Unpack failed: incomplete input")
- except RecursionError:
- raise StackError
- if unpacker._got_extradata():
- raise ExtraData(ret, unpacker._get_extradata())
- return ret
-
-
-_NO_FORMAT_USED = ""
-_MSGPACK_HEADERS = {
- 0xC4: (1, _NO_FORMAT_USED, TYPE_BIN),
- 0xC5: (2, ">H", TYPE_BIN),
- 0xC6: (4, ">I", TYPE_BIN),
- 0xC7: (2, "Bb", TYPE_EXT),
- 0xC8: (3, ">Hb", TYPE_EXT),
- 0xC9: (5, ">Ib", TYPE_EXT),
- 0xCA: (4, ">f"),
- 0xCB: (8, ">d"),
- 0xCC: (1, _NO_FORMAT_USED),
- 0xCD: (2, ">H"),
- 0xCE: (4, ">I"),
- 0xCF: (8, ">Q"),
- 0xD0: (1, "b"),
- 0xD1: (2, ">h"),
- 0xD2: (4, ">i"),
- 0xD3: (8, ">q"),
- 0xD4: (1, "b1s", TYPE_EXT),
- 0xD5: (2, "b2s", TYPE_EXT),
- 0xD6: (4, "b4s", TYPE_EXT),
- 0xD7: (8, "b8s", TYPE_EXT),
- 0xD8: (16, "b16s", TYPE_EXT),
- 0xD9: (1, _NO_FORMAT_USED, TYPE_RAW),
- 0xDA: (2, ">H", TYPE_RAW),
- 0xDB: (4, ">I", TYPE_RAW),
- 0xDC: (2, ">H", TYPE_ARRAY),
- 0xDD: (4, ">I", TYPE_ARRAY),
- 0xDE: (2, ">H", TYPE_MAP),
- 0xDF: (4, ">I", TYPE_MAP),
-}
-
-
-class Unpacker:
- """Streaming unpacker.
-
- Arguments:
-
- :param file_like:
- File-like object having `.read(n)` method.
- If specified, unpacker reads serialized data from it and `.feed()` is not usable.
-
- :param int read_size:
- Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`)
-
- :param bool use_list:
- If true, unpack msgpack array to Python list.
- Otherwise, unpack to Python tuple. (default: True)
-
- :param bool raw:
- If true, unpack msgpack raw to Python bytes.
- Otherwise, unpack to Python str by decoding with UTF-8 encoding (default).
-
- :param int timestamp:
- Control how timestamp type is unpacked:
-
- 0 - Timestamp
- 1 - float (Seconds from the EPOCH)
- 2 - int (Nanoseconds from the EPOCH)
- 3 - datetime.datetime (UTC).
-
- :param bool strict_map_key:
- If true (default), only str or bytes are accepted for map (dict) keys.
-
- :param object_hook:
- When specified, it should be callable.
- Unpacker calls it with a dict argument after unpacking msgpack map.
- (See also simplejson)
-
- :param object_pairs_hook:
- When specified, it should be callable.
- Unpacker calls it with a list of key-value pairs after unpacking msgpack map.
- (See also simplejson)
-
- :param str unicode_errors:
- The error handler for decoding unicode. (default: 'strict')
- This option should be used only when you have msgpack data which
- contains invalid UTF-8 string.
-
- :param int max_buffer_size:
- Limits size of data waiting unpacked. 0 means 2**32-1.
- The default value is 100*1024*1024 (100MiB).
- Raises `BufferFull` exception when it is insufficient.
- You should set this parameter when unpacking data from untrusted source.
-
- :param int max_str_len:
- Deprecated, use *max_buffer_size* instead.
- Limits max length of str. (default: max_buffer_size)
-
- :param int max_bin_len:
- Deprecated, use *max_buffer_size* instead.
- Limits max length of bin. (default: max_buffer_size)
-
- :param int max_array_len:
- Limits max length of array.
- (default: max_buffer_size)
-
- :param int max_map_len:
- Limits max length of map.
- (default: max_buffer_size//2)
-
- :param int max_ext_len:
- Deprecated, use *max_buffer_size* instead.
- Limits max size of ext type. (default: max_buffer_size)
-
- Example of streaming deserialize from file-like object::
-
- unpacker = Unpacker(file_like)
- for o in unpacker:
- process(o)
-
- Example of streaming deserialize from socket::
-
- unpacker = Unpacker()
- while True:
- buf = sock.recv(1024**2)
- if not buf:
- break
- unpacker.feed(buf)
- for o in unpacker:
- process(o)
-
- Raises ``ExtraData`` when *packed* contains extra bytes.
- Raises ``OutOfData`` when *packed* is incomplete.
- Raises ``FormatError`` when *packed* is not valid msgpack.
- Raises ``StackError`` when *packed* contains too nested.
- Other exceptions can be raised during unpacking.
- """
-
- def __init__(
- self,
- file_like=None,
- *,
- read_size=0,
- use_list=True,
- raw=False,
- timestamp=0,
- strict_map_key=True,
- object_hook=None,
- object_pairs_hook=None,
- list_hook=None,
- unicode_errors=None,
- max_buffer_size=100 * 1024 * 1024,
- ext_hook=ExtType,
- max_str_len=-1,
- max_bin_len=-1,
- max_array_len=-1,
- max_map_len=-1,
- max_ext_len=-1,
- ):
- if unicode_errors is None:
- unicode_errors = "strict"
-
- if file_like is None:
- self._feeding = True
- else:
- if not callable(file_like.read):
- raise TypeError("`file_like.read` must be callable")
- self.file_like = file_like
- self._feeding = False
-
- #: array of bytes fed.
- self._buffer = bytearray()
- #: Which position we currently reads
- self._buff_i = 0
-
- # When Unpacker is used as an iterable, between the calls to next(),
- # the buffer is not "consumed" completely, for efficiency sake.
- # Instead, it is done sloppily. To make sure we raise BufferFull at
- # the correct moments, we have to keep track of how sloppy we were.
- # Furthermore, when the buffer is incomplete (that is: in the case
- # we raise an OutOfData) we need to rollback the buffer to the correct
- # state, which _buf_checkpoint records.
- self._buf_checkpoint = 0
-
- if not max_buffer_size:
- max_buffer_size = 2**31 - 1
- if max_str_len == -1:
- max_str_len = max_buffer_size
- if max_bin_len == -1:
- max_bin_len = max_buffer_size
- if max_array_len == -1:
- max_array_len = max_buffer_size
- if max_map_len == -1:
- max_map_len = max_buffer_size // 2
- if max_ext_len == -1:
- max_ext_len = max_buffer_size
-
- self._max_buffer_size = max_buffer_size
- if read_size > self._max_buffer_size:
- raise ValueError("read_size must be smaller than max_buffer_size")
- self._read_size = read_size or min(self._max_buffer_size, 16 * 1024)
- self._raw = bool(raw)
- self._strict_map_key = bool(strict_map_key)
- self._unicode_errors = unicode_errors
- self._use_list = use_list
- if not (0 <= timestamp <= 3):
- raise ValueError("timestamp must be 0..3")
- self._timestamp = timestamp
- self._list_hook = list_hook
- self._object_hook = object_hook
- self._object_pairs_hook = object_pairs_hook
- self._ext_hook = ext_hook
- self._max_str_len = max_str_len
- self._max_bin_len = max_bin_len
- self._max_array_len = max_array_len
- self._max_map_len = max_map_len
- self._max_ext_len = max_ext_len
- self._stream_offset = 0
-
- if list_hook is not None and not callable(list_hook):
- raise TypeError("`list_hook` is not callable")
- if object_hook is not None and not callable(object_hook):
- raise TypeError("`object_hook` is not callable")
- if object_pairs_hook is not None and not callable(object_pairs_hook):
- raise TypeError("`object_pairs_hook` is not callable")
- if object_hook is not None and object_pairs_hook is not None:
- raise TypeError("object_pairs_hook and object_hook are mutually exclusive")
- if not callable(ext_hook):
- raise TypeError("`ext_hook` is not callable")
-
- def feed(self, next_bytes):
- assert self._feeding
- view = _get_data_from_buffer(next_bytes)
- if len(self._buffer) - self._buff_i + len(view) > self._max_buffer_size:
- raise BufferFull
-
- # Strip buffer before checkpoint before reading file.
- if self._buf_checkpoint > 0:
- del self._buffer[: self._buf_checkpoint]
- self._buff_i -= self._buf_checkpoint
- self._buf_checkpoint = 0
-
- # Use extend here: INPLACE_ADD += doesn't reliably typecast memoryview in jython
- self._buffer.extend(view)
- view.release()
-
- def _consume(self):
- """Gets rid of the used parts of the buffer."""
- self._stream_offset += self._buff_i - self._buf_checkpoint
- self._buf_checkpoint = self._buff_i
-
- def _got_extradata(self):
- return self._buff_i < len(self._buffer)
-
- def _get_extradata(self):
- return self._buffer[self._buff_i :]
-
- def read_bytes(self, n):
- ret = self._read(n, raise_outofdata=False)
- self._consume()
- return ret
-
- def _read(self, n, raise_outofdata=True):
- # (int) -> bytearray
- self._reserve(n, raise_outofdata=raise_outofdata)
- i = self._buff_i
- ret = self._buffer[i : i + n]
- self._buff_i = i + len(ret)
- return ret
-
- def _reserve(self, n, raise_outofdata=True):
- remain_bytes = len(self._buffer) - self._buff_i - n
-
- # Fast path: buffer has n bytes already
- if remain_bytes >= 0:
- return
-
- if self._feeding:
- self._buff_i = self._buf_checkpoint
- raise OutOfData
-
- # Strip buffer before checkpoint before reading file.
- if self._buf_checkpoint > 0:
- del self._buffer[: self._buf_checkpoint]
- self._buff_i -= self._buf_checkpoint
- self._buf_checkpoint = 0
-
- # Read from file
- remain_bytes = -remain_bytes
- if remain_bytes + len(self._buffer) > self._max_buffer_size:
- raise BufferFull
- while remain_bytes > 0:
- to_read_bytes = max(self._read_size, remain_bytes)
- read_data = self.file_like.read(to_read_bytes)
- if not read_data:
- break
- assert isinstance(read_data, bytes)
- self._buffer += read_data
- remain_bytes -= len(read_data)
-
- if len(self._buffer) < n + self._buff_i and raise_outofdata:
- self._buff_i = 0 # rollback
- raise OutOfData
-
- def _read_header(self):
- typ = TYPE_IMMEDIATE
- n = 0
- obj = None
- self._reserve(1)
- b = self._buffer[self._buff_i]
- self._buff_i += 1
- if b & 0b10000000 == 0:
- obj = b
- elif b & 0b11100000 == 0b11100000:
- obj = -1 - (b ^ 0xFF)
- elif b & 0b11100000 == 0b10100000:
- n = b & 0b00011111
- typ = TYPE_RAW
- if n > self._max_str_len:
- raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})")
- obj = self._read(n)
- elif b & 0b11110000 == 0b10010000:
- n = b & 0b00001111
- typ = TYPE_ARRAY
- if n > self._max_array_len:
- raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})")
- elif b & 0b11110000 == 0b10000000:
- n = b & 0b00001111
- typ = TYPE_MAP
- if n > self._max_map_len:
- raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})")
- elif b == 0xC0:
- obj = None
- elif b == 0xC2:
- obj = False
- elif b == 0xC3:
- obj = True
- elif 0xC4 <= b <= 0xC6:
- size, fmt, typ = _MSGPACK_HEADERS[b]
- self._reserve(size)
- if len(fmt) > 0:
- n = struct.unpack_from(fmt, self._buffer, self._buff_i)[0]
- else:
- n = self._buffer[self._buff_i]
- self._buff_i += size
- if n > self._max_bin_len:
- raise ValueError(f"{n} exceeds max_bin_len({self._max_bin_len})")
- obj = self._read(n)
- elif 0xC7 <= b <= 0xC9:
- size, fmt, typ = _MSGPACK_HEADERS[b]
- self._reserve(size)
- L, n = struct.unpack_from(fmt, self._buffer, self._buff_i)
- self._buff_i += size
- if L > self._max_ext_len:
- raise ValueError(f"{L} exceeds max_ext_len({self._max_ext_len})")
- obj = self._read(L)
- elif 0xCA <= b <= 0xD3:
- size, fmt = _MSGPACK_HEADERS[b]
- self._reserve(size)
- if len(fmt) > 0:
- obj = struct.unpack_from(fmt, self._buffer, self._buff_i)[0]
- else:
- obj = self._buffer[self._buff_i]
- self._buff_i += size
- elif 0xD4 <= b <= 0xD8:
- size, fmt, typ = _MSGPACK_HEADERS[b]
- if self._max_ext_len < size:
- raise ValueError(f"{size} exceeds max_ext_len({self._max_ext_len})")
- self._reserve(size + 1)
- n, obj = struct.unpack_from(fmt, self._buffer, self._buff_i)
- self._buff_i += size + 1
- elif 0xD9 <= b <= 0xDB:
- size, fmt, typ = _MSGPACK_HEADERS[b]
- self._reserve(size)
- if len(fmt) > 0:
- (n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
- else:
- n = self._buffer[self._buff_i]
- self._buff_i += size
- if n > self._max_str_len:
- raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})")
- obj = self._read(n)
- elif 0xDC <= b <= 0xDD:
- size, fmt, typ = _MSGPACK_HEADERS[b]
- self._reserve(size)
- (n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
- self._buff_i += size
- if n > self._max_array_len:
- raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})")
- elif 0xDE <= b <= 0xDF:
- size, fmt, typ = _MSGPACK_HEADERS[b]
- self._reserve(size)
- (n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
- self._buff_i += size
- if n > self._max_map_len:
- raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})")
- else:
- raise FormatError("Unknown header: 0x%x" % b)
- return typ, n, obj
-
- def _unpack(self, execute=EX_CONSTRUCT):
- typ, n, obj = self._read_header()
-
- if execute == EX_READ_ARRAY_HEADER:
- if typ != TYPE_ARRAY:
- raise ValueError("Expected array")
- return n
- if execute == EX_READ_MAP_HEADER:
- if typ != TYPE_MAP:
- raise ValueError("Expected map")
- return n
- # TODO should we eliminate the recursion?
- if typ == TYPE_ARRAY:
- if execute == EX_SKIP:
- for i in range(n):
- # TODO check whether we need to call `list_hook`
- self._unpack(EX_SKIP)
- return
- ret = newlist_hint(n)
- for i in range(n):
- ret.append(self._unpack(EX_CONSTRUCT))
- if self._list_hook is not None:
- ret = self._list_hook(ret)
- # TODO is the interaction between `list_hook` and `use_list` ok?
- return ret if self._use_list else tuple(ret)
- if typ == TYPE_MAP:
- if execute == EX_SKIP:
- for i in range(n):
- # TODO check whether we need to call hooks
- self._unpack(EX_SKIP)
- self._unpack(EX_SKIP)
- return
- if self._object_pairs_hook is not None:
- ret = self._object_pairs_hook(
- (self._unpack(EX_CONSTRUCT), self._unpack(EX_CONSTRUCT)) for _ in range(n)
- )
- else:
- ret = {}
- for _ in range(n):
- key = self._unpack(EX_CONSTRUCT)
- if self._strict_map_key and type(key) not in (str, bytes):
- raise ValueError("%s is not allowed for map key" % str(type(key)))
- if isinstance(key, str):
- key = sys.intern(key)
- ret[key] = self._unpack(EX_CONSTRUCT)
- if self._object_hook is not None:
- ret = self._object_hook(ret)
- return ret
- if execute == EX_SKIP:
- return
- if typ == TYPE_RAW:
- if self._raw:
- obj = bytes(obj)
- else:
- obj = obj.decode("utf_8", self._unicode_errors)
- return obj
- if typ == TYPE_BIN:
- return bytes(obj)
- if typ == TYPE_EXT:
- if n == -1: # timestamp
- ts = Timestamp.from_bytes(bytes(obj))
- if self._timestamp == 1:
- return ts.to_unix()
- elif self._timestamp == 2:
- return ts.to_unix_nano()
- elif self._timestamp == 3:
- return ts.to_datetime()
- else:
- return ts
- else:
- return self._ext_hook(n, bytes(obj))
- assert typ == TYPE_IMMEDIATE
- return obj
-
- def __iter__(self):
- return self
-
- def __next__(self):
- try:
- ret = self._unpack(EX_CONSTRUCT)
- self._consume()
- return ret
- except OutOfData:
- self._consume()
- raise StopIteration
- except RecursionError:
- raise StackError
-
- next = __next__
-
- def skip(self):
- self._unpack(EX_SKIP)
- self._consume()
-
- def unpack(self):
- try:
- ret = self._unpack(EX_CONSTRUCT)
- except RecursionError:
- raise StackError
- self._consume()
- return ret
-
- def read_array_header(self):
- ret = self._unpack(EX_READ_ARRAY_HEADER)
- self._consume()
- return ret
-
- def read_map_header(self):
- ret = self._unpack(EX_READ_MAP_HEADER)
- self._consume()
- return ret
-
- def tell(self):
- return self._stream_offset
-
-
-class Packer:
- """
- MessagePack Packer
-
- Usage::
-
- packer = Packer()
- astream.write(packer.pack(a))
- astream.write(packer.pack(b))
-
- Packer's constructor has some keyword arguments:
-
- :param default:
- When specified, it should be callable.
- Convert user type to builtin type that Packer supports.
- See also simplejson's document.
-
- :param bool use_single_float:
- Use single precision float type for float. (default: False)
-
- :param bool autoreset:
- Reset buffer after each pack and return its content as `bytes`. (default: True).
- If set this to false, use `bytes()` to get content and `.reset()` to clear buffer.
-
- :param bool use_bin_type:
- Use bin type introduced in msgpack spec 2.0 for bytes.
- It also enables str8 type for unicode. (default: True)
-
- :param bool strict_types:
- If set to true, types will be checked to be exact. Derived classes
- from serializable types will not be serialized and will be
- treated as unsupported type and forwarded to default.
- Additionally tuples will not be serialized as lists.
- This is useful when trying to implement accurate serialization
- for python types.
-
- :param bool datetime:
- If set to true, datetime with tzinfo is packed into Timestamp type.
- Note that the tzinfo is stripped in the timestamp.
- You can get UTC datetime with `timestamp=3` option of the Unpacker.
-
- :param str unicode_errors:
- The error handler for encoding unicode. (default: 'strict')
- DO NOT USE THIS!! This option is kept for very specific usage.
-
- :param int buf_size:
- Internal buffer size. This option is used only for C implementation.
- """
-
- def __init__(
- self,
- *,
- default=None,
- use_single_float=False,
- autoreset=True,
- use_bin_type=True,
- strict_types=False,
- datetime=False,
- unicode_errors=None,
- buf_size=None,
- ):
- self._strict_types = strict_types
- self._use_float = use_single_float
- self._autoreset = autoreset
- self._use_bin_type = use_bin_type
- self._buffer = BytesIO()
- self._datetime = bool(datetime)
- self._unicode_errors = unicode_errors or "strict"
- if default is not None and not callable(default):
- raise TypeError("default must be callable")
- self._default = default
-
- def _pack(
- self,
- obj,
- nest_limit=DEFAULT_RECURSE_LIMIT,
- check=isinstance,
- check_type_strict=_check_type_strict,
- ):
- default_used = False
- if self._strict_types:
- check = check_type_strict
- list_types = list
- else:
- list_types = (list, tuple)
- while True:
- if nest_limit < 0:
- raise ValueError("recursion limit exceeded")
- if obj is None:
- return self._buffer.write(b"\xc0")
- if check(obj, bool):
- if obj:
- return self._buffer.write(b"\xc3")
- return self._buffer.write(b"\xc2")
- if check(obj, int):
- if 0 <= obj < 0x80:
- return self._buffer.write(struct.pack("B", obj))
- if -0x20 <= obj < 0:
- return self._buffer.write(struct.pack("b", obj))
- if 0x80 <= obj <= 0xFF:
- return self._buffer.write(struct.pack("BB", 0xCC, obj))
- if -0x80 <= obj < 0:
- return self._buffer.write(struct.pack(">Bb", 0xD0, obj))
- if 0xFF < obj <= 0xFFFF:
- return self._buffer.write(struct.pack(">BH", 0xCD, obj))
- if -0x8000 <= obj < -0x80:
- return self._buffer.write(struct.pack(">Bh", 0xD1, obj))
- if 0xFFFF < obj <= 0xFFFFFFFF:
- return self._buffer.write(struct.pack(">BI", 0xCE, obj))
- if -0x80000000 <= obj < -0x8000:
- return self._buffer.write(struct.pack(">Bi", 0xD2, obj))
- if 0xFFFFFFFF < obj <= 0xFFFFFFFFFFFFFFFF:
- return self._buffer.write(struct.pack(">BQ", 0xCF, obj))
- if -0x8000000000000000 <= obj < -0x80000000:
- return self._buffer.write(struct.pack(">Bq", 0xD3, obj))
- if not default_used and self._default is not None:
- obj = self._default(obj)
- default_used = True
- continue
- raise OverflowError("Integer value out of range")
- if check(obj, (bytes, bytearray)):
- n = len(obj)
- if n >= 2**32:
- raise ValueError("%s is too large" % type(obj).__name__)
- self._pack_bin_header(n)
- return self._buffer.write(obj)
- if check(obj, str):
- obj = obj.encode("utf-8", self._unicode_errors)
- n = len(obj)
- if n >= 2**32:
- raise ValueError("String is too large")
- self._pack_raw_header(n)
- return self._buffer.write(obj)
- if check(obj, memoryview):
- n = obj.nbytes
- if n >= 2**32:
- raise ValueError("Memoryview is too large")
- self._pack_bin_header(n)
- return self._buffer.write(obj)
- if check(obj, float):
- if self._use_float:
- return self._buffer.write(struct.pack(">Bf", 0xCA, obj))
- return self._buffer.write(struct.pack(">Bd", 0xCB, obj))
- if check(obj, (ExtType, Timestamp)):
- if check(obj, Timestamp):
- code = -1
- data = obj.to_bytes()
- else:
- code = obj.code
- data = obj.data
- assert isinstance(code, int)
- assert isinstance(data, bytes)
- L = len(data)
- if L == 1:
- self._buffer.write(b"\xd4")
- elif L == 2:
- self._buffer.write(b"\xd5")
- elif L == 4:
- self._buffer.write(b"\xd6")
- elif L == 8:
- self._buffer.write(b"\xd7")
- elif L == 16:
- self._buffer.write(b"\xd8")
- elif L <= 0xFF:
- self._buffer.write(struct.pack(">BB", 0xC7, L))
- elif L <= 0xFFFF:
- self._buffer.write(struct.pack(">BH", 0xC8, L))
- else:
- self._buffer.write(struct.pack(">BI", 0xC9, L))
- self._buffer.write(struct.pack("b", code))
- self._buffer.write(data)
- return
- if check(obj, list_types):
- n = len(obj)
- self._pack_array_header(n)
- for i in range(n):
- self._pack(obj[i], nest_limit - 1)
- return
- if check(obj, dict):
- return self._pack_map_pairs(len(obj), obj.items(), nest_limit - 1)
-
- if self._datetime and check(obj, _DateTime) and obj.tzinfo is not None:
- obj = Timestamp.from_datetime(obj)
- default_used = 1
- continue
-
- if not default_used and self._default is not None:
- obj = self._default(obj)
- default_used = 1
- continue
-
- if self._datetime and check(obj, _DateTime):
- raise ValueError(f"Cannot serialize {obj!r} where tzinfo=None")
-
- raise TypeError(f"Cannot serialize {obj!r}")
-
- def pack(self, obj):
- try:
- self._pack(obj)
- except:
- self._buffer = BytesIO() # force reset
- raise
- if self._autoreset:
- ret = self._buffer.getvalue()
- self._buffer = BytesIO()
- return ret
-
- def pack_map_pairs(self, pairs):
- self._pack_map_pairs(len(pairs), pairs)
- if self._autoreset:
- ret = self._buffer.getvalue()
- self._buffer = BytesIO()
- return ret
-
- def pack_array_header(self, n):
- if n >= 2**32:
- raise ValueError
- self._pack_array_header(n)
- if self._autoreset:
- ret = self._buffer.getvalue()
- self._buffer = BytesIO()
- return ret
-
- def pack_map_header(self, n):
- if n >= 2**32:
- raise ValueError
- self._pack_map_header(n)
- if self._autoreset:
- ret = self._buffer.getvalue()
- self._buffer = BytesIO()
- return ret
-
- def pack_ext_type(self, typecode, data):
- if not isinstance(typecode, int):
- raise TypeError("typecode must have int type.")
- if not 0 <= typecode <= 127:
- raise ValueError("typecode should be 0-127")
- if not isinstance(data, bytes):
- raise TypeError("data must have bytes type")
- L = len(data)
- if L > 0xFFFFFFFF:
- raise ValueError("Too large data")
- if L == 1:
- self._buffer.write(b"\xd4")
- elif L == 2:
- self._buffer.write(b"\xd5")
- elif L == 4:
- self._buffer.write(b"\xd6")
- elif L == 8:
- self._buffer.write(b"\xd7")
- elif L == 16:
- self._buffer.write(b"\xd8")
- elif L <= 0xFF:
- self._buffer.write(b"\xc7" + struct.pack("B", L))
- elif L <= 0xFFFF:
- self._buffer.write(b"\xc8" + struct.pack(">H", L))
- else:
- self._buffer.write(b"\xc9" + struct.pack(">I", L))
- self._buffer.write(struct.pack("B", typecode))
- self._buffer.write(data)
-
- def _pack_array_header(self, n):
- if n <= 0x0F:
- return self._buffer.write(struct.pack("B", 0x90 + n))
- if n <= 0xFFFF:
- return self._buffer.write(struct.pack(">BH", 0xDC, n))
- if n <= 0xFFFFFFFF:
- return self._buffer.write(struct.pack(">BI", 0xDD, n))
- raise ValueError("Array is too large")
-
- def _pack_map_header(self, n):
- if n <= 0x0F:
- return self._buffer.write(struct.pack("B", 0x80 + n))
- if n <= 0xFFFF:
- return self._buffer.write(struct.pack(">BH", 0xDE, n))
- if n <= 0xFFFFFFFF:
- return self._buffer.write(struct.pack(">BI", 0xDF, n))
- raise ValueError("Dict is too large")
-
- def _pack_map_pairs(self, n, pairs, nest_limit=DEFAULT_RECURSE_LIMIT):
- self._pack_map_header(n)
- for k, v in pairs:
- self._pack(k, nest_limit - 1)
- self._pack(v, nest_limit - 1)
-
- def _pack_raw_header(self, n):
- if n <= 0x1F:
- self._buffer.write(struct.pack("B", 0xA0 + n))
- elif self._use_bin_type and n <= 0xFF:
- self._buffer.write(struct.pack(">BB", 0xD9, n))
- elif n <= 0xFFFF:
- self._buffer.write(struct.pack(">BH", 0xDA, n))
- elif n <= 0xFFFFFFFF:
- self._buffer.write(struct.pack(">BI", 0xDB, n))
- else:
- raise ValueError("Raw is too large")
-
- def _pack_bin_header(self, n):
- if not self._use_bin_type:
- return self._pack_raw_header(n)
- elif n <= 0xFF:
- return self._buffer.write(struct.pack(">BB", 0xC4, n))
- elif n <= 0xFFFF:
- return self._buffer.write(struct.pack(">BH", 0xC5, n))
- elif n <= 0xFFFFFFFF:
- return self._buffer.write(struct.pack(">BI", 0xC6, n))
- else:
- raise ValueError("Bin is too large")
-
- def bytes(self):
- """Return internal buffer contents as bytes object"""
- return self._buffer.getvalue()
-
- def reset(self):
- """Reset internal buffer.
-
- This method is useful only when autoreset=False.
- """
- self._buffer = BytesIO()
-
- def getbuffer(self):
- """Return view of internal buffer."""
- if _USING_STRINGBUILDER:
- return memoryview(self.bytes())
- else:
- return self._buffer.getbuffer()
diff --git a/srsly/msgpack/pack.h b/srsly/msgpack/pack.h
deleted file mode 100644
index edf3a3f..0000000
--- a/srsly/msgpack/pack.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * MessagePack for Python packing routine
- *
- * Copyright (C) 2009 Naoki INADA
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include
-#include "sysdep.h"
-#include
-#include
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-typedef struct msgpack_packer {
- char *buf;
- size_t length;
- size_t buf_size;
- bool use_bin_type;
-} msgpack_packer;
-
-typedef struct Packer Packer;
-
-static inline int msgpack_pack_write(msgpack_packer* pk, const char *data, size_t l)
-{
- char* buf = pk->buf;
- size_t bs = pk->buf_size;
- size_t len = pk->length;
-
- if (len + l > bs) {
- bs = (len + l) * 2;
- buf = (char*)PyMem_Realloc(buf, bs);
- if (!buf) {
- PyErr_NoMemory();
- return -1;
- }
- }
- memcpy(buf + len, data, l);
- len += l;
-
- pk->buf = buf;
- pk->buf_size = bs;
- pk->length = len;
- return 0;
-}
-
-#define msgpack_pack_append_buffer(user, buf, len) \
- return msgpack_pack_write(user, (const char*)buf, len)
-
-#include "pack_template.h"
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/srsly/msgpack/pack_template.h b/srsly/msgpack/pack_template.h
deleted file mode 100644
index b8959f0..0000000
--- a/srsly/msgpack/pack_template.h
+++ /dev/null
@@ -1,596 +0,0 @@
-/*
- * MessagePack packing routine template
- *
- * Copyright (C) 2008-2010 FURUHASHI Sadayuki
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#if defined(__LITTLE_ENDIAN__)
-#define TAKE8_8(d) ((uint8_t*)&d)[0]
-#define TAKE8_16(d) ((uint8_t*)&d)[0]
-#define TAKE8_32(d) ((uint8_t*)&d)[0]
-#define TAKE8_64(d) ((uint8_t*)&d)[0]
-#elif defined(__BIG_ENDIAN__)
-#define TAKE8_8(d) ((uint8_t*)&d)[0]
-#define TAKE8_16(d) ((uint8_t*)&d)[1]
-#define TAKE8_32(d) ((uint8_t*)&d)[3]
-#define TAKE8_64(d) ((uint8_t*)&d)[7]
-#endif
-
-#ifndef msgpack_pack_append_buffer
-#error msgpack_pack_append_buffer callback is not defined
-#endif
-
-
-/*
- * Integer
- */
-
-#define msgpack_pack_real_uint16(x, d) \
-do { \
- if(d < (1<<7)) { \
- /* fixnum */ \
- msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \
- } else if(d < (1<<8)) { \
- /* unsigned 8 */ \
- unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } else { \
- /* unsigned 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } \
-} while(0)
-
-#define msgpack_pack_real_uint32(x, d) \
-do { \
- if(d < (1<<8)) { \
- if(d < (1<<7)) { \
- /* fixnum */ \
- msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \
- } else { \
- /* unsigned 8 */ \
- unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } \
- } else { \
- if(d < (1<<16)) { \
- /* unsigned 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } else { \
- /* unsigned 32 */ \
- unsigned char buf[5]; \
- buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
- msgpack_pack_append_buffer(x, buf, 5); \
- } \
- } \
-} while(0)
-
-#define msgpack_pack_real_uint64(x, d) \
-do { \
- if(d < (1ULL<<8)) { \
- if(d < (1ULL<<7)) { \
- /* fixnum */ \
- msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \
- } else { \
- /* unsigned 8 */ \
- unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } \
- } else { \
- if(d < (1ULL<<16)) { \
- /* unsigned 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } else if(d < (1ULL<<32)) { \
- /* unsigned 32 */ \
- unsigned char buf[5]; \
- buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
- msgpack_pack_append_buffer(x, buf, 5); \
- } else { \
- /* unsigned 64 */ \
- unsigned char buf[9]; \
- buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \
- msgpack_pack_append_buffer(x, buf, 9); \
- } \
- } \
-} while(0)
-
-#define msgpack_pack_real_int16(x, d) \
-do { \
- if(d < -(1<<5)) { \
- if(d < -(1<<7)) { \
- /* signed 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } else { \
- /* signed 8 */ \
- unsigned char buf[2] = {0xd0, TAKE8_16(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } \
- } else if(d < (1<<7)) { \
- /* fixnum */ \
- msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \
- } else { \
- if(d < (1<<8)) { \
- /* unsigned 8 */ \
- unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } else { \
- /* unsigned 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } \
- } \
-} while(0)
-
-#define msgpack_pack_real_int32(x, d) \
-do { \
- if(d < -(1<<5)) { \
- if(d < -(1<<15)) { \
- /* signed 32 */ \
- unsigned char buf[5]; \
- buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \
- msgpack_pack_append_buffer(x, buf, 5); \
- } else if(d < -(1<<7)) { \
- /* signed 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } else { \
- /* signed 8 */ \
- unsigned char buf[2] = {0xd0, TAKE8_32(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } \
- } else if(d < (1<<7)) { \
- /* fixnum */ \
- msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \
- } else { \
- if(d < (1<<8)) { \
- /* unsigned 8 */ \
- unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } else if(d < (1<<16)) { \
- /* unsigned 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } else { \
- /* unsigned 32 */ \
- unsigned char buf[5]; \
- buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
- msgpack_pack_append_buffer(x, buf, 5); \
- } \
- } \
-} while(0)
-
-#define msgpack_pack_real_int64(x, d) \
-do { \
- if(d < -(1LL<<5)) { \
- if(d < -(1LL<<15)) { \
- if(d < -(1LL<<31)) { \
- /* signed 64 */ \
- unsigned char buf[9]; \
- buf[0] = 0xd3; _msgpack_store64(&buf[1], d); \
- msgpack_pack_append_buffer(x, buf, 9); \
- } else { \
- /* signed 32 */ \
- unsigned char buf[5]; \
- buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \
- msgpack_pack_append_buffer(x, buf, 5); \
- } \
- } else { \
- if(d < -(1<<7)) { \
- /* signed 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } else { \
- /* signed 8 */ \
- unsigned char buf[2] = {0xd0, TAKE8_64(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } \
- } \
- } else if(d < (1<<7)) { \
- /* fixnum */ \
- msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \
- } else { \
- if(d < (1LL<<16)) { \
- if(d < (1<<8)) { \
- /* unsigned 8 */ \
- unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \
- msgpack_pack_append_buffer(x, buf, 2); \
- } else { \
- /* unsigned 16 */ \
- unsigned char buf[3]; \
- buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
- msgpack_pack_append_buffer(x, buf, 3); \
- } \
- } else { \
- if(d < (1LL<<32)) { \
- /* unsigned 32 */ \
- unsigned char buf[5]; \
- buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
- msgpack_pack_append_buffer(x, buf, 5); \
- } else { \
- /* unsigned 64 */ \
- unsigned char buf[9]; \
- buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \
- msgpack_pack_append_buffer(x, buf, 9); \
- } \
- } \
- } \
-} while(0)
-
-
-static inline int msgpack_pack_short(msgpack_packer* x, short d)
-{
-#if defined(SIZEOF_SHORT)
-#if SIZEOF_SHORT == 2
- msgpack_pack_real_int16(x, d);
-#elif SIZEOF_SHORT == 4
- msgpack_pack_real_int32(x, d);
-#else
- msgpack_pack_real_int64(x, d);
-#endif
-
-#elif defined(SHRT_MAX)
-#if SHRT_MAX == 0x7fff
- msgpack_pack_real_int16(x, d);
-#elif SHRT_MAX == 0x7fffffff
- msgpack_pack_real_int32(x, d);
-#else
- msgpack_pack_real_int64(x, d);
-#endif
-
-#else
-if(sizeof(short) == 2) {
- msgpack_pack_real_int16(x, d);
-} else if(sizeof(short) == 4) {
- msgpack_pack_real_int32(x, d);
-} else {
- msgpack_pack_real_int64(x, d);
-}
-#endif
-}
-
-static inline int msgpack_pack_int(msgpack_packer* x, int d)
-{
-#if defined(SIZEOF_INT)
-#if SIZEOF_INT == 2
- msgpack_pack_real_int16(x, d);
-#elif SIZEOF_INT == 4
- msgpack_pack_real_int32(x, d);
-#else
- msgpack_pack_real_int64(x, d);
-#endif
-
-#elif defined(INT_MAX)
-#if INT_MAX == 0x7fff
- msgpack_pack_real_int16(x, d);
-#elif INT_MAX == 0x7fffffff
- msgpack_pack_real_int32(x, d);
-#else
- msgpack_pack_real_int64(x, d);
-#endif
-
-#else
-if(sizeof(int) == 2) {
- msgpack_pack_real_int16(x, d);
-} else if(sizeof(int) == 4) {
- msgpack_pack_real_int32(x, d);
-} else {
- msgpack_pack_real_int64(x, d);
-}
-#endif
-}
-
-static inline int msgpack_pack_long(msgpack_packer* x, long d)
-{
-#if defined(SIZEOF_LONG)
-#if SIZEOF_LONG == 4
- msgpack_pack_real_int32(x, d);
-#else
- msgpack_pack_real_int64(x, d);
-#endif
-
-#elif defined(LONG_MAX)
-#if LONG_MAX == 0x7fffffffL
- msgpack_pack_real_int32(x, d);
-#else
- msgpack_pack_real_int64(x, d);
-#endif
-
-#else
- if (sizeof(long) == 4) {
- msgpack_pack_real_int32(x, d);
- } else {
- msgpack_pack_real_int64(x, d);
- }
-#endif
-}
-
-static inline int msgpack_pack_long_long(msgpack_packer* x, long long d)
-{
- msgpack_pack_real_int64(x, d);
-}
-
-static inline int msgpack_pack_unsigned_long_long(msgpack_packer* x, unsigned long long d)
-{
- msgpack_pack_real_uint64(x, d);
-}
-
-
-/*
- * Float
- */
-
-static inline int msgpack_pack_float(msgpack_packer* x, float d)
-{
- unsigned char buf[5];
- buf[0] = 0xca;
-
-#if PY_VERSION_HEX >= 0x030B00A7
- PyFloat_Pack4(d, (char *)&buf[1], 0);
-#else
- _PyFloat_Pack4(d, &buf[1], 0);
-#endif
- msgpack_pack_append_buffer(x, buf, 5);
-}
-
-static inline int msgpack_pack_double(msgpack_packer* x, double d)
-{
- unsigned char buf[9];
- buf[0] = 0xcb;
-#if PY_VERSION_HEX >= 0x030B00A7
- PyFloat_Pack8(d, (char *)&buf[1], 0);
-#else
- _PyFloat_Pack8(d, &buf[1], 0);
-#endif
- msgpack_pack_append_buffer(x, buf, 9);
-}
-
-
-/*
- * Nil
- */
-
-static inline int msgpack_pack_nil(msgpack_packer* x)
-{
- static const unsigned char d = 0xc0;
- msgpack_pack_append_buffer(x, &d, 1);
-}
-
-
-/*
- * Boolean
- */
-
-static inline int msgpack_pack_true(msgpack_packer* x)
-{
- static const unsigned char d = 0xc3;
- msgpack_pack_append_buffer(x, &d, 1);
-}
-
-static inline int msgpack_pack_false(msgpack_packer* x)
-{
- static const unsigned char d = 0xc2;
- msgpack_pack_append_buffer(x, &d, 1);
-}
-
-
-/*
- * Array
- */
-
-static inline int msgpack_pack_array(msgpack_packer* x, unsigned int n)
-{
- if(n < 16) {
- unsigned char d = 0x90 | n;
- msgpack_pack_append_buffer(x, &d, 1);
- } else if(n < 65536) {
- unsigned char buf[3];
- buf[0] = 0xdc; _msgpack_store16(&buf[1], (uint16_t)n);
- msgpack_pack_append_buffer(x, buf, 3);
- } else {
- unsigned char buf[5];
- buf[0] = 0xdd; _msgpack_store32(&buf[1], (uint32_t)n);
- msgpack_pack_append_buffer(x, buf, 5);
- }
-}
-
-
-/*
- * Map
- */
-
-static inline int msgpack_pack_map(msgpack_packer* x, unsigned int n)
-{
- if(n < 16) {
- unsigned char d = 0x80 | n;
- msgpack_pack_append_buffer(x, &TAKE8_8(d), 1);
- } else if(n < 65536) {
- unsigned char buf[3];
- buf[0] = 0xde; _msgpack_store16(&buf[1], (uint16_t)n);
- msgpack_pack_append_buffer(x, buf, 3);
- } else {
- unsigned char buf[5];
- buf[0] = 0xdf; _msgpack_store32(&buf[1], (uint32_t)n);
- msgpack_pack_append_buffer(x, buf, 5);
- }
-}
-
-
-/*
- * Raw
- */
-
-static inline int msgpack_pack_raw(msgpack_packer* x, size_t l)
-{
- if (l < 32) {
- unsigned char d = 0xa0 | (uint8_t)l;
- msgpack_pack_append_buffer(x, &TAKE8_8(d), 1);
- } else if (x->use_bin_type && l < 256) { // str8 is new format introduced with bin.
- unsigned char buf[2] = {0xd9, (uint8_t)l};
- msgpack_pack_append_buffer(x, buf, 2);
- } else if (l < 65536) {
- unsigned char buf[3];
- buf[0] = 0xda; _msgpack_store16(&buf[1], (uint16_t)l);
- msgpack_pack_append_buffer(x, buf, 3);
- } else {
- unsigned char buf[5];
- buf[0] = 0xdb; _msgpack_store32(&buf[1], (uint32_t)l);
- msgpack_pack_append_buffer(x, buf, 5);
- }
-}
-
-/*
- * bin
- */
-static inline int msgpack_pack_bin(msgpack_packer *x, size_t l)
-{
- if (!x->use_bin_type) {
- return msgpack_pack_raw(x, l);
- }
- if (l < 256) {
- unsigned char buf[2] = {0xc4, (unsigned char)l};
- msgpack_pack_append_buffer(x, buf, 2);
- } else if (l < 65536) {
- unsigned char buf[3] = {0xc5};
- _msgpack_store16(&buf[1], (uint16_t)l);
- msgpack_pack_append_buffer(x, buf, 3);
- } else {
- unsigned char buf[5] = {0xc6};
- _msgpack_store32(&buf[1], (uint32_t)l);
- msgpack_pack_append_buffer(x, buf, 5);
- }
-}
-
-static inline int msgpack_pack_raw_body(msgpack_packer* x, const void* b, size_t l)
-{
- if (l > 0) msgpack_pack_append_buffer(x, (const unsigned char*)b, l);
- return 0;
-}
-
-/*
- * Ext
- */
-static inline int msgpack_pack_ext(msgpack_packer* x, char typecode, size_t l)
-{
- if (l == 1) {
- unsigned char buf[2];
- buf[0] = 0xd4;
- buf[1] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 2);
- }
- else if(l == 2) {
- unsigned char buf[2];
- buf[0] = 0xd5;
- buf[1] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 2);
- }
- else if(l == 4) {
- unsigned char buf[2];
- buf[0] = 0xd6;
- buf[1] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 2);
- }
- else if(l == 8) {
- unsigned char buf[2];
- buf[0] = 0xd7;
- buf[1] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 2);
- }
- else if(l == 16) {
- unsigned char buf[2];
- buf[0] = 0xd8;
- buf[1] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 2);
- }
- else if(l < 256) {
- unsigned char buf[3];
- buf[0] = 0xc7;
- buf[1] = l;
- buf[2] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 3);
- } else if(l < 65536) {
- unsigned char buf[4];
- buf[0] = 0xc8;
- _msgpack_store16(&buf[1], (uint16_t)l);
- buf[3] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 4);
- } else {
- unsigned char buf[6];
- buf[0] = 0xc9;
- _msgpack_store32(&buf[1], (uint32_t)l);
- buf[5] = (unsigned char)typecode;
- msgpack_pack_append_buffer(x, buf, 6);
- }
-
-}
-
-/*
- * Pack Timestamp extension type. Follows msgpack-c pack_template.h.
- */
-static inline int msgpack_pack_timestamp(msgpack_packer* x, int64_t seconds, uint32_t nanoseconds)
-{
- if ((seconds >> 34) == 0) {
- /* seconds is unsigned and fits in 34 bits */
- uint64_t data64 = ((uint64_t)nanoseconds << 34) | (uint64_t)seconds;
- if ((data64 & 0xffffffff00000000L) == 0) {
- /* no nanoseconds and seconds is 32bits or smaller. timestamp32. */
- unsigned char buf[4];
- uint32_t data32 = (uint32_t)data64;
- msgpack_pack_ext(x, -1, 4);
- _msgpack_store32(buf, data32);
- msgpack_pack_raw_body(x, buf, 4);
- } else {
- /* timestamp64 */
- unsigned char buf[8];
- msgpack_pack_ext(x, -1, 8);
- _msgpack_store64(buf, data64);
- msgpack_pack_raw_body(x, buf, 8);
-
- }
- } else {
- /* seconds is signed or >34bits */
- unsigned char buf[12];
- _msgpack_store32(&buf[0], nanoseconds);
- _msgpack_store64(&buf[4], seconds);
- msgpack_pack_ext(x, -1, 12);
- msgpack_pack_raw_body(x, buf, 12);
- }
- return 0;
-}
-
-
-#undef msgpack_pack_append_buffer
-
-#undef TAKE8_8
-#undef TAKE8_16
-#undef TAKE8_32
-#undef TAKE8_64
-
-#undef msgpack_pack_real_uint16
-#undef msgpack_pack_real_uint32
-#undef msgpack_pack_real_uint64
-#undef msgpack_pack_real_int16
-#undef msgpack_pack_real_int32
-#undef msgpack_pack_real_int64
diff --git a/srsly/msgpack/sysdep.h b/srsly/msgpack/sysdep.h
deleted file mode 100644
index 7067300..0000000
--- a/srsly/msgpack/sysdep.h
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * MessagePack system dependencies
- *
- * Copyright (C) 2008-2010 FURUHASHI Sadayuki
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MSGPACK_SYSDEP_H__
-#define MSGPACK_SYSDEP_H__
-
-#include
-#include
-#if defined(_MSC_VER) && _MSC_VER < 1600
-typedef __int8 int8_t;
-typedef unsigned __int8 uint8_t;
-typedef __int16 int16_t;
-typedef unsigned __int16 uint16_t;
-typedef __int32 int32_t;
-typedef unsigned __int32 uint32_t;
-typedef __int64 int64_t;
-typedef unsigned __int64 uint64_t;
-#elif defined(_MSC_VER) // && _MSC_VER >= 1600
-#include
-#else
-#include
-#include
-#endif
-
-#ifdef _WIN32
-#define _msgpack_atomic_counter_header
-typedef long _msgpack_atomic_counter_t;
-#define _msgpack_sync_decr_and_fetch(ptr) InterlockedDecrement(ptr)
-#define _msgpack_sync_incr_and_fetch(ptr) InterlockedIncrement(ptr)
-#elif defined(__GNUC__) && ((__GNUC__*10 + __GNUC_MINOR__) < 41)
-#define _msgpack_atomic_counter_header "gcc_atomic.h"
-#else
-typedef unsigned int _msgpack_atomic_counter_t;
-#define _msgpack_sync_decr_and_fetch(ptr) __sync_sub_and_fetch(ptr, 1)
-#define _msgpack_sync_incr_and_fetch(ptr) __sync_add_and_fetch(ptr, 1)
-#endif
-
-#ifdef _WIN32
-
-#ifdef __cplusplus
-/* numeric_limits::min,max */
-#ifdef max
-#undef max
-#endif
-#ifdef min
-#undef min
-#endif
-#endif
-
-#else /* _WIN32 */
-#include /* ntohs, ntohl */
-#endif
-
-#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)
-#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
-#define __LITTLE_ENDIAN__
-#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-#define __BIG_ENDIAN__
-#elif _WIN32
-#define __LITTLE_ENDIAN__
-#endif
-#endif
-
-
-#ifdef __LITTLE_ENDIAN__
-
-#ifdef _WIN32
-# if defined(ntohs)
-# define _msgpack_be16(x) ntohs(x)
-# elif defined(_byteswap_ushort) || (defined(_MSC_VER) && _MSC_VER >= 1400)
-# define _msgpack_be16(x) ((uint16_t)_byteswap_ushort((unsigned short)x))
-# else
-# define _msgpack_be16(x) ( \
- ((((uint16_t)x) << 8) ) | \
- ((((uint16_t)x) >> 8) ) )
-# endif
-#else
-# define _msgpack_be16(x) ntohs(x)
-#endif
-
-#ifdef _WIN32
-# if defined(ntohl)
-# define _msgpack_be32(x) ntohl(x)
-# elif defined(_byteswap_ulong) || defined(_MSC_VER)
-# define _msgpack_be32(x) ((uint32_t)_byteswap_ulong((unsigned long)x))
-# else
-# define _msgpack_be32(x) \
- ( ((((uint32_t)x) << 24) ) | \
- ((((uint32_t)x) << 8) & 0x00ff0000U ) | \
- ((((uint32_t)x) >> 8) & 0x0000ff00U ) | \
- ((((uint32_t)x) >> 24) ) )
-# endif
-#else
-# define _msgpack_be32(x) ntohl(x)
-#endif
-
-#if defined(_byteswap_uint64) || defined(_MSC_VER)
-# define _msgpack_be64(x) (_byteswap_uint64(x))
-#elif defined(bswap_64)
-# define _msgpack_be64(x) bswap_64(x)
-#elif defined(__DARWIN_OSSwapInt64)
-# define _msgpack_be64(x) __DARWIN_OSSwapInt64(x)
-#else
-#define _msgpack_be64(x) \
- ( ((((uint64_t)x) << 56) ) | \
- ((((uint64_t)x) << 40) & 0x00ff000000000000ULL ) | \
- ((((uint64_t)x) << 24) & 0x0000ff0000000000ULL ) | \
- ((((uint64_t)x) << 8) & 0x000000ff00000000ULL ) | \
- ((((uint64_t)x) >> 8) & 0x00000000ff000000ULL ) | \
- ((((uint64_t)x) >> 24) & 0x0000000000ff0000ULL ) | \
- ((((uint64_t)x) >> 40) & 0x000000000000ff00ULL ) | \
- ((((uint64_t)x) >> 56) ) )
-#endif
-
-#define _msgpack_load16(cast, from) ((cast)( \
- (((uint16_t)((uint8_t*)(from))[0]) << 8) | \
- (((uint16_t)((uint8_t*)(from))[1]) ) ))
-
-#define _msgpack_load32(cast, from) ((cast)( \
- (((uint32_t)((uint8_t*)(from))[0]) << 24) | \
- (((uint32_t)((uint8_t*)(from))[1]) << 16) | \
- (((uint32_t)((uint8_t*)(from))[2]) << 8) | \
- (((uint32_t)((uint8_t*)(from))[3]) ) ))
-
-#define _msgpack_load64(cast, from) ((cast)( \
- (((uint64_t)((uint8_t*)(from))[0]) << 56) | \
- (((uint64_t)((uint8_t*)(from))[1]) << 48) | \
- (((uint64_t)((uint8_t*)(from))[2]) << 40) | \
- (((uint64_t)((uint8_t*)(from))[3]) << 32) | \
- (((uint64_t)((uint8_t*)(from))[4]) << 24) | \
- (((uint64_t)((uint8_t*)(from))[5]) << 16) | \
- (((uint64_t)((uint8_t*)(from))[6]) << 8) | \
- (((uint64_t)((uint8_t*)(from))[7]) ) ))
-
-#else
-
-#define _msgpack_be16(x) (x)
-#define _msgpack_be32(x) (x)
-#define _msgpack_be64(x) (x)
-
-#define _msgpack_load16(cast, from) ((cast)( \
- (((uint16_t)((uint8_t*)from)[0]) << 8) | \
- (((uint16_t)((uint8_t*)from)[1]) ) ))
-
-#define _msgpack_load32(cast, from) ((cast)( \
- (((uint32_t)((uint8_t*)from)[0]) << 24) | \
- (((uint32_t)((uint8_t*)from)[1]) << 16) | \
- (((uint32_t)((uint8_t*)from)[2]) << 8) | \
- (((uint32_t)((uint8_t*)from)[3]) ) ))
-
-#define _msgpack_load64(cast, from) ((cast)( \
- (((uint64_t)((uint8_t*)from)[0]) << 56) | \
- (((uint64_t)((uint8_t*)from)[1]) << 48) | \
- (((uint64_t)((uint8_t*)from)[2]) << 40) | \
- (((uint64_t)((uint8_t*)from)[3]) << 32) | \
- (((uint64_t)((uint8_t*)from)[4]) << 24) | \
- (((uint64_t)((uint8_t*)from)[5]) << 16) | \
- (((uint64_t)((uint8_t*)from)[6]) << 8) | \
- (((uint64_t)((uint8_t*)from)[7]) ) ))
-#endif
-
-
-#define _msgpack_store16(to, num) \
- do { uint16_t val = _msgpack_be16(num); memcpy(to, &val, 2); } while(0)
-#define _msgpack_store32(to, num) \
- do { uint32_t val = _msgpack_be32(num); memcpy(to, &val, 4); } while(0)
-#define _msgpack_store64(to, num) \
- do { uint64_t val = _msgpack_be64(num); memcpy(to, &val, 8); } while(0)
-
-/*
-#define _msgpack_load16(cast, from) \
- ({ cast val; memcpy(&val, (char*)from, 2); _msgpack_be16(val); })
-#define _msgpack_load32(cast, from) \
- ({ cast val; memcpy(&val, (char*)from, 4); _msgpack_be32(val); })
-#define _msgpack_load64(cast, from) \
- ({ cast val; memcpy(&val, (char*)from, 8); _msgpack_be64(val); })
-*/
-
-
-#endif /* msgpack/sysdep.h */
diff --git a/srsly/msgpack/unpack.h b/srsly/msgpack/unpack.h
deleted file mode 100644
index dabb5c1..0000000
--- a/srsly/msgpack/unpack.h
+++ /dev/null
@@ -1,393 +0,0 @@
-/*
- * MessagePack for Python unpacking routine
- *
- * Copyright (C) 2009 Naoki INADA
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#define MSGPACK_EMBED_STACK_SIZE (1024)
-#include "unpack_define.h"
-
-typedef struct unpack_user {
- bool use_list;
- bool raw;
- bool has_pairs_hook;
- bool strict_map_key;
- int timestamp;
- PyObject *object_hook;
- PyObject *list_hook;
- PyObject *ext_hook;
- PyObject *timestamp_t;
- PyObject *giga;
- PyObject *utc;
- const char *encoding;
- const char *unicode_errors;
- Py_ssize_t max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len;
-} unpack_user;
-
-typedef PyObject* msgpack_unpack_object;
-struct unpack_context;
-typedef struct unpack_context unpack_context;
-typedef int (*execute_fn)(unpack_context *ctx, const char* data, Py_ssize_t len, Py_ssize_t* off);
-
-static inline msgpack_unpack_object unpack_callback_root(unpack_user* u)
-{
- return NULL;
-}
-
-static inline int unpack_callback_uint16(unpack_user* u, uint16_t d, msgpack_unpack_object* o)
-{
- PyObject *p = PyLong_FromLong((long)d);
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-static inline int unpack_callback_uint8(unpack_user* u, uint8_t d, msgpack_unpack_object* o)
-{
- return unpack_callback_uint16(u, d, o);
-}
-
-
-static inline int unpack_callback_uint32(unpack_user* u, uint32_t d, msgpack_unpack_object* o)
-{
- PyObject *p = PyLong_FromSize_t((size_t)d);
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_uint64(unpack_user* u, uint64_t d, msgpack_unpack_object* o)
-{
- PyObject *p;
- if (d > LONG_MAX) {
- p = PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG)d);
- } else {
- p = PyLong_FromLong((long)d);
- }
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_int32(unpack_user* u, int32_t d, msgpack_unpack_object* o)
-{
- PyObject *p = PyLong_FromLong(d);
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_int16(unpack_user* u, int16_t d, msgpack_unpack_object* o)
-{
- return unpack_callback_int32(u, d, o);
-}
-
-static inline int unpack_callback_int8(unpack_user* u, int8_t d, msgpack_unpack_object* o)
-{
- return unpack_callback_int32(u, d, o);
-}
-
-static inline int unpack_callback_int64(unpack_user* u, int64_t d, msgpack_unpack_object* o)
-{
- PyObject *p;
- if (d > LONG_MAX || d < LONG_MIN) {
- p = PyLong_FromLongLong((PY_LONG_LONG)d);
- } else {
- p = PyLong_FromLong((long)d);
- }
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_double(unpack_user* u, double d, msgpack_unpack_object* o)
-{
- PyObject *p = PyFloat_FromDouble(d);
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_float(unpack_user* u, float d, msgpack_unpack_object* o)
-{
- return unpack_callback_double(u, d, o);
-}
-
-static inline int unpack_callback_nil(unpack_user* u, msgpack_unpack_object* o)
-{ Py_INCREF(Py_None); *o = Py_None; return 0; }
-
-static inline int unpack_callback_true(unpack_user* u, msgpack_unpack_object* o)
-{ Py_INCREF(Py_True); *o = Py_True; return 0; }
-
-static inline int unpack_callback_false(unpack_user* u, msgpack_unpack_object* o)
-{ Py_INCREF(Py_False); *o = Py_False; return 0; }
-
-static inline int unpack_callback_array(unpack_user* u, unsigned int n, msgpack_unpack_object* o)
-{
- if (n > u->max_array_len) {
- PyErr_Format(PyExc_ValueError, "%u exceeds max_array_len(%zd)", n, u->max_array_len);
- return -1;
- }
- PyObject *p = u->use_list ? PyList_New(n) : PyTuple_New(n);
-
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_array_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object o)
-{
- if (u->use_list)
- PyList_SET_ITEM(*c, current, o);
- else
- PyTuple_SET_ITEM(*c, current, o);
- return 0;
-}
-
-static inline int unpack_callback_array_end(unpack_user* u, msgpack_unpack_object* c)
-{
- if (u->list_hook) {
- PyObject *new_c = PyObject_CallFunctionObjArgs(u->list_hook, *c, NULL);
- if (!new_c)
- return -1;
- Py_DECREF(*c);
- *c = new_c;
- }
- return 0;
-}
-
-static inline int unpack_callback_map(unpack_user* u, unsigned int n, msgpack_unpack_object* o)
-{
- if (n > u->max_map_len) {
- PyErr_Format(PyExc_ValueError, "%u exceeds max_map_len(%zd)", n, u->max_map_len);
- return -1;
- }
- PyObject *p;
- if (u->has_pairs_hook) {
- p = PyList_New(n); // Or use tuple?
- }
- else {
- p = PyDict_New();
- }
- if (!p)
- return -1;
- *o = p;
- return 0;
-}
-
-static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v)
-{
- if (u->strict_map_key && !PyUnicode_CheckExact(k) && !PyBytes_CheckExact(k)) {
- PyErr_Format(PyExc_ValueError, "%.100s is not allowed for map key when strict_map_key=True", Py_TYPE(k)->tp_name);
- return -1;
- }
- if (PyUnicode_CheckExact(k)) {
- PyUnicode_InternInPlace(&k);
- }
- if (u->has_pairs_hook) {
- msgpack_unpack_object item = PyTuple_Pack(2, k, v);
- if (!item)
- return -1;
- Py_DECREF(k);
- Py_DECREF(v);
- PyList_SET_ITEM(*c, current, item);
- return 0;
- }
- else if (PyDict_SetItem(*c, k, v) == 0) {
- Py_DECREF(k);
- Py_DECREF(v);
- return 0;
- }
- return -1;
-}
-
-static inline int unpack_callback_map_end(unpack_user* u, msgpack_unpack_object* c)
-{
- if (u->object_hook) {
- PyObject *new_c = PyObject_CallFunctionObjArgs(u->object_hook, *c, NULL);
- if (!new_c)
- return -1;
-
- Py_DECREF(*c);
- *c = new_c;
- }
- return 0;
-}
-
-static inline int unpack_callback_raw(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
-{
- if (l > u->max_str_len) {
- PyErr_Format(PyExc_ValueError, "%u exceeds max_str_len(%zd)", l, u->max_str_len);
- return -1;
- }
-
- PyObject *py;
- if (u->encoding) {
- py = PyUnicode_Decode(p, l, u->encoding, u->unicode_errors);
- } else if (u->raw) {
- py = PyBytes_FromStringAndSize(p, l);
- } else {
- py = PyUnicode_DecodeUTF8(p, l, u->unicode_errors);
- }
- if (!py)
- return -1;
- *o = py;
- return 0;
-}
-
-static inline int unpack_callback_bin(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
-{
- if (l > u->max_bin_len) {
- PyErr_Format(PyExc_ValueError, "%u exceeds max_bin_len(%zd)", l, u->max_bin_len);
- return -1;
- }
-
- PyObject *py = PyBytes_FromStringAndSize(p, l);
- if (!py)
- return -1;
- *o = py;
- return 0;
-}
-
-typedef struct msgpack_timestamp {
- int64_t tv_sec;
- uint32_t tv_nsec;
-} msgpack_timestamp;
-
-/*
- * Unpack ext buffer to a timestamp. Pulled from msgpack-c timestamp.h.
- */
-static int unpack_timestamp(const char* buf, unsigned int buflen, msgpack_timestamp* ts) {
- switch (buflen) {
- case 4:
- ts->tv_nsec = 0;
- {
- uint32_t v = _msgpack_load32(uint32_t, buf);
- ts->tv_sec = (int64_t)v;
- }
- return 0;
- case 8: {
- uint64_t value =_msgpack_load64(uint64_t, buf);
- ts->tv_nsec = (uint32_t)(value >> 34);
- ts->tv_sec = value & 0x00000003ffffffffLL;
- return 0;
- }
- case 12:
- ts->tv_nsec = _msgpack_load32(uint32_t, buf);
- ts->tv_sec = _msgpack_load64(int64_t, buf + 4);
- return 0;
- default:
- return -1;
- }
-}
-
-#include "datetime.h"
-
-static int unpack_callback_ext(unpack_user* u, const char* base, const char* pos,
- unsigned int length, msgpack_unpack_object* o)
-{
- int8_t typecode = (int8_t)*pos++;
- if (!u->ext_hook) {
- PyErr_SetString(PyExc_AssertionError, "u->ext_hook cannot be NULL");
- return -1;
- }
- if (length-1 > u->max_ext_len) {
- PyErr_Format(PyExc_ValueError, "%u exceeds max_ext_len(%zd)", length, u->max_ext_len);
- return -1;
- }
-
- PyObject *py = NULL;
- // length also includes the typecode, so the actual data is length-1
- if (typecode == -1) {
- msgpack_timestamp ts;
- if (unpack_timestamp(pos, length-1, &ts) < 0) {
- return -1;
- }
-
- if (u->timestamp == 2) { // int
- PyObject *a = PyLong_FromLongLong(ts.tv_sec);
- if (a == NULL) return -1;
-
- PyObject *c = PyNumber_Multiply(a, u->giga);
- Py_DECREF(a);
- if (c == NULL) {
- return -1;
- }
-
- PyObject *b = PyLong_FromUnsignedLong(ts.tv_nsec);
- if (b == NULL) {
- Py_DECREF(c);
- return -1;
- }
-
- py = PyNumber_Add(c, b);
- Py_DECREF(c);
- Py_DECREF(b);
- }
- else if (u->timestamp == 0) { // Timestamp
- py = PyObject_CallFunction(u->timestamp_t, "(Lk)", ts.tv_sec, ts.tv_nsec);
- }
- else if (u->timestamp == 3) { // datetime
- // Calculate datetime using epoch + delta
- // due to limitations PyDateTime_FromTimestamp on Windows with negative timestamps
- PyObject *epoch = PyDateTimeAPI->DateTime_FromDateAndTime(1970, 1, 1, 0, 0, 0, 0, u->utc, PyDateTimeAPI->DateTimeType);
- if (epoch == NULL) {
- return -1;
- }
-
- PyObject* d = PyDelta_FromDSU(ts.tv_sec/(24*3600), ts.tv_sec%(24*3600), ts.tv_nsec / 1000);
- if (d == NULL) {
- Py_DECREF(epoch);
- return -1;
- }
-
- py = PyNumber_Add(epoch, d);
-
- Py_DECREF(epoch);
- Py_DECREF(d);
- }
- else { // float
- PyObject *a = PyFloat_FromDouble((double)ts.tv_nsec);
- if (a == NULL) return -1;
-
- PyObject *b = PyNumber_TrueDivide(a, u->giga);
- Py_DECREF(a);
- if (b == NULL) return -1;
-
- PyObject *c = PyLong_FromLongLong(ts.tv_sec);
- if (c == NULL) {
- Py_DECREF(b);
- return -1;
- }
-
- a = PyNumber_Add(b, c);
- Py_DECREF(b);
- Py_DECREF(c);
- py = a;
- }
- } else {
- py = PyObject_CallFunction(u->ext_hook, "(iy#)", (int)typecode, pos, (Py_ssize_t)length-1);
- }
- if (!py)
- return -1;
- *o = py;
- return 0;
-}
-
-#include "unpack_template.h"
diff --git a/srsly/msgpack/unpack_container_header.h b/srsly/msgpack/unpack_container_header.h
deleted file mode 100644
index c14a3c2..0000000
--- a/srsly/msgpack/unpack_container_header.h
+++ /dev/null
@@ -1,51 +0,0 @@
-static inline int unpack_container_header(unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off)
-{
- assert(len >= *off);
- uint32_t size;
- const unsigned char *const p = (unsigned char*)data + *off;
-
-#define inc_offset(inc) \
- if (len - *off < inc) \
- return 0; \
- *off += inc;
-
- switch (*p) {
- case var_offset:
- inc_offset(3);
- size = _msgpack_load16(uint16_t, p + 1);
- break;
- case var_offset + 1:
- inc_offset(5);
- size = _msgpack_load32(uint32_t, p + 1);
- break;
-#ifdef USE_CASE_RANGE
- case fixed_offset + 0x0 ... fixed_offset + 0xf:
-#else
- case fixed_offset + 0x0:
- case fixed_offset + 0x1:
- case fixed_offset + 0x2:
- case fixed_offset + 0x3:
- case fixed_offset + 0x4:
- case fixed_offset + 0x5:
- case fixed_offset + 0x6:
- case fixed_offset + 0x7:
- case fixed_offset + 0x8:
- case fixed_offset + 0x9:
- case fixed_offset + 0xa:
- case fixed_offset + 0xb:
- case fixed_offset + 0xc:
- case fixed_offset + 0xd:
- case fixed_offset + 0xe:
- case fixed_offset + 0xf:
-#endif
- ++*off;
- size = ((unsigned int)*p) & 0x0f;
- break;
- default:
- PyErr_SetString(PyExc_ValueError, "Unexpected type header on stream");
- return -1;
- }
- unpack_callback_uint32(&ctx->user, size, &ctx->stack[0].obj);
- return 1;
-}
-
diff --git a/srsly/msgpack/unpack_define.h b/srsly/msgpack/unpack_define.h
deleted file mode 100644
index 0dd708d..0000000
--- a/srsly/msgpack/unpack_define.h
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * MessagePack unpacking routine template
- *
- * Copyright (C) 2008-2010 FURUHASHI Sadayuki
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MSGPACK_UNPACK_DEFINE_H__
-#define MSGPACK_UNPACK_DEFINE_H__
-
-#include "msgpack/sysdep.h"
-#include
-#include
-#include
-#include
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-
-#ifndef MSGPACK_EMBED_STACK_SIZE
-#define MSGPACK_EMBED_STACK_SIZE 32
-#endif
-
-
-// CS is first byte & 0x1f
-typedef enum {
- CS_HEADER = 0x00, // nil
-
- //CS_ = 0x01,
- //CS_ = 0x02, // false
- //CS_ = 0x03, // true
-
- CS_BIN_8 = 0x04,
- CS_BIN_16 = 0x05,
- CS_BIN_32 = 0x06,
-
- CS_EXT_8 = 0x07,
- CS_EXT_16 = 0x08,
- CS_EXT_32 = 0x09,
-
- CS_FLOAT = 0x0a,
- CS_DOUBLE = 0x0b,
- CS_UINT_8 = 0x0c,
- CS_UINT_16 = 0x0d,
- CS_UINT_32 = 0x0e,
- CS_UINT_64 = 0x0f,
- CS_INT_8 = 0x10,
- CS_INT_16 = 0x11,
- CS_INT_32 = 0x12,
- CS_INT_64 = 0x13,
-
- //CS_FIXEXT1 = 0x14,
- //CS_FIXEXT2 = 0x15,
- //CS_FIXEXT4 = 0x16,
- //CS_FIXEXT8 = 0x17,
- //CS_FIXEXT16 = 0x18,
-
- CS_RAW_8 = 0x19,
- CS_RAW_16 = 0x1a,
- CS_RAW_32 = 0x1b,
- CS_ARRAY_16 = 0x1c,
- CS_ARRAY_32 = 0x1d,
- CS_MAP_16 = 0x1e,
- CS_MAP_32 = 0x1f,
-
- ACS_RAW_VALUE,
- ACS_BIN_VALUE,
- ACS_EXT_VALUE,
-} msgpack_unpack_state;
-
-
-typedef enum {
- CT_ARRAY_ITEM,
- CT_MAP_KEY,
- CT_MAP_VALUE,
-} msgpack_container_type;
-
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif /* msgpack/unpack_define.h */
diff --git a/srsly/msgpack/unpack_template.h b/srsly/msgpack/unpack_template.h
deleted file mode 100644
index cce29e7..0000000
--- a/srsly/msgpack/unpack_template.h
+++ /dev/null
@@ -1,423 +0,0 @@
-/*
- * MessagePack unpacking routine template
- *
- * Copyright (C) 2008-2010 FURUHASHI Sadayuki
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef USE_CASE_RANGE
-#if !defined(_MSC_VER)
-#define USE_CASE_RANGE
-#endif
-#endif
-
-typedef struct unpack_stack {
- PyObject* obj;
- Py_ssize_t size;
- Py_ssize_t count;
- unsigned int ct;
- PyObject* map_key;
-} unpack_stack;
-
-struct unpack_context {
- unpack_user user;
- unsigned int cs;
- unsigned int trail;
- unsigned int top;
- /*
- unpack_stack* stack;
- unsigned int stack_size;
- unpack_stack embed_stack[MSGPACK_EMBED_STACK_SIZE];
- */
- unpack_stack stack[MSGPACK_EMBED_STACK_SIZE];
-};
-
-
-static inline void unpack_init(unpack_context* ctx)
-{
- ctx->cs = CS_HEADER;
- ctx->trail = 0;
- ctx->top = 0;
- /*
- ctx->stack = ctx->embed_stack;
- ctx->stack_size = MSGPACK_EMBED_STACK_SIZE;
- */
- ctx->stack[0].obj = unpack_callback_root(&ctx->user);
-}
-
-/*
-static inline void unpack_destroy(unpack_context* ctx)
-{
- if(ctx->stack_size != MSGPACK_EMBED_STACK_SIZE) {
- free(ctx->stack);
- }
-}
-*/
-
-static inline PyObject* unpack_data(unpack_context* ctx)
-{
- return (ctx)->stack[0].obj;
-}
-
-static inline void unpack_clear(unpack_context *ctx)
-{
- Py_CLEAR(ctx->stack[0].obj);
-}
-
-static inline int unpack_execute(bool construct, unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off)
-{
- assert(len >= *off);
-
- const unsigned char* p = (unsigned char*)data + *off;
- const unsigned char* const pe = (unsigned char*)data + len;
- const void* n = p;
-
- unsigned int trail = ctx->trail;
- unsigned int cs = ctx->cs;
- unsigned int top = ctx->top;
- unpack_stack* stack = ctx->stack;
- /*
- unsigned int stack_size = ctx->stack_size;
- */
- unpack_user* user = &ctx->user;
-
- PyObject* obj = NULL;
- unpack_stack* c = NULL;
-
- int ret;
-
-#define construct_cb(name) \
- construct && unpack_callback ## name
-
-#define push_simple_value(func) \
- if(construct_cb(func)(user, &obj) < 0) { goto _failed; } \
- goto _push
-#define push_fixed_value(func, arg) \
- if(construct_cb(func)(user, arg, &obj) < 0) { goto _failed; } \
- goto _push
-#define push_variable_value(func, base, pos, len) \
- if(construct_cb(func)(user, \
- (const char*)base, (const char*)pos, len, &obj) < 0) { goto _failed; } \
- goto _push
-
-#define again_fixed_trail(_cs, trail_len) \
- trail = trail_len; \
- cs = _cs; \
- goto _fixed_trail_again
-#define again_fixed_trail_if_zero(_cs, trail_len, ifzero) \
- trail = trail_len; \
- if(trail == 0) { goto ifzero; } \
- cs = _cs; \
- goto _fixed_trail_again
-
-#define start_container(func, count_, ct_) \
- if(top >= MSGPACK_EMBED_STACK_SIZE) { ret = -3; goto _end; } \
- if(construct_cb(func)(user, count_, &stack[top].obj) < 0) { goto _failed; } \
- if((count_) == 0) { obj = stack[top].obj; \
- if (construct_cb(func##_end)(user, &obj) < 0) { goto _failed; } \
- goto _push; } \
- stack[top].ct = ct_; \
- stack[top].size = count_; \
- stack[top].count = 0; \
- ++top; \
- goto _header_again
-
-#define NEXT_CS(p) ((unsigned int)*p & 0x1f)
-
-#ifdef USE_CASE_RANGE
-#define SWITCH_RANGE_BEGIN switch(*p) {
-#define SWITCH_RANGE(FROM, TO) case FROM ... TO:
-#define SWITCH_RANGE_DEFAULT default:
-#define SWITCH_RANGE_END }
-#else
-#define SWITCH_RANGE_BEGIN { if(0) {
-#define SWITCH_RANGE(FROM, TO) } else if(FROM <= *p && *p <= TO) {
-#define SWITCH_RANGE_DEFAULT } else {
-#define SWITCH_RANGE_END } }
-#endif
-
- if(p == pe) { goto _out; }
- do {
- switch(cs) {
- case CS_HEADER:
- SWITCH_RANGE_BEGIN
- SWITCH_RANGE(0x00, 0x7f) // Positive Fixnum
- push_fixed_value(_uint8, *(uint8_t*)p);
- SWITCH_RANGE(0xe0, 0xff) // Negative Fixnum
- push_fixed_value(_int8, *(int8_t*)p);
- SWITCH_RANGE(0xc0, 0xdf) // Variable
- switch(*p) {
- case 0xc0: // nil
- push_simple_value(_nil);
- //case 0xc1: // never used
- case 0xc2: // false
- push_simple_value(_false);
- case 0xc3: // true
- push_simple_value(_true);
- case 0xc4: // bin 8
- again_fixed_trail(NEXT_CS(p), 1);
- case 0xc5: // bin 16
- again_fixed_trail(NEXT_CS(p), 2);
- case 0xc6: // bin 32
- again_fixed_trail(NEXT_CS(p), 4);
- case 0xc7: // ext 8
- again_fixed_trail(NEXT_CS(p), 1);
- case 0xc8: // ext 16
- again_fixed_trail(NEXT_CS(p), 2);
- case 0xc9: // ext 32
- again_fixed_trail(NEXT_CS(p), 4);
- case 0xca: // float
- case 0xcb: // double
- case 0xcc: // unsigned int 8
- case 0xcd: // unsigned int 16
- case 0xce: // unsigned int 32
- case 0xcf: // unsigned int 64
- case 0xd0: // signed int 8
- case 0xd1: // signed int 16
- case 0xd2: // signed int 32
- case 0xd3: // signed int 64
- again_fixed_trail(NEXT_CS(p), 1 << (((unsigned int)*p) & 0x03));
- case 0xd4: // fixext 1
- case 0xd5: // fixext 2
- case 0xd6: // fixext 4
- case 0xd7: // fixext 8
- again_fixed_trail_if_zero(ACS_EXT_VALUE,
- (1 << (((unsigned int)*p) & 0x03))+1,
- _ext_zero);
- case 0xd8: // fixext 16
- again_fixed_trail_if_zero(ACS_EXT_VALUE, 16+1, _ext_zero);
- case 0xd9: // str 8
- again_fixed_trail(NEXT_CS(p), 1);
- case 0xda: // raw 16
- case 0xdb: // raw 32
- case 0xdc: // array 16
- case 0xdd: // array 32
- case 0xde: // map 16
- case 0xdf: // map 32
- again_fixed_trail(NEXT_CS(p), 2 << (((unsigned int)*p) & 0x01));
- default:
- ret = -2;
- goto _end;
- }
- SWITCH_RANGE(0xa0, 0xbf) // FixRaw
- again_fixed_trail_if_zero(ACS_RAW_VALUE, ((unsigned int)*p & 0x1f), _raw_zero);
- SWITCH_RANGE(0x90, 0x9f) // FixArray
- start_container(_array, ((unsigned int)*p) & 0x0f, CT_ARRAY_ITEM);
- SWITCH_RANGE(0x80, 0x8f) // FixMap
- start_container(_map, ((unsigned int)*p) & 0x0f, CT_MAP_KEY);
-
- SWITCH_RANGE_DEFAULT
- ret = -2;
- goto _end;
- SWITCH_RANGE_END
- // end CS_HEADER
-
-
- _fixed_trail_again:
- ++p;
-
- default:
- if((size_t)(pe - p) < trail) { goto _out; }
- n = p; p += trail - 1;
- switch(cs) {
- case CS_EXT_8:
- again_fixed_trail_if_zero(ACS_EXT_VALUE, *(uint8_t*)n+1, _ext_zero);
- case CS_EXT_16:
- again_fixed_trail_if_zero(ACS_EXT_VALUE,
- _msgpack_load16(uint16_t,n)+1,
- _ext_zero);
- case CS_EXT_32:
- again_fixed_trail_if_zero(ACS_EXT_VALUE,
- _msgpack_load32(uint32_t,n)+1,
- _ext_zero);
- case CS_FLOAT: {
- double f;
-#if PY_VERSION_HEX >= 0x030B00A7
- f = PyFloat_Unpack4((const char*)n, 0);
-#else
- f = _PyFloat_Unpack4((unsigned char*)n, 0);
-#endif
- push_fixed_value(_float, f); }
- case CS_DOUBLE: {
- double f;
-#if PY_VERSION_HEX >= 0x030B00A7
- f = PyFloat_Unpack8((const char*)n, 0);
-#else
- f = _PyFloat_Unpack8((unsigned char*)n, 0);
-#endif
- push_fixed_value(_double, f); }
- case CS_UINT_8:
- push_fixed_value(_uint8, *(uint8_t*)n);
- case CS_UINT_16:
- push_fixed_value(_uint16, _msgpack_load16(uint16_t,n));
- case CS_UINT_32:
- push_fixed_value(_uint32, _msgpack_load32(uint32_t,n));
- case CS_UINT_64:
- push_fixed_value(_uint64, _msgpack_load64(uint64_t,n));
-
- case CS_INT_8:
- push_fixed_value(_int8, *(int8_t*)n);
- case CS_INT_16:
- push_fixed_value(_int16, _msgpack_load16(int16_t,n));
- case CS_INT_32:
- push_fixed_value(_int32, _msgpack_load32(int32_t,n));
- case CS_INT_64:
- push_fixed_value(_int64, _msgpack_load64(int64_t,n));
-
- case CS_BIN_8:
- again_fixed_trail_if_zero(ACS_BIN_VALUE, *(uint8_t*)n, _bin_zero);
- case CS_BIN_16:
- again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load16(uint16_t,n), _bin_zero);
- case CS_BIN_32:
- again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load32(uint32_t,n), _bin_zero);
- case ACS_BIN_VALUE:
- _bin_zero:
- push_variable_value(_bin, data, n, trail);
-
- case CS_RAW_8:
- again_fixed_trail_if_zero(ACS_RAW_VALUE, *(uint8_t*)n, _raw_zero);
- case CS_RAW_16:
- again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load16(uint16_t,n), _raw_zero);
- case CS_RAW_32:
- again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load32(uint32_t,n), _raw_zero);
- case ACS_RAW_VALUE:
- _raw_zero:
- push_variable_value(_raw, data, n, trail);
-
- case ACS_EXT_VALUE:
- _ext_zero:
- push_variable_value(_ext, data, n, trail);
-
- case CS_ARRAY_16:
- start_container(_array, _msgpack_load16(uint16_t,n), CT_ARRAY_ITEM);
- case CS_ARRAY_32:
- /* FIXME security guard */
- start_container(_array, _msgpack_load32(uint32_t,n), CT_ARRAY_ITEM);
-
- case CS_MAP_16:
- start_container(_map, _msgpack_load16(uint16_t,n), CT_MAP_KEY);
- case CS_MAP_32:
- /* FIXME security guard */
- start_container(_map, _msgpack_load32(uint32_t,n), CT_MAP_KEY);
-
- default:
- goto _failed;
- }
- }
-
-_push:
- if(top == 0) { goto _finish; }
- c = &stack[top-1];
- switch(c->ct) {
- case CT_ARRAY_ITEM:
- if(construct_cb(_array_item)(user, c->count, &c->obj, obj) < 0) { goto _failed; }
- if(++c->count == c->size) {
- obj = c->obj;
- if (construct_cb(_array_end)(user, &obj) < 0) { goto _failed; }
- --top;
- /*printf("stack pop %d\n", top);*/
- goto _push;
- }
- goto _header_again;
- case CT_MAP_KEY:
- c->map_key = obj;
- c->ct = CT_MAP_VALUE;
- goto _header_again;
- case CT_MAP_VALUE:
- if(construct_cb(_map_item)(user, c->count, &c->obj, c->map_key, obj) < 0) { goto _failed; }
- if(++c->count == c->size) {
- obj = c->obj;
- if (construct_cb(_map_end)(user, &obj) < 0) { goto _failed; }
- --top;
- /*printf("stack pop %d\n", top);*/
- goto _push;
- }
- c->ct = CT_MAP_KEY;
- goto _header_again;
-
- default:
- goto _failed;
- }
-
-_header_again:
- cs = CS_HEADER;
- ++p;
- } while(p != pe);
- goto _out;
-
-
-_finish:
- if (!construct)
- unpack_callback_nil(user, &obj);
- stack[0].obj = obj;
- ++p;
- ret = 1;
- /*printf("-- finish --\n"); */
- goto _end;
-
-_failed:
- /*printf("** FAILED **\n"); */
- ret = -1;
- goto _end;
-
-_out:
- ret = 0;
- goto _end;
-
-_end:
- ctx->cs = cs;
- ctx->trail = trail;
- ctx->top = top;
- *off = p - (const unsigned char*)data;
-
- return ret;
-#undef construct_cb
-}
-
-#undef NEXT_CS
-#undef SWITCH_RANGE_BEGIN
-#undef SWITCH_RANGE
-#undef SWITCH_RANGE_DEFAULT
-#undef SWITCH_RANGE_END
-#undef push_simple_value
-#undef push_fixed_value
-#undef push_variable_value
-#undef again_fixed_trail
-#undef again_fixed_trail_if_zero
-#undef start_container
-
-static int unpack_construct(unpack_context *ctx, const char *data, Py_ssize_t len, Py_ssize_t *off) {
- return unpack_execute(1, ctx, data, len, off);
-}
-static int unpack_skip(unpack_context *ctx, const char *data, Py_ssize_t len, Py_ssize_t *off) {
- return unpack_execute(0, ctx, data, len, off);
-}
-
-#define unpack_container_header read_array_header
-#define fixed_offset 0x90
-#define var_offset 0xdc
-#include "unpack_container_header.h"
-#undef unpack_container_header
-#undef fixed_offset
-#undef var_offset
-
-#define unpack_container_header read_map_header
-#define fixed_offset 0x80
-#define var_offset 0xde
-#include "unpack_container_header.h"
-#undef unpack_container_header
-#undef fixed_offset
-#undef var_offset
-
-/* vim: set ts=4 sw=4 sts=4 expandtab */
diff --git a/srsly/msgpack/util.py b/srsly/msgpack/util.py
deleted file mode 100644
index 967d66d..0000000
--- a/srsly/msgpack/util.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from __future__ import unicode_literals
-
-try:
- unicode
-except NameError:
- unicode = str
-
-
-def ensure_bytes(string):
- """Ensure a string is returned as a bytes object, encoded as utf8."""
- if isinstance(string, unicode):
- return string.encode("utf8")
- else:
- return string
diff --git a/srsly/ruamel_yaml/LICENSE b/srsly/ruamel_yaml/LICENSE
deleted file mode 100755
index 5b863d3..0000000
--- a/srsly/ruamel_yaml/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
- The MIT License (MIT)
-
- Copyright (c) 2014-2020 Anthon van der Neut, Ruamel bvba
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in
- all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
diff --git a/srsly/ruamel_yaml/__init__.py b/srsly/ruamel_yaml/__init__.py
deleted file mode 100755
index b7678ec..0000000
--- a/srsly/ruamel_yaml/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-__with_libyaml__ = False
-
-from .main import * # NOQA
-
-version_info = (0, 16, 7)
-__version__ = "0.16.7"
diff --git a/srsly/ruamel_yaml/anchor.py b/srsly/ruamel_yaml/anchor.py
deleted file mode 100755
index aa649f5..0000000
--- a/srsly/ruamel_yaml/anchor.py
+++ /dev/null
@@ -1,20 +0,0 @@
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Union, Optional, Iterator # NOQA
-
-anchor_attrib = '_yaml_anchor'
-
-
-class Anchor(object):
- __slots__ = 'value', 'always_dump'
- attrib = anchor_attrib
-
- def __init__(self):
- # type: () -> None
- self.value = None
- self.always_dump = False
-
- def __repr__(self):
- # type: () -> Any
- ad = ', (always dump)' if self.always_dump else ""
- return 'Anchor({!r}{})'.format(self.value, ad)
diff --git a/srsly/ruamel_yaml/comments.py b/srsly/ruamel_yaml/comments.py
deleted file mode 100755
index 9090298..0000000
--- a/srsly/ruamel_yaml/comments.py
+++ /dev/null
@@ -1,1153 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import, print_function
-
-"""
-stuff to deal with comments and formatting on dict/list/ordereddict/set
-these are not really related, formatting could be factored out as
-a separate base
-"""
-
-import sys
-import copy
-
-
-from .compat import ordereddict # type: ignore
-from .compat import PY2, string_types, MutableSliceableSequence
-from .scalarstring import ScalarString
-from .anchor import Anchor
-
-if PY2:
- from collections import MutableSet, Sized, Set, Mapping
-else:
- from collections.abc import MutableSet, Sized, Set, Mapping
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Union, Optional, Iterator # NOQA
-
-# fmt: off
-__all__ = ['CommentedSeq', 'CommentedKeySeq',
- 'CommentedMap', 'CommentedOrderedMap',
- 'CommentedSet', 'comment_attrib', 'merge_attrib']
-# fmt: on
-
-comment_attrib = "_yaml_comment"
-format_attrib = "_yaml_format"
-line_col_attrib = "_yaml_line_col"
-merge_attrib = "_yaml_merge"
-tag_attrib = "_yaml_tag"
-
-
-class Comment(object):
- # sys.getsize tested the Comment objects, __slots__ makes them bigger
- # and adding self.end did not matter
- __slots__ = "comment", "_items", "_end", "_start"
- attrib = comment_attrib
-
- def __init__(self):
- # type: () -> None
- self.comment = None # [post, [pre]]
- # map key (mapping/omap/dict) or index (sequence/list) to a list of
- # dict: post_key, pre_key, post_value, pre_value
- # list: pre item, post item
- self._items = {} # type: Dict[Any, Any]
- # self._start = [] # should not put these on first item
- self._end = [] # type: List[Any] # end of document comments
-
- def __str__(self):
- # type: () -> str
- if bool(self._end):
- end = ",\n end=" + str(self._end)
- else:
- end = ""
- return "Comment(comment={0},\n items={1}{2})".format(
- self.comment, self._items, end
- )
-
- @property
- def items(self):
- # type: () -> Any
- return self._items
-
- @property
- def end(self):
- # type: () -> Any
- return self._end
-
- @end.setter
- def end(self, value):
- # type: (Any) -> None
- self._end = value
-
- @property
- def start(self):
- # type: () -> Any
- return self._start
-
- @start.setter
- def start(self, value):
- # type: (Any) -> None
- self._start = value
-
-
-# to distinguish key from None
-def NoComment():
- # type: () -> None
- pass
-
-
-class Format(object):
- __slots__ = ("_flow_style",)
- attrib = format_attrib
-
- def __init__(self):
- # type: () -> None
- self._flow_style = None # type: Any
-
- def set_flow_style(self):
- # type: () -> None
- self._flow_style = True
-
- def set_block_style(self):
- # type: () -> None
- self._flow_style = False
-
- def flow_style(self, default=None):
- # type: (Optional[Any]) -> Any
- """if default (the flow_style) is None, the flow style tacked on to
- the object explicitly will be taken. If that is None as well the
- default flow style rules the format down the line, or the type
- of the constituent values (simple -> flow, map/list -> block)"""
- if self._flow_style is None:
- return default
- return self._flow_style
-
-
-class LineCol(object):
- attrib = line_col_attrib
-
- def __init__(self):
- # type: () -> None
- self.line = None
- self.col = None
- self.data = None # type: Optional[Dict[Any, Any]]
-
- def add_kv_line_col(self, key, data):
- # type: (Any, Any) -> None
- if self.data is None:
- self.data = {}
- self.data[key] = data
-
- def key(self, k):
- # type: (Any) -> Any
- return self._kv(k, 0, 1)
-
- def value(self, k):
- # type: (Any) -> Any
- return self._kv(k, 2, 3)
-
- def _kv(self, k, x0, x1):
- # type: (Any, Any, Any) -> Any
- if self.data is None:
- return None
- data = self.data[k]
- return data[x0], data[x1]
-
- def item(self, idx):
- # type: (Any) -> Any
- if self.data is None:
- return None
- return self.data[idx][0], self.data[idx][1]
-
- def add_idx_line_col(self, key, data):
- # type: (Any, Any) -> None
- if self.data is None:
- self.data = {}
- self.data[key] = data
-
-
-class Tag(object):
- """store tag information for roundtripping"""
-
- __slots__ = ("value",)
- attrib = tag_attrib
-
- def __init__(self):
- # type: () -> None
- self.value = None
-
- def __repr__(self):
- # type: () -> Any
- return "{0.__class__.__name__}({0.value!r})".format(self)
-
-
-class CommentedBase(object):
- @property
- def ca(self):
- # type: () -> Any
- if not hasattr(self, Comment.attrib):
- setattr(self, Comment.attrib, Comment())
- return getattr(self, Comment.attrib)
-
- def yaml_end_comment_extend(self, comment, clear=False):
- # type: (Any, bool) -> None
- if comment is None:
- return
- if clear or self.ca.end is None:
- self.ca.end = []
- self.ca.end.extend(comment)
-
- def yaml_key_comment_extend(self, key, comment, clear=False):
- # type: (Any, Any, bool) -> None
- r = self.ca._items.setdefault(key, [None, None, None, None])
- if clear or r[1] is None:
- if comment[1] is not None:
- assert isinstance(comment[1], list)
- r[1] = comment[1]
- else:
- r[1].extend(comment[0])
- r[0] = comment[0]
-
- def yaml_value_comment_extend(self, key, comment, clear=False):
- # type: (Any, Any, bool) -> None
- r = self.ca._items.setdefault(key, [None, None, None, None])
- if clear or r[3] is None:
- if comment[1] is not None:
- assert isinstance(comment[1], list)
- r[3] = comment[1]
- else:
- r[3].extend(comment[0])
- r[2] = comment[0]
-
- def yaml_set_start_comment(self, comment, indent=0):
- # type: (Any, Any) -> None
- """overwrites any preceding comment lines on an object
- expects comment to be without `#` and possible have multiple lines
- """
- from .error import CommentMark
- from .tokens import CommentToken
-
- pre_comments = self._yaml_get_pre_comment()
- if comment[-1] == "\n":
- comment = comment[:-1] # strip final newline if there
- start_mark = CommentMark(indent)
- for com in comment.split("\n"):
- pre_comments.append(CommentToken("# " + com + "\n", start_mark, None))
-
- def yaml_set_comment_before_after_key(
- self, key, before=None, indent=0, after=None, after_indent=None
- ):
- # type: (Any, Any, Any, Any, Any) -> None
- """
- expects comment (before/after) to be without `#` and possible have multiple lines
- """
- from srsly.ruamel_yaml.error import CommentMark
- from srsly.ruamel_yaml.tokens import CommentToken
-
- def comment_token(s, mark):
- # type: (Any, Any) -> Any
- # handle empty lines as having no comment
- return CommentToken(("# " if s else "") + s + "\n", mark, None)
-
- if after_indent is None:
- after_indent = indent + 2
- if before and (len(before) > 1) and before[-1] == "\n":
- before = before[:-1] # strip final newline if there
- if after and after[-1] == "\n":
- after = after[:-1] # strip final newline if there
- start_mark = CommentMark(indent)
- c = self.ca.items.setdefault(key, [None, [], None, None])
- if before == "\n":
- c[1].append(comment_token("", start_mark))
- elif before:
- for com in before.split("\n"):
- c[1].append(comment_token(com, start_mark))
- if after:
- start_mark = CommentMark(after_indent)
- if c[3] is None:
- c[3] = []
- for com in after.split("\n"):
- c[3].append(comment_token(com, start_mark)) # type: ignore
-
- @property
- def fa(self):
- # type: () -> Any
- """format attribute
-
- set_flow_style()/set_block_style()"""
- if not hasattr(self, Format.attrib):
- setattr(self, Format.attrib, Format())
- return getattr(self, Format.attrib)
-
- def yaml_add_eol_comment(self, comment, key=NoComment, column=None):
- # type: (Any, Optional[Any], Optional[Any]) -> None
- """
- there is a problem as eol comments should start with ' #'
- (but at the beginning of the line the space doesn't have to be before
- the #. The column index is for the # mark
- """
- from .tokens import CommentToken
- from .error import CommentMark
-
- if column is None:
- try:
- column = self._yaml_get_column(key)
- except AttributeError:
- column = 0
- if comment[0] != "#":
- comment = "# " + comment
- if column is None:
- if comment[0] == "#":
- comment = " " + comment
- column = 0
- start_mark = CommentMark(column)
- ct = [CommentToken(comment, start_mark, None), None]
- self._yaml_add_eol_comment(ct, key=key)
-
- @property
- def lc(self):
- # type: () -> Any
- if not hasattr(self, LineCol.attrib):
- setattr(self, LineCol.attrib, LineCol())
- return getattr(self, LineCol.attrib)
-
- def _yaml_set_line_col(self, line, col):
- # type: (Any, Any) -> None
- self.lc.line = line
- self.lc.col = col
-
- def _yaml_set_kv_line_col(self, key, data):
- # type: (Any, Any) -> None
- self.lc.add_kv_line_col(key, data)
-
- def _yaml_set_idx_line_col(self, key, data):
- # type: (Any, Any) -> None
- self.lc.add_idx_line_col(key, data)
-
- @property
- def anchor(self):
- # type: () -> Any
- if not hasattr(self, Anchor.attrib):
- setattr(self, Anchor.attrib, Anchor())
- return getattr(self, Anchor.attrib)
-
- def yaml_anchor(self):
- # type: () -> Any
- if not hasattr(self, Anchor.attrib):
- return None
- return self.anchor
-
- def yaml_set_anchor(self, value, always_dump=False):
- # type: (Any, bool) -> None
- self.anchor.value = value
- self.anchor.always_dump = always_dump
-
- @property
- def tag(self):
- # type: () -> Any
- if not hasattr(self, Tag.attrib):
- setattr(self, Tag.attrib, Tag())
- return getattr(self, Tag.attrib)
-
- def yaml_set_tag(self, value):
- # type: (Any) -> None
- self.tag.value = value
-
- def copy_attributes(self, t, memo=None):
- # type: (Any, Any) -> None
- # fmt: off
- for a in [Comment.attrib, Format.attrib, LineCol.attrib, Anchor.attrib,
- Tag.attrib, merge_attrib]:
- if hasattr(self, a):
- if memo is not None:
- setattr(t, a, copy.deepcopy(getattr(self, a, memo)))
- else:
- setattr(t, a, getattr(self, a))
- # fmt: on
-
- def _yaml_add_eol_comment(self, comment, key):
- # type: (Any, Any) -> None
- raise NotImplementedError
-
- def _yaml_get_pre_comment(self):
- # type: () -> Any
- raise NotImplementedError
-
- def _yaml_get_column(self, key):
- # type: (Any) -> Any
- raise NotImplementedError
-
-
-class CommentedSeq(MutableSliceableSequence, list, CommentedBase): # type: ignore
- __slots__ = (Comment.attrib, "_lst")
-
- def __init__(self, *args, **kw):
- # type: (Any, Any) -> None
- list.__init__(self, *args, **kw)
-
- def __getsingleitem__(self, idx):
- # type: (Any) -> Any
- return list.__getitem__(self, idx)
-
- def __setsingleitem__(self, idx, value):
- # type: (Any, Any) -> None
- # try to preserve the scalarstring type if setting an existing key to a new value
- if idx < len(self):
- if (
- isinstance(value, string_types)
- and not isinstance(value, ScalarString)
- and isinstance(self[idx], ScalarString)
- ):
- value = type(self[idx])(value)
- list.__setitem__(self, idx, value)
-
- def __delsingleitem__(self, idx=None):
- # type: (Any) -> Any
- list.__delitem__(self, idx)
- self.ca.items.pop(idx, None) # might not be there -> default value
- for list_index in sorted(self.ca.items):
- if list_index < idx:
- continue
- self.ca.items[list_index - 1] = self.ca.items.pop(list_index)
-
- def __len__(self):
- # type: () -> int
- return list.__len__(self)
-
- def insert(self, idx, val):
- # type: (Any, Any) -> None
- """the comments after the insertion have to move forward"""
- list.insert(self, idx, val)
- for list_index in sorted(self.ca.items, reverse=True):
- if list_index < idx:
- break
- self.ca.items[list_index + 1] = self.ca.items.pop(list_index)
-
- def extend(self, val):
- # type: (Any) -> None
- list.extend(self, val)
-
- def __eq__(self, other):
- # type: (Any) -> bool
- return list.__eq__(self, other)
-
- def _yaml_add_comment(self, comment, key=NoComment):
- # type: (Any, Optional[Any]) -> None
- if key is not NoComment:
- self.yaml_key_comment_extend(key, comment)
- else:
- self.ca.comment = comment
-
- def _yaml_add_eol_comment(self, comment, key):
- # type: (Any, Any) -> None
- self._yaml_add_comment(comment, key=key)
-
- def _yaml_get_columnX(self, key):
- # type: (Any) -> Any
- return self.ca.items[key][0].start_mark.column
-
- def _yaml_get_column(self, key):
- # type: (Any) -> Any
- column = None
- sel_idx = None
- pre, post = key - 1, key + 1
- if pre in self.ca.items:
- sel_idx = pre
- elif post in self.ca.items:
- sel_idx = post
- else:
- # self.ca.items is not ordered
- for row_idx, _k1 in enumerate(self):
- if row_idx >= key:
- break
- if row_idx not in self.ca.items:
- continue
- sel_idx = row_idx
- if sel_idx is not None:
- column = self._yaml_get_columnX(sel_idx)
- return column
-
- def _yaml_get_pre_comment(self):
- # type: () -> Any
- pre_comments = [] # type: List[Any]
- if self.ca.comment is None:
- self.ca.comment = [None, pre_comments]
- else:
- self.ca.comment[1] = pre_comments
- return pre_comments
-
- def __deepcopy__(self, memo):
- # type: (Any) -> Any
- res = self.__class__()
- memo[id(self)] = res
- for k in self:
- res.append(copy.deepcopy(k, memo))
- self.copy_attributes(res, memo=memo)
- return res
-
- def __add__(self, other):
- # type: (Any) -> Any
- return list.__add__(self, other)
-
- def sort(self, key=None, reverse=False): # type: ignore
- # type: (Any, bool) -> None
- if key is None:
- tmp_lst = sorted(zip(self, range(len(self))), reverse=reverse)
- list.__init__(self, [x[0] for x in tmp_lst])
- else:
- tmp_lst = sorted(
- zip(map(key, list.__iter__(self)), range(len(self))), reverse=reverse
- )
- list.__init__(self, [list.__getitem__(self, x[1]) for x in tmp_lst])
- itm = self.ca.items
- self.ca._items = {}
- for idx, x in enumerate(tmp_lst):
- old_index = x[1]
- if old_index in itm:
- self.ca.items[idx] = itm[old_index]
-
- def __repr__(self):
- # type: () -> Any
- return list.__repr__(self)
-
-
-class CommentedKeySeq(tuple, CommentedBase): # type: ignore
- """This primarily exists to be able to roundtrip keys that are sequences"""
-
- def _yaml_add_comment(self, comment, key=NoComment):
- # type: (Any, Optional[Any]) -> None
- if key is not NoComment:
- self.yaml_key_comment_extend(key, comment)
- else:
- self.ca.comment = comment
-
- def _yaml_add_eol_comment(self, comment, key):
- # type: (Any, Any) -> None
- self._yaml_add_comment(comment, key=key)
-
- def _yaml_get_columnX(self, key):
- # type: (Any) -> Any
- return self.ca.items[key][0].start_mark.column
-
- def _yaml_get_column(self, key):
- # type: (Any) -> Any
- column = None
- sel_idx = None
- pre, post = key - 1, key + 1
- if pre in self.ca.items:
- sel_idx = pre
- elif post in self.ca.items:
- sel_idx = post
- else:
- # self.ca.items is not ordered
- for row_idx, _k1 in enumerate(self):
- if row_idx >= key:
- break
- if row_idx not in self.ca.items:
- continue
- sel_idx = row_idx
- if sel_idx is not None:
- column = self._yaml_get_columnX(sel_idx)
- return column
-
- def _yaml_get_pre_comment(self):
- # type: () -> Any
- pre_comments = [] # type: List[Any]
- if self.ca.comment is None:
- self.ca.comment = [None, pre_comments]
- else:
- self.ca.comment[1] = pre_comments
- return pre_comments
-
-
-class CommentedMapView(Sized):
- __slots__ = ("_mapping",)
-
- def __init__(self, mapping):
- # type: (Any) -> None
- self._mapping = mapping
-
- def __len__(self):
- # type: () -> int
- count = len(self._mapping)
- return count
-
-
-class CommentedMapKeysView(CommentedMapView, Set): # type: ignore
- __slots__ = ()
-
- @classmethod
- def _from_iterable(self, it):
- # type: (Any) -> Any
- return set(it)
-
- def __contains__(self, key):
- # type: (Any) -> Any
- return key in self._mapping
-
- def __iter__(self):
- # type: () -> Any # yield from self._mapping # not in py27, pypy
- # for x in self._mapping._keys():
- for x in self._mapping:
- yield x
-
-
-class CommentedMapItemsView(CommentedMapView, Set): # type: ignore
- __slots__ = ()
-
- @classmethod
- def _from_iterable(self, it):
- # type: (Any) -> Any
- return set(it)
-
- def __contains__(self, item):
- # type: (Any) -> Any
- key, value = item
- try:
- v = self._mapping[key]
- except KeyError:
- return False
- else:
- return v == value
-
- def __iter__(self):
- # type: () -> Any
- for key in self._mapping._keys():
- yield (key, self._mapping[key])
-
-
-class CommentedMapValuesView(CommentedMapView):
- __slots__ = ()
-
- def __contains__(self, value):
- # type: (Any) -> Any
- for key in self._mapping:
- if value == self._mapping[key]:
- return True
- return False
-
- def __iter__(self):
- # type: () -> Any
- for key in self._mapping._keys():
- yield self._mapping[key]
-
-
-class CommentedMap(ordereddict, CommentedBase): # type: ignore
- __slots__ = (Comment.attrib, "_ok", "_ref")
-
- def __init__(self, *args, **kw):
- # type: (Any, Any) -> None
- self._ok = set() # type: MutableSet[Any] # own keys
- self._ref = [] # type: List[CommentedMap]
- ordereddict.__init__(self, *args, **kw)
-
- def _yaml_add_comment(self, comment, key=NoComment, value=NoComment):
- # type: (Any, Optional[Any], Optional[Any]) -> None
- """values is set to key to indicate a value attachment of comment"""
- if key is not NoComment:
- self.yaml_key_comment_extend(key, comment)
- return
- if value is not NoComment:
- self.yaml_value_comment_extend(value, comment)
- else:
- self.ca.comment = comment
-
- def _yaml_add_eol_comment(self, comment, key):
- # type: (Any, Any) -> None
- """add on the value line, with value specified by the key"""
- self._yaml_add_comment(comment, value=key)
-
- def _yaml_get_columnX(self, key):
- # type: (Any) -> Any
- return self.ca.items[key][2].start_mark.column
-
- def _yaml_get_column(self, key):
- # type: (Any) -> Any
- column = None
- sel_idx = None
- pre, post, last = None, None, None
- for x in self:
- if pre is not None and x != key:
- post = x
- break
- if x == key:
- pre = last
- last = x
- if pre in self.ca.items:
- sel_idx = pre
- elif post in self.ca.items:
- sel_idx = post
- else:
- # self.ca.items is not ordered
- for k1 in self:
- if k1 >= key:
- break
- if k1 not in self.ca.items:
- continue
- sel_idx = k1
- if sel_idx is not None:
- column = self._yaml_get_columnX(sel_idx)
- return column
-
- def _yaml_get_pre_comment(self):
- # type: () -> Any
- pre_comments = [] # type: List[Any]
- if self.ca.comment is None:
- self.ca.comment = [None, pre_comments]
- else:
- self.ca.comment[1] = pre_comments
- return pre_comments
-
- def update(self, vals):
- # type: (Any) -> None
- try:
- ordereddict.update(self, vals)
- except TypeError:
- # probably a dict that is used
- for x in vals:
- self[x] = vals[x]
- try:
- self._ok.update(vals.keys()) # type: ignore
- except AttributeError:
- # assume a list/tuple of two element lists/tuples
- for x in vals:
- self._ok.add(x[0])
-
- def insert(self, pos, key, value, comment=None):
- # type: (Any, Any, Any, Optional[Any]) -> None
- """insert key value into given position
- attach comment if provided
- """
- ordereddict.insert(self, pos, key, value)
- self._ok.add(key)
- if comment is not None:
- self.yaml_add_eol_comment(comment, key=key)
-
- def mlget(self, key, default=None, list_ok=False):
- # type: (Any, Any, Any) -> Any
- """multi-level get that expects dicts within dicts"""
- if not isinstance(key, list):
- return self.get(key, default)
- # assume that the key is a list of recursively accessible dicts
-
- def get_one_level(key_list, level, d):
- # type: (Any, Any, Any) -> Any
- if not list_ok:
- assert isinstance(d, dict)
- if level >= len(key_list):
- if level > len(key_list):
- raise IndexError
- return d[key_list[level - 1]]
- return get_one_level(key_list, level + 1, d[key_list[level - 1]])
-
- try:
- return get_one_level(key, 1, self)
- except KeyError:
- return default
- except (TypeError, IndexError):
- if not list_ok:
- raise
- return default
-
- def __getitem__(self, key):
- # type: (Any) -> Any
- try:
- return ordereddict.__getitem__(self, key)
- except KeyError:
- for merged in getattr(self, merge_attrib, []):
- if key in merged[1]:
- return merged[1][key]
- raise
-
- def __setitem__(self, key, value):
- # type: (Any, Any) -> None
- # try to preserve the scalarstring type if setting an existing key to a new value
- if key in self:
- if (
- isinstance(value, string_types)
- and not isinstance(value, ScalarString)
- and isinstance(self[key], ScalarString)
- ):
- value = type(self[key])(value)
- ordereddict.__setitem__(self, key, value)
- self._ok.add(key)
-
- def _unmerged_contains(self, key):
- # type: (Any) -> Any
- if key in self._ok:
- return True
- return None
-
- def __contains__(self, key):
- # type: (Any) -> bool
- return bool(ordereddict.__contains__(self, key))
-
- def get(self, key, default=None):
- # type: (Any, Any) -> Any
- try:
- return self.__getitem__(key)
- except: # NOQA
- return default
-
- def __repr__(self):
- # type: () -> Any
- return ordereddict.__repr__(self).replace("CommentedMap", "ordereddict")
-
- def non_merged_items(self):
- # type: () -> Any
- for x in ordereddict.__iter__(self):
- if x in self._ok:
- yield x, ordereddict.__getitem__(self, x)
-
- def __delitem__(self, key):
- # type: (Any) -> None
- # for merged in getattr(self, merge_attrib, []):
- # if key in merged[1]:
- # value = merged[1][key]
- # break
- # else:
- # # not found in merged in stuff
- # ordereddict.__delitem__(self, key)
- # for referer in self._ref:
- # referer.update_key_value(key)
- # return
- #
- # ordereddict.__setitem__(self, key, value) # merge might have different value
- # self._ok.discard(key)
- self._ok.discard(key)
- ordereddict.__delitem__(self, key)
- for referer in self._ref:
- referer.update_key_value(key)
-
- def __iter__(self):
- # type: () -> Any
- for x in ordereddict.__iter__(self):
- yield x
-
- def _keys(self):
- # type: () -> Any
- for x in ordereddict.__iter__(self):
- yield x
-
- def __len__(self):
- # type: () -> int
- return int(ordereddict.__len__(self))
-
- def __eq__(self, other):
- # type: (Any) -> bool
- return bool(dict(self) == other)
-
- if PY2:
-
- def keys(self):
- # type: () -> Any
- return list(self._keys())
-
- def iterkeys(self):
- # type: () -> Any
- return self._keys()
-
- def viewkeys(self):
- # type: () -> Any
- return CommentedMapKeysView(self)
-
- else:
-
- def keys(self):
- # type: () -> Any
- return CommentedMapKeysView(self)
-
- if PY2:
-
- def _values(self):
- # type: () -> Any
- for x in ordereddict.__iter__(self):
- yield ordereddict.__getitem__(self, x)
-
- def values(self):
- # type: () -> Any
- return list(self._values())
-
- def itervalues(self):
- # type: () -> Any
- return self._values()
-
- def viewvalues(self):
- # type: () -> Any
- return CommentedMapValuesView(self)
-
- else:
-
- def values(self):
- # type: () -> Any
- return CommentedMapValuesView(self)
-
- def _items(self):
- # type: () -> Any
- for x in ordereddict.__iter__(self):
- yield x, ordereddict.__getitem__(self, x)
-
- if PY2:
-
- def items(self):
- # type: () -> Any
- return list(self._items())
-
- def iteritems(self):
- # type: () -> Any
- return self._items()
-
- def viewitems(self):
- # type: () -> Any
- return CommentedMapItemsView(self)
-
- else:
-
- def items(self):
- # type: () -> Any
- return CommentedMapItemsView(self)
-
- @property
- def merge(self):
- # type: () -> Any
- if not hasattr(self, merge_attrib):
- setattr(self, merge_attrib, [])
- return getattr(self, merge_attrib)
-
- def copy(self):
- # type: () -> Any
- x = type(self)() # update doesn't work
- for k, v in self._items():
- x[k] = v
- self.copy_attributes(x)
- return x
-
- def add_referent(self, cm):
- # type: (Any) -> None
- if cm not in self._ref:
- self._ref.append(cm)
-
- def add_yaml_merge(self, value):
- # type: (Any) -> None
- for v in value:
- v[1].add_referent(self)
- for k, v in v[1].items():
- if ordereddict.__contains__(self, k):
- continue
- ordereddict.__setitem__(self, k, v)
- self.merge.extend(value)
-
- def update_key_value(self, key):
- # type: (Any) -> None
- if key in self._ok:
- return
- for v in self.merge:
- if key in v[1]:
- ordereddict.__setitem__(self, key, v[1][key])
- return
- ordereddict.__delitem__(self, key)
-
- def __deepcopy__(self, memo):
- # type: (Any) -> Any
- res = self.__class__()
- memo[id(self)] = res
- for k in self:
- res[k] = copy.deepcopy(self[k], memo)
- self.copy_attributes(res, memo=memo)
- return res
-
-
-# based on brownie mappings
-@classmethod # type: ignore
-def raise_immutable(cls, *args, **kwargs):
- # type: (Any, *Any, **Any) -> None
- raise TypeError("{} objects are immutable".format(cls.__name__))
-
-
-class CommentedKeyMap(CommentedBase, Mapping): # type: ignore
- __slots__ = Comment.attrib, "_od"
- """This primarily exists to be able to roundtrip keys that are mappings"""
-
- def __init__(self, *args, **kw):
- # type: (Any, Any) -> None
- if hasattr(self, "_od"):
- raise_immutable(self)
- try:
- self._od = ordereddict(*args, **kw)
- except TypeError:
- if PY2:
- self._od = ordereddict(args[0].items())
- else:
- raise
-
- __delitem__ = (
- __setitem__
- ) = clear = pop = popitem = setdefault = update = raise_immutable
-
- # need to implement __getitem__, __iter__ and __len__
- def __getitem__(self, index):
- # type: (Any) -> Any
- return self._od[index]
-
- def __iter__(self):
- # type: () -> Iterator[Any]
- for x in self._od.__iter__():
- yield x
-
- def __len__(self):
- # type: () -> int
- return len(self._od)
-
- def __hash__(self):
- # type: () -> Any
- return hash(tuple(self.items()))
-
- def __repr__(self):
- # type: () -> Any
- if not hasattr(self, merge_attrib):
- return self._od.__repr__()
- return "ordereddict(" + repr(list(self._od.items())) + ")"
-
- @classmethod
- def fromkeys(keys, v=None):
- # type: (Any, Any) -> Any
- return CommentedKeyMap(dict.fromkeys(keys, v))
-
- def _yaml_add_comment(self, comment, key=NoComment):
- # type: (Any, Optional[Any]) -> None
- if key is not NoComment:
- self.yaml_key_comment_extend(key, comment)
- else:
- self.ca.comment = comment
-
- def _yaml_add_eol_comment(self, comment, key):
- # type: (Any, Any) -> None
- self._yaml_add_comment(comment, key=key)
-
- def _yaml_get_columnX(self, key):
- # type: (Any) -> Any
- return self.ca.items[key][0].start_mark.column
-
- def _yaml_get_column(self, key):
- # type: (Any) -> Any
- column = None
- sel_idx = None
- pre, post = key - 1, key + 1
- if pre in self.ca.items:
- sel_idx = pre
- elif post in self.ca.items:
- sel_idx = post
- else:
- # self.ca.items is not ordered
- for row_idx, _k1 in enumerate(self):
- if row_idx >= key:
- break
- if row_idx not in self.ca.items:
- continue
- sel_idx = row_idx
- if sel_idx is not None:
- column = self._yaml_get_columnX(sel_idx)
- return column
-
- def _yaml_get_pre_comment(self):
- # type: () -> Any
- pre_comments = [] # type: List[Any]
- if self.ca.comment is None:
- self.ca.comment = [None, pre_comments]
- else:
- self.ca.comment[1] = pre_comments
- return pre_comments
-
-
-class CommentedOrderedMap(CommentedMap):
- __slots__ = (Comment.attrib,)
-
-
-class CommentedSet(MutableSet, CommentedBase): # type: ignore # NOQA
- __slots__ = Comment.attrib, "odict"
-
- def __init__(self, values=None):
- # type: (Any) -> None
- self.odict = ordereddict()
- MutableSet.__init__(self)
- if values is not None:
- self |= values # type: ignore
-
- def _yaml_add_comment(self, comment, key=NoComment, value=NoComment):
- # type: (Any, Optional[Any], Optional[Any]) -> None
- """values is set to key to indicate a value attachment of comment"""
- if key is not NoComment:
- self.yaml_key_comment_extend(key, comment)
- return
- if value is not NoComment:
- self.yaml_value_comment_extend(value, comment)
- else:
- self.ca.comment = comment
-
- def _yaml_add_eol_comment(self, comment, key):
- # type: (Any, Any) -> None
- """add on the value line, with value specified by the key"""
- self._yaml_add_comment(comment, value=key)
-
- def add(self, value):
- # type: (Any) -> None
- """Add an element."""
- self.odict[value] = None
-
- def discard(self, value):
- # type: (Any) -> None
- """Remove an element. Do not raise an exception if absent."""
- del self.odict[value]
-
- def __contains__(self, x):
- # type: (Any) -> Any
- return x in self.odict
-
- def __iter__(self):
- # type: () -> Any
- for x in self.odict:
- yield x
-
- def __len__(self):
- # type: () -> int
- return len(self.odict)
-
- def __repr__(self):
- # type: () -> str
- return "set({0!r})".format(self.odict.keys())
-
-
-class TaggedScalar(CommentedBase):
- # the value and style attributes are set during roundtrip construction
- def __init__(self, value=None, style=None, tag=None):
- # type: (Any, Any, Any) -> None
- self.value = value
- self.style = style
- if tag is not None:
- self.yaml_set_tag(tag)
-
- def __str__(self):
- # type: () -> Any
- return self.value
-
-
-def dump_comments(d, name="", sep=".", out=sys.stdout):
- # type: (Any, str, str, Any) -> None
- """
- recursively dump comments, all but the toplevel preceded by the path
- in dotted form x.0.a
- """
- if isinstance(d, dict) and hasattr(d, "ca"):
- if name:
- sys.stdout.write("{}\n".format(name))
- out.write("{}\n".format(d.ca)) # type: ignore
- for k in d:
- dump_comments(d[k], name=(name + sep + k) if name else k, sep=sep, out=out)
- elif isinstance(d, list) and hasattr(d, "ca"):
- if name:
- sys.stdout.write("{}\n".format(name))
- out.write("{}\n".format(d.ca)) # type: ignore
- for idx, k in enumerate(d):
- dump_comments(
- k, name=(name + sep + str(idx)) if name else str(idx), sep=sep, out=out
- )
diff --git a/srsly/ruamel_yaml/compat.py b/srsly/ruamel_yaml/compat.py
deleted file mode 100755
index 5544205..0000000
--- a/srsly/ruamel_yaml/compat.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-# partially from package six by Benjamin Peterson
-
-import sys
-import os
-import types
-import traceback
-from abc import abstractmethod
-from collections import OrderedDict # type: ignore
-
-
-# fmt: off
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Union, BinaryIO, IO, Text, Tuple # NOQA
- from typing import Optional # NOQA
-# fmt: on
-
-_DEFAULT_YAML_VERSION = (1, 2)
-
-
-class ordereddict(OrderedDict): # type: ignore
- if not hasattr(OrderedDict, "insert"):
-
- def insert(self, pos, key, value):
- # type: (int, Any, Any) -> None
- if pos >= len(self):
- self[key] = value
- return
- od = ordereddict()
- od.update(self)
- for k in od:
- del self[k]
- for index, old_key in enumerate(od):
- if pos == index:
- self[key] = value
- self[old_key] = od[old_key]
-
-
-PY2 = sys.version_info[0] == 2
-PY3 = sys.version_info[0] == 3
-
-
-if PY3:
-
- def utf8(s):
- # type: (str) -> str
- return s
-
- def to_str(s):
- # type: (str) -> str
- return s
-
- def to_unicode(s):
- # type: (str) -> str
- return s
-
-
-else:
- if False:
- unicode = str
-
- def utf8(s):
- # type: (unicode) -> str
- return s.encode("utf-8")
-
- def to_str(s):
- # type: (str) -> str
- return str(s)
-
- def to_unicode(s):
- # type: (str) -> unicode
- return unicode(s) # NOQA
-
-
-if PY3:
- string_types = str
- integer_types = int
- class_types = type
- text_type = str
- binary_type = bytes
-
- MAXSIZE = sys.maxsize
- unichr = chr
- import io
-
- StringIO = io.StringIO
- BytesIO = io.BytesIO
- # have unlimited precision
- no_limit_int = int
- from collections.abc import (
- Hashable,
- MutableSequence,
- MutableMapping,
- Mapping,
- ) # NOQA
-
-else:
- string_types = basestring # NOQA
- integer_types = (int, long) # NOQA
- class_types = (type, types.ClassType)
- text_type = unicode # NOQA
- binary_type = str
-
- # to allow importing
- unichr = unichr
- from StringIO import StringIO as _StringIO
-
- StringIO = _StringIO
- import cStringIO
-
- BytesIO = cStringIO.StringIO
- # have unlimited precision
- no_limit_int = long # NOQA not available on Python 3
- from collections import Hashable, MutableSequence, MutableMapping, Mapping # NOQA
-
-if False: # MYPY
- # StreamType = Union[BinaryIO, IO[str], IO[unicode], StringIO]
- # StreamType = Union[BinaryIO, IO[str], StringIO] # type: ignore
- StreamType = Any
-
- StreamTextType = StreamType # Union[Text, StreamType]
- VersionType = Union[List[int], str, Tuple[int, int]]
-
-if PY3:
- builtins_module = "builtins"
-else:
- builtins_module = "__builtin__"
-
-UNICODE_SIZE = 4 if sys.maxunicode > 65535 else 2
-
-
-def with_metaclass(meta, *bases):
- # type: (Any, Any) -> Any
- """Create a base class with a metaclass."""
- return meta("NewBase", bases, {})
-
-
-DBG_TOKEN = 1
-DBG_EVENT = 2
-DBG_NODE = 4
-
-
-_debug = None # type: Optional[int]
-if "RUAMELDEBUG" in os.environ:
- _debugx = os.environ.get("RUAMELDEBUG")
- if _debugx is None:
- _debug = 0
- else:
- _debug = int(_debugx)
-
-
-if bool(_debug):
-
- class ObjectCounter(object):
- def __init__(self):
- # type: () -> None
- self.map = {} # type: Dict[Any, Any]
-
- def __call__(self, k):
- # type: (Any) -> None
- self.map[k] = self.map.get(k, 0) + 1
-
- def dump(self):
- # type: () -> None
- for k in sorted(self.map):
- sys.stdout.write("{} -> {}".format(k, self.map[k]))
-
- object_counter = ObjectCounter()
-
-
-# used from yaml util when testing
-def dbg(val=None):
- # type: (Any) -> Any
- global _debug
- if _debug is None:
- # set to true or false
- _debugx = os.environ.get("YAMLDEBUG")
- if _debugx is None:
- _debug = 0
- else:
- _debug = int(_debugx)
- if val is None:
- return _debug
- return _debug & val
-
-
-class Nprint(object):
- def __init__(self, file_name=None):
- # type: (Any) -> None
- self._max_print = None # type: Any
- self._count = None # type: Any
- self._file_name = file_name
-
- def __call__(self, *args, **kw):
- # type: (Any, Any) -> None
- if not bool(_debug):
- return
- out = sys.stdout if self._file_name is None else open(self._file_name, "a")
- dbgprint = print # to fool checking for print statements by dv utility
- kw1 = kw.copy()
- kw1["file"] = out
- dbgprint(*args, **kw1)
- out.flush()
- if self._max_print is not None:
- if self._count is None:
- self._count = self._max_print
- self._count -= 1
- if self._count == 0:
- dbgprint("forced exit\n")
- traceback.print_stack()
- out.flush()
- sys.exit(0)
- if self._file_name:
- out.close()
-
- def set_max_print(self, i):
- # type: (int) -> None
- self._max_print = i
- self._count = None
-
-
-nprint = Nprint()
-nprintf = Nprint("/var/tmp/srsly.ruamel_yaml.log")
-
-# char checkers following production rules
-
-
-def check_namespace_char(ch):
- # type: (Any) -> bool
- if u"\x21" <= ch <= u"\x7E": # ! to ~
- return True
- if u"\xA0" <= ch <= u"\uD7FF":
- return True
- if (u"\uE000" <= ch <= u"\uFFFD") and ch != u"\uFEFF": # excl. byte order mark
- return True
- if u"\U00010000" <= ch <= u"\U0010FFFF":
- return True
- return False
-
-
-def check_anchorname_char(ch):
- # type: (Any) -> bool
- if ch in u",[]{}":
- return False
- return check_namespace_char(ch)
-
-
-def version_tnf(t1, t2=None):
- # type: (Any, Any) -> Any
- """
- return True if srsly.ruamel_yaml version_info < t1, None if t2 is specified and bigger else False
- """
- from srsly.ruamel_yaml import version_info # NOQA
-
- if version_info < t1:
- return True
- if t2 is not None and version_info < t2:
- return None
- return False
-
-
-class MutableSliceableSequence(MutableSequence): # type: ignore
- __slots__ = ()
-
- def __getitem__(self, index):
- # type: (Any) -> Any
- if not isinstance(index, slice):
- return self.__getsingleitem__(index)
- return type(self)(
- [self[i] for i in range(*index.indices(len(self)))]
- ) # type: ignore
-
- def __setitem__(self, index, value):
- # type: (Any, Any) -> None
- if not isinstance(index, slice):
- return self.__setsingleitem__(index, value)
- assert iter(value)
- # nprint(index.start, index.stop, index.step, index.indices(len(self)))
- if index.step is None:
- del self[index.start : index.stop]
- for elem in reversed(value):
- self.insert(0 if index.start is None else index.start, elem)
- else:
- range_parms = index.indices(len(self))
- nr_assigned_items = (range_parms[1] - range_parms[0] - 1) // range_parms[
- 2
- ] + 1
- # need to test before changing, in case TypeError is caught
- if nr_assigned_items < len(value):
- raise TypeError(
- "too many elements in value {} < {}".format(
- nr_assigned_items, len(value)
- )
- )
- elif nr_assigned_items > len(value):
- raise TypeError(
- "not enough elements in value {} > {}".format(
- nr_assigned_items, len(value)
- )
- )
- for idx, i in enumerate(range(*range_parms)):
- self[i] = value[idx]
-
- def __delitem__(self, index):
- # type: (Any) -> None
- if not isinstance(index, slice):
- return self.__delsingleitem__(index)
- # nprint(index.start, index.stop, index.step, index.indices(len(self)))
- for i in reversed(range(*index.indices(len(self)))):
- del self[i]
-
- @abstractmethod
- def __getsingleitem__(self, index):
- # type: (Any) -> Any
- raise IndexError
-
- @abstractmethod
- def __setsingleitem__(self, index, value):
- # type: (Any, Any) -> None
- raise IndexError
-
- @abstractmethod
- def __delsingleitem__(self, index):
- # type: (Any) -> None
- raise IndexError
diff --git a/srsly/ruamel_yaml/composer.py b/srsly/ruamel_yaml/composer.py
deleted file mode 100755
index 27d5a48..0000000
--- a/srsly/ruamel_yaml/composer.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import, print_function
-
-import warnings
-
-from .error import MarkedYAMLError, ReusedAnchorWarning
-from .compat import utf8, nprint, nprintf # NOQA
-
-from .events import (
- StreamStartEvent,
- StreamEndEvent,
- MappingStartEvent,
- MappingEndEvent,
- SequenceStartEvent,
- SequenceEndEvent,
- AliasEvent,
- ScalarEvent,
-)
-from .nodes import MappingNode, ScalarNode, SequenceNode
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List # NOQA
-
-__all__ = ["Composer", "ComposerError"]
-
-
-class ComposerError(MarkedYAMLError):
- pass
-
-
-class Composer(object):
- def __init__(self, loader=None):
- # type: (Any) -> None
- self.loader = loader
- if self.loader is not None and getattr(self.loader, "_composer", None) is None:
- self.loader._composer = self
- self.anchors = {} # type: Dict[Any, Any]
-
- @property
- def parser(self):
- # type: () -> Any
- if hasattr(self.loader, "typ"):
- self.loader.parser
- return self.loader._parser
-
- @property
- def resolver(self):
- # type: () -> Any
- # assert self.loader._resolver is not None
- if hasattr(self.loader, "typ"):
- self.loader.resolver
- return self.loader._resolver
-
- def check_node(self):
- # type: () -> Any
- # Drop the STREAM-START event.
- if self.parser.check_event(StreamStartEvent):
- self.parser.get_event()
-
- # If there are more documents available?
- return not self.parser.check_event(StreamEndEvent)
-
- def get_node(self):
- # type: () -> Any
- # Get the root node of the next document.
- if not self.parser.check_event(StreamEndEvent):
- return self.compose_document()
-
- def get_single_node(self):
- # type: () -> Any
- # Drop the STREAM-START event.
- self.parser.get_event()
-
- # Compose a document if the stream is not empty.
- document = None # type: Any
- if not self.parser.check_event(StreamEndEvent):
- document = self.compose_document()
-
- # Ensure that the stream contains no more documents.
- if not self.parser.check_event(StreamEndEvent):
- event = self.parser.get_event()
- raise ComposerError(
- "expected a single document in the stream",
- document.start_mark,
- "but found another document",
- event.start_mark,
- )
-
- # Drop the STREAM-END event.
- self.parser.get_event()
-
- return document
-
- def compose_document(self):
- # type: (Any) -> Any
- # Drop the DOCUMENT-START event.
- self.parser.get_event()
-
- # Compose the root node.
- node = self.compose_node(None, None)
-
- # Drop the DOCUMENT-END event.
- self.parser.get_event()
-
- self.anchors = {}
- return node
-
- def compose_node(self, parent, index):
- # type: (Any, Any) -> Any
- if self.parser.check_event(AliasEvent):
- event = self.parser.get_event()
- alias = event.anchor
- if alias not in self.anchors:
- raise ComposerError(
- None,
- None,
- "found undefined alias %r" % utf8(alias),
- event.start_mark,
- )
- return self.anchors[alias]
- event = self.parser.peek_event()
- anchor = event.anchor
- if anchor is not None: # have an anchor
- if anchor in self.anchors:
- # raise ComposerError(
- # "found duplicate anchor %r; first occurrence"
- # % utf8(anchor), self.anchors[anchor].start_mark,
- # "second occurrence", event.start_mark)
- ws = (
- "\nfound duplicate anchor {!r}\nfirst occurrence {}\nsecond occurrence "
- "{}".format(
- (anchor), self.anchors[anchor].start_mark, event.start_mark
- )
- )
- warnings.warn(ws, ReusedAnchorWarning)
- self.resolver.descend_resolver(parent, index)
- if self.parser.check_event(ScalarEvent):
- node = self.compose_scalar_node(anchor)
- elif self.parser.check_event(SequenceStartEvent):
- node = self.compose_sequence_node(anchor)
- elif self.parser.check_event(MappingStartEvent):
- node = self.compose_mapping_node(anchor)
- self.resolver.ascend_resolver()
- return node
-
- def compose_scalar_node(self, anchor):
- # type: (Any) -> Any
- event = self.parser.get_event()
- tag = event.tag
- if tag is None or tag == u"!":
- tag = self.resolver.resolve(ScalarNode, event.value, event.implicit)
- node = ScalarNode(
- tag,
- event.value,
- event.start_mark,
- event.end_mark,
- style=event.style,
- comment=event.comment,
- anchor=anchor,
- )
- if anchor is not None:
- self.anchors[anchor] = node
- return node
-
- def compose_sequence_node(self, anchor):
- # type: (Any) -> Any
- start_event = self.parser.get_event()
- tag = start_event.tag
- if tag is None or tag == u"!":
- tag = self.resolver.resolve(SequenceNode, None, start_event.implicit)
- node = SequenceNode(
- tag,
- [],
- start_event.start_mark,
- None,
- flow_style=start_event.flow_style,
- comment=start_event.comment,
- anchor=anchor,
- )
- if anchor is not None:
- self.anchors[anchor] = node
- index = 0
- while not self.parser.check_event(SequenceEndEvent):
- node.value.append(self.compose_node(node, index))
- index += 1
- end_event = self.parser.get_event()
- if node.flow_style is True and end_event.comment is not None:
- if node.comment is not None:
- nprint(
- "Warning: unexpected end_event commment in sequence "
- "node {}".format(node.flow_style)
- )
- node.comment = end_event.comment
- node.end_mark = end_event.end_mark
- self.check_end_doc_comment(end_event, node)
- return node
-
- def compose_mapping_node(self, anchor):
- # type: (Any) -> Any
- start_event = self.parser.get_event()
- tag = start_event.tag
- if tag is None or tag == u"!":
- tag = self.resolver.resolve(MappingNode, None, start_event.implicit)
- node = MappingNode(
- tag,
- [],
- start_event.start_mark,
- None,
- flow_style=start_event.flow_style,
- comment=start_event.comment,
- anchor=anchor,
- )
- if anchor is not None:
- self.anchors[anchor] = node
- while not self.parser.check_event(MappingEndEvent):
- # key_event = self.parser.peek_event()
- item_key = self.compose_node(node, None)
- # if item_key in node.value:
- # raise ComposerError("while composing a mapping",
- # start_event.start_mark,
- # "found duplicate key", key_event.start_mark)
- item_value = self.compose_node(node, item_key)
- # node.value[item_key] = item_value
- node.value.append((item_key, item_value))
- end_event = self.parser.get_event()
- if node.flow_style is True and end_event.comment is not None:
- node.comment = end_event.comment
- node.end_mark = end_event.end_mark
- self.check_end_doc_comment(end_event, node)
- return node
-
- def check_end_doc_comment(self, end_event, node):
- # type: (Any, Any) -> None
- if end_event.comment and end_event.comment[1]:
- # pre comments on an end_event, no following to move to
- if node.comment is None:
- node.comment = [None, None]
- assert not isinstance(node, ScalarEvent)
- # this is a post comment on a mapping node, add as third element
- # in the list
- node.comment.append(end_event.comment[1])
- end_event.comment[1] = None
diff --git a/srsly/ruamel_yaml/configobjwalker.py b/srsly/ruamel_yaml/configobjwalker.py
deleted file mode 100755
index 511184e..0000000
--- a/srsly/ruamel_yaml/configobjwalker.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# coding: utf-8
-
-import warnings
-
-from .util import configobj_walker as new_configobj_walker
-
-if False: # MYPY
- from typing import Any # NOQA
-
-
-def configobj_walker(cfg):
- # type: (Any) -> Any
- warnings.warn(
- "configobj_walker has moved to srsly.ruamel_yaml.util, please update your code"
- )
- return new_configobj_walker(cfg)
diff --git a/srsly/ruamel_yaml/constructor.py b/srsly/ruamel_yaml/constructor.py
deleted file mode 100755
index b600339..0000000
--- a/srsly/ruamel_yaml/constructor.py
+++ /dev/null
@@ -1,1706 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division
-
-import datetime
-import base64
-import binascii
-import re
-import sys
-import types
-import warnings
-
-# fmt: off
-from .error import (MarkedYAMLError, MarkedYAMLFutureWarning,
- MantissaNoDotYAML1_1Warning)
-from .nodes import * # NOQA
-from .nodes import (SequenceNode, MappingNode, ScalarNode)
-from .compat import (utf8, builtins_module, to_str, PY2, PY3, # NOQA
- text_type, nprint, nprintf, version_tnf)
-from .compat import ordereddict, Hashable, MutableSequence # type: ignore
-from .compat import MutableMapping # type: ignore
-
-from .comments import * # NOQA
-from .comments import (CommentedMap, CommentedOrderedMap, CommentedSet,
- CommentedKeySeq, CommentedSeq, TaggedScalar,
- CommentedKeyMap)
-from .scalarstring import (SingleQuotedScalarString, DoubleQuotedScalarString,
- LiteralScalarString, FoldedScalarString,
- PlainScalarString, ScalarString,)
-from .scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt
-from .scalarfloat import ScalarFloat
-from .scalarbool import ScalarBoolean
-from .timestamp import TimeStamp
-from .util import RegExp
-
-if False: # MYPY
- from typing import Any, Dict, List, Set, Generator, Union, Optional # NOQA
-
-
-__all__ = ['BaseConstructor', 'SafeConstructor', 'Constructor',
- 'ConstructorError', 'RoundTripConstructor']
-# fmt: on
-
-
-class ConstructorError(MarkedYAMLError):
- pass
-
-
-class DuplicateKeyFutureWarning(MarkedYAMLFutureWarning):
- pass
-
-
-class DuplicateKeyError(MarkedYAMLFutureWarning):
- pass
-
-
-class BaseConstructor(object):
-
- yaml_constructors = {} # type: Dict[Any, Any]
- yaml_multi_constructors = {} # type: Dict[Any, Any]
-
- def __init__(self, preserve_quotes=None, loader=None):
- # type: (Optional[bool], Any) -> None
- self.loader = loader
- if (
- self.loader is not None
- and getattr(self.loader, "_constructor", None) is None
- ):
- self.loader._constructor = self
- self.loader = loader
- self.yaml_base_dict_type = dict
- self.yaml_base_list_type = list
- self.constructed_objects = {} # type: Dict[Any, Any]
- self.recursive_objects = {} # type: Dict[Any, Any]
- self.state_generators = [] # type: List[Any]
- self.deep_construct = False
- self._preserve_quotes = preserve_quotes
- self.allow_duplicate_keys = version_tnf((0, 15, 1), (0, 16))
-
- @property
- def composer(self):
- # type: () -> Any
- if hasattr(self.loader, "typ"):
- return self.loader.composer
- try:
- return self.loader._composer
- except AttributeError:
- sys.stdout.write("slt {}\n".format(type(self)))
- sys.stdout.write("slc {}\n".format(self.loader._composer))
- sys.stdout.write("{}\n".format(dir(self)))
- raise
-
- @property
- def resolver(self):
- # type: () -> Any
- if hasattr(self.loader, "typ"):
- return self.loader.resolver
- return self.loader._resolver
-
- def check_data(self):
- # type: () -> Any
- # If there are more documents available?
- return self.composer.check_node()
-
- def get_data(self):
- # type: () -> Any
- # Construct and return the next document.
- if self.composer.check_node():
- return self.construct_document(self.composer.get_node())
-
- def get_single_data(self):
- # type: () -> Any
- # Ensure that the stream contains a single document and construct it.
- node = self.composer.get_single_node()
- if node is not None:
- return self.construct_document(node)
- return None
-
- def construct_document(self, node):
- # type: (Any) -> Any
- data = self.construct_object(node)
- while bool(self.state_generators):
- state_generators = self.state_generators
- self.state_generators = []
- for generator in state_generators:
- for _dummy in generator:
- pass
- self.constructed_objects = {}
- self.recursive_objects = {}
- self.deep_construct = False
- return data
-
- def construct_object(self, node, deep=False):
- # type: (Any, bool) -> Any
- """deep is True when creating an object/mapping recursively,
- in that case want the underlying elements available during construction
- """
- if node in self.constructed_objects:
- return self.constructed_objects[node]
- if deep:
- old_deep = self.deep_construct
- self.deep_construct = True
- if node in self.recursive_objects:
- return self.recursive_objects[node]
- # raise ConstructorError(
- # None, None, 'found unconstructable recursive node', node.start_mark
- # )
- self.recursive_objects[node] = None
- data = self.construct_non_recursive_object(node)
-
- self.constructed_objects[node] = data
- del self.recursive_objects[node]
- if deep:
- self.deep_construct = old_deep
- return data
-
- def construct_non_recursive_object(self, node, tag=None):
- # type: (Any, Optional[str]) -> Any
- constructor = None # type: Any
- tag_suffix = None
- if tag is None:
- tag = node.tag
- if tag in self.yaml_constructors:
- constructor = self.yaml_constructors[tag]
- else:
- for tag_prefix in self.yaml_multi_constructors:
- if tag.startswith(tag_prefix):
- tag_suffix = tag[len(tag_prefix) :]
- constructor = self.yaml_multi_constructors[tag_prefix]
- break
- else:
- if None in self.yaml_multi_constructors:
- tag_suffix = tag
- constructor = self.yaml_multi_constructors[None]
- elif None in self.yaml_constructors:
- constructor = self.yaml_constructors[None]
- elif isinstance(node, ScalarNode):
- constructor = self.__class__.construct_scalar
- elif isinstance(node, SequenceNode):
- constructor = self.__class__.construct_sequence
- elif isinstance(node, MappingNode):
- constructor = self.__class__.construct_mapping
- if tag_suffix is None:
- data = constructor(self, node)
- else:
- data = constructor(self, tag_suffix, node)
- if isinstance(data, types.GeneratorType):
- generator = data
- data = next(generator)
- if self.deep_construct:
- for _dummy in generator:
- pass
- else:
- self.state_generators.append(generator)
- return data
-
- def construct_scalar(self, node):
- # type: (Any) -> Any
- if not isinstance(node, ScalarNode):
- raise ConstructorError(
- None,
- None,
- "expected a scalar node, but found %s" % node.id,
- node.start_mark,
- )
- return node.value
-
- def construct_sequence(self, node, deep=False):
- # type: (Any, bool) -> Any
- """deep is True when creating an object/mapping recursively,
- in that case want the underlying elements available during construction
- """
- if not isinstance(node, SequenceNode):
- raise ConstructorError(
- None,
- None,
- "expected a sequence node, but found %s" % node.id,
- node.start_mark,
- )
- return [self.construct_object(child, deep=deep) for child in node.value]
-
- def construct_mapping(self, node, deep=False):
- # type: (Any, bool) -> Any
- """deep is True when creating an object/mapping recursively,
- in that case want the underlying elements available during construction
- """
- if not isinstance(node, MappingNode):
- raise ConstructorError(
- None,
- None,
- "expected a mapping node, but found %s" % node.id,
- node.start_mark,
- )
- total_mapping = self.yaml_base_dict_type()
- if getattr(node, "merge", None) is not None:
- todo = [(node.merge, False), (node.value, False)]
- else:
- todo = [(node.value, True)]
- for values, check in todo:
- mapping = self.yaml_base_dict_type() # type: Dict[Any, Any]
- for key_node, value_node in values:
- # keys can be list -> deep
- key = self.construct_object(key_node, deep=True)
- # lists are not hashable, but tuples are
- if not isinstance(key, Hashable):
- if isinstance(key, list):
- key = tuple(key)
- if PY2:
- try:
- hash(key)
- except TypeError as exc:
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "found unacceptable key (%s)" % exc,
- key_node.start_mark,
- )
- else:
- if not isinstance(key, Hashable):
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "found unhashable key",
- key_node.start_mark,
- )
-
- value = self.construct_object(value_node, deep=deep)
- if check:
- if self.check_mapping_key(node, key_node, mapping, key, value):
- mapping[key] = value
- else:
- mapping[key] = value
- total_mapping.update(mapping)
- return total_mapping
-
- def check_mapping_key(self, node, key_node, mapping, key, value):
- # type: (Any, Any, Any, Any, Any) -> bool
- """return True if key is unique"""
- if key in mapping:
- if not self.allow_duplicate_keys:
- mk = mapping.get(key)
- if PY2:
- if isinstance(key, unicode):
- key = key.encode("utf-8")
- if isinstance(value, unicode):
- value = value.encode("utf-8")
- if isinstance(mk, unicode):
- mk = mk.encode("utf-8")
- args = [
- "while constructing a mapping",
- node.start_mark,
- 'found duplicate key "{}" with value "{}" '
- '(original value: "{}")'.format(key, value, mk),
- key_node.start_mark,
- """
- To suppress this check see:
- http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
- """,
- """\
- Duplicate keys will become an error in future releases, and are errors
- by default when using the new API.
- """,
- ]
- if self.allow_duplicate_keys is None:
- warnings.warn(DuplicateKeyFutureWarning(*args))
- else:
- raise DuplicateKeyError(*args)
- return False
- return True
-
- def check_set_key(self, node, key_node, setting, key):
- # type: (Any, Any, Any, Any, Any) -> None
- if key in setting:
- if not self.allow_duplicate_keys:
- if PY2:
- if isinstance(key, unicode):
- key = key.encode("utf-8")
- args = [
- "while constructing a set",
- node.start_mark,
- 'found duplicate key "{}"'.format(key),
- key_node.start_mark,
- """
- To suppress this check see:
- http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
- """,
- """\
- Duplicate keys will become an error in future releases, and are errors
- by default when using the new API.
- """,
- ]
- if self.allow_duplicate_keys is None:
- warnings.warn(DuplicateKeyFutureWarning(*args))
- else:
- raise DuplicateKeyError(*args)
-
- def construct_pairs(self, node, deep=False):
- # type: (Any, bool) -> Any
- if not isinstance(node, MappingNode):
- raise ConstructorError(
- None,
- None,
- "expected a mapping node, but found %s" % node.id,
- node.start_mark,
- )
- pairs = []
- for key_node, value_node in node.value:
- key = self.construct_object(key_node, deep=deep)
- value = self.construct_object(value_node, deep=deep)
- pairs.append((key, value))
- return pairs
-
- @classmethod
- def add_constructor(cls, tag, constructor):
- # type: (Any, Any) -> None
- if "yaml_constructors" not in cls.__dict__:
- cls.yaml_constructors = cls.yaml_constructors.copy()
- cls.yaml_constructors[tag] = constructor
-
- @classmethod
- def add_multi_constructor(cls, tag_prefix, multi_constructor):
- # type: (Any, Any) -> None
- if "yaml_multi_constructors" not in cls.__dict__:
- cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
- cls.yaml_multi_constructors[tag_prefix] = multi_constructor
-
-
-class SafeConstructor(BaseConstructor):
- def construct_scalar(self, node):
- # type: (Any) -> Any
- if isinstance(node, MappingNode):
- for key_node, value_node in node.value:
- if key_node.tag == u"tag:yaml.org,2002:value":
- return self.construct_scalar(value_node)
- return BaseConstructor.construct_scalar(self, node)
-
- def flatten_mapping(self, node):
- # type: (Any) -> Any
- """
- This implements the merge key feature http://yaml.org/type/merge.html
- by inserting keys from the merge dict/list of dicts if not yet
- available in this node
- """
- merge = [] # type: List[Any]
- index = 0
- while index < len(node.value):
- key_node, value_node = node.value[index]
- if key_node.tag == u"tag:yaml.org,2002:merge":
- if merge: # double << key
- if self.allow_duplicate_keys:
- del node.value[index]
- index += 1
- continue
- args = [
- "while constructing a mapping",
- node.start_mark,
- 'found duplicate key "{}"'.format(key_node.value),
- key_node.start_mark,
- """
- To suppress this check see:
- http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
- """,
- """\
- Duplicate keys will become an error in future releases, and are errors
- by default when using the new API.
- """,
- ]
- if self.allow_duplicate_keys is None:
- warnings.warn(DuplicateKeyFutureWarning(*args))
- else:
- raise DuplicateKeyError(*args)
- del node.value[index]
- if isinstance(value_node, MappingNode):
- self.flatten_mapping(value_node)
- merge.extend(value_node.value)
- elif isinstance(value_node, SequenceNode):
- submerge = []
- for subnode in value_node.value:
- if not isinstance(subnode, MappingNode):
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "expected a mapping for merging, but found %s"
- % subnode.id,
- subnode.start_mark,
- )
- self.flatten_mapping(subnode)
- submerge.append(subnode.value)
- submerge.reverse()
- for value in submerge:
- merge.extend(value)
- else:
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "expected a mapping or list of mappings for merging, "
- "but found %s" % value_node.id,
- value_node.start_mark,
- )
- elif key_node.tag == u"tag:yaml.org,2002:value":
- key_node.tag = u"tag:yaml.org,2002:str"
- index += 1
- else:
- index += 1
- if bool(merge):
- node.merge = (
- merge
- ) # separate merge keys to be able to update without duplicate
- node.value = merge + node.value
-
- def construct_mapping(self, node, deep=False):
- # type: (Any, bool) -> Any
- """deep is True when creating an object/mapping recursively,
- in that case want the underlying elements available during construction
- """
- if isinstance(node, MappingNode):
- self.flatten_mapping(node)
- return BaseConstructor.construct_mapping(self, node, deep=deep)
-
- def construct_yaml_null(self, node):
- # type: (Any) -> Any
- self.construct_scalar(node)
- return None
-
- # YAML 1.2 spec doesn't mention yes/no etc any more, 1.1 does
- bool_values = {
- u"yes": True,
- u"no": False,
- u"y": True,
- u"n": False,
- u"true": True,
- u"false": False,
- u"on": True,
- u"off": False,
- }
-
- def construct_yaml_bool(self, node):
- # type: (Any) -> bool
- value = self.construct_scalar(node)
- return self.bool_values[value.lower()]
-
- def construct_yaml_int(self, node):
- # type: (Any) -> int
- value_s = to_str(self.construct_scalar(node))
- value_s = value_s.replace("_", "")
- sign = +1
- if value_s[0] == "-":
- sign = -1
- if value_s[0] in "+-":
- value_s = value_s[1:]
- if value_s == "0":
- return 0
- elif value_s.startswith("0b"):
- return sign * int(value_s[2:], 2)
- elif value_s.startswith("0x"):
- return sign * int(value_s[2:], 16)
- elif value_s.startswith("0o"):
- return sign * int(value_s[2:], 8)
- elif self.resolver.processing_version == (1, 1) and value_s[0] == "0":
- return sign * int(value_s, 8)
- elif self.resolver.processing_version == (1, 1) and ":" in value_s:
- digits = [int(part) for part in value_s.split(":")]
- digits.reverse()
- base = 1
- value = 0
- for digit in digits:
- value += digit * base
- base *= 60
- return sign * value
- else:
- return sign * int(value_s)
-
- inf_value = 1e300
- while inf_value != inf_value * inf_value:
- inf_value *= inf_value
- nan_value = -inf_value / inf_value # Trying to make a quiet NaN (like C99).
-
- def construct_yaml_float(self, node):
- # type: (Any) -> float
- value_so = to_str(self.construct_scalar(node))
- value_s = value_so.replace("_", "").lower()
- sign = +1
- if value_s[0] == "-":
- sign = -1
- if value_s[0] in "+-":
- value_s = value_s[1:]
- if value_s == ".inf":
- return sign * self.inf_value
- elif value_s == ".nan":
- return self.nan_value
- elif self.resolver.processing_version != (1, 2) and ":" in value_s:
- digits = [float(part) for part in value_s.split(":")]
- digits.reverse()
- base = 1
- value = 0.0
- for digit in digits:
- value += digit * base
- base *= 60
- return sign * value
- else:
- if self.resolver.processing_version != (1, 2) and "e" in value_s:
- # value_s is lower case independent of input
- mantissa, exponent = value_s.split("e")
- if "." not in mantissa:
- warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so))
- return sign * float(value_s)
-
- if PY3:
-
- def construct_yaml_binary(self, node):
- # type: (Any) -> Any
- try:
- value = self.construct_scalar(node).encode("ascii")
- except UnicodeEncodeError as exc:
- raise ConstructorError(
- None,
- None,
- "failed to convert base64 data into ascii: %s" % exc,
- node.start_mark,
- )
- try:
- if hasattr(base64, "decodebytes"):
- return base64.decodebytes(value)
- else:
- return base64.decodestring(value)
- except binascii.Error as exc:
- raise ConstructorError(
- None,
- None,
- "failed to decode base64 data: %s" % exc,
- node.start_mark,
- )
-
- else:
-
- def construct_yaml_binary(self, node):
- # type: (Any) -> Any
- value = self.construct_scalar(node)
- try:
- return to_str(value).decode("base64")
- except (binascii.Error, UnicodeEncodeError) as exc:
- raise ConstructorError(
- None,
- None,
- "failed to decode base64 data: %s" % exc,
- node.start_mark,
- )
-
- timestamp_regexp = RegExp(
- u"""^(?P[0-9][0-9][0-9][0-9])
- -(?P[0-9][0-9]?)
- -(?P[0-9][0-9]?)
- (?:((?P[Tt])|[ \\t]+) # explictly not retaining extra spaces
- (?P[0-9][0-9]?)
- :(?P[0-9][0-9])
- :(?P[0-9][0-9])
- (?:\\.(?P[0-9]*))?
- (?:[ \\t]*(?PZ|(?P[-+])(?P[0-9][0-9]?)
- (?::(?P[0-9][0-9]))?))?)?$""",
- re.X,
- )
-
- def construct_yaml_timestamp(self, node, values=None):
- # type: (Any, Any) -> Any
- if values is None:
- try:
- match = self.timestamp_regexp.match(node.value)
- except TypeError:
- match = None
- if match is None:
- raise ConstructorError(
- None,
- None,
- 'failed to construct timestamp from "{}"'.format(node.value),
- node.start_mark,
- )
- values = match.groupdict()
- year = int(values["year"])
- month = int(values["month"])
- day = int(values["day"])
- if not values["hour"]:
- return datetime.date(year, month, day)
- hour = int(values["hour"])
- minute = int(values["minute"])
- second = int(values["second"])
- fraction = 0
- if values["fraction"]:
- fraction_s = values["fraction"][:6]
- while len(fraction_s) < 6:
- fraction_s += "0"
- fraction = int(fraction_s)
- if len(values["fraction"]) > 6 and int(values["fraction"][6]) > 4:
- fraction += 1
- delta = None
- if values["tz_sign"]:
- tz_hour = int(values["tz_hour"])
- minutes = values["tz_minute"]
- tz_minute = int(minutes) if minutes else 0
- delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
- if values["tz_sign"] == "-":
- delta = -delta
- # should do something else instead (or hook this up to the preceding if statement
- # in reverse
- # if delta is None:
- # return datetime.datetime(year, month, day, hour, minute, second, fraction)
- # return datetime.datetime(year, month, day, hour, minute, second, fraction,
- # datetime.timezone.utc)
- # the above is not good enough though, should provide tzinfo. In Python3 that is easily
- # doable drop that kind of support for Python2 as it has not native tzinfo
- data = datetime.datetime(year, month, day, hour, minute, second, fraction)
- if delta:
- data -= delta
- return data
-
- def construct_yaml_omap(self, node):
- # type: (Any) -> Any
- # Note: we do now check for duplicate keys
- omap = ordereddict()
- yield omap
- if not isinstance(node, SequenceNode):
- raise ConstructorError(
- "while constructing an ordered map",
- node.start_mark,
- "expected a sequence, but found %s" % node.id,
- node.start_mark,
- )
- for subnode in node.value:
- if not isinstance(subnode, MappingNode):
- raise ConstructorError(
- "while constructing an ordered map",
- node.start_mark,
- "expected a mapping of length 1, but found %s" % subnode.id,
- subnode.start_mark,
- )
- if len(subnode.value) != 1:
- raise ConstructorError(
- "while constructing an ordered map",
- node.start_mark,
- "expected a single mapping item, but found %d items"
- % len(subnode.value),
- subnode.start_mark,
- )
- key_node, value_node = subnode.value[0]
- key = self.construct_object(key_node)
- assert key not in omap
- value = self.construct_object(value_node)
- omap[key] = value
-
- def construct_yaml_pairs(self, node):
- # type: (Any) -> Any
- # Note: the same code as `construct_yaml_omap`.
- pairs = [] # type: List[Any]
- yield pairs
- if not isinstance(node, SequenceNode):
- raise ConstructorError(
- "while constructing pairs",
- node.start_mark,
- "expected a sequence, but found %s" % node.id,
- node.start_mark,
- )
- for subnode in node.value:
- if not isinstance(subnode, MappingNode):
- raise ConstructorError(
- "while constructing pairs",
- node.start_mark,
- "expected a mapping of length 1, but found %s" % subnode.id,
- subnode.start_mark,
- )
- if len(subnode.value) != 1:
- raise ConstructorError(
- "while constructing pairs",
- node.start_mark,
- "expected a single mapping item, but found %d items"
- % len(subnode.value),
- subnode.start_mark,
- )
- key_node, value_node = subnode.value[0]
- key = self.construct_object(key_node)
- value = self.construct_object(value_node)
- pairs.append((key, value))
-
- def construct_yaml_set(self, node):
- # type: (Any) -> Any
- data = set() # type: Set[Any]
- yield data
- value = self.construct_mapping(node)
- data.update(value)
-
- def construct_yaml_str(self, node):
- # type: (Any) -> Any
- value = self.construct_scalar(node)
- if PY3:
- return value
- try:
- return value.encode("ascii")
- except UnicodeEncodeError:
- return value
-
- def construct_yaml_seq(self, node):
- # type: (Any) -> Any
- data = self.yaml_base_list_type() # type: List[Any]
- yield data
- data.extend(self.construct_sequence(node))
-
- def construct_yaml_map(self, node):
- # type: (Any) -> Any
- data = self.yaml_base_dict_type() # type: Dict[Any, Any]
- yield data
- value = self.construct_mapping(node)
- data.update(value)
-
- def construct_yaml_object(self, node, cls):
- # type: (Any, Any) -> Any
- data = cls.__new__(cls)
- yield data
- if hasattr(data, "__setstate__"):
- state = self.construct_mapping(node, deep=True)
- data.__setstate__(state)
- else:
- state = self.construct_mapping(node)
- data.__dict__.update(state)
-
- def construct_undefined(self, node):
- # type: (Any) -> None
- raise ConstructorError(
- None,
- None,
- "could not determine a constructor for the tag %r" % utf8(node.tag),
- node.start_mark,
- )
-
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:null", SafeConstructor.construct_yaml_null
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:bool", SafeConstructor.construct_yaml_bool
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:int", SafeConstructor.construct_yaml_int
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:float", SafeConstructor.construct_yaml_float
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:binary", SafeConstructor.construct_yaml_binary
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:timestamp", SafeConstructor.construct_yaml_timestamp
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:omap", SafeConstructor.construct_yaml_omap
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:pairs", SafeConstructor.construct_yaml_pairs
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:set", SafeConstructor.construct_yaml_set
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:str", SafeConstructor.construct_yaml_str
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:seq", SafeConstructor.construct_yaml_seq
-)
-
-SafeConstructor.add_constructor(
- u"tag:yaml.org,2002:map", SafeConstructor.construct_yaml_map
-)
-
-SafeConstructor.add_constructor(None, SafeConstructor.construct_undefined)
-
-if PY2:
-
- class classobj:
- pass
-
-
-class Constructor(SafeConstructor):
- def construct_python_str(self, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_unicode(self, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- if PY3:
-
- def construct_python_bytes(self, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_long(self, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_complex(self, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_tuple(self, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def find_python_module(self, name, mark):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def find_python_name(self, name, mark):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_name(self, suffix, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_module(self, suffix, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def set_python_instance_state(self, instance, state):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_object(self, suffix, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_object_apply(self, suffix, node, newobj=False):
- raise ValueError("Unsafe constructor not implemented in this library")
-
- def construct_python_object_new(self, suffix, node):
- raise ValueError("Unsafe constructor not implemented in this library")
-
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/none", Constructor.construct_yaml_null
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/bool", Constructor.construct_yaml_bool
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/str", Constructor.construct_python_str
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/unicode", Constructor.construct_python_unicode
-)
-
-if PY3:
- Constructor.add_constructor(
- u"tag:yaml.org,2002:python/bytes", Constructor.construct_python_bytes
- )
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/int", Constructor.construct_yaml_int
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/long", Constructor.construct_python_long
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/float", Constructor.construct_yaml_float
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/complex", Constructor.construct_python_complex
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/list", Constructor.construct_yaml_seq
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/tuple", Constructor.construct_python_tuple
-)
-
-Constructor.add_constructor(
- u"tag:yaml.org,2002:python/dict", Constructor.construct_yaml_map
-)
-
-Constructor.add_multi_constructor(
- u"tag:yaml.org,2002:python/name:", Constructor.construct_python_name
-)
-
-Constructor.add_multi_constructor(
- u"tag:yaml.org,2002:python/module:", Constructor.construct_python_module
-)
-
-Constructor.add_multi_constructor(
- u"tag:yaml.org,2002:python/object:", Constructor.construct_python_object
-)
-
-Constructor.add_multi_constructor(
- u"tag:yaml.org,2002:python/object/apply:", Constructor.construct_python_object_apply
-)
-
-Constructor.add_multi_constructor(
- u"tag:yaml.org,2002:python/object/new:", Constructor.construct_python_object_new
-)
-
-
-class RoundTripConstructor(SafeConstructor):
- """need to store the comments on the node itself,
- as well as on the items
- """
-
- def construct_scalar(self, node):
- # type: (Any) -> Any
- if not isinstance(node, ScalarNode):
- raise ConstructorError(
- None,
- None,
- "expected a scalar node, but found %s" % node.id,
- node.start_mark,
- )
-
- if node.style == "|" and isinstance(node.value, text_type):
- lss = LiteralScalarString(node.value, anchor=node.anchor)
- if node.comment and node.comment[1]:
- lss.comment = node.comment[1][0] # type: ignore
- return lss
- if node.style == ">" and isinstance(node.value, text_type):
- fold_positions = [] # type: List[int]
- idx = -1
- while True:
- idx = node.value.find("\a", idx + 1)
- if idx < 0:
- break
- fold_positions.append(idx - len(fold_positions))
- fss = FoldedScalarString(node.value.replace("\a", ""), anchor=node.anchor)
- if node.comment and node.comment[1]:
- fss.comment = node.comment[1][0] # type: ignore
- if fold_positions:
- fss.fold_pos = fold_positions # type: ignore
- return fss
- elif bool(self._preserve_quotes) and isinstance(node.value, text_type):
- if node.style == "'":
- return SingleQuotedScalarString(node.value, anchor=node.anchor)
- if node.style == '"':
- return DoubleQuotedScalarString(node.value, anchor=node.anchor)
- if node.anchor:
- return PlainScalarString(node.value, anchor=node.anchor)
- return node.value
-
- def construct_yaml_int(self, node):
- # type: (Any) -> Any
- width = None # type: Any
- value_su = to_str(self.construct_scalar(node))
- try:
- sx = value_su.rstrip("_")
- underscore = [len(sx) - sx.rindex("_") - 1, False, False] # type: Any
- except ValueError:
- underscore = None
- except IndexError:
- underscore = None
- value_s = value_su.replace("_", "")
- sign = +1
- if value_s[0] == "-":
- sign = -1
- if value_s[0] in "+-":
- value_s = value_s[1:]
- if value_s == "0":
- return 0
- elif value_s.startswith("0b"):
- if self.resolver.processing_version > (1, 1) and value_s[2] == "0":
- width = len(value_s[2:])
- if underscore is not None:
- underscore[1] = value_su[2] == "_"
- underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_"
- return BinaryInt(
- sign * int(value_s[2:], 2),
- width=width,
- underscore=underscore,
- anchor=node.anchor,
- )
- elif value_s.startswith("0x"):
- # default to lower-case if no a-fA-F in string
- if self.resolver.processing_version > (1, 1) and value_s[2] == "0":
- width = len(value_s[2:])
- hex_fun = HexInt # type: Any
- for ch in value_s[2:]:
- if ch in "ABCDEF": # first non-digit is capital
- hex_fun = HexCapsInt
- break
- if ch in "abcdef":
- break
- if underscore is not None:
- underscore[1] = value_su[2] == "_"
- underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_"
- return hex_fun(
- sign * int(value_s[2:], 16),
- width=width,
- underscore=underscore,
- anchor=node.anchor,
- )
- elif value_s.startswith("0o"):
- if self.resolver.processing_version > (1, 1) and value_s[2] == "0":
- width = len(value_s[2:])
- if underscore is not None:
- underscore[1] = value_su[2] == "_"
- underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_"
- return OctalInt(
- sign * int(value_s[2:], 8),
- width=width,
- underscore=underscore,
- anchor=node.anchor,
- )
- elif self.resolver.processing_version != (1, 2) and value_s[0] == "0":
- return sign * int(value_s, 8)
- elif self.resolver.processing_version != (1, 2) and ":" in value_s:
- digits = [int(part) for part in value_s.split(":")]
- digits.reverse()
- base = 1
- value = 0
- for digit in digits:
- value += digit * base
- base *= 60
- return sign * value
- elif self.resolver.processing_version > (1, 1) and value_s[0] == "0":
- # not an octal, an integer with leading zero(s)
- if underscore is not None:
- # cannot have a leading underscore
- underscore[2] = len(value_su) > 1 and value_su[-1] == "_"
- return ScalarInt(
- sign * int(value_s), width=len(value_s), underscore=underscore
- )
- elif underscore:
- # cannot have a leading underscore
- underscore[2] = len(value_su) > 1 and value_su[-1] == "_"
- return ScalarInt(
- sign * int(value_s),
- width=None,
- underscore=underscore,
- anchor=node.anchor,
- )
- elif node.anchor:
- return ScalarInt(sign * int(value_s), width=None, anchor=node.anchor)
- else:
- return sign * int(value_s)
-
- def construct_yaml_float(self, node):
- # type: (Any) -> Any
- def leading_zeros(v):
- # type: (Any) -> int
- lead0 = 0
- idx = 0
- while idx < len(v) and v[idx] in "0.":
- if v[idx] == "0":
- lead0 += 1
- idx += 1
- return lead0
-
- # underscore = None
- m_sign = False # type: Any
- value_so = to_str(self.construct_scalar(node))
- value_s = value_so.replace("_", "").lower()
- sign = +1
- if value_s[0] == "-":
- sign = -1
- if value_s[0] in "+-":
- m_sign = value_s[0]
- value_s = value_s[1:]
- if value_s == ".inf":
- return sign * self.inf_value
- if value_s == ".nan":
- return self.nan_value
- if self.resolver.processing_version != (1, 2) and ":" in value_s:
- digits = [float(part) for part in value_s.split(":")]
- digits.reverse()
- base = 1
- value = 0.0
- for digit in digits:
- value += digit * base
- base *= 60
- return sign * value
- if "e" in value_s:
- try:
- mantissa, exponent = value_so.split("e")
- exp = "e"
- except ValueError:
- mantissa, exponent = value_so.split("E")
- exp = "E"
- if self.resolver.processing_version != (1, 2):
- # value_s is lower case independent of input
- if "." not in mantissa:
- warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so))
- lead0 = leading_zeros(mantissa)
- width = len(mantissa)
- prec = mantissa.find(".")
- if m_sign:
- width -= 1
- e_width = len(exponent)
- e_sign = exponent[0] in "+-"
- # nprint('sf', width, prec, m_sign, exp, e_width, e_sign)
- return ScalarFloat(
- sign * float(value_s),
- width=width,
- prec=prec,
- m_sign=m_sign,
- m_lead0=lead0,
- exp=exp,
- e_width=e_width,
- e_sign=e_sign,
- anchor=node.anchor,
- )
- width = len(value_so)
- prec = value_so.index(
- "."
- ) # you can use index, this would not be float without dot
- lead0 = leading_zeros(value_so)
- return ScalarFloat(
- sign * float(value_s),
- width=width,
- prec=prec,
- m_sign=m_sign,
- m_lead0=lead0,
- anchor=node.anchor,
- )
-
- def construct_yaml_str(self, node):
- # type: (Any) -> Any
- value = self.construct_scalar(node)
- if isinstance(value, ScalarString):
- return value
- if PY3:
- return value
- try:
- return value.encode("ascii")
- except AttributeError:
- # in case you replace the node dynamically e.g. with a dict
- return value
- except UnicodeEncodeError:
- return value
-
- def construct_rt_sequence(self, node, seqtyp, deep=False):
- # type: (Any, Any, bool) -> Any
- if not isinstance(node, SequenceNode):
- raise ConstructorError(
- None,
- None,
- "expected a sequence node, but found %s" % node.id,
- node.start_mark,
- )
- ret_val = []
- if node.comment:
- seqtyp._yaml_add_comment(node.comment[:2])
- if len(node.comment) > 2:
- seqtyp.yaml_end_comment_extend(node.comment[2], clear=True)
- if node.anchor:
- from .serializer import templated_id
-
- if not templated_id(node.anchor):
- seqtyp.yaml_set_anchor(node.anchor)
- for idx, child in enumerate(node.value):
- if child.comment:
- seqtyp._yaml_add_comment(child.comment, key=idx)
- child.comment = None # if moved to sequence remove from child
- ret_val.append(self.construct_object(child, deep=deep))
- seqtyp._yaml_set_idx_line_col(
- idx, [child.start_mark.line, child.start_mark.column]
- )
- return ret_val
-
- def flatten_mapping(self, node):
- # type: (Any) -> Any
- """
- This implements the merge key feature http://yaml.org/type/merge.html
- by inserting keys from the merge dict/list of dicts if not yet
- available in this node
- """
-
- def constructed(value_node):
- # type: (Any) -> Any
- # If the contents of a merge are defined within the
- # merge marker, then they won't have been constructed
- # yet. But if they were already constructed, we need to use
- # the existing object.
- if value_node in self.constructed_objects:
- value = self.constructed_objects[value_node]
- else:
- value = self.construct_object(value_node, deep=False)
- return value
-
- # merge = []
- merge_map_list = [] # type: List[Any]
- index = 0
- while index < len(node.value):
- key_node, value_node = node.value[index]
- if key_node.tag == u"tag:yaml.org,2002:merge":
- if merge_map_list: # double << key
- if self.allow_duplicate_keys:
- del node.value[index]
- index += 1
- continue
- args = [
- "while constructing a mapping",
- node.start_mark,
- 'found duplicate key "{}"'.format(key_node.value),
- key_node.start_mark,
- """
- To suppress this check see:
- http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
- """,
- """\
- Duplicate keys will become an error in future releases, and are errors
- by default when using the new API.
- """,
- ]
- if self.allow_duplicate_keys is None:
- warnings.warn(DuplicateKeyFutureWarning(*args))
- else:
- raise DuplicateKeyError(*args)
- del node.value[index]
- if isinstance(value_node, MappingNode):
- merge_map_list.append((index, constructed(value_node)))
- # self.flatten_mapping(value_node)
- # merge.extend(value_node.value)
- elif isinstance(value_node, SequenceNode):
- # submerge = []
- for subnode in value_node.value:
- if not isinstance(subnode, MappingNode):
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "expected a mapping for merging, but found %s"
- % subnode.id,
- subnode.start_mark,
- )
- merge_map_list.append((index, constructed(subnode)))
- # self.flatten_mapping(subnode)
- # submerge.append(subnode.value)
- # submerge.reverse()
- # for value in submerge:
- # merge.extend(value)
- else:
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "expected a mapping or list of mappings for merging, "
- "but found %s" % value_node.id,
- value_node.start_mark,
- )
- elif key_node.tag == u"tag:yaml.org,2002:value":
- key_node.tag = u"tag:yaml.org,2002:str"
- index += 1
- else:
- index += 1
- return merge_map_list
- # if merge:
- # node.value = merge + node.value
-
- def _sentinel(self):
- # type: () -> None
- pass
-
- def construct_mapping(self, node, maptyp, deep=False): # type: ignore
- # type: (Any, Any, bool) -> Any
- if not isinstance(node, MappingNode):
- raise ConstructorError(
- None,
- None,
- "expected a mapping node, but found %s" % node.id,
- node.start_mark,
- )
- merge_map = self.flatten_mapping(node)
- # mapping = {}
- if node.comment:
- maptyp._yaml_add_comment(node.comment[:2])
- if len(node.comment) > 2:
- maptyp.yaml_end_comment_extend(node.comment[2], clear=True)
- if node.anchor:
- from .serializer import templated_id
-
- if not templated_id(node.anchor):
- maptyp.yaml_set_anchor(node.anchor)
- last_key, last_value = None, self._sentinel
- for key_node, value_node in node.value:
- # keys can be list -> deep
- key = self.construct_object(key_node, deep=True)
- # lists are not hashable, but tuples are
- if not isinstance(key, Hashable):
- if isinstance(key, MutableSequence):
- key_s = CommentedKeySeq(key)
- if key_node.flow_style is True:
- key_s.fa.set_flow_style()
- elif key_node.flow_style is False:
- key_s.fa.set_block_style()
- key = key_s
- elif isinstance(key, MutableMapping):
- key_m = CommentedKeyMap(key)
- if key_node.flow_style is True:
- key_m.fa.set_flow_style()
- elif key_node.flow_style is False:
- key_m.fa.set_block_style()
- key = key_m
- if PY2:
- try:
- hash(key)
- except TypeError as exc:
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "found unacceptable key (%s)" % exc,
- key_node.start_mark,
- )
- else:
- if not isinstance(key, Hashable):
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "found unhashable key",
- key_node.start_mark,
- )
- value = self.construct_object(value_node, deep=deep)
- if self.check_mapping_key(node, key_node, maptyp, key, value):
- if (
- key_node.comment
- and len(key_node.comment) > 4
- and key_node.comment[4]
- ):
- if last_value is None:
- key_node.comment[0] = key_node.comment.pop(4)
- maptyp._yaml_add_comment(key_node.comment, value=last_key)
- else:
- key_node.comment[2] = key_node.comment.pop(4)
- maptyp._yaml_add_comment(key_node.comment, key=key)
- key_node.comment = None
- if key_node.comment:
- maptyp._yaml_add_comment(key_node.comment, key=key)
- if value_node.comment:
- maptyp._yaml_add_comment(value_node.comment, value=key)
- maptyp._yaml_set_kv_line_col(
- key,
- [
- key_node.start_mark.line,
- key_node.start_mark.column,
- value_node.start_mark.line,
- value_node.start_mark.column,
- ],
- )
- maptyp[key] = value
- last_key, last_value = key, value # could use indexing
- # do this last, or <<: before a key will prevent insertion in instances
- # of collections.OrderedDict (as they have no __contains__
- if merge_map:
- maptyp.add_yaml_merge(merge_map)
-
- def construct_setting(self, node, typ, deep=False):
- # type: (Any, Any, bool) -> Any
- if not isinstance(node, MappingNode):
- raise ConstructorError(
- None,
- None,
- "expected a mapping node, but found %s" % node.id,
- node.start_mark,
- )
- if node.comment:
- typ._yaml_add_comment(node.comment[:2])
- if len(node.comment) > 2:
- typ.yaml_end_comment_extend(node.comment[2], clear=True)
- if node.anchor:
- from .serializer import templated_id
-
- if not templated_id(node.anchor):
- typ.yaml_set_anchor(node.anchor)
- for key_node, value_node in node.value:
- # keys can be list -> deep
- key = self.construct_object(key_node, deep=True)
- # lists are not hashable, but tuples are
- if not isinstance(key, Hashable):
- if isinstance(key, list):
- key = tuple(key)
- if PY2:
- try:
- hash(key)
- except TypeError as exc:
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "found unacceptable key (%s)" % exc,
- key_node.start_mark,
- )
- else:
- if not isinstance(key, Hashable):
- raise ConstructorError(
- "while constructing a mapping",
- node.start_mark,
- "found unhashable key",
- key_node.start_mark,
- )
- # construct but should be null
- value = self.construct_object(value_node, deep=deep) # NOQA
- self.check_set_key(node, key_node, typ, key)
- if key_node.comment:
- typ._yaml_add_comment(key_node.comment, key=key)
- if value_node.comment:
- typ._yaml_add_comment(value_node.comment, value=key)
- typ.add(key)
-
- def construct_yaml_seq(self, node):
- # type: (Any) -> Any
- data = CommentedSeq()
- data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
- if node.comment:
- data._yaml_add_comment(node.comment)
- yield data
- data.extend(self.construct_rt_sequence(node, data))
- self.set_collection_style(data, node)
-
- def construct_yaml_map(self, node):
- # type: (Any) -> Any
- data = CommentedMap()
- data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
- yield data
- self.construct_mapping(node, data, deep=True)
- self.set_collection_style(data, node)
-
- def set_collection_style(self, data, node):
- # type: (Any, Any) -> None
- if len(data) == 0:
- return
- if node.flow_style is True:
- data.fa.set_flow_style()
- elif node.flow_style is False:
- data.fa.set_block_style()
-
- def construct_yaml_object(self, node, cls):
- # type: (Any, Any) -> Any
- data = cls.__new__(cls)
- yield data
- if hasattr(data, "__setstate__"):
- state = SafeConstructor.construct_mapping(self, node, deep=True)
- data.__setstate__(state)
- else:
- state = SafeConstructor.construct_mapping(self, node)
- data.__dict__.update(state)
-
- def construct_yaml_omap(self, node):
- # type: (Any) -> Any
- # Note: we do now check for duplicate keys
- omap = CommentedOrderedMap()
- omap._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
- if node.flow_style is True:
- omap.fa.set_flow_style()
- elif node.flow_style is False:
- omap.fa.set_block_style()
- yield omap
- if node.comment:
- omap._yaml_add_comment(node.comment[:2])
- if len(node.comment) > 2:
- omap.yaml_end_comment_extend(node.comment[2], clear=True)
- if not isinstance(node, SequenceNode):
- raise ConstructorError(
- "while constructing an ordered map",
- node.start_mark,
- "expected a sequence, but found %s" % node.id,
- node.start_mark,
- )
- for subnode in node.value:
- if not isinstance(subnode, MappingNode):
- raise ConstructorError(
- "while constructing an ordered map",
- node.start_mark,
- "expected a mapping of length 1, but found %s" % subnode.id,
- subnode.start_mark,
- )
- if len(subnode.value) != 1:
- raise ConstructorError(
- "while constructing an ordered map",
- node.start_mark,
- "expected a single mapping item, but found %d items"
- % len(subnode.value),
- subnode.start_mark,
- )
- key_node, value_node = subnode.value[0]
- key = self.construct_object(key_node)
- assert key not in omap
- value = self.construct_object(value_node)
- if key_node.comment:
- omap._yaml_add_comment(key_node.comment, key=key)
- if subnode.comment:
- omap._yaml_add_comment(subnode.comment, key=key)
- if value_node.comment:
- omap._yaml_add_comment(value_node.comment, value=key)
- omap[key] = value
-
- def construct_yaml_set(self, node):
- # type: (Any) -> Any
- data = CommentedSet()
- data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
- yield data
- self.construct_setting(node, data)
-
- def construct_undefined(self, node):
- # type: (Any) -> Any
- try:
- if isinstance(node, MappingNode):
- data = CommentedMap()
- data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
- if node.flow_style is True:
- data.fa.set_flow_style()
- elif node.flow_style is False:
- data.fa.set_block_style()
- data.yaml_set_tag(node.tag)
- yield data
- if node.anchor:
- data.yaml_set_anchor(node.anchor)
- self.construct_mapping(node, data)
- return
- elif isinstance(node, ScalarNode):
- data2 = TaggedScalar()
- data2.value = self.construct_scalar(node)
- data2.style = node.style
- data2.yaml_set_tag(node.tag)
- yield data2
- if node.anchor:
- data2.yaml_set_anchor(node.anchor, always_dump=True)
- return
- elif isinstance(node, SequenceNode):
- data3 = CommentedSeq()
- data3._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
- if node.flow_style is True:
- data3.fa.set_flow_style()
- elif node.flow_style is False:
- data3.fa.set_block_style()
- data3.yaml_set_tag(node.tag)
- yield data3
- if node.anchor:
- data3.yaml_set_anchor(node.anchor)
- data3.extend(self.construct_sequence(node))
- return
- except: # NOQA
- pass
- raise ConstructorError(
- None,
- None,
- "could not determine a constructor for the tag %r" % utf8(node.tag),
- node.start_mark,
- )
-
- def construct_yaml_timestamp(self, node, values=None):
- # type: (Any, Any) -> Any
- try:
- match = self.timestamp_regexp.match(node.value)
- except TypeError:
- match = None
- if match is None:
- raise ConstructorError(
- None,
- None,
- 'failed to construct timestamp from "{}"'.format(node.value),
- node.start_mark,
- )
- values = match.groupdict()
- if not values["hour"]:
- return SafeConstructor.construct_yaml_timestamp(self, node, values)
- for part in ["t", "tz_sign", "tz_hour", "tz_minute"]:
- if values[part]:
- break
- else:
- return SafeConstructor.construct_yaml_timestamp(self, node, values)
- year = int(values["year"])
- month = int(values["month"])
- day = int(values["day"])
- hour = int(values["hour"])
- minute = int(values["minute"])
- second = int(values["second"])
- fraction = 0
- if values["fraction"]:
- fraction_s = values["fraction"][:6]
- while len(fraction_s) < 6:
- fraction_s += "0"
- fraction = int(fraction_s)
- if len(values["fraction"]) > 6 and int(values["fraction"][6]) > 4:
- fraction += 1
- delta = None
- if values["tz_sign"]:
- tz_hour = int(values["tz_hour"])
- minutes = values["tz_minute"]
- tz_minute = int(minutes) if minutes else 0
- delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
- if values["tz_sign"] == "-":
- delta = -delta
- if delta:
- dt = datetime.datetime(year, month, day, hour, minute)
- dt -= delta
- data = TimeStamp(
- dt.year, dt.month, dt.day, dt.hour, dt.minute, second, fraction
- )
- data._yaml["delta"] = delta
- tz = values["tz_sign"] + values["tz_hour"]
- if values["tz_minute"]:
- tz += ":" + values["tz_minute"]
- data._yaml["tz"] = tz
- else:
- data = TimeStamp(year, month, day, hour, minute, second, fraction)
- if values["tz"]: # no delta
- data._yaml["tz"] = values["tz"]
-
- if values["t"]:
- data._yaml["t"] = True
- return data
-
- def construct_yaml_bool(self, node):
- # type: (Any) -> Any
- b = SafeConstructor.construct_yaml_bool(self, node)
- if node.anchor:
- return ScalarBoolean(b, anchor=node.anchor)
- return b
-
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:null", RoundTripConstructor.construct_yaml_null
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:bool", RoundTripConstructor.construct_yaml_bool
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:int", RoundTripConstructor.construct_yaml_int
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:float", RoundTripConstructor.construct_yaml_float
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:binary", RoundTripConstructor.construct_yaml_binary
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:timestamp", RoundTripConstructor.construct_yaml_timestamp
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:omap", RoundTripConstructor.construct_yaml_omap
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:pairs", RoundTripConstructor.construct_yaml_pairs
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:set", RoundTripConstructor.construct_yaml_set
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:str", RoundTripConstructor.construct_yaml_str
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:seq", RoundTripConstructor.construct_yaml_seq
-)
-
-RoundTripConstructor.add_constructor(
- u"tag:yaml.org,2002:map", RoundTripConstructor.construct_yaml_map
-)
-
-RoundTripConstructor.add_constructor(None, RoundTripConstructor.construct_undefined)
diff --git a/srsly/ruamel_yaml/cyaml.py b/srsly/ruamel_yaml/cyaml.py
deleted file mode 100755
index 3375944..0000000
--- a/srsly/ruamel_yaml/cyaml.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-from _ruamel_yaml import CParser, CEmitter # type: ignore
-
-from .constructor import Constructor, BaseConstructor, SafeConstructor
-from .representer import Representer, SafeRepresenter, BaseRepresenter
-from .resolver import Resolver, BaseResolver
-
-if False: # MYPY
- from typing import Any, Union, Optional # NOQA
- from .compat import StreamTextType, StreamType, VersionType # NOQA
-
-__all__ = [
- "CBaseLoader",
- "CSafeLoader",
- "CLoader",
- "CBaseDumper",
- "CSafeDumper",
- "CDumper",
-]
-
-
-# this includes some hacks to solve the usage of resolver by lower level
-# parts of the parser
-
-
-class CBaseLoader(CParser, BaseConstructor, BaseResolver): # type: ignore
- def __init__(self, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
- CParser.__init__(self, stream)
- self._parser = self._composer = self
- BaseConstructor.__init__(self, loader=self)
- BaseResolver.__init__(self, loadumper=self)
- # self.descend_resolver = self._resolver.descend_resolver
- # self.ascend_resolver = self._resolver.ascend_resolver
- # self.resolve = self._resolver.resolve
-
-
-class CSafeLoader(CParser, SafeConstructor, Resolver): # type: ignore
- def __init__(self, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
- CParser.__init__(self, stream)
- self._parser = self._composer = self
- SafeConstructor.__init__(self, loader=self)
- Resolver.__init__(self, loadumper=self)
- # self.descend_resolver = self._resolver.descend_resolver
- # self.ascend_resolver = self._resolver.ascend_resolver
- # self.resolve = self._resolver.resolve
-
-
-class CLoader(CParser, Constructor, Resolver): # type: ignore
- def __init__(self, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
- CParser.__init__(self, stream)
- self._parser = self._composer = self
- Constructor.__init__(self, loader=self)
- Resolver.__init__(self, loadumper=self)
- # self.descend_resolver = self._resolver.descend_resolver
- # self.ascend_resolver = self._resolver.ascend_resolver
- # self.resolve = self._resolver.resolve
-
-
-class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): # type: ignore
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- CEmitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- encoding=encoding,
- allow_unicode=allow_unicode,
- line_break=line_break,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- )
- self._emitter = self._serializer = self._representer = self
- BaseRepresenter.__init__(
- self,
- default_style=default_style,
- default_flow_style=default_flow_style,
- dumper=self,
- )
- BaseResolver.__init__(self, loadumper=self)
-
-
-class CSafeDumper(CEmitter, SafeRepresenter, Resolver): # type: ignore
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- self._emitter = self._serializer = self._representer = self
- CEmitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- encoding=encoding,
- allow_unicode=allow_unicode,
- line_break=line_break,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- )
- self._emitter = self._serializer = self._representer = self
- SafeRepresenter.__init__(
- self, default_style=default_style, default_flow_style=default_flow_style
- )
- Resolver.__init__(self)
-
-
-class CDumper(CEmitter, Representer, Resolver): # type: ignore
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- CEmitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- encoding=encoding,
- allow_unicode=allow_unicode,
- line_break=line_break,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- )
- self._emitter = self._serializer = self._representer = self
- Representer.__init__(
- self, default_style=default_style, default_flow_style=default_flow_style
- )
- Resolver.__init__(self)
diff --git a/srsly/ruamel_yaml/dumper.py b/srsly/ruamel_yaml/dumper.py
deleted file mode 100755
index 04d3b4d..0000000
--- a/srsly/ruamel_yaml/dumper.py
+++ /dev/null
@@ -1,221 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-from .emitter import Emitter
-from .serializer import Serializer
-from .representer import (
- Representer,
- SafeRepresenter,
- BaseRepresenter,
- RoundTripRepresenter,
-)
-from .resolver import Resolver, BaseResolver, VersionedResolver
-
-if False: # MYPY
- from typing import Any, Dict, List, Union, Optional # NOQA
- from .compat import StreamType, VersionType # NOQA
-
-__all__ = ["BaseDumper", "SafeDumper", "Dumper", "RoundTripDumper"]
-
-
-class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver):
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (Any, StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- Emitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- block_seq_indent=block_seq_indent,
- dumper=self,
- )
- Serializer.__init__(
- self,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- dumper=self,
- )
- BaseRepresenter.__init__(
- self,
- default_style=default_style,
- default_flow_style=default_flow_style,
- dumper=self,
- )
- BaseResolver.__init__(self, loadumper=self)
-
-
-class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver):
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- Emitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- block_seq_indent=block_seq_indent,
- dumper=self,
- )
- Serializer.__init__(
- self,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- dumper=self,
- )
- SafeRepresenter.__init__(
- self,
- default_style=default_style,
- default_flow_style=default_flow_style,
- dumper=self,
- )
- Resolver.__init__(self, loadumper=self)
-
-
-class Dumper(Emitter, Serializer, Representer, Resolver):
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- Emitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- block_seq_indent=block_seq_indent,
- dumper=self,
- )
- Serializer.__init__(
- self,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- dumper=self,
- )
- Representer.__init__(
- self,
- default_style=default_style,
- default_flow_style=default_flow_style,
- dumper=self,
- )
- Resolver.__init__(self, loadumper=self)
-
-
-class RoundTripDumper(Emitter, Serializer, RoundTripRepresenter, VersionedResolver):
- def __init__(
- self,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Optional[bool], Optional[int], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- Emitter.__init__(
- self,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- block_seq_indent=block_seq_indent,
- top_level_colon_align=top_level_colon_align,
- prefix_colon=prefix_colon,
- dumper=self,
- )
- Serializer.__init__(
- self,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- dumper=self,
- )
- RoundTripRepresenter.__init__(
- self,
- default_style=default_style,
- default_flow_style=default_flow_style,
- dumper=self,
- )
- VersionedResolver.__init__(self, loader=self)
diff --git a/srsly/ruamel_yaml/emitter.py b/srsly/ruamel_yaml/emitter.py
deleted file mode 100755
index 74a1660..0000000
--- a/srsly/ruamel_yaml/emitter.py
+++ /dev/null
@@ -1,1727 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-from __future__ import print_function
-
-# Emitter expects events obeying the following grammar:
-# stream ::= STREAM-START document* STREAM-END
-# document ::= DOCUMENT-START node DOCUMENT-END
-# node ::= SCALAR | sequence | mapping
-# sequence ::= SEQUENCE-START node* SEQUENCE-END
-# mapping ::= MAPPING-START (node node)* MAPPING-END
-
-import sys
-from .error import YAMLError, YAMLStreamError
-from .events import * # NOQA
-
-# fmt: off
-from .compat import utf8, text_type, PY2, nprint, dbg, DBG_EVENT, \
- check_anchorname_char
-# fmt: on
-
-if False: # MYPY
- from typing import Any, Dict, List, Union, Text, Tuple, Optional # NOQA
- from .compat import StreamType # NOQA
-
-__all__ = ["Emitter", "EmitterError"]
-
-
-class EmitterError(YAMLError):
- pass
-
-
-class ScalarAnalysis(object):
- def __init__(
- self,
- scalar,
- empty,
- multiline,
- allow_flow_plain,
- allow_block_plain,
- allow_single_quoted,
- allow_double_quoted,
- allow_block,
- ):
- # type: (Any, Any, Any, bool, bool, bool, bool, bool) -> None
- self.scalar = scalar
- self.empty = empty
- self.multiline = multiline
- self.allow_flow_plain = allow_flow_plain
- self.allow_block_plain = allow_block_plain
- self.allow_single_quoted = allow_single_quoted
- self.allow_double_quoted = allow_double_quoted
- self.allow_block = allow_block
-
-
-class Indents(object):
- # replacement for the list based stack of None/int
- def __init__(self):
- # type: () -> None
- self.values = [] # type: List[Tuple[int, bool]]
-
- def append(self, val, seq):
- # type: (Any, Any) -> None
- self.values.append((val, seq))
-
- def pop(self):
- # type: () -> Any
- return self.values.pop()[0]
-
- def last_seq(self):
- # type: () -> bool
- # return the seq(uence) value for the element added before the last one
- # in increase_indent()
- try:
- return self.values[-2][1]
- except IndexError:
- return False
-
- def seq_flow_align(self, seq_indent, column):
- # type: (int, int) -> int
- # extra spaces because of dash
- if len(self.values) < 2 or not self.values[-1][1]:
- return 0
- # -1 for the dash
- base = self.values[-1][0] if self.values[-1][0] is not None else 0
- return base + seq_indent - column - 1
-
- def __len__(self):
- # type: () -> int
- return len(self.values)
-
-
-class Emitter(object):
- # fmt: off
- DEFAULT_TAG_PREFIXES = {
- u'!': u'!',
- u'tag:yaml.org,2002:': u'!!',
- }
- # fmt: on
-
- MAX_SIMPLE_KEY_LENGTH = 128
-
- def __init__(
- self,
- stream,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- brace_single_entry_mapping_in_flow_sequence=None,
- dumper=None,
- ):
- # type: (StreamType, Any, Optional[int], Optional[int], Optional[bool], Any, Optional[int], Optional[bool], Any, Optional[bool], Any) -> None # NOQA
- self.dumper = dumper
- if self.dumper is not None and getattr(self.dumper, "_emitter", None) is None:
- self.dumper._emitter = self
- self.stream = stream
-
- # Encoding can be overriden by STREAM-START.
- self.encoding = None # type: Optional[Text]
- self.allow_space_break = None
-
- # Emitter is a state machine with a stack of states to handle nested
- # structures.
- self.states = [] # type: List[Any]
- self.state = self.expect_stream_start # type: Any
-
- # Current event and the event queue.
- self.events = [] # type: List[Any]
- self.event = None # type: Any
-
- # The current indentation level and the stack of previous indents.
- self.indents = Indents()
- self.indent = None # type: Optional[int]
-
- # flow_context is an expanding/shrinking list consisting of '{' and '['
- # for each unclosed flow context. If empty list that means block context
- self.flow_context = [] # type: List[Text]
-
- # Contexts.
- self.root_context = False
- self.sequence_context = False
- self.mapping_context = False
- self.simple_key_context = False
-
- # Characteristics of the last emitted character:
- # - current position.
- # - is it a whitespace?
- # - is it an indention character
- # (indentation space, '-', '?', or ':')?
- self.line = 0
- self.column = 0
- self.whitespace = True
- self.indention = True
- self.compact_seq_seq = True # dash after dash
- self.compact_seq_map = True # key after dash
- # self.compact_ms = False # dash after key, only when excplicit key with ?
- self.no_newline = None # type: Optional[bool] # set if directly after `- `
-
- # Whether the document requires an explicit document end indicator
- self.open_ended = False
-
- # colon handling
- self.colon = u":"
- self.prefixed_colon = (
- self.colon if prefix_colon is None else prefix_colon + self.colon
- )
- # single entry mappings in flow sequence
- self.brace_single_entry_mapping_in_flow_sequence = (
- brace_single_entry_mapping_in_flow_sequence
- ) # NOQA
-
- # Formatting details.
- self.canonical = canonical
- self.allow_unicode = allow_unicode
- # set to False to get "\Uxxxxxxxx" for non-basic unicode like emojis
- self.unicode_supplementary = sys.maxunicode > 0xFFFF
- self.sequence_dash_offset = block_seq_indent if block_seq_indent else 0
- self.top_level_colon_align = top_level_colon_align
- self.best_sequence_indent = 2
- self.requested_indent = indent # specific for literal zero indent
- if indent and 1 < indent < 10:
- self.best_sequence_indent = indent
- self.best_map_indent = self.best_sequence_indent
- # if self.best_sequence_indent < self.sequence_dash_offset + 1:
- # self.best_sequence_indent = self.sequence_dash_offset + 1
- self.best_width = 80
- if width and width > self.best_sequence_indent * 2:
- self.best_width = width
- self.best_line_break = u"\n" # type: Any
- if line_break in [u"\r", u"\n", u"\r\n"]:
- self.best_line_break = line_break
-
- # Tag prefixes.
- self.tag_prefixes = None # type: Any
-
- # Prepared anchor and tag.
- self.prepared_anchor = None # type: Any
- self.prepared_tag = None # type: Any
-
- # Scalar analysis and style.
- self.analysis = None # type: Any
- self.style = None # type: Any
-
- self.scalar_after_indicator = True # write a scalar on the same line as `---`
-
- @property
- def stream(self):
- # type: () -> Any
- try:
- return self._stream
- except AttributeError:
- raise YAMLStreamError("output stream needs to specified")
-
- @stream.setter
- def stream(self, val):
- # type: (Any) -> None
- if val is None:
- return
- if not hasattr(val, "write"):
- raise YAMLStreamError("stream argument needs to have a write() method")
- self._stream = val
-
- @property
- def serializer(self):
- # type: () -> Any
- try:
- if hasattr(self.dumper, "typ"):
- return self.dumper.serializer
- return self.dumper._serializer
- except AttributeError:
- return self # cyaml
-
- @property
- def flow_level(self):
- # type: () -> int
- return len(self.flow_context)
-
- def dispose(self):
- # type: () -> None
- # Reset the state attributes (to clear self-references)
- self.states = []
- self.state = None
-
- def emit(self, event):
- # type: (Any) -> None
- if dbg(DBG_EVENT):
- nprint(event)
- self.events.append(event)
- while not self.need_more_events():
- self.event = self.events.pop(0)
- self.state()
- self.event = None
-
- # In some cases, we wait for a few next events before emitting.
-
- def need_more_events(self):
- # type: () -> bool
- if not self.events:
- return True
- event = self.events[0]
- if isinstance(event, DocumentStartEvent):
- return self.need_events(1)
- elif isinstance(event, SequenceStartEvent):
- return self.need_events(2)
- elif isinstance(event, MappingStartEvent):
- return self.need_events(3)
- else:
- return False
-
- def need_events(self, count):
- # type: (int) -> bool
- level = 0
- for event in self.events[1:]:
- if isinstance(event, (DocumentStartEvent, CollectionStartEvent)):
- level += 1
- elif isinstance(event, (DocumentEndEvent, CollectionEndEvent)):
- level -= 1
- elif isinstance(event, StreamEndEvent):
- level = -1
- if level < 0:
- return False
- return len(self.events) < count + 1
-
- def increase_indent(self, flow=False, sequence=None, indentless=False):
- # type: (bool, Optional[bool], bool) -> None
- self.indents.append(self.indent, sequence)
- if self.indent is None: # top level
- if flow:
- # self.indent = self.best_sequence_indent if self.indents.last_seq() else \
- # self.best_map_indent
- # self.indent = self.best_sequence_indent
- self.indent = self.requested_indent
- else:
- self.indent = 0
- elif not indentless:
- self.indent += (
- self.best_sequence_indent
- if self.indents.last_seq()
- else self.best_map_indent
- )
- # if self.indents.last_seq():
- # if self.indent == 0: # top level block sequence
- # self.indent = self.best_sequence_indent - self.sequence_dash_offset
- # else:
- # self.indent += self.best_sequence_indent
- # else:
- # self.indent += self.best_map_indent
-
- # States.
-
- # Stream handlers.
-
- def expect_stream_start(self):
- # type: () -> None
- if isinstance(self.event, StreamStartEvent):
- if PY2:
- if self.event.encoding and not getattr(self.stream, "encoding", None):
- self.encoding = self.event.encoding
- else:
- if self.event.encoding and not hasattr(self.stream, "encoding"):
- self.encoding = self.event.encoding
- self.write_stream_start()
- self.state = self.expect_first_document_start
- else:
- raise EmitterError("expected StreamStartEvent, but got %s" % (self.event,))
-
- def expect_nothing(self):
- # type: () -> None
- raise EmitterError("expected nothing, but got %s" % (self.event,))
-
- # Document handlers.
-
- def expect_first_document_start(self):
- # type: () -> Any
- return self.expect_document_start(first=True)
-
- def expect_document_start(self, first=False):
- # type: (bool) -> None
- if isinstance(self.event, DocumentStartEvent):
- if (self.event.version or self.event.tags) and self.open_ended:
- self.write_indicator(u"...", True)
- self.write_indent()
- if self.event.version:
- version_text = self.prepare_version(self.event.version)
- self.write_version_directive(version_text)
- self.tag_prefixes = self.DEFAULT_TAG_PREFIXES.copy()
- if self.event.tags:
- handles = sorted(self.event.tags.keys())
- for handle in handles:
- prefix = self.event.tags[handle]
- self.tag_prefixes[prefix] = handle
- handle_text = self.prepare_tag_handle(handle)
- prefix_text = self.prepare_tag_prefix(prefix)
- self.write_tag_directive(handle_text, prefix_text)
- implicit = (
- first
- and not self.event.explicit
- and not self.canonical
- and not self.event.version
- and not self.event.tags
- and not self.check_empty_document()
- )
- if not implicit:
- self.write_indent()
- self.write_indicator(u"---", True)
- if self.canonical:
- self.write_indent()
- self.state = self.expect_document_root
- elif isinstance(self.event, StreamEndEvent):
- if self.open_ended:
- self.write_indicator(u"...", True)
- self.write_indent()
- self.write_stream_end()
- self.state = self.expect_nothing
- else:
- raise EmitterError(
- "expected DocumentStartEvent, but got %s" % (self.event,)
- )
-
- def expect_document_end(self):
- # type: () -> None
- if isinstance(self.event, DocumentEndEvent):
- self.write_indent()
- if self.event.explicit:
- self.write_indicator(u"...", True)
- self.write_indent()
- self.flush_stream()
- self.state = self.expect_document_start
- else:
- raise EmitterError("expected DocumentEndEvent, but got %s" % (self.event,))
-
- def expect_document_root(self):
- # type: () -> None
- self.states.append(self.expect_document_end)
- self.expect_node(root=True)
-
- # Node handlers.
-
- def expect_node(self, root=False, sequence=False, mapping=False, simple_key=False):
- # type: (bool, bool, bool, bool) -> None
- self.root_context = root
- self.sequence_context = sequence # not used in PyYAML
- self.mapping_context = mapping
- self.simple_key_context = simple_key
- if isinstance(self.event, AliasEvent):
- self.expect_alias()
- elif isinstance(self.event, (ScalarEvent, CollectionStartEvent)):
- if (
- self.process_anchor(u"&")
- and isinstance(self.event, ScalarEvent)
- and self.sequence_context
- ):
- self.sequence_context = False
- if (
- root
- and isinstance(self.event, ScalarEvent)
- and not self.scalar_after_indicator
- ):
- self.write_indent()
- self.process_tag()
- if isinstance(self.event, ScalarEvent):
- # nprint('@', self.indention, self.no_newline, self.column)
- self.expect_scalar()
- elif isinstance(self.event, SequenceStartEvent):
- # nprint('@', self.indention, self.no_newline, self.column)
- i2, n2 = self.indention, self.no_newline # NOQA
- if self.event.comment:
- if self.event.flow_style is False and self.event.comment:
- if self.write_post_comment(self.event):
- self.indention = False
- self.no_newline = True
- if self.write_pre_comment(self.event):
- self.indention = i2
- self.no_newline = not self.indention
- if (
- self.flow_level
- or self.canonical
- or self.event.flow_style
- or self.check_empty_sequence()
- ):
- self.expect_flow_sequence()
- else:
- self.expect_block_sequence()
- elif isinstance(self.event, MappingStartEvent):
- if self.event.flow_style is False and self.event.comment:
- self.write_post_comment(self.event)
- if self.event.comment and self.event.comment[1]:
- self.write_pre_comment(self.event)
- if (
- self.flow_level
- or self.canonical
- or self.event.flow_style
- or self.check_empty_mapping()
- ):
- self.expect_flow_mapping(single=self.event.nr_items == 1)
- else:
- self.expect_block_mapping()
- else:
- raise EmitterError("expected NodeEvent, but got %s" % (self.event,))
-
- def expect_alias(self):
- # type: () -> None
- if self.event.anchor is None:
- raise EmitterError("anchor is not specified for alias")
- self.process_anchor(u"*")
- self.state = self.states.pop()
-
- def expect_scalar(self):
- # type: () -> None
- self.increase_indent(flow=True)
- self.process_scalar()
- self.indent = self.indents.pop()
- self.state = self.states.pop()
-
- # Flow sequence handlers.
-
- def expect_flow_sequence(self):
- # type: () -> None
- ind = self.indents.seq_flow_align(self.best_sequence_indent, self.column)
- self.write_indicator(u" " * ind + u"[", True, whitespace=True)
- self.increase_indent(flow=True, sequence=True)
- self.flow_context.append("[")
- self.state = self.expect_first_flow_sequence_item
-
- def expect_first_flow_sequence_item(self):
- # type: () -> None
- if isinstance(self.event, SequenceEndEvent):
- self.indent = self.indents.pop()
- popped = self.flow_context.pop()
- assert popped == "["
- self.write_indicator(u"]", False)
- if self.event.comment and self.event.comment[0]:
- # eol comment on empty flow sequence
- self.write_post_comment(self.event)
- elif self.flow_level == 0:
- self.write_line_break()
- self.state = self.states.pop()
- else:
- if self.canonical or self.column > self.best_width:
- self.write_indent()
- self.states.append(self.expect_flow_sequence_item)
- self.expect_node(sequence=True)
-
- def expect_flow_sequence_item(self):
- # type: () -> None
- if isinstance(self.event, SequenceEndEvent):
- self.indent = self.indents.pop()
- popped = self.flow_context.pop()
- assert popped == "["
- if self.canonical:
- self.write_indicator(u",", False)
- self.write_indent()
- self.write_indicator(u"]", False)
- if self.event.comment and self.event.comment[0]:
- # eol comment on flow sequence
- self.write_post_comment(self.event)
- else:
- self.no_newline = False
- self.state = self.states.pop()
- else:
- self.write_indicator(u",", False)
- if self.canonical or self.column > self.best_width:
- self.write_indent()
- self.states.append(self.expect_flow_sequence_item)
- self.expect_node(sequence=True)
-
- # Flow mapping handlers.
-
- def expect_flow_mapping(self, single=False):
- # type: (Optional[bool]) -> None
- ind = self.indents.seq_flow_align(self.best_sequence_indent, self.column)
- map_init = u"{"
- if (
- single
- and self.flow_level
- and self.flow_context[-1] == "["
- and not self.canonical
- and not self.brace_single_entry_mapping_in_flow_sequence
- ):
- # single map item with flow context, no curly braces necessary
- map_init = u""
- self.write_indicator(u" " * ind + map_init, True, whitespace=True)
- self.flow_context.append(map_init)
- self.increase_indent(flow=True, sequence=False)
- self.state = self.expect_first_flow_mapping_key
-
- def expect_first_flow_mapping_key(self):
- # type: () -> None
- if isinstance(self.event, MappingEndEvent):
- self.indent = self.indents.pop()
- popped = self.flow_context.pop()
- assert popped == "{" # empty flow mapping
- self.write_indicator(u"}", False)
- if self.event.comment and self.event.comment[0]:
- # eol comment on empty mapping
- self.write_post_comment(self.event)
- elif self.flow_level == 0:
- self.write_line_break()
- self.state = self.states.pop()
- else:
- if self.canonical or self.column > self.best_width:
- self.write_indent()
- if not self.canonical and self.check_simple_key():
- self.states.append(self.expect_flow_mapping_simple_value)
- self.expect_node(mapping=True, simple_key=True)
- else:
- self.write_indicator(u"?", True)
- self.states.append(self.expect_flow_mapping_value)
- self.expect_node(mapping=True)
-
- def expect_flow_mapping_key(self):
- # type: () -> None
- if isinstance(self.event, MappingEndEvent):
- # if self.event.comment and self.event.comment[1]:
- # self.write_pre_comment(self.event)
- self.indent = self.indents.pop()
- popped = self.flow_context.pop()
- assert popped in [u"{", u""]
- if self.canonical:
- self.write_indicator(u",", False)
- self.write_indent()
- if popped != u"":
- self.write_indicator(u"}", False)
- if self.event.comment and self.event.comment[0]:
- # eol comment on flow mapping, never reached on empty mappings
- self.write_post_comment(self.event)
- else:
- self.no_newline = False
- self.state = self.states.pop()
- else:
- self.write_indicator(u",", False)
- if self.canonical or self.column > self.best_width:
- self.write_indent()
- if not self.canonical and self.check_simple_key():
- self.states.append(self.expect_flow_mapping_simple_value)
- self.expect_node(mapping=True, simple_key=True)
- else:
- self.write_indicator(u"?", True)
- self.states.append(self.expect_flow_mapping_value)
- self.expect_node(mapping=True)
-
- def expect_flow_mapping_simple_value(self):
- # type: () -> None
- self.write_indicator(self.prefixed_colon, False)
- self.states.append(self.expect_flow_mapping_key)
- self.expect_node(mapping=True)
-
- def expect_flow_mapping_value(self):
- # type: () -> None
- if self.canonical or self.column > self.best_width:
- self.write_indent()
- self.write_indicator(self.prefixed_colon, True)
- self.states.append(self.expect_flow_mapping_key)
- self.expect_node(mapping=True)
-
- # Block sequence handlers.
-
- def expect_block_sequence(self):
- # type: () -> None
- if self.mapping_context:
- indentless = not self.indention
- else:
- indentless = False
- if not self.compact_seq_seq and self.column != 0:
- self.write_line_break()
- self.increase_indent(flow=False, sequence=True, indentless=indentless)
- self.state = self.expect_first_block_sequence_item
-
- def expect_first_block_sequence_item(self):
- # type: () -> Any
- return self.expect_block_sequence_item(first=True)
-
- def expect_block_sequence_item(self, first=False):
- # type: (bool) -> None
- if not first and isinstance(self.event, SequenceEndEvent):
- if self.event.comment and self.event.comment[1]:
- # final comments on a block list e.g. empty line
- self.write_pre_comment(self.event)
- self.indent = self.indents.pop()
- self.state = self.states.pop()
- self.no_newline = False
- else:
- if self.event.comment and self.event.comment[1]:
- self.write_pre_comment(self.event)
- nonl = self.no_newline if self.column == 0 else False
- self.write_indent()
- ind = self.sequence_dash_offset # if len(self.indents) > 1 else 0
- self.write_indicator(u" " * ind + u"-", True, indention=True)
- if nonl or self.sequence_dash_offset + 2 > self.best_sequence_indent:
- self.no_newline = True
- self.states.append(self.expect_block_sequence_item)
- self.expect_node(sequence=True)
-
- # Block mapping handlers.
-
- def expect_block_mapping(self):
- # type: () -> None
- if not self.mapping_context and not (self.compact_seq_map or self.column == 0):
- self.write_line_break()
- self.increase_indent(flow=False, sequence=False)
- self.state = self.expect_first_block_mapping_key
-
- def expect_first_block_mapping_key(self):
- # type: () -> None
- return self.expect_block_mapping_key(first=True)
-
- def expect_block_mapping_key(self, first=False):
- # type: (Any) -> None
- if not first and isinstance(self.event, MappingEndEvent):
- if self.event.comment and self.event.comment[1]:
- # final comments from a doc
- self.write_pre_comment(self.event)
- self.indent = self.indents.pop()
- self.state = self.states.pop()
- else:
- if self.event.comment and self.event.comment[1]:
- # final comments from a doc
- self.write_pre_comment(self.event)
- self.write_indent()
- if self.check_simple_key():
- if not isinstance(
- self.event, (SequenceStartEvent, MappingStartEvent)
- ): # sequence keys
- try:
- if self.event.style == "?":
- self.write_indicator(u"?", True, indention=True)
- except AttributeError: # aliases have no style
- pass
- self.states.append(self.expect_block_mapping_simple_value)
- self.expect_node(mapping=True, simple_key=True)
- if isinstance(self.event, AliasEvent):
- self.stream.write(u" ")
- else:
- self.write_indicator(u"?", True, indention=True)
- self.states.append(self.expect_block_mapping_value)
- self.expect_node(mapping=True)
-
- def expect_block_mapping_simple_value(self):
- # type: () -> None
- if getattr(self.event, "style", None) != "?":
- # prefix = u''
- if self.indent == 0 and self.top_level_colon_align is not None:
- # write non-prefixed colon
- c = u" " * (self.top_level_colon_align - self.column) + self.colon
- else:
- c = self.prefixed_colon
- self.write_indicator(c, False)
- self.states.append(self.expect_block_mapping_key)
- self.expect_node(mapping=True)
-
- def expect_block_mapping_value(self):
- # type: () -> None
- self.write_indent()
- self.write_indicator(self.prefixed_colon, True, indention=True)
- self.states.append(self.expect_block_mapping_key)
- self.expect_node(mapping=True)
-
- # Checkers.
-
- def check_empty_sequence(self):
- # type: () -> bool
- return (
- isinstance(self.event, SequenceStartEvent)
- and bool(self.events)
- and isinstance(self.events[0], SequenceEndEvent)
- )
-
- def check_empty_mapping(self):
- # type: () -> bool
- return (
- isinstance(self.event, MappingStartEvent)
- and bool(self.events)
- and isinstance(self.events[0], MappingEndEvent)
- )
-
- def check_empty_document(self):
- # type: () -> bool
- if not isinstance(self.event, DocumentStartEvent) or not self.events:
- return False
- event = self.events[0]
- return (
- isinstance(event, ScalarEvent)
- and event.anchor is None
- and event.tag is None
- and event.implicit
- and event.value == ""
- )
-
- def check_simple_key(self):
- # type: () -> bool
- length = 0
- if isinstance(self.event, NodeEvent) and self.event.anchor is not None:
- if self.prepared_anchor is None:
- self.prepared_anchor = self.prepare_anchor(self.event.anchor)
- length += len(self.prepared_anchor)
- if (
- isinstance(self.event, (ScalarEvent, CollectionStartEvent))
- and self.event.tag is not None
- ):
- if self.prepared_tag is None:
- self.prepared_tag = self.prepare_tag(self.event.tag)
- length += len(self.prepared_tag)
- if isinstance(self.event, ScalarEvent):
- if self.analysis is None:
- self.analysis = self.analyze_scalar(self.event.value)
- length += len(self.analysis.scalar)
- return length < self.MAX_SIMPLE_KEY_LENGTH and (
- isinstance(self.event, AliasEvent)
- or (
- isinstance(self.event, SequenceStartEvent)
- and self.event.flow_style is True
- )
- or (
- isinstance(self.event, MappingStartEvent)
- and self.event.flow_style is True
- )
- or (
- isinstance(self.event, ScalarEvent)
- # if there is an explicit style for an empty string, it is a simple key
- and not (self.analysis.empty and self.style and self.style not in "'\"")
- and not self.analysis.multiline
- )
- or self.check_empty_sequence()
- or self.check_empty_mapping()
- )
-
- # Anchor, Tag, and Scalar processors.
-
- def process_anchor(self, indicator):
- # type: (Any) -> bool
- if self.event.anchor is None:
- self.prepared_anchor = None
- return False
- if self.prepared_anchor is None:
- self.prepared_anchor = self.prepare_anchor(self.event.anchor)
- if self.prepared_anchor:
- self.write_indicator(indicator + self.prepared_anchor, True)
- # issue 288
- self.no_newline = False
- self.prepared_anchor = None
- return True
-
- def process_tag(self):
- # type: () -> None
- tag = self.event.tag
- if isinstance(self.event, ScalarEvent):
- if self.style is None:
- self.style = self.choose_scalar_style()
- if (not self.canonical or tag is None) and (
- (self.style == "" and self.event.implicit[0])
- or (self.style != "" and self.event.implicit[1])
- ):
- self.prepared_tag = None
- return
- if self.event.implicit[0] and tag is None:
- tag = u"!"
- self.prepared_tag = None
- else:
- if (not self.canonical or tag is None) and self.event.implicit:
- self.prepared_tag = None
- return
- if tag is None:
- raise EmitterError("tag is not specified")
- if self.prepared_tag is None:
- self.prepared_tag = self.prepare_tag(tag)
- if self.prepared_tag:
- self.write_indicator(self.prepared_tag, True)
- if (
- self.sequence_context
- and not self.flow_level
- and isinstance(self.event, ScalarEvent)
- ):
- self.no_newline = True
- self.prepared_tag = None
-
- def choose_scalar_style(self):
- # type: () -> Any
- if self.analysis is None:
- self.analysis = self.analyze_scalar(self.event.value)
- if self.event.style == '"' or self.canonical:
- return '"'
- if (not self.event.style or self.event.style == "?") and (
- self.event.implicit[0] or not self.event.implicit[2]
- ):
- if not (
- self.simple_key_context
- and (self.analysis.empty or self.analysis.multiline)
- ) and (
- self.flow_level
- and self.analysis.allow_flow_plain
- or (not self.flow_level and self.analysis.allow_block_plain)
- ):
- return ""
- self.analysis.allow_block = True
- if self.event.style and self.event.style in "|>":
- if (
- not self.flow_level
- and not self.simple_key_context
- and self.analysis.allow_block
- ):
- return self.event.style
- if not self.event.style and self.analysis.allow_double_quoted:
- if "'" in self.event.value or "\n" in self.event.value:
- return '"'
- if not self.event.style or self.event.style == "'":
- if self.analysis.allow_single_quoted and not (
- self.simple_key_context and self.analysis.multiline
- ):
- return "'"
- return '"'
-
- def process_scalar(self):
- # type: () -> None
- if self.analysis is None:
- self.analysis = self.analyze_scalar(self.event.value)
- if self.style is None:
- self.style = self.choose_scalar_style()
- split = not self.simple_key_context
- # if self.analysis.multiline and split \
- # and (not self.style or self.style in '\'\"'):
- # self.write_indent()
- # nprint('xx', self.sequence_context, self.flow_level)
- if self.sequence_context and not self.flow_level:
- self.write_indent()
- if self.style == '"':
- self.write_double_quoted(self.analysis.scalar, split)
- elif self.style == "'":
- self.write_single_quoted(self.analysis.scalar, split)
- elif self.style == ">":
- self.write_folded(self.analysis.scalar)
- elif self.style == "|":
- self.write_literal(self.analysis.scalar, self.event.comment)
- else:
- self.write_plain(self.analysis.scalar, split)
- self.analysis = None
- self.style = None
- if self.event.comment:
- self.write_post_comment(self.event)
-
- # Analyzers.
-
- def prepare_version(self, version):
- # type: (Any) -> Any
- major, minor = version
- if major != 1:
- raise EmitterError("unsupported YAML version: %d.%d" % (major, minor))
- return u"%d.%d" % (major, minor)
-
- def prepare_tag_handle(self, handle):
- # type: (Any) -> Any
- if not handle:
- raise EmitterError("tag handle must not be empty")
- if handle[0] != u"!" or handle[-1] != u"!":
- raise EmitterError(
- "tag handle must start and end with '!': %r" % (utf8(handle))
- )
- for ch in handle[1:-1]:
- if not (
- u"0" <= ch <= u"9"
- or u"A" <= ch <= u"Z"
- or u"a" <= ch <= u"z"
- or ch in u"-_"
- ):
- raise EmitterError(
- "invalid character %r in the tag handle: %r"
- % (utf8(ch), utf8(handle))
- )
- return handle
-
- def prepare_tag_prefix(self, prefix):
- # type: (Any) -> Any
- if not prefix:
- raise EmitterError("tag prefix must not be empty")
- chunks = [] # type: List[Any]
- start = end = 0
- if prefix[0] == u"!":
- end = 1
- ch_set = u"-;/?:@&=+$,_.~*'()[]"
- if self.dumper:
- version = getattr(self.dumper, "version", (1, 2))
- if version is None or version >= (1, 2):
- ch_set += u"#"
- while end < len(prefix):
- ch = prefix[end]
- if (
- u"0" <= ch <= u"9"
- or u"A" <= ch <= u"Z"
- or u"a" <= ch <= u"z"
- or ch in ch_set
- ):
- end += 1
- else:
- if start < end:
- chunks.append(prefix[start:end])
- start = end = end + 1
- data = utf8(ch)
- for ch in data:
- chunks.append(u"%%%02X" % ord(ch))
- if start < end:
- chunks.append(prefix[start:end])
- return "".join(chunks)
-
- def prepare_tag(self, tag):
- # type: (Any) -> Any
- if not tag:
- raise EmitterError("tag must not be empty")
- if tag == u"!":
- return tag
- handle = None
- suffix = tag
- prefixes = sorted(self.tag_prefixes.keys())
- for prefix in prefixes:
- if tag.startswith(prefix) and (prefix == u"!" or len(prefix) < len(tag)):
- handle = self.tag_prefixes[prefix]
- suffix = tag[len(prefix) :]
- chunks = [] # type: List[Any]
- start = end = 0
- ch_set = u"-;/?:@&=+$,_.~*'()[]"
- if self.dumper:
- version = getattr(self.dumper, "version", (1, 2))
- if version is None or version >= (1, 2):
- ch_set += u"#"
- while end < len(suffix):
- ch = suffix[end]
- if (
- u"0" <= ch <= u"9"
- or u"A" <= ch <= u"Z"
- or u"a" <= ch <= u"z"
- or ch in ch_set
- or (ch == u"!" and handle != u"!")
- ):
- end += 1
- else:
- if start < end:
- chunks.append(suffix[start:end])
- start = end = end + 1
- data = utf8(ch)
- for ch in data:
- chunks.append(u"%%%02X" % ord(ch))
- if start < end:
- chunks.append(suffix[start:end])
- suffix_text = "".join(chunks)
- if handle:
- return u"%s%s" % (handle, suffix_text)
- else:
- return u"!<%s>" % suffix_text
-
- def prepare_anchor(self, anchor):
- # type: (Any) -> Any
- if not anchor:
- raise EmitterError("anchor must not be empty")
- for ch in anchor:
- if not check_anchorname_char(ch):
- raise EmitterError(
- "invalid character %r in the anchor: %r" % (utf8(ch), utf8(anchor))
- )
- return anchor
-
- def analyze_scalar(self, scalar):
- # type: (Any) -> Any
- # Empty scalar is a special case.
- if not scalar:
- return ScalarAnalysis(
- scalar=scalar,
- empty=True,
- multiline=False,
- allow_flow_plain=False,
- allow_block_plain=True,
- allow_single_quoted=True,
- allow_double_quoted=True,
- allow_block=False,
- )
-
- # Indicators and special characters.
- block_indicators = False
- flow_indicators = False
- line_breaks = False
- special_characters = False
-
- # Important whitespace combinations.
- leading_space = False
- leading_break = False
- trailing_space = False
- trailing_break = False
- break_space = False
- space_break = False
-
- # Check document indicators.
- if scalar.startswith(u"---") or scalar.startswith(u"..."):
- block_indicators = True
- flow_indicators = True
-
- # First character or preceded by a whitespace.
- preceeded_by_whitespace = True
-
- # Last character or followed by a whitespace.
- followed_by_whitespace = (
- len(scalar) == 1 or scalar[1] in u"\0 \t\r\n\x85\u2028\u2029"
- )
-
- # The previous character is a space.
- previous_space = False
-
- # The previous character is a break.
- previous_break = False
-
- index = 0
- while index < len(scalar):
- ch = scalar[index]
-
- # Check for indicators.
- if index == 0:
- # Leading indicators are special characters.
- if ch in u"#,[]{}&*!|>'\"%@`":
- flow_indicators = True
- block_indicators = True
- if ch in u"?:": # ToDo
- if self.serializer.use_version == (1, 1):
- flow_indicators = True
- elif len(scalar) == 1: # single character
- flow_indicators = True
- if followed_by_whitespace:
- block_indicators = True
- if ch == u"-" and followed_by_whitespace:
- flow_indicators = True
- block_indicators = True
- else:
- # Some indicators cannot appear within a scalar as well.
- if ch in u",[]{}": # http://yaml.org/spec/1.2/spec.html#id2788859
- flow_indicators = True
- if ch == u"?" and self.serializer.use_version == (1, 1):
- flow_indicators = True
- if ch == u":":
- if followed_by_whitespace:
- flow_indicators = True
- block_indicators = True
- if ch == u"#" and preceeded_by_whitespace:
- flow_indicators = True
- block_indicators = True
-
- # Check for line breaks, special, and unicode characters.
- if ch in u"\n\x85\u2028\u2029":
- line_breaks = True
- if not (ch == u"\n" or u"\x20" <= ch <= u"\x7E"):
- if (
- ch == u"\x85"
- or u"\xA0" <= ch <= u"\uD7FF"
- or u"\uE000" <= ch <= u"\uFFFD"
- or (
- self.unicode_supplementary
- and (u"\U00010000" <= ch <= u"\U0010FFFF")
- )
- ) and ch != u"\uFEFF":
- # unicode_characters = True
- if not self.allow_unicode:
- special_characters = True
- else:
- special_characters = True
-
- # Detect important whitespace combinations.
- if ch == u" ":
- if index == 0:
- leading_space = True
- if index == len(scalar) - 1:
- trailing_space = True
- if previous_break:
- break_space = True
- previous_space = True
- previous_break = False
- elif ch in u"\n\x85\u2028\u2029":
- if index == 0:
- leading_break = True
- if index == len(scalar) - 1:
- trailing_break = True
- if previous_space:
- space_break = True
- previous_space = False
- previous_break = True
- else:
- previous_space = False
- previous_break = False
-
- # Prepare for the next character.
- index += 1
- preceeded_by_whitespace = ch in u"\0 \t\r\n\x85\u2028\u2029"
- followed_by_whitespace = (
- index + 1 >= len(scalar)
- or scalar[index + 1] in u"\0 \t\r\n\x85\u2028\u2029"
- )
-
- # Let's decide what styles are allowed.
- allow_flow_plain = True
- allow_block_plain = True
- allow_single_quoted = True
- allow_double_quoted = True
- allow_block = True
-
- # Leading and trailing whitespaces are bad for plain scalars.
- if leading_space or leading_break or trailing_space or trailing_break:
- allow_flow_plain = allow_block_plain = False
-
- # We do not permit trailing spaces for block scalars.
- if trailing_space:
- allow_block = False
-
- # Spaces at the beginning of a new line are only acceptable for block
- # scalars.
- if break_space:
- allow_flow_plain = allow_block_plain = allow_single_quoted = False
-
- # Spaces followed by breaks, as well as special character are only
- # allowed for double quoted scalars.
- if special_characters:
- allow_flow_plain = (
- allow_block_plain
- ) = allow_single_quoted = allow_block = False
- elif space_break:
- allow_flow_plain = allow_block_plain = allow_single_quoted = False
- if not self.allow_space_break:
- allow_block = False
-
- # Although the plain scalar writer supports breaks, we never emit
- # multiline plain scalars.
- if line_breaks:
- allow_flow_plain = allow_block_plain = False
-
- # Flow indicators are forbidden for flow plain scalars.
- if flow_indicators:
- allow_flow_plain = False
-
- # Block indicators are forbidden for block plain scalars.
- if block_indicators:
- allow_block_plain = False
-
- return ScalarAnalysis(
- scalar=scalar,
- empty=False,
- multiline=line_breaks,
- allow_flow_plain=allow_flow_plain,
- allow_block_plain=allow_block_plain,
- allow_single_quoted=allow_single_quoted,
- allow_double_quoted=allow_double_quoted,
- allow_block=allow_block,
- )
-
- # Writers.
-
- def flush_stream(self):
- # type: () -> None
- if hasattr(self.stream, "flush"):
- self.stream.flush()
-
- def write_stream_start(self):
- # type: () -> None
- # Write BOM if needed.
- if self.encoding and self.encoding.startswith("utf-16"):
- self.stream.write(u"\uFEFF".encode(self.encoding))
-
- def write_stream_end(self):
- # type: () -> None
- self.flush_stream()
-
- def write_indicator(
- self, indicator, need_whitespace, whitespace=False, indention=False
- ):
- # type: (Any, Any, bool, bool) -> None
- if self.whitespace or not need_whitespace:
- data = indicator
- else:
- data = u" " + indicator
- self.whitespace = whitespace
- self.indention = self.indention and indention
- self.column += len(data)
- self.open_ended = False
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
-
- def write_indent(self):
- # type: () -> None
- indent = self.indent or 0
- if (
- not self.indention
- or self.column > indent
- or (self.column == indent and not self.whitespace)
- ):
- if bool(self.no_newline):
- self.no_newline = False
- else:
- self.write_line_break()
- if self.column < indent:
- self.whitespace = True
- data = u" " * (indent - self.column)
- self.column = indent
- if self.encoding:
- data = data.encode(self.encoding)
- self.stream.write(data)
-
- def write_line_break(self, data=None):
- # type: (Any) -> None
- if data is None:
- data = self.best_line_break
- self.whitespace = True
- self.indention = True
- self.line += 1
- self.column = 0
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
-
- def write_version_directive(self, version_text):
- # type: (Any) -> None
- data = u"%%YAML %s" % version_text
- if self.encoding:
- data = data.encode(self.encoding)
- self.stream.write(data)
- self.write_line_break()
-
- def write_tag_directive(self, handle_text, prefix_text):
- # type: (Any, Any) -> None
- data = u"%%TAG %s %s" % (handle_text, prefix_text)
- if self.encoding:
- data = data.encode(self.encoding)
- self.stream.write(data)
- self.write_line_break()
-
- # Scalar streams.
-
- def write_single_quoted(self, text, split=True):
- # type: (Any, Any) -> None
- if self.root_context:
- if self.requested_indent is not None:
- self.write_line_break()
- if self.requested_indent != 0:
- self.write_indent()
- self.write_indicator(u"'", True)
- spaces = False
- breaks = False
- start = end = 0
- while end <= len(text):
- ch = None
- if end < len(text):
- ch = text[end]
- if spaces:
- if ch is None or ch != u" ":
- if (
- start + 1 == end
- and self.column > self.best_width
- and split
- and start != 0
- and end != len(text)
- ):
- self.write_indent()
- else:
- data = text[start:end]
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end
- elif breaks:
- if ch is None or ch not in u"\n\x85\u2028\u2029":
- if text[start] == u"\n":
- self.write_line_break()
- for br in text[start:end]:
- if br == u"\n":
- self.write_line_break()
- else:
- self.write_line_break(br)
- self.write_indent()
- start = end
- else:
- if ch is None or ch in u" \n\x85\u2028\u2029" or ch == u"'":
- if start < end:
- data = text[start:end]
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end
- if ch == u"'":
- data = u"''"
- self.column += 2
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end + 1
- if ch is not None:
- spaces = ch == u" "
- breaks = ch in u"\n\x85\u2028\u2029"
- end += 1
- self.write_indicator(u"'", False)
-
- ESCAPE_REPLACEMENTS = {
- u"\0": u"0",
- u"\x07": u"a",
- u"\x08": u"b",
- u"\x09": u"t",
- u"\x0A": u"n",
- u"\x0B": u"v",
- u"\x0C": u"f",
- u"\x0D": u"r",
- u"\x1B": u"e",
- u'"': u'"',
- u"\\": u"\\",
- u"\x85": u"N",
- u"\xA0": u"_",
- u"\u2028": u"L",
- u"\u2029": u"P",
- }
-
- def write_double_quoted(self, text, split=True):
- # type: (Any, Any) -> None
- if self.root_context:
- if self.requested_indent is not None:
- self.write_line_break()
- if self.requested_indent != 0:
- self.write_indent()
- self.write_indicator(u'"', True)
- start = end = 0
- while end <= len(text):
- ch = None
- if end < len(text):
- ch = text[end]
- if (
- ch is None
- or ch in u'"\\\x85\u2028\u2029\uFEFF'
- or not (
- u"\x20" <= ch <= u"\x7E"
- or (
- self.allow_unicode
- and (u"\xA0" <= ch <= u"\uD7FF" or u"\uE000" <= ch <= u"\uFFFD")
- )
- )
- ):
- if start < end:
- data = text[start:end]
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end
- if ch is not None:
- if ch in self.ESCAPE_REPLACEMENTS:
- data = u"\\" + self.ESCAPE_REPLACEMENTS[ch]
- elif ch <= u"\xFF":
- data = u"\\x%02X" % ord(ch)
- elif ch <= u"\uFFFF":
- data = u"\\u%04X" % ord(ch)
- else:
- data = u"\\U%08X" % ord(ch)
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end + 1
- if (
- 0 < end < len(text) - 1
- and (ch == u" " or start >= end)
- and self.column + (end - start) > self.best_width
- and split
- ):
- data = text[start:end] + u"\\"
- if start < end:
- start = end
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- self.write_indent()
- self.whitespace = False
- self.indention = False
- if text[start] == u" ":
- data = u"\\"
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- end += 1
- self.write_indicator(u'"', False)
-
- def determine_block_hints(self, text):
- # type: (Any) -> Any
- indent = 0
- indicator = u""
- hints = u""
- if text:
- if text[0] in u" \n\x85\u2028\u2029":
- indent = self.best_sequence_indent
- hints += text_type(indent)
- elif self.root_context:
- for end in ["\n---", "\n..."]:
- pos = 0
- while True:
- pos = text.find(end, pos)
- if pos == -1:
- break
- try:
- if text[pos + 4] in " \r\n":
- break
- except IndexError:
- pass
- pos += 1
- if pos > -1:
- break
- if pos > 0:
- indent = self.best_sequence_indent
- if text[-1] not in u"\n\x85\u2028\u2029":
- indicator = u"-"
- elif len(text) == 1 or text[-2] in u"\n\x85\u2028\u2029":
- indicator = u"+"
- hints += indicator
- return hints, indent, indicator
-
- def write_folded(self, text):
- # type: (Any) -> None
- hints, _indent, _indicator = self.determine_block_hints(text)
- self.write_indicator(u">" + hints, True)
- if _indicator == u"+":
- self.open_ended = True
- self.write_line_break()
- leading_space = True
- spaces = False
- breaks = True
- start = end = 0
- while end <= len(text):
- ch = None
- if end < len(text):
- ch = text[end]
- if breaks:
- if ch is None or ch not in u"\n\x85\u2028\u2029\a":
- if (
- not leading_space
- and ch is not None
- and ch != u" "
- and text[start] == u"\n"
- ):
- self.write_line_break()
- leading_space = ch == u" "
- for br in text[start:end]:
- if br == u"\n":
- self.write_line_break()
- else:
- self.write_line_break(br)
- if ch is not None:
- self.write_indent()
- start = end
- elif spaces:
- if ch != u" ":
- if start + 1 == end and self.column > self.best_width:
- self.write_indent()
- else:
- data = text[start:end]
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end
- else:
- if ch is None or ch in u" \n\x85\u2028\u2029\a":
- data = text[start:end]
- self.column += len(data)
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- if ch == u"\a":
- if end < (len(text) - 1) and not text[end + 2].isspace():
- self.write_line_break()
- self.write_indent()
- end += 2 # \a and the space that is inserted on the fold
- else:
- raise EmitterError(
- "unexcpected fold indicator \\a before space"
- )
- if ch is None:
- self.write_line_break()
- start = end
- if ch is not None:
- breaks = ch in u"\n\x85\u2028\u2029"
- spaces = ch == u" "
- end += 1
-
- def write_literal(self, text, comment=None):
- # type: (Any, Any) -> None
- hints, _indent, _indicator = self.determine_block_hints(text)
- self.write_indicator(u"|" + hints, True)
- try:
- comment = comment[1][0]
- if comment:
- self.stream.write(comment)
- except (TypeError, IndexError):
- pass
- if _indicator == u"+":
- self.open_ended = True
- self.write_line_break()
- breaks = True
- start = end = 0
- while end <= len(text):
- ch = None
- if end < len(text):
- ch = text[end]
- if breaks:
- if ch is None or ch not in u"\n\x85\u2028\u2029":
- for br in text[start:end]:
- if br == u"\n":
- self.write_line_break()
- else:
- self.write_line_break(br)
- if ch is not None:
- if self.root_context:
- idnx = self.indent if self.indent is not None else 0
- self.stream.write(u" " * (_indent + idnx))
- else:
- self.write_indent()
- start = end
- else:
- if ch is None or ch in u"\n\x85\u2028\u2029":
- data = text[start:end]
- if bool(self.encoding):
- data = data.encode(self.encoding)
- self.stream.write(data)
- if ch is None:
- self.write_line_break()
- start = end
- if ch is not None:
- breaks = ch in u"\n\x85\u2028\u2029"
- end += 1
-
- def write_plain(self, text, split=True):
- # type: (Any, Any) -> None
- if self.root_context:
- if self.requested_indent is not None:
- self.write_line_break()
- if self.requested_indent != 0:
- self.write_indent()
- else:
- self.open_ended = True
- if not text:
- return
- if not self.whitespace:
- data = u" "
- self.column += len(data)
- if self.encoding:
- data = data.encode(self.encoding)
- self.stream.write(data)
- self.whitespace = False
- self.indention = False
- spaces = False
- breaks = False
- start = end = 0
- while end <= len(text):
- ch = None
- if end < len(text):
- ch = text[end]
- if spaces:
- if ch != u" ":
- if start + 1 == end and self.column > self.best_width and split:
- self.write_indent()
- self.whitespace = False
- self.indention = False
- else:
- data = text[start:end]
- self.column += len(data)
- if self.encoding:
- data = data.encode(self.encoding)
- self.stream.write(data)
- start = end
- elif breaks:
- if ch not in u"\n\x85\u2028\u2029": # type: ignore
- if text[start] == u"\n":
- self.write_line_break()
- for br in text[start:end]:
- if br == u"\n":
- self.write_line_break()
- else:
- self.write_line_break(br)
- self.write_indent()
- self.whitespace = False
- self.indention = False
- start = end
- else:
- if ch is None or ch in u" \n\x85\u2028\u2029":
- data = text[start:end]
- self.column += len(data)
- if self.encoding:
- data = data.encode(self.encoding)
- try:
- self.stream.write(data)
- except: # NOQA
- sys.stdout.write(repr(data) + "\n")
- raise
- start = end
- if ch is not None:
- spaces = ch == u" "
- breaks = ch in u"\n\x85\u2028\u2029"
- end += 1
-
- def write_comment(self, comment, pre=False):
- # type: (Any, bool) -> None
- value = comment.value
- # nprintf('{:02d} {:02d} {!r}'.format(self.column, comment.start_mark.column, value))
- if not pre and value[-1] == "\n":
- value = value[:-1]
- try:
- # get original column position
- col = comment.start_mark.column
- if comment.value and comment.value.startswith("\n"):
- # never inject extra spaces if the comment starts with a newline
- # and not a real comment (e.g. if you have an empty line following a key-value
- col = self.column
- elif col < self.column + 1:
- ValueError
- except ValueError:
- col = self.column + 1
- # nprint('post_comment', self.line, self.column, value)
- try:
- # at least one space if the current column >= the start column of the comment
- # but not at the start of a line
- nr_spaces = col - self.column
- if self.column and value.strip() and nr_spaces < 1 and value[0] != "\n":
- nr_spaces = 1
- value = " " * nr_spaces + value
- try:
- if bool(self.encoding):
- value = value.encode(self.encoding)
- except UnicodeDecodeError:
- pass
- self.stream.write(value)
- except TypeError:
- raise
- if not pre:
- self.write_line_break()
-
- def write_pre_comment(self, event):
- # type: (Any) -> bool
- comments = event.comment[1]
- if comments is None:
- return False
- try:
- start_events = (MappingStartEvent, SequenceStartEvent)
- for comment in comments:
- if isinstance(event, start_events) and getattr(
- comment, "pre_done", None
- ):
- continue
- if self.column != 0:
- self.write_line_break()
- self.write_comment(comment, pre=True)
- if isinstance(event, start_events):
- comment.pre_done = True
- except TypeError:
- sys.stdout.write("eventtt {} {}".format(type(event), event))
- raise
- return True
-
- def write_post_comment(self, event):
- # type: (Any) -> bool
- if self.event.comment[0] is None:
- return False
- comment = event.comment[0]
- self.write_comment(comment)
- return True
diff --git a/srsly/ruamel_yaml/error.py b/srsly/ruamel_yaml/error.py
deleted file mode 100755
index 93040ac..0000000
--- a/srsly/ruamel_yaml/error.py
+++ /dev/null
@@ -1,321 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-import warnings
-import textwrap
-
-from .compat import utf8
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Text # NOQA
-
-
-__all__ = [
- "FileMark",
- "StringMark",
- "CommentMark",
- "YAMLError",
- "MarkedYAMLError",
- "ReusedAnchorWarning",
- "UnsafeLoaderWarning",
- "MarkedYAMLWarning",
- "MarkedYAMLFutureWarning",
-]
-
-
-class StreamMark(object):
- __slots__ = "name", "index", "line", "column"
-
- def __init__(self, name, index, line, column):
- # type: (Any, int, int, int) -> None
- self.name = name
- self.index = index
- self.line = line
- self.column = column
-
- def __str__(self):
- # type: () -> Any
- where = ' in "%s", line %d, column %d' % (
- self.name,
- self.line + 1,
- self.column + 1,
- )
- return where
-
- def __eq__(self, other):
- # type: (Any) -> bool
- if self.line != other.line or self.column != other.column:
- return False
- if self.name != other.name or self.index != other.index:
- return False
- return True
-
- def __ne__(self, other):
- # type: (Any) -> bool
- return not self.__eq__(other)
-
-
-class FileMark(StreamMark):
- __slots__ = ()
-
-
-class StringMark(StreamMark):
- __slots__ = "name", "index", "line", "column", "buffer", "pointer"
-
- def __init__(self, name, index, line, column, buffer, pointer):
- # type: (Any, int, int, int, Any, Any) -> None
- StreamMark.__init__(self, name, index, line, column)
- self.buffer = buffer
- self.pointer = pointer
-
- def get_snippet(self, indent=4, max_length=75):
- # type: (int, int) -> Any
- if self.buffer is None: # always False
- return None
- head = ""
- start = self.pointer
- while start > 0 and self.buffer[start - 1] not in u"\0\r\n\x85\u2028\u2029":
- start -= 1
- if self.pointer - start > max_length / 2 - 1:
- head = " ... "
- start += 5
- break
- tail = ""
- end = self.pointer
- while (
- end < len(self.buffer) and self.buffer[end] not in u"\0\r\n\x85\u2028\u2029"
- ):
- end += 1
- if end - self.pointer > max_length / 2 - 1:
- tail = " ... "
- end -= 5
- break
- snippet = utf8(self.buffer[start:end])
- caret = "^"
- caret = "^ (line: {})".format(self.line + 1)
- return (
- " " * indent
- + head
- + snippet
- + tail
- + "\n"
- + " " * (indent + self.pointer - start + len(head))
- + caret
- )
-
- def __str__(self):
- # type: () -> Any
- snippet = self.get_snippet()
- where = ' in "%s", line %d, column %d' % (
- self.name,
- self.line + 1,
- self.column + 1,
- )
- if snippet is not None:
- where += ":\n" + snippet
- return where
-
-
-class CommentMark(object):
- __slots__ = ("column",)
-
- def __init__(self, column):
- # type: (Any) -> None
- self.column = column
-
-
-class YAMLError(Exception):
- pass
-
-
-class MarkedYAMLError(YAMLError):
- def __init__(
- self,
- context=None,
- context_mark=None,
- problem=None,
- problem_mark=None,
- note=None,
- warn=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- self.context = context
- self.context_mark = context_mark
- self.problem = problem
- self.problem_mark = problem_mark
- self.note = note
- # warn is ignored
-
- def __str__(self):
- # type: () -> Any
- lines = [] # type: List[str]
- if self.context is not None:
- lines.append(self.context)
- if self.context_mark is not None and (
- self.problem is None
- or self.problem_mark is None
- or self.context_mark.name != self.problem_mark.name
- or self.context_mark.line != self.problem_mark.line
- or self.context_mark.column != self.problem_mark.column
- ):
- lines.append(str(self.context_mark))
- if self.problem is not None:
- lines.append(self.problem)
- if self.problem_mark is not None:
- lines.append(str(self.problem_mark))
- if self.note is not None and self.note:
- note = textwrap.dedent(self.note)
- lines.append(note)
- return "\n".join(lines)
-
-
-class YAMLStreamError(Exception):
- pass
-
-
-class YAMLWarning(Warning):
- pass
-
-
-class MarkedYAMLWarning(YAMLWarning):
- def __init__(
- self,
- context=None,
- context_mark=None,
- problem=None,
- problem_mark=None,
- note=None,
- warn=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- self.context = context
- self.context_mark = context_mark
- self.problem = problem
- self.problem_mark = problem_mark
- self.note = note
- self.warn = warn
-
- def __str__(self):
- # type: () -> Any
- lines = [] # type: List[str]
- if self.context is not None:
- lines.append(self.context)
- if self.context_mark is not None and (
- self.problem is None
- or self.problem_mark is None
- or self.context_mark.name != self.problem_mark.name
- or self.context_mark.line != self.problem_mark.line
- or self.context_mark.column != self.problem_mark.column
- ):
- lines.append(str(self.context_mark))
- if self.problem is not None:
- lines.append(self.problem)
- if self.problem_mark is not None:
- lines.append(str(self.problem_mark))
- if self.note is not None and self.note:
- note = textwrap.dedent(self.note)
- lines.append(note)
- if self.warn is not None and self.warn:
- warn = textwrap.dedent(self.warn)
- lines.append(warn)
- return "\n".join(lines)
-
-
-class ReusedAnchorWarning(YAMLWarning):
- pass
-
-
-class UnsafeLoaderWarning(YAMLWarning):
- text = """
-The default 'Loader' for 'load(stream)' without further arguments can be unsafe.
-Use 'load(stream, Loader=srsly.ruamel_yaml.Loader)' explicitly if that is OK.
-Alternatively include the following in your code:
-
- import warnings
- warnings.simplefilter('ignore', srsly.ruamel_yaml.error.UnsafeLoaderWarning)
-
-In most other cases you should consider using 'safe_load(stream)'"""
- pass
-
-
-warnings.simplefilter("once", UnsafeLoaderWarning)
-
-
-class MantissaNoDotYAML1_1Warning(YAMLWarning):
- def __init__(self, node, flt_str):
- # type: (Any, Any) -> None
- self.node = node
- self.flt = flt_str
-
- def __str__(self):
- # type: () -> Any
- line = self.node.start_mark.line
- col = self.node.start_mark.column
- return """
-In YAML 1.1 floating point values should have a dot ('.') in their mantissa.
-See the Floating-Point Language-Independent Type for YAML™ Version 1.1 specification
-( http://yaml.org/type/float.html ). This dot is not required for JSON nor for YAML 1.2
-
-Correct your float: "{}" on line: {}, column: {}
-
-or alternatively include the following in your code:
-
- import warnings
- warnings.simplefilter('ignore', srsly.ruamel_yaml.error.MantissaNoDotYAML1_1Warning)
-
-""".format(
- self.flt, line, col
- )
-
-
-warnings.simplefilter("once", MantissaNoDotYAML1_1Warning)
-
-
-class YAMLFutureWarning(Warning):
- pass
-
-
-class MarkedYAMLFutureWarning(YAMLFutureWarning):
- def __init__(
- self,
- context=None,
- context_mark=None,
- problem=None,
- problem_mark=None,
- note=None,
- warn=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- self.context = context
- self.context_mark = context_mark
- self.problem = problem
- self.problem_mark = problem_mark
- self.note = note
- self.warn = warn
-
- def __str__(self):
- # type: () -> Any
- lines = [] # type: List[str]
- if self.context is not None:
- lines.append(self.context)
-
- if self.context_mark is not None and (
- self.problem is None
- or self.problem_mark is None
- or self.context_mark.name != self.problem_mark.name
- or self.context_mark.line != self.problem_mark.line
- or self.context_mark.column != self.problem_mark.column
- ):
- lines.append(str(self.context_mark))
- if self.problem is not None:
- lines.append(self.problem)
- if self.problem_mark is not None:
- lines.append(str(self.problem_mark))
- if self.note is not None and self.note:
- note = textwrap.dedent(self.note)
- lines.append(note)
- if self.warn is not None and self.warn:
- warn = textwrap.dedent(self.warn)
- lines.append(warn)
- return "\n".join(lines)
diff --git a/srsly/ruamel_yaml/events.py b/srsly/ruamel_yaml/events.py
deleted file mode 100755
index 58b2121..0000000
--- a/srsly/ruamel_yaml/events.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# coding: utf-8
-
-# Abstract classes.
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List # NOQA
-
-
-def CommentCheck():
- # type: () -> None
- pass
-
-
-class Event(object):
- __slots__ = 'start_mark', 'end_mark', 'comment'
-
- def __init__(self, start_mark=None, end_mark=None, comment=CommentCheck):
- # type: (Any, Any, Any) -> None
- self.start_mark = start_mark
- self.end_mark = end_mark
- # assert comment is not CommentCheck
- if comment is CommentCheck:
- comment = None
- self.comment = comment
-
- def __repr__(self):
- # type: () -> Any
- attributes = [
- key
- for key in ['anchor', 'tag', 'implicit', 'value', 'flow_style', 'style']
- if hasattr(self, key)
- ]
- arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) for key in attributes])
- if self.comment not in [None, CommentCheck]:
- arguments += ', comment={!r}'.format(self.comment)
- return '%s(%s)' % (self.__class__.__name__, arguments)
-
-
-class NodeEvent(Event):
- __slots__ = ('anchor',)
-
- def __init__(self, anchor, start_mark=None, end_mark=None, comment=None):
- # type: (Any, Any, Any, Any) -> None
- Event.__init__(self, start_mark, end_mark, comment)
- self.anchor = anchor
-
-
-class CollectionStartEvent(NodeEvent):
- __slots__ = 'tag', 'implicit', 'flow_style', 'nr_items'
-
- def __init__(
- self,
- anchor,
- tag,
- implicit,
- start_mark=None,
- end_mark=None,
- flow_style=None,
- comment=None,
- nr_items=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any, Any, Optional[int]) -> None
- NodeEvent.__init__(self, anchor, start_mark, end_mark, comment)
- self.tag = tag
- self.implicit = implicit
- self.flow_style = flow_style
- self.nr_items = nr_items
-
-
-class CollectionEndEvent(Event):
- __slots__ = ()
-
-
-# Implementations.
-
-
-class StreamStartEvent(Event):
- __slots__ = ('encoding',)
-
- def __init__(self, start_mark=None, end_mark=None, encoding=None, comment=None):
- # type: (Any, Any, Any, Any) -> None
- Event.__init__(self, start_mark, end_mark, comment)
- self.encoding = encoding
-
-
-class StreamEndEvent(Event):
- __slots__ = ()
-
-
-class DocumentStartEvent(Event):
- __slots__ = 'explicit', 'version', 'tags'
-
- def __init__(
- self,
- start_mark=None,
- end_mark=None,
- explicit=None,
- version=None,
- tags=None,
- comment=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- Event.__init__(self, start_mark, end_mark, comment)
- self.explicit = explicit
- self.version = version
- self.tags = tags
-
-
-class DocumentEndEvent(Event):
- __slots__ = ('explicit',)
-
- def __init__(self, start_mark=None, end_mark=None, explicit=None, comment=None):
- # type: (Any, Any, Any, Any) -> None
- Event.__init__(self, start_mark, end_mark, comment)
- self.explicit = explicit
-
-
-class AliasEvent(NodeEvent):
- __slots__ = ()
-
-
-class ScalarEvent(NodeEvent):
- __slots__ = 'tag', 'implicit', 'value', 'style'
-
- def __init__(
- self,
- anchor,
- tag,
- implicit,
- value,
- start_mark=None,
- end_mark=None,
- style=None,
- comment=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any, Any, Any) -> None
- NodeEvent.__init__(self, anchor, start_mark, end_mark, comment)
- self.tag = tag
- self.implicit = implicit
- self.value = value
- self.style = style
-
-
-class SequenceStartEvent(CollectionStartEvent):
- __slots__ = ()
-
-
-class SequenceEndEvent(CollectionEndEvent):
- __slots__ = ()
-
-
-class MappingStartEvent(CollectionStartEvent):
- __slots__ = ()
-
-
-class MappingEndEvent(CollectionEndEvent):
- __slots__ = ()
diff --git a/srsly/ruamel_yaml/loader.py b/srsly/ruamel_yaml/loader.py
deleted file mode 100755
index 177cac2..0000000
--- a/srsly/ruamel_yaml/loader.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-
-from .reader import Reader
-from .scanner import Scanner, RoundTripScanner
-from .parser import Parser, RoundTripParser
-from .composer import Composer
-from .constructor import (
- BaseConstructor,
- SafeConstructor,
- Constructor,
- RoundTripConstructor,
-)
-from .resolver import VersionedResolver
-
-if False: # MYPY
- from typing import Any, Dict, List, Union, Optional # NOQA
- from .compat import StreamTextType, VersionType # NOQA
-
-__all__ = ["BaseLoader", "SafeLoader", "Loader", "RoundTripLoader"]
-
-
-class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, VersionedResolver):
- def __init__(self, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
- Reader.__init__(self, stream, loader=self)
- Scanner.__init__(self, loader=self)
- Parser.__init__(self, loader=self)
- Composer.__init__(self, loader=self)
- BaseConstructor.__init__(self, loader=self)
- VersionedResolver.__init__(self, version, loader=self)
-
-
-class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, VersionedResolver):
- def __init__(self, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
- Reader.__init__(self, stream, loader=self)
- Scanner.__init__(self, loader=self)
- Parser.__init__(self, loader=self)
- Composer.__init__(self, loader=self)
- SafeConstructor.__init__(self, loader=self)
- VersionedResolver.__init__(self, version, loader=self)
-
-
-class Loader(Reader, Scanner, Parser, Composer, Constructor, VersionedResolver):
- def __init__(self, stream, version=None, preserve_quotes=None):
- raise ValueError("Unsafe loader not implemented in this library.")
-
-
-class RoundTripLoader(
- Reader,
- RoundTripScanner,
- RoundTripParser,
- Composer,
- RoundTripConstructor,
- VersionedResolver,
-):
- def __init__(self, stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
- # self.reader = Reader.__init__(self, stream)
- Reader.__init__(self, stream, loader=self)
- RoundTripScanner.__init__(self, loader=self)
- RoundTripParser.__init__(self, loader=self)
- Composer.__init__(self, loader=self)
- RoundTripConstructor.__init__(
- self, preserve_quotes=preserve_quotes, loader=self
- )
- VersionedResolver.__init__(self, version, loader=self)
diff --git a/srsly/ruamel_yaml/main.py b/srsly/ruamel_yaml/main.py
deleted file mode 100755
index 433cf21..0000000
--- a/srsly/ruamel_yaml/main.py
+++ /dev/null
@@ -1,1561 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import, unicode_literals, print_function
-
-import sys
-import os
-import warnings
-import glob
-from importlib import import_module
-
-
-from . import resolver
-from . import emitter
-from . import representer
-from . import parser
-from . import composer
-from . import constructor
-from . import serializer
-from . import scanner
-from . import loader
-from . import dumper
-from . import reader
-from .error import UnsafeLoaderWarning, YAMLError # NOQA
-
-from .tokens import * # NOQA
-from .events import * # NOQA
-from .nodes import * # NOQA
-
-from .loader import BaseLoader, SafeLoader, Loader, RoundTripLoader # NOQA
-from .dumper import BaseDumper, SafeDumper, Dumper, RoundTripDumper # NOQA
-from .compat import StringIO, BytesIO, with_metaclass, PY3, nprint
-from .resolver import VersionedResolver, Resolver # NOQA
-from .representer import (
- BaseRepresenter,
- SafeRepresenter,
- Representer,
- RoundTripRepresenter,
-)
-from .constructor import (
- BaseConstructor,
- SafeConstructor,
- Constructor,
- RoundTripConstructor,
-)
-from .loader import Loader as UnsafeLoader
-
-if False: # MYPY
- from typing import List, Set, Dict, Union, Any, Callable, Optional, Text # NOQA
- from .compat import StreamType, StreamTextType, VersionType # NOQA
-
- if PY3:
- from pathlib import Path
- else:
- Path = Any
-
-try:
- from _ruamel_yaml import CParser, CEmitter # type: ignore
-except: # NOQA
- CParser = CEmitter = None
-
-# import io
-
-enforce = object()
-
-
-# YAML is an acronym, i.e. spoken: rhymes with "camel". And thus a
-# subset of abbreviations, which should be all caps according to PEP8
-
-
-class YAML(object):
- def __init__(
- self,
- _kw=enforce,
- typ=None,
- pure=False,
- output=None,
- plug_ins=None, # input=None,
- ):
- # type: (Any, Optional[Text], Any, Any, Any) -> None
- """
- _kw: not used, forces keyword arguments in 2.7 (in 3 you can do (*, safe_load=..)
- typ: 'rt'/None -> RoundTripLoader/RoundTripDumper, (default)
- 'safe' -> SafeLoader/SafeDumper,
- 'unsafe' -> normal/unsafe Loader/Dumper
- 'base' -> baseloader
- pure: if True only use Python modules
- input/output: needed to work as context manager
- plug_ins: a list of plug-in files
- """
- if _kw is not enforce:
- raise TypeError(
- "{}.__init__() takes no positional argument but at least "
- "one was given ({!r})".format(self.__class__.__name__, _kw)
- )
-
- self.typ = ["rt"] if typ is None else (typ if isinstance(typ, list) else [typ])
- self.pure = pure
-
- # self._input = input
- self._output = output
- self._context_manager = None # type: Any
-
- self.plug_ins = [] # type: List[Any]
- for pu in ([] if plug_ins is None else plug_ins) + self.official_plug_ins():
- file_name = pu.replace(os.sep, ".")
- self.plug_ins.append(import_module(file_name))
- self.Resolver = resolver.VersionedResolver # type: Any
- self.allow_unicode = True
- self.Reader = None # type: Any
- self.Representer = None # type: Any
- self.Constructor = None # type: Any
- self.Scanner = None # type: Any
- self.Serializer = None # type: Any
- self.default_flow_style = None # type: Any
- typ_found = 1
- setup_rt = False
- if "rt" in self.typ:
- setup_rt = True
- elif "safe" in self.typ:
- self.Emitter = emitter.Emitter if pure or CEmitter is None else CEmitter
- self.Representer = representer.SafeRepresenter
- self.Parser = parser.Parser if pure or CParser is None else CParser
- self.Composer = composer.Composer
- self.Constructor = constructor.SafeConstructor
- elif "base" in self.typ:
- self.Emitter = emitter.Emitter
- self.Representer = representer.BaseRepresenter
- self.Parser = parser.Parser if pure or CParser is None else CParser
- self.Composer = composer.Composer
- self.Constructor = constructor.BaseConstructor
- elif "unsafe" in self.typ:
- self.Emitter = emitter.Emitter if pure or CEmitter is None else CEmitter
- self.Representer = representer.Representer
- self.Parser = parser.Parser if pure or CParser is None else CParser
- self.Composer = composer.Composer
- self.Constructor = constructor.Constructor
- else:
- setup_rt = True
- typ_found = 0
- if setup_rt:
- self.default_flow_style = False
- # no optimized rt-dumper yet
- self.Emitter = emitter.Emitter
- self.Serializer = serializer.Serializer
- self.Representer = representer.RoundTripRepresenter
- self.Scanner = scanner.RoundTripScanner
- # no optimized rt-parser yet
- self.Parser = parser.RoundTripParser
- self.Composer = composer.Composer
- self.Constructor = constructor.RoundTripConstructor
- del setup_rt
- self.stream = None
- self.canonical = None
- self.old_indent = None
- self.width = None
- self.line_break = None
-
- self.map_indent = None
- self.sequence_indent = None
- self.sequence_dash_offset = 0
- self.compact_seq_seq = None
- self.compact_seq_map = None
- self.sort_base_mapping_type_on_output = None # default: sort
-
- self.top_level_colon_align = None
- self.prefix_colon = None
- self.version = None
- self.preserve_quotes = None
- self.allow_duplicate_keys = False # duplicate keys in map, set
- self.encoding = "utf-8"
- self.explicit_start = None
- self.explicit_end = None
- self.tags = None
- self.default_style = None
- self.top_level_block_style_scalar_no_indent_error_1_1 = False
- # directives end indicator with single scalar document
- self.scalar_after_indicator = None
- # [a, b: 1, c: {d: 2}] vs. [a, {b: 1}, {c: {d: 2}}]
- self.brace_single_entry_mapping_in_flow_sequence = False
- for module in self.plug_ins:
- if getattr(module, "typ", None) in self.typ:
- typ_found += 1
- module.init_typ(self)
- break
- if typ_found == 0:
- raise NotImplementedError(
- 'typ "{}"not recognised (need to install plug-in?)'.format(self.typ)
- )
-
- @property
- def reader(self):
- # type: () -> Any
- try:
- return self._reader # type: ignore
- except AttributeError:
- self._reader = self.Reader(None, loader=self)
- return self._reader
-
- @property
- def scanner(self):
- # type: () -> Any
- try:
- return self._scanner # type: ignore
- except AttributeError:
- self._scanner = self.Scanner(loader=self)
- return self._scanner
-
- @property
- def parser(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- if self.Parser is not CParser:
- setattr(self, attr, self.Parser(loader=self))
- else:
- if getattr(self, "_stream", None) is None:
- # wait for the stream
- return None
- else:
- # if not hasattr(self._stream, 'read') and hasattr(self._stream, 'open'):
- # # pathlib.Path() instance
- # setattr(self, attr, CParser(self._stream))
- # else:
- setattr(self, attr, CParser(self._stream))
- # self._parser = self._composer = self
- # nprint('scanner', self.loader.scanner)
-
- return getattr(self, attr)
-
- @property
- def composer(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- setattr(self, attr, self.Composer(loader=self))
- return getattr(self, attr)
-
- @property
- def constructor(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- cnst = self.Constructor(preserve_quotes=self.preserve_quotes, loader=self)
- cnst.allow_duplicate_keys = self.allow_duplicate_keys
- setattr(self, attr, cnst)
- return getattr(self, attr)
-
- @property
- def resolver(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- setattr(self, attr, self.Resolver(version=self.version, loader=self))
- return getattr(self, attr)
-
- @property
- def emitter(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- if self.Emitter is not CEmitter:
- _emitter = self.Emitter(
- None,
- canonical=self.canonical,
- indent=self.old_indent,
- width=self.width,
- allow_unicode=self.allow_unicode,
- line_break=self.line_break,
- prefix_colon=self.prefix_colon,
- brace_single_entry_mapping_in_flow_sequence=self.brace_single_entry_mapping_in_flow_sequence, # NOQA
- dumper=self,
- )
- setattr(self, attr, _emitter)
- if self.map_indent is not None:
- _emitter.best_map_indent = self.map_indent
- if self.sequence_indent is not None:
- _emitter.best_sequence_indent = self.sequence_indent
- if self.sequence_dash_offset is not None:
- _emitter.sequence_dash_offset = self.sequence_dash_offset
- # _emitter.block_seq_indent = self.sequence_dash_offset
- if self.compact_seq_seq is not None:
- _emitter.compact_seq_seq = self.compact_seq_seq
- if self.compact_seq_map is not None:
- _emitter.compact_seq_map = self.compact_seq_map
- else:
- if getattr(self, "_stream", None) is None:
- # wait for the stream
- return None
- return None
- return getattr(self, attr)
-
- @property
- def serializer(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- setattr(
- self,
- attr,
- self.Serializer(
- encoding=self.encoding,
- explicit_start=self.explicit_start,
- explicit_end=self.explicit_end,
- version=self.version,
- tags=self.tags,
- dumper=self,
- ),
- )
- return getattr(self, attr)
-
- @property
- def representer(self):
- # type: () -> Any
- attr = "_" + sys._getframe().f_code.co_name
- if not hasattr(self, attr):
- repres = self.Representer(
- default_style=self.default_style,
- default_flow_style=self.default_flow_style,
- dumper=self,
- )
- if self.sort_base_mapping_type_on_output is not None:
- repres.sort_base_mapping_type_on_output = (
- self.sort_base_mapping_type_on_output
- )
- setattr(self, attr, repres)
- return getattr(self, attr)
-
- # separate output resolver?
-
- # def load(self, stream=None):
- # if self._context_manager:
- # if not self._input:
- # raise TypeError("Missing input stream while dumping from context manager")
- # for data in self._context_manager.load():
- # yield data
- # return
- # if stream is None:
- # raise TypeError("Need a stream argument when not loading from context manager")
- # return self.load_one(stream)
-
- def load(self, stream):
- # type: (Union[Path, StreamTextType]) -> Any
- """
- at this point you either have the non-pure Parser (which has its own reader and
- scanner) or you have the pure Parser.
- If the pure Parser is set, then set the Reader and Scanner, if not already set.
- If either the Scanner or Reader are set, you cannot use the non-pure Parser,
- so reset it to the pure parser and set the Reader resp. Scanner if necessary
- """
- if not hasattr(stream, "read") and hasattr(stream, "open"):
- # pathlib.Path() instance
- with stream.open("rb") as fp:
- return self.load(fp)
- constructor, parser = self.get_constructor_parser(stream)
- try:
- return constructor.get_single_data()
- finally:
- parser.dispose()
- try:
- self._reader.reset_reader()
- except AttributeError:
- pass
- try:
- self._scanner.reset_scanner()
- except AttributeError:
- pass
-
- def load_all(self, stream, _kw=enforce): # , skip=None):
- # type: (Union[Path, StreamTextType], Any) -> Any
- if _kw is not enforce:
- raise TypeError(
- "{}.__init__() takes no positional argument but at least "
- "one was given ({!r})".format(self.__class__.__name__, _kw)
- )
- if not hasattr(stream, "read") and hasattr(stream, "open"):
- # pathlib.Path() instance
- with stream.open("r") as fp:
- for d in self.load_all(fp, _kw=enforce):
- yield d
- return
- # if skip is None:
- # skip = []
- # elif isinstance(skip, int):
- # skip = [skip]
- constructor, parser = self.get_constructor_parser(stream)
- try:
- while constructor.check_data():
- yield constructor.get_data()
- finally:
- parser.dispose()
- try:
- self._reader.reset_reader()
- except AttributeError:
- pass
- try:
- self._scanner.reset_scanner()
- except AttributeError:
- pass
-
- def get_constructor_parser(self, stream):
- # type: (StreamTextType) -> Any
- """
- the old cyaml needs special setup, and therefore the stream
- """
- if self.Parser is not CParser:
- if self.Reader is None:
- self.Reader = reader.Reader
- if self.Scanner is None:
- self.Scanner = scanner.Scanner
- self.reader.stream = stream
- else:
- if self.Reader is not None:
- if self.Scanner is None:
- self.Scanner = scanner.Scanner
- self.Parser = parser.Parser
- self.reader.stream = stream
- elif self.Scanner is not None:
- if self.Reader is None:
- self.Reader = reader.Reader
- self.Parser = parser.Parser
- self.reader.stream = stream
- else:
- # combined C level reader>scanner>parser
- # does some calls to the resolver, e.g. BaseResolver.descend_resolver
- # if you just initialise the CParser, to much of resolver.py
- # is actually used
- rslvr = self.Resolver
- # if rslvr is srsly.ruamel_yaml.resolver.VersionedResolver:
- # rslvr = srsly.ruamel_yaml.resolver.Resolver
-
- class XLoader(self.Parser, self.Constructor, rslvr): # type: ignore
- def __init__(
- selfx, stream, version=self.version, preserve_quotes=None
- ):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None # NOQA
- CParser.__init__(selfx, stream)
- selfx._parser = selfx._composer = selfx
- self.Constructor.__init__(selfx, loader=selfx)
- selfx.allow_duplicate_keys = self.allow_duplicate_keys
- rslvr.__init__(selfx, version=version, loadumper=selfx)
-
- self._stream = stream
- loader = XLoader(stream)
- return loader, loader
- return self.constructor, self.parser
-
- def dump(self, data, stream=None, _kw=enforce, transform=None):
- # type: (Any, Union[Path, StreamType], Any, Any) -> Any
- if self._context_manager:
- if not self._output:
- raise TypeError(
- "Missing output stream while dumping from context manager"
- )
- if _kw is not enforce:
- raise TypeError(
- "{}.dump() takes one positional argument but at least "
- "two were given ({!r})".format(self.__class__.__name__, _kw)
- )
- if transform is not None:
- raise TypeError(
- "{}.dump() in the context manager cannot have transform keyword "
- "".format(self.__class__.__name__)
- )
- self._context_manager.dump(data)
- else: # old style
- if stream is None:
- raise TypeError(
- "Need a stream argument when not dumping from context manager"
- )
- return self.dump_all([data], stream, _kw, transform=transform)
-
- def dump_all(self, documents, stream, _kw=enforce, transform=None):
- # type: (Any, Union[Path, StreamType], Any, Any) -> Any
- if self._context_manager:
- raise NotImplementedError
- if _kw is not enforce:
- raise TypeError(
- "{}.dump(_all) takes two positional argument but at least "
- "three were given ({!r})".format(self.__class__.__name__, _kw)
- )
- self._output = stream
- self._context_manager = YAMLContextManager(self, transform=transform)
- for data in documents:
- self._context_manager.dump(data)
- self._context_manager.teardown_output()
- self._output = None
- self._context_manager = None
-
- def Xdump_all(self, documents, stream, _kw=enforce, transform=None):
- # type: (Any, Union[Path, StreamType], Any, Any) -> Any
- """
- Serialize a sequence of Python objects into a YAML stream.
- """
- if not hasattr(stream, "write") and hasattr(stream, "open"):
- # pathlib.Path() instance
- with stream.open("w") as fp:
- return self.dump_all(documents, fp, _kw, transform=transform)
- if _kw is not enforce:
- raise TypeError(
- "{}.dump(_all) takes two positional argument but at least "
- "three were given ({!r})".format(self.__class__.__name__, _kw)
- )
- # The stream should have the methods `write` and possibly `flush`.
- if self.top_level_colon_align is True:
- tlca = max([len(str(x)) for x in documents[0]]) # type: Any
- else:
- tlca = self.top_level_colon_align
- if transform is not None:
- fstream = stream
- if self.encoding is None:
- stream = StringIO()
- else:
- stream = BytesIO()
- serializer, representer, emitter = self.get_serializer_representer_emitter(
- stream, tlca
- )
- try:
- self.serializer.open()
- for data in documents:
- try:
- self.representer.represent(data)
- except AttributeError:
- # nprint(dir(dumper._representer))
- raise
- self.serializer.close()
- finally:
- try:
- self.emitter.dispose()
- except AttributeError:
- raise
- # self.dumper.dispose() # cyaml
- delattr(self, "_serializer")
- delattr(self, "_emitter")
- if transform:
- val = stream.getvalue()
- if self.encoding:
- val = val.decode(self.encoding)
- if fstream is None:
- transform(val)
- else:
- fstream.write(transform(val))
- return None
-
- def get_serializer_representer_emitter(self, stream, tlca):
- # type: (StreamType, Any) -> Any
- # we have only .Serializer to deal with (vs .Reader & .Scanner), much simpler
- if self.Emitter is not CEmitter:
- if self.Serializer is None:
- self.Serializer = serializer.Serializer
- self.emitter.stream = stream
- self.emitter.top_level_colon_align = tlca
- if self.scalar_after_indicator is not None:
- self.emitter.scalar_after_indicator = self.scalar_after_indicator
- return self.serializer, self.representer, self.emitter
- if self.Serializer is not None:
- # cannot set serializer with CEmitter
- self.Emitter = emitter.Emitter
- self.emitter.stream = stream
- self.emitter.top_level_colon_align = tlca
- if self.scalar_after_indicator is not None:
- self.emitter.scalar_after_indicator = self.scalar_after_indicator
- return self.serializer, self.representer, self.emitter
- # C routines
-
- rslvr = resolver.BaseResolver if "base" in self.typ else resolver.Resolver
-
- class XDumper(CEmitter, self.Representer, rslvr): # type: ignore
- def __init__(
- selfx,
- stream,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- ):
- # type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
- CEmitter.__init__(
- selfx,
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- encoding=encoding,
- allow_unicode=allow_unicode,
- line_break=line_break,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- )
- selfx._emitter = selfx._serializer = selfx._representer = selfx
- self.Representer.__init__(
- selfx,
- default_style=default_style,
- default_flow_style=default_flow_style,
- )
- rslvr.__init__(selfx)
-
- self._stream = stream
- dumper = XDumper(
- stream,
- default_style=self.default_style,
- default_flow_style=self.default_flow_style,
- canonical=self.canonical,
- indent=self.old_indent,
- width=self.width,
- allow_unicode=self.allow_unicode,
- line_break=self.line_break,
- explicit_start=self.explicit_start,
- explicit_end=self.explicit_end,
- version=self.version,
- tags=self.tags,
- )
- self._emitter = self._serializer = dumper
- return dumper, dumper, dumper
-
- # basic types
- def map(self, **kw):
- # type: (Any) -> Any
- if "rt" in self.typ:
- from .comments import CommentedMap
-
- return CommentedMap(**kw)
- else:
- return dict(**kw)
-
- def seq(self, *args):
- # type: (Any) -> Any
- if "rt" in self.typ:
- from .comments import CommentedSeq
-
- return CommentedSeq(*args)
- else:
- return list(*args)
-
- # helpers
- def official_plug_ins(self):
- # type: () -> Any
- bd = os.path.dirname(__file__)
- gpbd = os.path.dirname(os.path.dirname(bd))
- res = [x.replace(gpbd, "")[1:-3] for x in glob.glob(bd + "/*/__plug_in__.py")]
- return res
-
- def register_class(self, cls):
- # type:(Any) -> Any
- """
- register a class for dumping loading
- - if it has attribute yaml_tag use that to register, else use class name
- - if it has methods to_yaml/from_yaml use those to dump/load else dump attributes
- as mapping
- """
- tag = getattr(cls, "yaml_tag", "!" + cls.__name__)
- try:
- self.representer.add_representer(cls, cls.to_yaml)
- except AttributeError:
-
- def t_y(representer, data):
- # type: (Any, Any) -> Any
- return representer.represent_yaml_object(
- tag, data, cls, flow_style=representer.default_flow_style
- )
-
- self.representer.add_representer(cls, t_y)
- try:
- self.constructor.add_constructor(tag, cls.from_yaml)
- except AttributeError:
-
- def f_y(constructor, node):
- # type: (Any, Any) -> Any
- return constructor.construct_yaml_object(node, cls)
-
- self.constructor.add_constructor(tag, f_y)
- return cls
-
- def parse(self, stream):
- # type: (StreamTextType) -> Any
- """
- Parse a YAML stream and produce parsing events.
- """
- _, parser = self.get_constructor_parser(stream)
- try:
- while parser.check_event():
- yield parser.get_event()
- finally:
- parser.dispose()
- try:
- self._reader.reset_reader()
- except AttributeError:
- pass
- try:
- self._scanner.reset_scanner()
- except AttributeError:
- pass
-
- # ### context manager
-
- def __enter__(self):
- # type: () -> Any
- self._context_manager = YAMLContextManager(self)
- return self
-
- def __exit__(self, typ, value, traceback):
- # type: (Any, Any, Any) -> None
- if typ:
- nprint("typ", typ)
- self._context_manager.teardown_output()
- # self._context_manager.teardown_input()
- self._context_manager = None
-
- # ### backwards compatibility
- def _indent(self, mapping=None, sequence=None, offset=None):
- # type: (Any, Any, Any) -> None
- if mapping is not None:
- self.map_indent = mapping
- if sequence is not None:
- self.sequence_indent = sequence
- if offset is not None:
- self.sequence_dash_offset = offset
-
- @property
- def indent(self):
- # type: () -> Any
- return self._indent
-
- @indent.setter
- def indent(self, val):
- # type: (Any) -> None
- self.old_indent = val
-
- @property
- def block_seq_indent(self):
- # type: () -> Any
- return self.sequence_dash_offset
-
- @block_seq_indent.setter
- def block_seq_indent(self, val):
- # type: (Any) -> None
- self.sequence_dash_offset = val
-
- def compact(self, seq_seq=None, seq_map=None):
- # type: (Any, Any) -> None
- self.compact_seq_seq = seq_seq
- self.compact_seq_map = seq_map
-
-
-class YAMLContextManager(object):
- def __init__(self, yaml, transform=None):
- # type: (Any, Any) -> None # used to be: (Any, Optional[Callable]) -> None
- self._yaml = yaml
- self._output_inited = False
- self._output_path = None
- self._output = self._yaml._output
- self._transform = transform
-
- # self._input_inited = False
- # self._input = input
- # self._input_path = None
- # self._transform = yaml.transform
- # self._fstream = None
-
- if not hasattr(self._output, "write") and hasattr(self._output, "open"):
- # pathlib.Path() instance, open with the same mode
- self._output_path = self._output
- self._output = self._output_path.open("w")
-
- # if not hasattr(self._stream, 'write') and hasattr(stream, 'open'):
- # if not hasattr(self._input, 'read') and hasattr(self._input, 'open'):
- # # pathlib.Path() instance, open with the same mode
- # self._input_path = self._input
- # self._input = self._input_path.open('r')
-
- if self._transform is not None:
- self._fstream = self._output
- if self._yaml.encoding is None:
- self._output = StringIO()
- else:
- self._output = BytesIO()
-
- def teardown_output(self):
- # type: () -> None
- if self._output_inited:
- self._yaml.serializer.close()
- else:
- return
- try:
- self._yaml.emitter.dispose()
- except AttributeError:
- raise
- # self.dumper.dispose() # cyaml
- try:
- delattr(self._yaml, "_serializer")
- delattr(self._yaml, "_emitter")
- except AttributeError:
- raise
- if self._transform:
- val = self._output.getvalue()
- if self._yaml.encoding:
- val = val.decode(self._yaml.encoding)
- if self._fstream is None:
- self._transform(val)
- else:
- self._fstream.write(self._transform(val))
- self._fstream.flush()
- self._output = self._fstream # maybe not necessary
- if self._output_path is not None:
- self._output.close()
-
- def init_output(self, first_data):
- # type: (Any) -> None
- if self._yaml.top_level_colon_align is True:
- tlca = max([len(str(x)) for x in first_data]) # type: Any
- else:
- tlca = self._yaml.top_level_colon_align
- self._yaml.get_serializer_representer_emitter(self._output, tlca)
- self._yaml.serializer.open()
- self._output_inited = True
-
- def dump(self, data):
- # type: (Any) -> None
- if not self._output_inited:
- self.init_output(data)
- try:
- self._yaml.representer.represent(data)
- except AttributeError:
- # nprint(dir(dumper._representer))
- raise
-
- # def teardown_input(self):
- # pass
- #
- # def init_input(self):
- # # set the constructor and parser on YAML() instance
- # self._yaml.get_constructor_parser(stream)
- #
- # def load(self):
- # if not self._input_inited:
- # self.init_input()
- # try:
- # while self._yaml.constructor.check_data():
- # yield self._yaml.constructor.get_data()
- # finally:
- # parser.dispose()
- # try:
- # self._reader.reset_reader() # type: ignore
- # except AttributeError:
- # pass
- # try:
- # self._scanner.reset_scanner() # type: ignore
- # except AttributeError:
- # pass
-
-
-def yaml_object(yml):
- # type: (Any) -> Any
- """ decorator for classes that needs to dump/load objects
- The tag for such objects is taken from the class attribute yaml_tag (or the
- class name in lowercase in case unavailable)
- If methods to_yaml and/or from_yaml are available, these are called for dumping resp.
- loading, default routines (dumping a mapping of the attributes) used otherwise.
- """
-
- def yo_deco(cls):
- # type: (Any) -> Any
- tag = getattr(cls, "yaml_tag", "!" + cls.__name__)
- try:
- yml.representer.add_representer(cls, cls.to_yaml)
- except AttributeError:
-
- def t_y(representer, data):
- # type: (Any, Any) -> Any
- return representer.represent_yaml_object(
- tag, data, cls, flow_style=representer.default_flow_style
- )
-
- yml.representer.add_representer(cls, t_y)
- try:
- yml.constructor.add_constructor(tag, cls.from_yaml)
- except AttributeError:
-
- def f_y(constructor, node):
- # type: (Any, Any) -> Any
- return constructor.construct_yaml_object(node, cls)
-
- yml.constructor.add_constructor(tag, f_y)
- return cls
-
- return yo_deco
-
-
-########################################################################################
-
-
-def scan(stream, Loader=Loader):
- # type: (StreamTextType, Any) -> Any
- """
- Scan a YAML stream and produce scanning tokens.
- """
- loader = Loader(stream)
- try:
- while loader.scanner.check_token():
- yield loader.scanner.get_token()
- finally:
- loader._parser.dispose()
-
-
-def parse(stream, Loader=Loader):
- # type: (StreamTextType, Any) -> Any
- """
- Parse a YAML stream and produce parsing events.
- """
- loader = Loader(stream)
- try:
- while loader._parser.check_event():
- yield loader._parser.get_event()
- finally:
- loader._parser.dispose()
-
-
-def compose(stream, Loader=Loader):
- # type: (StreamTextType, Any) -> Any
- """
- Parse the first YAML document in a stream
- and produce the corresponding representation tree.
- """
- loader = Loader(stream)
- try:
- return loader.get_single_node()
- finally:
- loader.dispose()
-
-
-def compose_all(stream, Loader=Loader):
- # type: (StreamTextType, Any) -> Any
- """
- Parse all YAML documents in a stream
- and produce corresponding representation trees.
- """
- loader = Loader(stream)
- try:
- while loader.check_node():
- yield loader._composer.get_node()
- finally:
- loader._parser.dispose()
-
-
-def load(stream, Loader=None, version=None, preserve_quotes=None):
- # type: (StreamTextType, Any, Optional[VersionType], Any) -> Any
- """
- Parse the first YAML document in a stream
- and produce the corresponding Python object.
- """
- if Loader is None:
- warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2)
- Loader = UnsafeLoader
- loader = Loader(stream, version, preserve_quotes=preserve_quotes)
- try:
- return loader._constructor.get_single_data()
- finally:
- loader._parser.dispose()
- try:
- loader._reader.reset_reader()
- except AttributeError:
- pass
- try:
- loader._scanner.reset_scanner()
- except AttributeError:
- pass
-
-
-def load_all(stream, Loader=None, version=None, preserve_quotes=None):
- # type: (Optional[StreamTextType], Any, Optional[VersionType], Optional[bool]) -> Any # NOQA
- """
- Parse all YAML documents in a stream
- and produce corresponding Python objects.
- """
- if Loader is None:
- warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2)
- Loader = UnsafeLoader
- loader = Loader(stream, version, preserve_quotes=preserve_quotes)
- try:
- while loader._constructor.check_data():
- yield loader._constructor.get_data()
- finally:
- loader._parser.dispose()
- try:
- loader._reader.reset_reader()
- except AttributeError:
- pass
- try:
- loader._scanner.reset_scanner()
- except AttributeError:
- pass
-
-
-def safe_load(stream, version=None):
- # type: (StreamTextType, Optional[VersionType]) -> Any
- """
- Parse the first YAML document in a stream
- and produce the corresponding Python object.
- Resolve only basic YAML tags.
- """
- return load(stream, SafeLoader, version)
-
-
-def safe_load_all(stream, version=None):
- # type: (StreamTextType, Optional[VersionType]) -> Any
- """
- Parse all YAML documents in a stream
- and produce corresponding Python objects.
- Resolve only basic YAML tags.
- """
- return load_all(stream, SafeLoader, version)
-
-
-def round_trip_load(stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any
- """
- Parse the first YAML document in a stream
- and produce the corresponding Python object.
- Resolve only basic YAML tags.
- """
- return load(stream, RoundTripLoader, version, preserve_quotes=preserve_quotes)
-
-
-def round_trip_load_all(stream, version=None, preserve_quotes=None):
- # type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any
- """
- Parse all YAML documents in a stream
- and produce corresponding Python objects.
- Resolve only basic YAML tags.
- """
- return load_all(stream, RoundTripLoader, version, preserve_quotes=preserve_quotes)
-
-
-def emit(
- events,
- stream=None,
- Dumper=Dumper,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
-):
- # type: (Any, Optional[StreamType], Any, Optional[bool], Union[int, None], Optional[int], Optional[bool], Any) -> Any # NOQA
- """
- Emit YAML parsing events into a stream.
- If stream is None, return the produced string instead.
- """
- getvalue = None
- if stream is None:
- stream = StringIO()
- getvalue = stream.getvalue
- dumper = Dumper(
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- )
- try:
- for event in events:
- dumper.emit(event)
- finally:
- try:
- dumper._emitter.dispose()
- except AttributeError:
- raise
- dumper.dispose() # cyaml
- if getvalue is not None:
- return getvalue()
-
-
-enc = None if PY3 else "utf-8"
-
-
-def serialize_all(
- nodes,
- stream=None,
- Dumper=Dumper,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=enc,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
-):
- # type: (Any, Optional[StreamType], Any, Any, Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any) -> Any # NOQA
- """
- Serialize a sequence of representation trees into a YAML stream.
- If stream is None, return the produced string instead.
- """
- getvalue = None
- if stream is None:
- if encoding is None:
- stream = StringIO()
- else:
- stream = BytesIO()
- getvalue = stream.getvalue
- dumper = Dumper(
- stream,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- encoding=encoding,
- version=version,
- tags=tags,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- )
- try:
- dumper._serializer.open()
- for node in nodes:
- dumper.serialize(node)
- dumper._serializer.close()
- finally:
- try:
- dumper._emitter.dispose()
- except AttributeError:
- raise
- dumper.dispose() # cyaml
- if getvalue is not None:
- return getvalue()
-
-
-def serialize(node, stream=None, Dumper=Dumper, **kwds):
- # type: (Any, Optional[StreamType], Any, Any) -> Any
- """
- Serialize a representation tree into a YAML stream.
- If stream is None, return the produced string instead.
- """
- return serialize_all([node], stream, Dumper=Dumper, **kwds)
-
-
-def dump_all(
- documents,
- stream=None,
- Dumper=Dumper,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=enc,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
-):
- # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> Optional[str] # NOQA
- """
- Serialize a sequence of Python objects into a YAML stream.
- If stream is None, return the produced string instead.
- """
- getvalue = None
- if top_level_colon_align is True:
- top_level_colon_align = max([len(str(x)) for x in documents[0]])
- if stream is None:
- if encoding is None:
- stream = StringIO()
- else:
- stream = BytesIO()
- getvalue = stream.getvalue
- dumper = Dumper(
- stream,
- default_style=default_style,
- default_flow_style=default_flow_style,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- block_seq_indent=block_seq_indent,
- top_level_colon_align=top_level_colon_align,
- prefix_colon=prefix_colon,
- )
- try:
- dumper._serializer.open()
- for data in documents:
- try:
- dumper._representer.represent(data)
- except AttributeError:
- # nprint(dir(dumper._representer))
- raise
- dumper._serializer.close()
- finally:
- try:
- dumper._emitter.dispose()
- except AttributeError:
- raise
- dumper.dispose() # cyaml
- if getvalue is not None:
- return getvalue()
- return None
-
-
-def dump(
- data,
- stream=None,
- Dumper=Dumper,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=enc,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
-):
- # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> Optional[str] # NOQA
- """
- Serialize a Python object into a YAML stream.
- If stream is None, return the produced string instead.
-
- default_style ∈ None, '', '"', "'", '|', '>'
-
- """
- return dump_all(
- [data],
- stream,
- Dumper=Dumper,
- default_style=default_style,
- default_flow_style=default_flow_style,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- block_seq_indent=block_seq_indent,
- )
-
-
-def safe_dump_all(documents, stream=None, **kwds):
- # type: (Any, Optional[StreamType], Any) -> Optional[str]
- """
- Serialize a sequence of Python objects into a YAML stream.
- Produce only basic YAML tags.
- If stream is None, return the produced string instead.
- """
- return dump_all(documents, stream, Dumper=SafeDumper, **kwds)
-
-
-def safe_dump(data, stream=None, **kwds):
- # type: (Any, Optional[StreamType], Any) -> Optional[str]
- """
- Serialize a Python object into a YAML stream.
- Produce only basic YAML tags.
- If stream is None, return the produced string instead.
- """
- return dump_all([data], stream, Dumper=SafeDumper, **kwds)
-
-
-def round_trip_dump(
- data,
- stream=None,
- Dumper=RoundTripDumper,
- default_style=None,
- default_flow_style=None,
- canonical=None,
- indent=None,
- width=None,
- allow_unicode=None,
- line_break=None,
- encoding=enc,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
-):
- # type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any, Any, Any) -> Optional[str] # NOQA
- allow_unicode = True if allow_unicode is None else allow_unicode
- return dump_all(
- [data],
- stream,
- Dumper=Dumper,
- default_style=default_style,
- default_flow_style=default_flow_style,
- canonical=canonical,
- indent=indent,
- width=width,
- allow_unicode=allow_unicode,
- line_break=line_break,
- encoding=encoding,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- tags=tags,
- block_seq_indent=block_seq_indent,
- top_level_colon_align=top_level_colon_align,
- prefix_colon=prefix_colon,
- )
-
-
-# Loader/Dumper are no longer composites, to get to the associated
-# Resolver()/Representer(), etc., you need to instantiate the class
-
-
-def add_implicit_resolver(
- tag, regexp, first=None, Loader=None, Dumper=None, resolver=Resolver
-):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- """
- Add an implicit scalar detector.
- If an implicit scalar value matches the given regexp,
- the corresponding tag is assigned to the scalar.
- first is a sequence of possible initial characters or None.
- """
- if Loader is None and Dumper is None:
- resolver.add_implicit_resolver(tag, regexp, first)
- return
- if Loader:
- if hasattr(Loader, "add_implicit_resolver"):
- Loader.add_implicit_resolver(tag, regexp, first)
- elif issubclass(
- Loader, (BaseLoader, SafeLoader, loader.Loader, RoundTripLoader)
- ):
- Resolver.add_implicit_resolver(tag, regexp, first)
- else:
- raise NotImplementedError
- if Dumper:
- if hasattr(Dumper, "add_implicit_resolver"):
- Dumper.add_implicit_resolver(tag, regexp, first)
- elif issubclass(
- Dumper, (BaseDumper, SafeDumper, dumper.Dumper, RoundTripDumper)
- ):
- Resolver.add_implicit_resolver(tag, regexp, first)
- else:
- raise NotImplementedError
-
-
-# this code currently not tested
-def add_path_resolver(
- tag, path, kind=None, Loader=None, Dumper=None, resolver=Resolver
-):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- """
- Add a path based resolver for the given tag.
- A path is a list of keys that forms a path
- to a node in the representation tree.
- Keys can be string values, integers, or None.
- """
- if Loader is None and Dumper is None:
- resolver.add_path_resolver(tag, path, kind)
- return
- if Loader:
- if hasattr(Loader, "add_path_resolver"):
- Loader.add_path_resolver(tag, path, kind)
- elif issubclass(
- Loader, (BaseLoader, SafeLoader, loader.Loader, RoundTripLoader)
- ):
- Resolver.add_path_resolver(tag, path, kind)
- else:
- raise NotImplementedError
- if Dumper:
- if hasattr(Dumper, "add_path_resolver"):
- Dumper.add_path_resolver(tag, path, kind)
- elif issubclass(
- Dumper, (BaseDumper, SafeDumper, dumper.Dumper, RoundTripDumper)
- ):
- Resolver.add_path_resolver(tag, path, kind)
- else:
- raise NotImplementedError
-
-
-def add_constructor(tag, object_constructor, Loader=None, constructor=Constructor):
- # type: (Any, Any, Any, Any) -> None
- """
- Add an object constructor for the given tag.
- object_onstructor is a function that accepts a Loader instance
- and a node object and produces the corresponding Python object.
- """
- if Loader is None:
- constructor.add_constructor(tag, object_constructor)
- else:
- if hasattr(Loader, "add_constructor"):
- Loader.add_constructor(tag, object_constructor)
- return
- if issubclass(Loader, BaseLoader):
- BaseConstructor.add_constructor(tag, object_constructor)
- elif issubclass(Loader, SafeLoader):
- SafeConstructor.add_constructor(tag, object_constructor)
- elif issubclass(Loader, Loader):
- Constructor.add_constructor(tag, object_constructor)
- elif issubclass(Loader, RoundTripLoader):
- RoundTripConstructor.add_constructor(tag, object_constructor)
- else:
- raise NotImplementedError
-
-
-def add_multi_constructor(
- tag_prefix, multi_constructor, Loader=None, constructor=Constructor
-):
- # type: (Any, Any, Any, Any) -> None
- """
- Add a multi-constructor for the given tag prefix.
- Multi-constructor is called for a node if its tag starts with tag_prefix.
- Multi-constructor accepts a Loader instance, a tag suffix,
- and a node object and produces the corresponding Python object.
- """
- if Loader is None:
- constructor.add_multi_constructor(tag_prefix, multi_constructor)
- else:
- if False and hasattr(Loader, "add_multi_constructor"):
- Loader.add_multi_constructor(tag_prefix, constructor)
- return
- if issubclass(Loader, BaseLoader):
- BaseConstructor.add_multi_constructor(tag_prefix, multi_constructor)
- elif issubclass(Loader, SafeLoader):
- SafeConstructor.add_multi_constructor(tag_prefix, multi_constructor)
- elif issubclass(Loader, loader.Loader):
- Constructor.add_multi_constructor(tag_prefix, multi_constructor)
- elif issubclass(Loader, RoundTripLoader):
- RoundTripConstructor.add_multi_constructor(tag_prefix, multi_constructor)
- else:
- raise NotImplementedError
-
-
-def add_representer(
- data_type, object_representer, Dumper=None, representer=Representer
-):
- # type: (Any, Any, Any, Any) -> None
- """
- Add a representer for the given type.
- object_representer is a function accepting a Dumper instance
- and an instance of the given data type
- and producing the corresponding representation node.
- """
- if Dumper is None:
- representer.add_representer(data_type, object_representer)
- else:
- if hasattr(Dumper, "add_representer"):
- Dumper.add_representer(data_type, object_representer)
- return
- if issubclass(Dumper, BaseDumper):
- BaseRepresenter.add_representer(data_type, object_representer)
- elif issubclass(Dumper, SafeDumper):
- SafeRepresenter.add_representer(data_type, object_representer)
- elif issubclass(Dumper, Dumper):
- Representer.add_representer(data_type, object_representer)
- elif issubclass(Dumper, RoundTripDumper):
- RoundTripRepresenter.add_representer(data_type, object_representer)
- else:
- raise NotImplementedError
-
-
-# this code currently not tested
-def add_multi_representer(
- data_type, multi_representer, Dumper=None, representer=Representer
-):
- # type: (Any, Any, Any, Any) -> None
- """
- Add a representer for the given type.
- multi_representer is a function accepting a Dumper instance
- and an instance of the given data type or subtype
- and producing the corresponding representation node.
- """
- if Dumper is None:
- representer.add_multi_representer(data_type, multi_representer)
- else:
- if hasattr(Dumper, "add_multi_representer"):
- Dumper.add_multi_representer(data_type, multi_representer)
- return
- if issubclass(Dumper, BaseDumper):
- BaseRepresenter.add_multi_representer(data_type, multi_representer)
- elif issubclass(Dumper, SafeDumper):
- SafeRepresenter.add_multi_representer(data_type, multi_representer)
- elif issubclass(Dumper, Dumper):
- Representer.add_multi_representer(data_type, multi_representer)
- elif issubclass(Dumper, RoundTripDumper):
- RoundTripRepresenter.add_multi_representer(data_type, multi_representer)
- else:
- raise NotImplementedError
-
-
-class YAMLObjectMetaclass(type):
- """
- The metaclass for YAMLObject.
- """
-
- def __init__(cls, name, bases, kwds):
- # type: (Any, Any, Any) -> None
- super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds)
- if "yaml_tag" in kwds and kwds["yaml_tag"] is not None:
- cls.yaml_constructor.add_constructor(
- cls.yaml_tag, cls.from_yaml
- ) # type: ignore
- cls.yaml_representer.add_representer(cls, cls.to_yaml) # type: ignore
-
-
-class YAMLObject(with_metaclass(YAMLObjectMetaclass)): # type: ignore
- """
- An object that can dump itself to a YAML stream
- and load itself from a YAML stream.
- """
-
- __slots__ = () # no direct instantiation, so allow immutable subclasses
-
- yaml_constructor = Constructor
- yaml_representer = Representer
-
- yaml_tag = None # type: Any
- yaml_flow_style = None # type: Any
-
- @classmethod
- def from_yaml(cls, constructor, node):
- # type: (Any, Any) -> Any
- """
- Convert a representation node to a Python object.
- """
- return constructor.construct_yaml_object(node, cls)
-
- @classmethod
- def to_yaml(cls, representer, data):
- # type: (Any, Any) -> Any
- """
- Convert a Python object to a representation node.
- """
- return representer.represent_yaml_object(
- cls.yaml_tag, data, cls, flow_style=cls.yaml_flow_style
- )
diff --git a/srsly/ruamel_yaml/nodes.py b/srsly/ruamel_yaml/nodes.py
deleted file mode 100755
index da86e9c..0000000
--- a/srsly/ruamel_yaml/nodes.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-import sys
-from .compat import string_types
-
-if False: # MYPY
- from typing import Dict, Any, Text # NOQA
-
-
-class Node(object):
- __slots__ = 'tag', 'value', 'start_mark', 'end_mark', 'comment', 'anchor'
-
- def __init__(self, tag, value, start_mark, end_mark, comment=None, anchor=None):
- # type: (Any, Any, Any, Any, Any, Any) -> None
- self.tag = tag
- self.value = value
- self.start_mark = start_mark
- self.end_mark = end_mark
- self.comment = comment
- self.anchor = anchor
-
- def __repr__(self):
- # type: () -> str
- value = self.value
- # if isinstance(value, list):
- # if len(value) == 0:
- # value = ''
- # elif len(value) == 1:
- # value = '<1 item>'
- # else:
- # value = '<%d items>' % len(value)
- # else:
- # if len(value) > 75:
- # value = repr(value[:70]+u' ... ')
- # else:
- # value = repr(value)
- value = repr(value)
- return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value)
-
- def dump(self, indent=0):
- # type: (int) -> None
- if isinstance(self.value, string_types):
- sys.stdout.write(
- '{}{}(tag={!r}, value={!r})\n'.format(
- ' ' * indent, self.__class__.__name__, self.tag, self.value
- )
- )
- if self.comment:
- sys.stdout.write(' {}comment: {})\n'.format(' ' * indent, self.comment))
- return
- sys.stdout.write(
- '{}{}(tag={!r})\n'.format(' ' * indent, self.__class__.__name__, self.tag)
- )
- if self.comment:
- sys.stdout.write(' {}comment: {})\n'.format(' ' * indent, self.comment))
- for v in self.value:
- if isinstance(v, tuple):
- for v1 in v:
- v1.dump(indent + 1)
- elif isinstance(v, Node):
- v.dump(indent + 1)
- else:
- sys.stdout.write('Node value type? {}\n'.format(type(v)))
-
-
-class ScalarNode(Node):
- """
- styles:
- ? -> set() ? key, no value
- " -> double quoted
- ' -> single quoted
- | -> literal style
- > -> folding style
- """
-
- __slots__ = ('style',)
- id = 'scalar'
-
- def __init__(
- self, tag, value, start_mark=None, end_mark=None, style=None, comment=None, anchor=None
- ):
- # type: (Any, Any, Any, Any, Any, Any, Any) -> None
- Node.__init__(self, tag, value, start_mark, end_mark, comment=comment, anchor=anchor)
- self.style = style
-
-
-class CollectionNode(Node):
- __slots__ = ('flow_style',)
-
- def __init__(
- self,
- tag,
- value,
- start_mark=None,
- end_mark=None,
- flow_style=None,
- comment=None,
- anchor=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any, Any) -> None
- Node.__init__(self, tag, value, start_mark, end_mark, comment=comment)
- self.flow_style = flow_style
- self.anchor = anchor
-
-
-class SequenceNode(CollectionNode):
- __slots__ = ()
- id = 'sequence'
-
-
-class MappingNode(CollectionNode):
- __slots__ = ('merge',)
- id = 'mapping'
-
- def __init__(
- self,
- tag,
- value,
- start_mark=None,
- end_mark=None,
- flow_style=None,
- comment=None,
- anchor=None,
- ):
- # type: (Any, Any, Any, Any, Any, Any, Any) -> None
- CollectionNode.__init__(
- self, tag, value, start_mark, end_mark, flow_style, comment, anchor
- )
- self.merge = None
diff --git a/srsly/ruamel_yaml/parser.py b/srsly/ruamel_yaml/parser.py
deleted file mode 100755
index 437bd22..0000000
--- a/srsly/ruamel_yaml/parser.py
+++ /dev/null
@@ -1,844 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-# The following YAML grammar is LL(1) and is parsed by a recursive descent
-# parser.
-#
-# stream ::= STREAM-START implicit_document? explicit_document*
-# STREAM-END
-# implicit_document ::= block_node DOCUMENT-END*
-# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
-# block_node_or_indentless_sequence ::=
-# ALIAS
-# | properties (block_content |
-# indentless_block_sequence)?
-# | block_content
-# | indentless_block_sequence
-# block_node ::= ALIAS
-# | properties block_content?
-# | block_content
-# flow_node ::= ALIAS
-# | properties flow_content?
-# | flow_content
-# properties ::= TAG ANCHOR? | ANCHOR TAG?
-# block_content ::= block_collection | flow_collection | SCALAR
-# flow_content ::= flow_collection | SCALAR
-# block_collection ::= block_sequence | block_mapping
-# flow_collection ::= flow_sequence | flow_mapping
-# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)*
-# BLOCK-END
-# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
-# block_mapping ::= BLOCK-MAPPING_START
-# ((KEY block_node_or_indentless_sequence?)?
-# (VALUE block_node_or_indentless_sequence?)?)*
-# BLOCK-END
-# flow_sequence ::= FLOW-SEQUENCE-START
-# (flow_sequence_entry FLOW-ENTRY)*
-# flow_sequence_entry?
-# FLOW-SEQUENCE-END
-# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
-# flow_mapping ::= FLOW-MAPPING-START
-# (flow_mapping_entry FLOW-ENTRY)*
-# flow_mapping_entry?
-# FLOW-MAPPING-END
-# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
-#
-# FIRST sets:
-#
-# stream: { STREAM-START }
-# explicit_document: { DIRECTIVE DOCUMENT-START }
-# implicit_document: FIRST(block_node)
-# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START
-# BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START }
-# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START }
-# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START
-# FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
-# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
-# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START }
-# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
-# block_sequence: { BLOCK-SEQUENCE-START }
-# block_mapping: { BLOCK-MAPPING-START }
-# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR
-# BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START
-# FLOW-MAPPING-START BLOCK-ENTRY }
-# indentless_sequence: { ENTRY }
-# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
-# flow_sequence: { FLOW-SEQUENCE-START }
-# flow_mapping: { FLOW-MAPPING-START }
-# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START
-# FLOW-MAPPING-START KEY }
-# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START
-# FLOW-MAPPING-START KEY }
-
-# need to have full path with import, as pkg_resources tries to load parser.py in __init__.py
-# only to not do anything with the package afterwards
-# and for Jython too
-
-
-from .error import MarkedYAMLError
-from .tokens import * # NOQA
-from .events import * # NOQA
-from .scanner import Scanner, RoundTripScanner, ScannerError # NOQA
-from .compat import utf8, nprint, nprintf # NOQA
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List # NOQA
-
-__all__ = ["Parser", "RoundTripParser", "ParserError"]
-
-
-class ParserError(MarkedYAMLError):
- pass
-
-
-class Parser(object):
- # Since writing a recursive-descendant parser is a straightforward task, we
- # do not give many comments here.
-
- DEFAULT_TAGS = {u"!": u"!", u"!!": u"tag:yaml.org,2002:"}
-
- def __init__(self, loader):
- # type: (Any) -> None
- self.loader = loader
- if self.loader is not None and getattr(self.loader, "_parser", None) is None:
- self.loader._parser = self
- self.reset_parser()
-
- def reset_parser(self):
- # type: () -> None
- # Reset the state attributes (to clear self-references)
- self.current_event = None
- self.tag_handles = {} # type: Dict[Any, Any]
- self.states = [] # type: List[Any]
- self.marks = [] # type: List[Any]
- self.state = self.parse_stream_start # type: Any
-
- def dispose(self):
- # type: () -> None
- self.reset_parser()
-
- @property
- def scanner(self):
- # type: () -> Any
- if hasattr(self.loader, "typ"):
- return self.loader.scanner
- return self.loader._scanner
-
- @property
- def resolver(self):
- # type: () -> Any
- if hasattr(self.loader, "typ"):
- return self.loader.resolver
- return self.loader._resolver
-
- def check_event(self, *choices):
- # type: (Any) -> bool
- # Check the type of the next event.
- if self.current_event is None:
- if self.state:
- self.current_event = self.state()
- if self.current_event is not None:
- if not choices:
- return True
- for choice in choices:
- if isinstance(self.current_event, choice):
- return True
- return False
-
- def peek_event(self):
- # type: () -> Any
- # Get the next event.
- if self.current_event is None:
- if self.state:
- self.current_event = self.state()
- return self.current_event
-
- def get_event(self):
- # type: () -> Any
- # Get the next event and proceed further.
- if self.current_event is None:
- if self.state:
- self.current_event = self.state()
- value = self.current_event
- self.current_event = None
- return value
-
- # stream ::= STREAM-START implicit_document? explicit_document*
- # STREAM-END
- # implicit_document ::= block_node DOCUMENT-END*
- # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
-
- def parse_stream_start(self):
- # type: () -> Any
- # Parse the stream start.
- token = self.scanner.get_token()
- token.move_comment(self.scanner.peek_token())
- event = StreamStartEvent(
- token.start_mark, token.end_mark, encoding=token.encoding
- )
-
- # Prepare the next state.
- self.state = self.parse_implicit_document_start
-
- return event
-
- def parse_implicit_document_start(self):
- # type: () -> Any
- # Parse an implicit document.
- if not self.scanner.check_token(
- DirectiveToken, DocumentStartToken, StreamEndToken
- ):
- self.tag_handles = self.DEFAULT_TAGS
- token = self.scanner.peek_token()
- start_mark = end_mark = token.start_mark
- event = DocumentStartEvent(start_mark, end_mark, explicit=False)
-
- # Prepare the next state.
- self.states.append(self.parse_document_end)
- self.state = self.parse_block_node
-
- return event
-
- else:
- return self.parse_document_start()
-
- def parse_document_start(self):
- # type: () -> Any
- # Parse any extra document end indicators.
- while self.scanner.check_token(DocumentEndToken):
- self.scanner.get_token()
- # Parse an explicit document.
- if not self.scanner.check_token(StreamEndToken):
- token = self.scanner.peek_token()
- start_mark = token.start_mark
- version, tags = self.process_directives()
- if not self.scanner.check_token(DocumentStartToken):
- raise ParserError(
- None,
- None,
- "expected '', but found %r"
- % self.scanner.peek_token().id,
- self.scanner.peek_token().start_mark,
- )
- token = self.scanner.get_token()
- end_mark = token.end_mark
- # if self.loader is not None and \
- # end_mark.line != self.scanner.peek_token().start_mark.line:
- # self.loader.scalar_after_indicator = False
- event = DocumentStartEvent(
- start_mark, end_mark, explicit=True, version=version, tags=tags
- ) # type: Any
- self.states.append(self.parse_document_end)
- self.state = self.parse_document_content
- else:
- # Parse the end of the stream.
- token = self.scanner.get_token()
- event = StreamEndEvent(
- token.start_mark, token.end_mark, comment=token.comment
- )
- assert not self.states
- assert not self.marks
- self.state = None
- return event
-
- def parse_document_end(self):
- # type: () -> Any
- # Parse the document end.
- token = self.scanner.peek_token()
- start_mark = end_mark = token.start_mark
- explicit = False
- if self.scanner.check_token(DocumentEndToken):
- token = self.scanner.get_token()
- end_mark = token.end_mark
- explicit = True
- event = DocumentEndEvent(start_mark, end_mark, explicit=explicit)
-
- # Prepare the next state.
- if self.resolver.processing_version == (1, 1):
- self.state = self.parse_document_start
- else:
- self.state = self.parse_implicit_document_start
-
- return event
-
- def parse_document_content(self):
- # type: () -> Any
- if self.scanner.check_token(
- DirectiveToken, DocumentStartToken, DocumentEndToken, StreamEndToken
- ):
- event = self.process_empty_scalar(self.scanner.peek_token().start_mark)
- self.state = self.states.pop()
- return event
- else:
- return self.parse_block_node()
-
- def process_directives(self):
- # type: () -> Any
- yaml_version = None
- self.tag_handles = {}
- while self.scanner.check_token(DirectiveToken):
- token = self.scanner.get_token()
- if token.name == u"YAML":
- if yaml_version is not None:
- raise ParserError(
- None, None, "found duplicate YAML directive", token.start_mark
- )
- major, minor = token.value
- if major != 1:
- raise ParserError(
- None,
- None,
- "found incompatible YAML document (version 1.* is " "required)",
- token.start_mark,
- )
- yaml_version = token.value
- elif token.name == u"TAG":
- handle, prefix = token.value
- if handle in self.tag_handles:
- raise ParserError(
- None,
- None,
- "duplicate tag handle %r" % utf8(handle),
- token.start_mark,
- )
- self.tag_handles[handle] = prefix
- if bool(self.tag_handles):
- value = yaml_version, self.tag_handles.copy() # type: Any
- else:
- value = yaml_version, None
- if self.loader is not None and hasattr(self.loader, "tags"):
- self.loader.version = yaml_version
- if self.loader.tags is None:
- self.loader.tags = {}
- for k in self.tag_handles:
- self.loader.tags[k] = self.tag_handles[k]
- for key in self.DEFAULT_TAGS:
- if key not in self.tag_handles:
- self.tag_handles[key] = self.DEFAULT_TAGS[key]
- return value
-
- # block_node_or_indentless_sequence ::= ALIAS
- # | properties (block_content | indentless_block_sequence)?
- # | block_content
- # | indentless_block_sequence
- # block_node ::= ALIAS
- # | properties block_content?
- # | block_content
- # flow_node ::= ALIAS
- # | properties flow_content?
- # | flow_content
- # properties ::= TAG ANCHOR? | ANCHOR TAG?
- # block_content ::= block_collection | flow_collection | SCALAR
- # flow_content ::= flow_collection | SCALAR
- # block_collection ::= block_sequence | block_mapping
- # flow_collection ::= flow_sequence | flow_mapping
-
- def parse_block_node(self):
- # type: () -> Any
- return self.parse_node(block=True)
-
- def parse_flow_node(self):
- # type: () -> Any
- return self.parse_node()
-
- def parse_block_node_or_indentless_sequence(self):
- # type: () -> Any
- return self.parse_node(block=True, indentless_sequence=True)
-
- def transform_tag(self, handle, suffix):
- # type: (Any, Any) -> Any
- return self.tag_handles[handle] + suffix
-
- def parse_node(self, block=False, indentless_sequence=False):
- # type: (bool, bool) -> Any
- if self.scanner.check_token(AliasToken):
- token = self.scanner.get_token()
- event = AliasEvent(
- token.value, token.start_mark, token.end_mark
- ) # type: Any
- self.state = self.states.pop()
- return event
-
- anchor = None
- tag = None
- start_mark = end_mark = tag_mark = None
- if self.scanner.check_token(AnchorToken):
- token = self.scanner.get_token()
- start_mark = token.start_mark
- end_mark = token.end_mark
- anchor = token.value
- if self.scanner.check_token(TagToken):
- token = self.scanner.get_token()
- tag_mark = token.start_mark
- end_mark = token.end_mark
- tag = token.value
- elif self.scanner.check_token(TagToken):
- token = self.scanner.get_token()
- start_mark = tag_mark = token.start_mark
- end_mark = token.end_mark
- tag = token.value
- if self.scanner.check_token(AnchorToken):
- token = self.scanner.get_token()
- start_mark = tag_mark = token.start_mark
- end_mark = token.end_mark
- anchor = token.value
- if tag is not None:
- handle, suffix = tag
- if handle is not None:
- if handle not in self.tag_handles:
- raise ParserError(
- "while parsing a node",
- start_mark,
- "found undefined tag handle %r" % utf8(handle),
- tag_mark,
- )
- tag = self.transform_tag(handle, suffix)
- else:
- tag = suffix
- # if tag == u'!':
- # raise ParserError("while parsing a node", start_mark,
- # "found non-specific tag '!'", tag_mark,
- # "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag'
- # and share your opinion.")
- if start_mark is None:
- start_mark = end_mark = self.scanner.peek_token().start_mark
- event = None
- implicit = tag is None or tag == u"!"
- if indentless_sequence and self.scanner.check_token(BlockEntryToken):
- comment = None
- pt = self.scanner.peek_token()
- if pt.comment and pt.comment[0]:
- comment = [pt.comment[0], []]
- pt.comment[0] = None
- end_mark = self.scanner.peek_token().end_mark
- event = SequenceStartEvent(
- anchor,
- tag,
- implicit,
- start_mark,
- end_mark,
- flow_style=False,
- comment=comment,
- )
- self.state = self.parse_indentless_sequence_entry
- return event
-
- if self.scanner.check_token(ScalarToken):
- token = self.scanner.get_token()
- # self.scanner.peek_token_same_line_comment(token)
- end_mark = token.end_mark
- if (token.plain and tag is None) or tag == u"!":
- implicit = (True, False)
- elif tag is None:
- implicit = (False, True)
- else:
- implicit = (False, False)
- # nprint('se', token.value, token.comment)
- event = ScalarEvent(
- anchor,
- tag,
- implicit,
- token.value,
- start_mark,
- end_mark,
- style=token.style,
- comment=token.comment,
- )
- self.state = self.states.pop()
- elif self.scanner.check_token(FlowSequenceStartToken):
- pt = self.scanner.peek_token()
- end_mark = pt.end_mark
- event = SequenceStartEvent(
- anchor,
- tag,
- implicit,
- start_mark,
- end_mark,
- flow_style=True,
- comment=pt.comment,
- )
- self.state = self.parse_flow_sequence_first_entry
- elif self.scanner.check_token(FlowMappingStartToken):
- pt = self.scanner.peek_token()
- end_mark = pt.end_mark
- event = MappingStartEvent(
- anchor,
- tag,
- implicit,
- start_mark,
- end_mark,
- flow_style=True,
- comment=pt.comment,
- )
- self.state = self.parse_flow_mapping_first_key
- elif block and self.scanner.check_token(BlockSequenceStartToken):
- end_mark = self.scanner.peek_token().start_mark
- # should inserting the comment be dependent on the
- # indentation?
- pt = self.scanner.peek_token()
- comment = pt.comment
- # nprint('pt0', type(pt))
- if comment is None or comment[1] is None:
- comment = pt.split_comment()
- # nprint('pt1', comment)
- event = SequenceStartEvent(
- anchor,
- tag,
- implicit,
- start_mark,
- end_mark,
- flow_style=False,
- comment=comment,
- )
- self.state = self.parse_block_sequence_first_entry
- elif block and self.scanner.check_token(BlockMappingStartToken):
- end_mark = self.scanner.peek_token().start_mark
- comment = self.scanner.peek_token().comment
- event = MappingStartEvent(
- anchor,
- tag,
- implicit,
- start_mark,
- end_mark,
- flow_style=False,
- comment=comment,
- )
- self.state = self.parse_block_mapping_first_key
- elif anchor is not None or tag is not None:
- # Empty scalars are allowed even if a tag or an anchor is
- # specified.
- event = ScalarEvent(
- anchor, tag, (implicit, False), "", start_mark, end_mark
- )
- self.state = self.states.pop()
- else:
- if block:
- node = "block"
- else:
- node = "flow"
- token = self.scanner.peek_token()
- raise ParserError(
- "while parsing a %s node" % node,
- start_mark,
- "expected the node content, but found %r" % token.id,
- token.start_mark,
- )
- return event
-
- # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)*
- # BLOCK-END
-
- def parse_block_sequence_first_entry(self):
- # type: () -> Any
- token = self.scanner.get_token()
- # move any comment from start token
- # token.move_comment(self.scanner.peek_token())
- self.marks.append(token.start_mark)
- return self.parse_block_sequence_entry()
-
- def parse_block_sequence_entry(self):
- # type: () -> Any
- if self.scanner.check_token(BlockEntryToken):
- token = self.scanner.get_token()
- token.move_comment(self.scanner.peek_token())
- if not self.scanner.check_token(BlockEntryToken, BlockEndToken):
- self.states.append(self.parse_block_sequence_entry)
- return self.parse_block_node()
- else:
- self.state = self.parse_block_sequence_entry
- return self.process_empty_scalar(token.end_mark)
- if not self.scanner.check_token(BlockEndToken):
- token = self.scanner.peek_token()
- raise ParserError(
- "while parsing a block collection",
- self.marks[-1],
- "expected , but found %r" % token.id,
- token.start_mark,
- )
- token = self.scanner.get_token() # BlockEndToken
- event = SequenceEndEvent(
- token.start_mark, token.end_mark, comment=token.comment
- )
- self.state = self.states.pop()
- self.marks.pop()
- return event
-
- # indentless_sequence ::= (BLOCK-ENTRY block_node?)+
-
- # indentless_sequence?
- # sequence:
- # - entry
- # - nested
-
- def parse_indentless_sequence_entry(self):
- # type: () -> Any
- if self.scanner.check_token(BlockEntryToken):
- token = self.scanner.get_token()
- token.move_comment(self.scanner.peek_token())
- if not self.scanner.check_token(
- BlockEntryToken, KeyToken, ValueToken, BlockEndToken
- ):
- self.states.append(self.parse_indentless_sequence_entry)
- return self.parse_block_node()
- else:
- self.state = self.parse_indentless_sequence_entry
- return self.process_empty_scalar(token.end_mark)
- token = self.scanner.peek_token()
- event = SequenceEndEvent(
- token.start_mark, token.start_mark, comment=token.comment
- )
- self.state = self.states.pop()
- return event
-
- # block_mapping ::= BLOCK-MAPPING_START
- # ((KEY block_node_or_indentless_sequence?)?
- # (VALUE block_node_or_indentless_sequence?)?)*
- # BLOCK-END
-
- def parse_block_mapping_first_key(self):
- # type: () -> Any
- token = self.scanner.get_token()
- self.marks.append(token.start_mark)
- return self.parse_block_mapping_key()
-
- def parse_block_mapping_key(self):
- # type: () -> Any
- if self.scanner.check_token(KeyToken):
- token = self.scanner.get_token()
- token.move_comment(self.scanner.peek_token())
- if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken):
- self.states.append(self.parse_block_mapping_value)
- return self.parse_block_node_or_indentless_sequence()
- else:
- self.state = self.parse_block_mapping_value
- return self.process_empty_scalar(token.end_mark)
- if self.resolver.processing_version > (1, 1) and self.scanner.check_token(
- ValueToken
- ):
- self.state = self.parse_block_mapping_value
- return self.process_empty_scalar(self.scanner.peek_token().start_mark)
- if not self.scanner.check_token(BlockEndToken):
- token = self.scanner.peek_token()
- raise ParserError(
- "while parsing a block mapping",
- self.marks[-1],
- "expected , but found %r" % token.id,
- token.start_mark,
- )
- token = self.scanner.get_token()
- token.move_comment(self.scanner.peek_token())
- event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment)
- self.state = self.states.pop()
- self.marks.pop()
- return event
-
- def parse_block_mapping_value(self):
- # type: () -> Any
- if self.scanner.check_token(ValueToken):
- token = self.scanner.get_token()
- # value token might have post comment move it to e.g. block
- if self.scanner.check_token(ValueToken):
- token.move_comment(self.scanner.peek_token())
- else:
- if not self.scanner.check_token(KeyToken):
- token.move_comment(self.scanner.peek_token(), empty=True)
- # else: empty value for this key cannot move token.comment
- if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken):
- self.states.append(self.parse_block_mapping_key)
- return self.parse_block_node_or_indentless_sequence()
- else:
- self.state = self.parse_block_mapping_key
- comment = token.comment
- if comment is None:
- token = self.scanner.peek_token()
- comment = token.comment
- if comment:
- token._comment = [None, comment[1]]
- comment = [comment[0], None]
- return self.process_empty_scalar(token.end_mark, comment=comment)
- else:
- self.state = self.parse_block_mapping_key
- token = self.scanner.peek_token()
- return self.process_empty_scalar(token.start_mark)
-
- # flow_sequence ::= FLOW-SEQUENCE-START
- # (flow_sequence_entry FLOW-ENTRY)*
- # flow_sequence_entry?
- # FLOW-SEQUENCE-END
- # flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
- #
- # Note that while production rules for both flow_sequence_entry and
- # flow_mapping_entry are equal, their interpretations are different.
- # For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
- # generate an inline mapping (set syntax).
-
- def parse_flow_sequence_first_entry(self):
- # type: () -> Any
- token = self.scanner.get_token()
- self.marks.append(token.start_mark)
- return self.parse_flow_sequence_entry(first=True)
-
- def parse_flow_sequence_entry(self, first=False):
- # type: (bool) -> Any
- if not self.scanner.check_token(FlowSequenceEndToken):
- if not first:
- if self.scanner.check_token(FlowEntryToken):
- self.scanner.get_token()
- else:
- token = self.scanner.peek_token()
- raise ParserError(
- "while parsing a flow sequence",
- self.marks[-1],
- "expected ',' or ']', but got %r" % token.id,
- token.start_mark,
- )
-
- if self.scanner.check_token(KeyToken):
- token = self.scanner.peek_token()
- event = MappingStartEvent(
- None, None, True, token.start_mark, token.end_mark, flow_style=True
- ) # type: Any
- self.state = self.parse_flow_sequence_entry_mapping_key
- return event
- elif not self.scanner.check_token(FlowSequenceEndToken):
- self.states.append(self.parse_flow_sequence_entry)
- return self.parse_flow_node()
- token = self.scanner.get_token()
- event = SequenceEndEvent(
- token.start_mark, token.end_mark, comment=token.comment
- )
- self.state = self.states.pop()
- self.marks.pop()
- return event
-
- def parse_flow_sequence_entry_mapping_key(self):
- # type: () -> Any
- token = self.scanner.get_token()
- if not self.scanner.check_token(
- ValueToken, FlowEntryToken, FlowSequenceEndToken
- ):
- self.states.append(self.parse_flow_sequence_entry_mapping_value)
- return self.parse_flow_node()
- else:
- self.state = self.parse_flow_sequence_entry_mapping_value
- return self.process_empty_scalar(token.end_mark)
-
- def parse_flow_sequence_entry_mapping_value(self):
- # type: () -> Any
- if self.scanner.check_token(ValueToken):
- token = self.scanner.get_token()
- if not self.scanner.check_token(FlowEntryToken, FlowSequenceEndToken):
- self.states.append(self.parse_flow_sequence_entry_mapping_end)
- return self.parse_flow_node()
- else:
- self.state = self.parse_flow_sequence_entry_mapping_end
- return self.process_empty_scalar(token.end_mark)
- else:
- self.state = self.parse_flow_sequence_entry_mapping_end
- token = self.scanner.peek_token()
- return self.process_empty_scalar(token.start_mark)
-
- def parse_flow_sequence_entry_mapping_end(self):
- # type: () -> Any
- self.state = self.parse_flow_sequence_entry
- token = self.scanner.peek_token()
- return MappingEndEvent(token.start_mark, token.start_mark)
-
- # flow_mapping ::= FLOW-MAPPING-START
- # (flow_mapping_entry FLOW-ENTRY)*
- # flow_mapping_entry?
- # FLOW-MAPPING-END
- # flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
-
- def parse_flow_mapping_first_key(self):
- # type: () -> Any
- token = self.scanner.get_token()
- self.marks.append(token.start_mark)
- return self.parse_flow_mapping_key(first=True)
-
- def parse_flow_mapping_key(self, first=False):
- # type: (Any) -> Any
- if not self.scanner.check_token(FlowMappingEndToken):
- if not first:
- if self.scanner.check_token(FlowEntryToken):
- self.scanner.get_token()
- else:
- token = self.scanner.peek_token()
- raise ParserError(
- "while parsing a flow mapping",
- self.marks[-1],
- "expected ',' or '}', but got %r" % token.id,
- token.start_mark,
- )
- if self.scanner.check_token(KeyToken):
- token = self.scanner.get_token()
- if not self.scanner.check_token(
- ValueToken, FlowEntryToken, FlowMappingEndToken
- ):
- self.states.append(self.parse_flow_mapping_value)
- return self.parse_flow_node()
- else:
- self.state = self.parse_flow_mapping_value
- return self.process_empty_scalar(token.end_mark)
- elif self.resolver.processing_version > (1, 1) and self.scanner.check_token(
- ValueToken
- ):
- self.state = self.parse_flow_mapping_value
- return self.process_empty_scalar(self.scanner.peek_token().end_mark)
- elif not self.scanner.check_token(FlowMappingEndToken):
- self.states.append(self.parse_flow_mapping_empty_value)
- return self.parse_flow_node()
- token = self.scanner.get_token()
- event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment)
- self.state = self.states.pop()
- self.marks.pop()
- return event
-
- def parse_flow_mapping_value(self):
- # type: () -> Any
- if self.scanner.check_token(ValueToken):
- token = self.scanner.get_token()
- if not self.scanner.check_token(FlowEntryToken, FlowMappingEndToken):
- self.states.append(self.parse_flow_mapping_key)
- return self.parse_flow_node()
- else:
- self.state = self.parse_flow_mapping_key
- return self.process_empty_scalar(token.end_mark)
- else:
- self.state = self.parse_flow_mapping_key
- token = self.scanner.peek_token()
- return self.process_empty_scalar(token.start_mark)
-
- def parse_flow_mapping_empty_value(self):
- # type: () -> Any
- self.state = self.parse_flow_mapping_key
- return self.process_empty_scalar(self.scanner.peek_token().start_mark)
-
- def process_empty_scalar(self, mark, comment=None):
- # type: (Any, Any) -> Any
- return ScalarEvent(None, None, (True, False), "", mark, mark, comment=comment)
-
-
-class RoundTripParser(Parser):
- """roundtrip is a safe loader, that wants to see the unmangled tag"""
-
- def transform_tag(self, handle, suffix):
- # type: (Any, Any) -> Any
- # return self.tag_handles[handle]+suffix
- if handle == "!!" and suffix in (
- u"null",
- u"bool",
- u"int",
- u"float",
- u"binary",
- u"timestamp",
- u"omap",
- u"pairs",
- u"set",
- u"str",
- u"seq",
- u"map",
- ):
- return Parser.transform_tag(self, handle, suffix)
- return handle + suffix
diff --git a/srsly/ruamel_yaml/py.typed b/srsly/ruamel_yaml/py.typed
deleted file mode 100755
index e69de29..0000000
diff --git a/srsly/ruamel_yaml/reader.py b/srsly/ruamel_yaml/reader.py
deleted file mode 100755
index b90c73b..0000000
--- a/srsly/ruamel_yaml/reader.py
+++ /dev/null
@@ -1,327 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-# This module contains abstractions for the input stream. You don't have to
-# looks further, there are no pretty code.
-#
-# We define two classes here.
-#
-# Mark(source, line, column)
-# It's just a record and its only use is producing nice error messages.
-# Parser does not use it for any other purposes.
-#
-# Reader(source, data)
-# Reader determines the encoding of `data` and converts it to unicode.
-# Reader provides the following methods and attributes:
-# reader.peek(length=1) - return the next `length` characters
-# reader.forward(length=1) - move the current position to `length`
-# characters.
-# reader.index - the number of the current character.
-# reader.line, stream.column - the line and the column of the current
-# character.
-
-import codecs
-
-from .error import YAMLError, FileMark, StringMark, YAMLStreamError
-from .compat import text_type, binary_type, PY3, UNICODE_SIZE
-from .util import RegExp
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Union, Text, Tuple, Optional # NOQA
-# from srsly.ruamel_yaml.compat import StreamTextType # NOQA
-
-__all__ = ["Reader", "ReaderError"]
-
-
-class ReaderError(YAMLError):
- def __init__(self, name, position, character, encoding, reason):
- # type: (Any, Any, Any, Any, Any) -> None
- self.name = name
- self.character = character
- self.position = position
- self.encoding = encoding
- self.reason = reason
-
- def __str__(self):
- # type: () -> str
- if isinstance(self.character, binary_type):
- return (
- "'%s' codec can't decode byte #x%02x: %s\n"
- ' in "%s", position %d'
- % (
- self.encoding,
- ord(self.character),
- self.reason,
- self.name,
- self.position,
- )
- )
- else:
- return "unacceptable character #x%04x: %s\n" ' in "%s", position %d' % (
- self.character,
- self.reason,
- self.name,
- self.position,
- )
-
-
-class Reader(object):
- # Reader:
- # - determines the data encoding and converts it to a unicode string,
- # - checks if characters are in allowed range,
- # - adds '\0' to the end.
-
- # Reader accepts
- # - a `str` object (PY2) / a `bytes` object (PY3),
- # - a `unicode` object (PY2) / a `str` object (PY3),
- # - a file-like object with its `read` method returning `str`,
- # - a file-like object with its `read` method returning `unicode`.
-
- # Yeah, it's ugly and slow.
-
- def __init__(self, stream, loader=None):
- # type: (Any, Any) -> None
- self.loader = loader
- if self.loader is not None and getattr(self.loader, "_reader", None) is None:
- self.loader._reader = self
- self.reset_reader()
- self.stream = stream # type: Any # as .read is called
-
- def reset_reader(self):
- # type: () -> None
- self.name = None # type: Any
- self.stream_pointer = 0
- self.eof = True
- self.buffer = ""
- self.pointer = 0
- self.raw_buffer = None # type: Any
- self.raw_decode = None
- self.encoding = None # type: Optional[Text]
- self.index = 0
- self.line = 0
- self.column = 0
-
- @property
- def stream(self):
- # type: () -> Any
- try:
- return self._stream
- except AttributeError:
- raise YAMLStreamError("input stream needs to specified")
-
- @stream.setter
- def stream(self, val):
- # type: (Any) -> None
- if val is None:
- return
- self._stream = None
- if isinstance(val, text_type):
- self.name = ""
- self.check_printable(val)
- self.buffer = val + u"\0" # type: ignore
- elif isinstance(val, binary_type):
- self.name = ""
- self.raw_buffer = val
- self.determine_encoding()
- else:
- if not hasattr(val, "read"):
- raise YAMLStreamError("stream argument needs to have a read() method")
- self._stream = val
- self.name = getattr(self.stream, "name", "")
- self.eof = False
- self.raw_buffer = None
- self.determine_encoding()
-
- def peek(self, index=0):
- # type: (int) -> Text
- try:
- return self.buffer[self.pointer + index]
- except IndexError:
- self.update(index + 1)
- return self.buffer[self.pointer + index]
-
- def prefix(self, length=1):
- # type: (int) -> Any
- if self.pointer + length >= len(self.buffer):
- self.update(length)
- return self.buffer[self.pointer : self.pointer + length]
-
- def forward_1_1(self, length=1):
- # type: (int) -> None
- if self.pointer + length + 1 >= len(self.buffer):
- self.update(length + 1)
- while length != 0:
- ch = self.buffer[self.pointer]
- self.pointer += 1
- self.index += 1
- if ch in u"\n\x85\u2028\u2029" or (
- ch == u"\r" and self.buffer[self.pointer] != u"\n"
- ):
- self.line += 1
- self.column = 0
- elif ch != u"\uFEFF":
- self.column += 1
- length -= 1
-
- def forward(self, length=1):
- # type: (int) -> None
- if self.pointer + length + 1 >= len(self.buffer):
- self.update(length + 1)
- while length != 0:
- ch = self.buffer[self.pointer]
- self.pointer += 1
- self.index += 1
- if ch == u"\n" or (ch == u"\r" and self.buffer[self.pointer] != u"\n"):
- self.line += 1
- self.column = 0
- elif ch != u"\uFEFF":
- self.column += 1
- length -= 1
-
- def get_mark(self):
- # type: () -> Any
- if self.stream is None:
- return StringMark(
- self.name, self.index, self.line, self.column, self.buffer, self.pointer
- )
- else:
- return FileMark(self.name, self.index, self.line, self.column)
-
- def determine_encoding(self):
- # type: () -> None
- while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
- self.update_raw()
- if isinstance(self.raw_buffer, binary_type):
- if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
- self.raw_decode = codecs.utf_16_le_decode # type: ignore
- self.encoding = "utf-16-le"
- elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
- self.raw_decode = codecs.utf_16_be_decode # type: ignore
- self.encoding = "utf-16-be"
- else:
- self.raw_decode = codecs.utf_8_decode # type: ignore
- self.encoding = "utf-8"
- self.update(1)
-
- if UNICODE_SIZE == 2:
- NON_PRINTABLE = RegExp(
- u"[^\x09\x0A\x0D\x20-\x7E\x85" u"\xA0-\uD7FF" u"\uE000-\uFFFD" u"]"
- )
- else:
- NON_PRINTABLE = RegExp(
- u"[^\x09\x0A\x0D\x20-\x7E\x85"
- u"\xA0-\uD7FF"
- u"\uE000-\uFFFD"
- u"\U00010000-\U0010FFFF"
- u"]"
- )
-
- _printable_ascii = ("\x09\x0A\x0D" + "".join(map(chr, range(0x20, 0x7F)))).encode(
- "ascii"
- )
-
- @classmethod
- def _get_non_printable_ascii(cls, data): # type: ignore
- # type: (Text, bytes) -> Optional[Tuple[int, Text]]
- ascii_bytes = data.encode("ascii")
- non_printables = ascii_bytes.translate(
- None, cls._printable_ascii
- ) # type: ignore
- if not non_printables:
- return None
- non_printable = non_printables[:1]
- return ascii_bytes.index(non_printable), non_printable.decode("ascii")
-
- @classmethod
- def _get_non_printable_regex(cls, data):
- # type: (Text) -> Optional[Tuple[int, Text]]
- match = cls.NON_PRINTABLE.search(data)
- if not bool(match):
- return None
- return match.start(), match.group()
-
- @classmethod
- def _get_non_printable(cls, data):
- # type: (Text) -> Optional[Tuple[int, Text]]
- try:
- return cls._get_non_printable_ascii(data) # type: ignore
- except UnicodeEncodeError:
- return cls._get_non_printable_regex(data)
-
- def check_printable(self, data):
- # type: (Any) -> None
- non_printable_match = self._get_non_printable(data)
- if non_printable_match is not None:
- start, character = non_printable_match
- position = self.index + (len(self.buffer) - self.pointer) + start
- raise ReaderError(
- self.name,
- position,
- ord(character),
- "unicode",
- "special characters are not allowed",
- )
-
- def update(self, length):
- # type: (int) -> None
- if self.raw_buffer is None:
- return
- self.buffer = self.buffer[self.pointer :]
- self.pointer = 0
- while len(self.buffer) < length:
- if not self.eof:
- self.update_raw()
- if self.raw_decode is not None:
- try:
- data, converted = self.raw_decode(
- self.raw_buffer, "strict", self.eof
- )
- except UnicodeDecodeError as exc:
- if PY3:
- character = self.raw_buffer[exc.start]
- else:
- character = exc.object[exc.start]
- if self.stream is not None:
- position = (
- self.stream_pointer - len(self.raw_buffer) + exc.start
- )
- elif self.stream is not None:
- position = (
- self.stream_pointer - len(self.raw_buffer) + exc.start
- )
- else:
- position = exc.start
- raise ReaderError(
- self.name, position, character, exc.encoding, exc.reason
- )
- else:
- data = self.raw_buffer
- converted = len(data)
- self.check_printable(data)
- self.buffer += data
- self.raw_buffer = self.raw_buffer[converted:]
- if self.eof:
- self.buffer += "\0"
- self.raw_buffer = None
- break
-
- def update_raw(self, size=None):
- # type: (Optional[int]) -> None
- if size is None:
- size = 4096 if PY3 else 1024
- data = self.stream.read(size)
- if self.raw_buffer is None:
- self.raw_buffer = data
- else:
- self.raw_buffer += data
- self.stream_pointer += len(data)
- if not data:
- self.eof = True
-
-
-# try:
-# import psyco
-# psyco.bind(Reader)
-# except ImportError:
-# pass
diff --git a/srsly/ruamel_yaml/representer.py b/srsly/ruamel_yaml/representer.py
deleted file mode 100755
index 73e1793..0000000
--- a/srsly/ruamel_yaml/representer.py
+++ /dev/null
@@ -1,1330 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division
-
-
-from .error import * # NOQA
-from .nodes import * # NOQA
-from .compat import text_type, binary_type, to_unicode, PY2, PY3
-from .compat import ordereddict # type: ignore
-from .compat import nprint, nprintf # NOQA
-from .scalarstring import (
- LiteralScalarString,
- FoldedScalarString,
- SingleQuotedScalarString,
- DoubleQuotedScalarString,
- PlainScalarString,
-)
-from .scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt
-from .scalarfloat import ScalarFloat
-from .scalarbool import ScalarBoolean
-from .timestamp import TimeStamp
-
-import datetime
-import sys
-import types
-
-if PY3:
- import copyreg
- import base64
-else:
- import copy_reg as copyreg # type: ignore
-
-if False: # MYPY
- from typing import Dict, List, Any, Union, Text, Optional # NOQA
-
-# fmt: off
-__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
- 'RepresenterError', 'RoundTripRepresenter']
-# fmt: on
-
-
-class RepresenterError(YAMLError):
- pass
-
-
-if PY2:
-
- def get_classobj_bases(cls):
- # type: (Any) -> Any
- bases = [cls]
- for base in cls.__bases__:
- bases.extend(get_classobj_bases(base))
- return bases
-
-
-class BaseRepresenter(object):
-
- yaml_representers = {} # type: Dict[Any, Any]
- yaml_multi_representers = {} # type: Dict[Any, Any]
-
- def __init__(self, default_style=None, default_flow_style=None, dumper=None):
- # type: (Any, Any, Any, Any) -> None
- self.dumper = dumper
- if self.dumper is not None:
- self.dumper._representer = self
- self.default_style = default_style
- self.default_flow_style = default_flow_style
- self.represented_objects = {} # type: Dict[Any, Any]
- self.object_keeper = [] # type: List[Any]
- self.alias_key = None # type: Optional[int]
- self.sort_base_mapping_type_on_output = True
-
- @property
- def serializer(self):
- # type: () -> Any
- try:
- if hasattr(self.dumper, "typ"):
- return self.dumper.serializer
- return self.dumper._serializer
- except AttributeError:
- return self # cyaml
-
- def represent(self, data):
- # type: (Any) -> None
- node = self.represent_data(data)
- self.serializer.serialize(node)
- self.represented_objects = {}
- self.object_keeper = []
- self.alias_key = None
-
- def represent_data(self, data):
- # type: (Any) -> Any
- if self.ignore_aliases(data):
- self.alias_key = None
- else:
- self.alias_key = id(data)
- if self.alias_key is not None:
- if self.alias_key in self.represented_objects:
- node = self.represented_objects[self.alias_key]
- # if node is None:
- # raise RepresenterError(
- # "recursive objects are not allowed: %r" % data)
- return node
- # self.represented_objects[alias_key] = None
- self.object_keeper.append(data)
- data_types = type(data).__mro__
- if PY2:
- # if type(data) is types.InstanceType:
- if isinstance(data, types.InstanceType):
- data_types = get_classobj_bases(data.__class__) + list(data_types)
- if data_types[0] in self.yaml_representers:
- node = self.yaml_representers[data_types[0]](self, data)
- else:
- for data_type in data_types:
- if data_type in self.yaml_multi_representers:
- node = self.yaml_multi_representers[data_type](self, data)
- break
- else:
- if None in self.yaml_multi_representers:
- node = self.yaml_multi_representers[None](self, data)
- elif None in self.yaml_representers:
- node = self.yaml_representers[None](self, data)
- else:
- node = ScalarNode(None, text_type(data))
- # if alias_key is not None:
- # self.represented_objects[alias_key] = node
- return node
-
- def represent_key(self, data):
- # type: (Any) -> Any
- """
- David Fraser: Extract a method to represent keys in mappings, so that
- a subclass can choose not to quote them (for example)
- used in represent_mapping
- https://bitbucket.org/davidfraser/pyyaml/commits/d81df6eb95f20cac4a79eed95ae553b5c6f77b8c
- """
- return self.represent_data(data)
-
- @classmethod
- def add_representer(cls, data_type, representer):
- # type: (Any, Any) -> None
- if "yaml_representers" not in cls.__dict__:
- cls.yaml_representers = cls.yaml_representers.copy()
- cls.yaml_representers[data_type] = representer
-
- @classmethod
- def add_multi_representer(cls, data_type, representer):
- # type: (Any, Any) -> None
- if "yaml_multi_representers" not in cls.__dict__:
- cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
- cls.yaml_multi_representers[data_type] = representer
-
- def represent_scalar(self, tag, value, style=None, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- if style is None:
- style = self.default_style
- comment = None
- if style and style[0] in "|>":
- comment = getattr(value, "comment", None)
- if comment:
- comment = [None, [comment]]
- node = ScalarNode(tag, value, style=style, comment=comment, anchor=anchor)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- return node
-
- def represent_sequence(self, tag, sequence, flow_style=None):
- # type: (Any, Any, Any) -> Any
- value = [] # type: List[Any]
- node = SequenceNode(tag, value, flow_style=flow_style)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- for item in sequence:
- node_item = self.represent_data(item)
- if not (isinstance(node_item, ScalarNode) and not node_item.style):
- best_style = False
- value.append(node_item)
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- def represent_omap(self, tag, omap, flow_style=None):
- # type: (Any, Any, Any) -> Any
- value = [] # type: List[Any]
- node = SequenceNode(tag, value, flow_style=flow_style)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- for item_key in omap:
- item_val = omap[item_key]
- node_item = self.represent_data({item_key: item_val})
- # if not (isinstance(node_item, ScalarNode) \
- # and not node_item.style):
- # best_style = False
- value.append(node_item)
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- def represent_mapping(self, tag, mapping, flow_style=None):
- # type: (Any, Any, Any) -> Any
- value = [] # type: List[Any]
- node = MappingNode(tag, value, flow_style=flow_style)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- if hasattr(mapping, "items"):
- mapping = list(mapping.items())
- if self.sort_base_mapping_type_on_output:
- try:
- mapping = sorted(mapping)
- except TypeError:
- pass
- for item_key, item_value in mapping:
- node_key = self.represent_key(item_key)
- node_value = self.represent_data(item_value)
- if not (isinstance(node_key, ScalarNode) and not node_key.style):
- best_style = False
- if not (isinstance(node_value, ScalarNode) and not node_value.style):
- best_style = False
- value.append((node_key, node_value))
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- def ignore_aliases(self, data):
- # type: (Any) -> bool
- return False
-
-
-class SafeRepresenter(BaseRepresenter):
- def ignore_aliases(self, data):
- # type: (Any) -> bool
- # https://docs.python.org/3/reference/expressions.html#parenthesized-forms :
- # "i.e. two occurrences of the empty tuple may or may not yield the same object"
- # so "data is ()" should not be used
- if data is None or (isinstance(data, tuple) and data == ()):
- return True
- if isinstance(data, (binary_type, text_type, bool, int, float)):
- return True
- return False
-
- def represent_none(self, data):
- # type: (Any) -> Any
- return self.represent_scalar(u"tag:yaml.org,2002:null", u"null")
-
- if PY3:
-
- def represent_str(self, data):
- # type: (Any) -> Any
- return self.represent_scalar(u"tag:yaml.org,2002:str", data)
-
- def represent_binary(self, data):
- # type: (Any) -> Any
- if hasattr(base64, "encodebytes"):
- data = base64.encodebytes(data).decode("ascii")
- else:
- data = base64.encodestring(data).decode("ascii")
- return self.represent_scalar(u"tag:yaml.org,2002:binary", data, style="|")
-
- else:
-
- def represent_str(self, data):
- # type: (Any) -> Any
- tag = None
- style = None
- try:
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- except UnicodeDecodeError:
- try:
- data = unicode(data, "utf-8")
- tag = u"tag:yaml.org,2002:str"
- except UnicodeDecodeError:
- data = data.encode("base64")
- tag = u"tag:yaml.org,2002:binary"
- style = "|"
- return self.represent_scalar(tag, data, style=style)
-
- def represent_unicode(self, data):
- # type: (Any) -> Any
- return self.represent_scalar(u"tag:yaml.org,2002:str", data)
-
- def represent_bool(self, data, anchor=None):
- # type: (Any, Optional[Any]) -> Any
- try:
- value = self.dumper.boolean_representation[bool(data)]
- except AttributeError:
- if data:
- value = u"true"
- else:
- value = u"false"
- return self.represent_scalar(u"tag:yaml.org,2002:bool", value, anchor=anchor)
-
- def represent_int(self, data):
- # type: (Any) -> Any
- return self.represent_scalar(u"tag:yaml.org,2002:int", text_type(data))
-
- if PY2:
-
- def represent_long(self, data):
- # type: (Any) -> Any
- return self.represent_scalar(u"tag:yaml.org,2002:int", text_type(data))
-
- inf_value = 1e300
- while repr(inf_value) != repr(inf_value * inf_value):
- inf_value *= inf_value
-
- def represent_float(self, data):
- # type: (Any) -> Any
- if data != data or (data == 0.0 and data == 1.0):
- value = u".nan"
- elif data == self.inf_value:
- value = u".inf"
- elif data == -self.inf_value:
- value = u"-.inf"
- else:
- value = to_unicode(repr(data)).lower()
- if getattr(self.serializer, "use_version", None) == (1, 1):
- if u"." not in value and u"e" in value:
- # Note that in some cases `repr(data)` represents a float number
- # without the decimal parts. For instance:
- # >>> repr(1e17)
- # '1e17'
- # Unfortunately, this is not a valid float representation according
- # to the definition of the `!!float` tag in YAML 1.1. We fix
- # this by adding '.0' before the 'e' symbol.
- value = value.replace(u"e", u".0e", 1)
- return self.represent_scalar(u"tag:yaml.org,2002:float", value)
-
- def represent_list(self, data):
- # type: (Any) -> Any
- # pairs = (len(data) > 0 and isinstance(data, list))
- # if pairs:
- # for item in data:
- # if not isinstance(item, tuple) or len(item) != 2:
- # pairs = False
- # break
- # if not pairs:
- return self.represent_sequence(u"tag:yaml.org,2002:seq", data)
-
- # value = []
- # for item_key, item_value in data:
- # value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
- # [(item_key, item_value)]))
- # return SequenceNode(u'tag:yaml.org,2002:pairs', value)
-
- def represent_dict(self, data):
- # type: (Any) -> Any
- return self.represent_mapping(u"tag:yaml.org,2002:map", data)
-
- def represent_ordereddict(self, data):
- # type: (Any) -> Any
- return self.represent_omap(u"tag:yaml.org,2002:omap", data)
-
- def represent_set(self, data):
- # type: (Any) -> Any
- value = {} # type: Dict[Any, None]
- for key in data:
- value[key] = None
- return self.represent_mapping(u"tag:yaml.org,2002:set", value)
-
- def represent_date(self, data):
- # type: (Any) -> Any
- value = to_unicode(data.isoformat())
- return self.represent_scalar(u"tag:yaml.org,2002:timestamp", value)
-
- def represent_datetime(self, data):
- # type: (Any) -> Any
- value = to_unicode(data.isoformat(" "))
- return self.represent_scalar(u"tag:yaml.org,2002:timestamp", value)
-
- def represent_yaml_object(self, tag, data, cls, flow_style=None):
- # type: (Any, Any, Any, Any) -> Any
- if hasattr(data, "__getstate__"):
- state = data.__getstate__()
- else:
- state = data.__dict__.copy()
- return self.represent_mapping(tag, state, flow_style=flow_style)
-
- def represent_undefined(self, data):
- # type: (Any) -> None
- raise RepresenterError("cannot represent an object: %s" % (data,))
-
-
-SafeRepresenter.add_representer(type(None), SafeRepresenter.represent_none)
-
-SafeRepresenter.add_representer(str, SafeRepresenter.represent_str)
-
-if PY2:
- SafeRepresenter.add_representer(unicode, SafeRepresenter.represent_unicode)
-else:
- SafeRepresenter.add_representer(bytes, SafeRepresenter.represent_binary)
-
-SafeRepresenter.add_representer(bool, SafeRepresenter.represent_bool)
-
-SafeRepresenter.add_representer(int, SafeRepresenter.represent_int)
-
-if PY2:
- SafeRepresenter.add_representer(long, SafeRepresenter.represent_long)
-
-SafeRepresenter.add_representer(float, SafeRepresenter.represent_float)
-
-SafeRepresenter.add_representer(list, SafeRepresenter.represent_list)
-
-SafeRepresenter.add_representer(tuple, SafeRepresenter.represent_list)
-
-SafeRepresenter.add_representer(dict, SafeRepresenter.represent_dict)
-
-SafeRepresenter.add_representer(set, SafeRepresenter.represent_set)
-
-SafeRepresenter.add_representer(ordereddict, SafeRepresenter.represent_ordereddict)
-
-if sys.version_info >= (2, 7):
- import collections
-
- SafeRepresenter.add_representer(
- collections.OrderedDict, SafeRepresenter.represent_ordereddict
- )
-
-SafeRepresenter.add_representer(datetime.date, SafeRepresenter.represent_date)
-
-SafeRepresenter.add_representer(datetime.datetime, SafeRepresenter.represent_datetime)
-
-SafeRepresenter.add_representer(None, SafeRepresenter.represent_undefined)
-
-
-class Representer(SafeRepresenter):
- if PY2:
-
- def represent_str(self, data):
- # type: (Any) -> Any
- tag = None
- style = None
- try:
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- except UnicodeDecodeError:
- try:
- data = unicode(data, "utf-8")
- tag = u"tag:yaml.org,2002:python/str"
- except UnicodeDecodeError:
- data = data.encode("base64")
- tag = u"tag:yaml.org,2002:binary"
- style = "|"
- return self.represent_scalar(tag, data, style=style)
-
- def represent_unicode(self, data):
- # type: (Any) -> Any
- tag = None
- try:
- data.encode("ascii")
- tag = u"tag:yaml.org,2002:python/unicode"
- except UnicodeEncodeError:
- tag = u"tag:yaml.org,2002:str"
- return self.represent_scalar(tag, data)
-
- def represent_long(self, data):
- # type: (Any) -> Any
- tag = u"tag:yaml.org,2002:int"
- if int(data) is not data:
- tag = u"tag:yaml.org,2002:python/long"
- return self.represent_scalar(tag, to_unicode(data))
-
- def represent_complex(self, data):
- # type: (Any) -> Any
- if data.imag == 0.0:
- data = u"%r" % data.real
- elif data.real == 0.0:
- data = u"%rj" % data.imag
- elif data.imag > 0:
- data = u"%r+%rj" % (data.real, data.imag)
- else:
- data = u"%r%rj" % (data.real, data.imag)
- return self.represent_scalar(u"tag:yaml.org,2002:python/complex", data)
-
- def represent_tuple(self, data):
- # type: (Any) -> Any
- return self.represent_sequence(u"tag:yaml.org,2002:python/tuple", data)
-
- def represent_name(self, data):
- # type: (Any) -> Any
- try:
- name = u"%s.%s" % (data.__module__, data.__qualname__)
- except AttributeError:
- # probably PY2
- name = u"%s.%s" % (data.__module__, data.__name__)
- return self.represent_scalar(u"tag:yaml.org,2002:python/name:" + name, "")
-
- def represent_module(self, data):
- # type: (Any) -> Any
- return self.represent_scalar(
- u"tag:yaml.org,2002:python/module:" + data.__name__, ""
- )
-
- if PY2:
-
- def represent_instance(self, data):
- # type: (Any) -> Any
- # For instances of classic classes, we use __getinitargs__ and
- # __getstate__ to serialize the data.
-
- # If data.__getinitargs__ exists, the object must be reconstructed
- # by calling cls(**args), where args is a tuple returned by
- # __getinitargs__. Otherwise, the cls.__init__ method should never
- # be called and the class instance is created by instantiating a
- # trivial class and assigning to the instance's __class__ variable.
-
- # If data.__getstate__ exists, it returns the state of the object.
- # Otherwise, the state of the object is data.__dict__.
-
- # We produce either a !!python/object or !!python/object/new node.
- # If data.__getinitargs__ does not exist and state is a dictionary,
- # we produce a !!python/object node . Otherwise we produce a
- # !!python/object/new node.
-
- cls = data.__class__
- class_name = u"%s.%s" % (cls.__module__, cls.__name__)
- args = None
- state = None
- if hasattr(data, "__getinitargs__"):
- args = list(data.__getinitargs__())
- if hasattr(data, "__getstate__"):
- state = data.__getstate__()
- else:
- state = data.__dict__
- if args is None and isinstance(state, dict):
- return self.represent_mapping(
- u"tag:yaml.org,2002:python/object:" + class_name, state
- )
- if isinstance(state, dict) and not state:
- return self.represent_sequence(
- u"tag:yaml.org,2002:python/object/new:" + class_name, args
- )
- value = {}
- if bool(args):
- value["args"] = args
- value["state"] = state # type: ignore
- return self.represent_mapping(
- u"tag:yaml.org,2002:python/object/new:" + class_name, value
- )
-
- def represent_object(self, data):
- # type: (Any) -> Any
- # We use __reduce__ API to save the data. data.__reduce__ returns
- # a tuple of length 2-5:
- # (function, args, state, listitems, dictitems)
-
- # For reconstructing, we calls function(*args), then set its state,
- # listitems, and dictitems if they are not None.
-
- # A special case is when function.__name__ == '__newobj__'. In this
- # case we create the object with args[0].__new__(*args).
-
- # Another special case is when __reduce__ returns a string - we don't
- # support it.
-
- # We produce a !!python/object, !!python/object/new or
- # !!python/object/apply node.
-
- cls = type(data)
- if cls in copyreg.dispatch_table:
- reduce = copyreg.dispatch_table[cls](data)
- elif hasattr(data, "__reduce_ex__"):
- reduce = data.__reduce_ex__(2)
- elif hasattr(data, "__reduce__"):
- reduce = data.__reduce__()
- else:
- raise RepresenterError("cannot represent object: %r" % (data,))
- reduce = (list(reduce) + [None] * 5)[:5]
- function, args, state, listitems, dictitems = reduce
- args = list(args)
- if state is None:
- state = {}
- if listitems is not None:
- listitems = list(listitems)
- if dictitems is not None:
- dictitems = dict(dictitems)
- if function.__name__ == "__newobj__":
- function = args[0]
- args = args[1:]
- tag = u"tag:yaml.org,2002:python/object/new:"
- newobj = True
- else:
- tag = u"tag:yaml.org,2002:python/object/apply:"
- newobj = False
- try:
- function_name = u"%s.%s" % (function.__module__, function.__qualname__)
- except AttributeError:
- # probably PY2
- function_name = u"%s.%s" % (function.__module__, function.__name__)
- if (
- not args
- and not listitems
- and not dictitems
- and isinstance(state, dict)
- and newobj
- ):
- return self.represent_mapping(
- u"tag:yaml.org,2002:python/object:" + function_name, state
- )
- if not listitems and not dictitems and isinstance(state, dict) and not state:
- return self.represent_sequence(tag + function_name, args)
- value = {}
- if args:
- value["args"] = args
- if state or not isinstance(state, dict):
- value["state"] = state
- if listitems:
- value["listitems"] = listitems
- if dictitems:
- value["dictitems"] = dictitems
- return self.represent_mapping(tag + function_name, value)
-
-
-if PY2:
- Representer.add_representer(str, Representer.represent_str)
-
- Representer.add_representer(unicode, Representer.represent_unicode)
-
- Representer.add_representer(long, Representer.represent_long)
-
-Representer.add_representer(complex, Representer.represent_complex)
-
-Representer.add_representer(tuple, Representer.represent_tuple)
-
-Representer.add_representer(type, Representer.represent_name)
-
-if PY2:
- Representer.add_representer(types.ClassType, Representer.represent_name)
-
-Representer.add_representer(types.FunctionType, Representer.represent_name)
-
-Representer.add_representer(types.BuiltinFunctionType, Representer.represent_name)
-
-Representer.add_representer(types.ModuleType, Representer.represent_module)
-
-if PY2:
- Representer.add_multi_representer(
- types.InstanceType, Representer.represent_instance
- )
-
-Representer.add_multi_representer(object, Representer.represent_object)
-
-Representer.add_multi_representer(type, Representer.represent_name)
-
-from .comments import (
- CommentedMap,
- CommentedOrderedMap,
- CommentedSeq,
- CommentedKeySeq,
- CommentedKeyMap,
- CommentedSet,
- comment_attrib,
- merge_attrib,
- TaggedScalar,
-) # NOQA
-
-
-class RoundTripRepresenter(SafeRepresenter):
- # need to add type here and write out the .comment
- # in serializer and emitter
-
- def __init__(self, default_style=None, default_flow_style=None, dumper=None):
- # type: (Any, Any, Any) -> None
- if not hasattr(dumper, "typ") and default_flow_style is None:
- default_flow_style = False
- SafeRepresenter.__init__(
- self,
- default_style=default_style,
- default_flow_style=default_flow_style,
- dumper=dumper,
- )
-
- def ignore_aliases(self, data):
- # type: (Any) -> bool
- try:
- if data.anchor is not None and data.anchor.value is not None:
- return False
- except AttributeError:
- pass
- return SafeRepresenter.ignore_aliases(self, data)
-
- def represent_none(self, data):
- # type: (Any) -> Any
- if (
- len(self.represented_objects) == 0
- and not self.serializer.use_explicit_start
- ):
- # this will be open ended (although it is not yet)
- return self.represent_scalar(u"tag:yaml.org,2002:null", u"null")
- return self.represent_scalar(u"tag:yaml.org,2002:null", "")
-
- def represent_literal_scalarstring(self, data):
- # type: (Any) -> Any
- tag = None
- style = "|"
- anchor = data.yaml_anchor(any=True)
- if PY2 and not isinstance(data, unicode):
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- return self.represent_scalar(tag, data, style=style, anchor=anchor)
-
- represent_preserved_scalarstring = represent_literal_scalarstring
-
- def represent_folded_scalarstring(self, data):
- # type: (Any) -> Any
- tag = None
- style = ">"
- anchor = data.yaml_anchor(any=True)
- for fold_pos in reversed(getattr(data, "fold_pos", [])):
- if (
- data[fold_pos] == " "
- and (fold_pos > 0 and not data[fold_pos - 1].isspace())
- and (fold_pos < len(data) and not data[fold_pos + 1].isspace())
- ):
- data = data[:fold_pos] + "\a" + data[fold_pos:]
- if PY2 and not isinstance(data, unicode):
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- return self.represent_scalar(tag, data, style=style, anchor=anchor)
-
- def represent_single_quoted_scalarstring(self, data):
- # type: (Any) -> Any
- tag = None
- style = "'"
- anchor = data.yaml_anchor(any=True)
- if PY2 and not isinstance(data, unicode):
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- return self.represent_scalar(tag, data, style=style, anchor=anchor)
-
- def represent_double_quoted_scalarstring(self, data):
- # type: (Any) -> Any
- tag = None
- style = '"'
- anchor = data.yaml_anchor(any=True)
- if PY2 and not isinstance(data, unicode):
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- return self.represent_scalar(tag, data, style=style, anchor=anchor)
-
- def represent_plain_scalarstring(self, data):
- # type: (Any) -> Any
- tag = None
- style = ""
- anchor = data.yaml_anchor(any=True)
- if PY2 and not isinstance(data, unicode):
- data = unicode(data, "ascii")
- tag = u"tag:yaml.org,2002:str"
- return self.represent_scalar(tag, data, style=style, anchor=anchor)
-
- def insert_underscore(self, prefix, s, underscore, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- if underscore is None:
- return self.represent_scalar(
- u"tag:yaml.org,2002:int", prefix + s, anchor=anchor
- )
- if underscore[0]:
- sl = list(s)
- pos = len(s) - underscore[0]
- while pos > 0:
- sl.insert(pos, "_")
- pos -= underscore[0]
- s = "".join(sl)
- if underscore[1]:
- s = "_" + s
- if underscore[2]:
- s += "_"
- return self.represent_scalar(
- u"tag:yaml.org,2002:int", prefix + s, anchor=anchor
- )
-
- def represent_scalar_int(self, data):
- # type: (Any) -> Any
- if data._width is not None:
- s = "{:0{}d}".format(data, data._width)
- else:
- s = format(data, "d")
- anchor = data.yaml_anchor(any=True)
- return self.insert_underscore("", s, data._underscore, anchor=anchor)
-
- def represent_binary_int(self, data):
- # type: (Any) -> Any
- if data._width is not None:
- # cannot use '{:#0{}b}', that strips the zeros
- s = "{:0{}b}".format(data, data._width)
- else:
- s = format(data, "b")
- anchor = data.yaml_anchor(any=True)
- return self.insert_underscore("0b", s, data._underscore, anchor=anchor)
-
- def represent_octal_int(self, data):
- # type: (Any) -> Any
- if data._width is not None:
- # cannot use '{:#0{}o}', that strips the zeros
- s = "{:0{}o}".format(data, data._width)
- else:
- s = format(data, "o")
- anchor = data.yaml_anchor(any=True)
- return self.insert_underscore("0o", s, data._underscore, anchor=anchor)
-
- def represent_hex_int(self, data):
- # type: (Any) -> Any
- if data._width is not None:
- # cannot use '{:#0{}x}', that strips the zeros
- s = "{:0{}x}".format(data, data._width)
- else:
- s = format(data, "x")
- anchor = data.yaml_anchor(any=True)
- return self.insert_underscore("0x", s, data._underscore, anchor=anchor)
-
- def represent_hex_caps_int(self, data):
- # type: (Any) -> Any
- if data._width is not None:
- # cannot use '{:#0{}X}', that strips the zeros
- s = "{:0{}X}".format(data, data._width)
- else:
- s = format(data, "X")
- anchor = data.yaml_anchor(any=True)
- return self.insert_underscore("0x", s, data._underscore, anchor=anchor)
-
- def represent_scalar_float(self, data):
- # type: (Any) -> Any
- """ this is way more complicated """
- value = None
- anchor = data.yaml_anchor(any=True)
- if data != data or (data == 0.0 and data == 1.0):
- value = u".nan"
- elif data == self.inf_value:
- value = u".inf"
- elif data == -self.inf_value:
- value = u"-.inf"
- if value:
- return self.represent_scalar(
- u"tag:yaml.org,2002:float", value, anchor=anchor
- )
- if data._exp is None and data._prec > 0 and data._prec == data._width - 1:
- # no exponent, but trailing dot
- value = u"{}{:d}.".format(
- data._m_sign if data._m_sign else "", abs(int(data))
- )
- elif data._exp is None:
- # no exponent, "normal" dot
- prec = data._prec
- ms = data._m_sign if data._m_sign else ""
- # -1 for the dot
- value = u"{}{:0{}.{}f}".format(
- ms, abs(data), data._width - len(ms), data._width - prec - 1
- )
- if prec == 0 or (prec == 1 and ms != ""):
- value = value.replace(u"0.", u".")
- while len(value) < data._width:
- value += u"0"
- else:
- # exponent
- m, es = u"{:{}.{}e}".format(
- # data, data._width, data._width - data._prec + (1 if data._m_sign else 0)
- data,
- data._width,
- data._width + (1 if data._m_sign else 0),
- ).split("e")
- w = data._width if data._prec > 0 else (data._width + 1)
- if data < 0:
- w += 1
- m = m[:w]
- e = int(es)
- m1, m2 = m.split(".") # always second?
- while len(m1) + len(m2) < data._width - (1 if data._prec >= 0 else 0):
- m2 += u"0"
- if data._m_sign and data > 0:
- m1 = "+" + m1
- esgn = u"+" if data._e_sign else ""
- if data._prec < 0: # mantissa without dot
- if m2 != u"0":
- e -= len(m2)
- else:
- m2 = ""
- while (len(m1) + len(m2) - (1 if data._m_sign else 0)) < data._width:
- m2 += u"0"
- e -= 1
- value = (
- m1 + m2 + data._exp + u"{:{}0{}d}".format(e, esgn, data._e_width)
- )
- elif data._prec == 0: # mantissa with trailing dot
- e -= len(m2)
- value = (
- m1
- + m2
- + u"."
- + data._exp
- + u"{:{}0{}d}".format(e, esgn, data._e_width)
- )
- else:
- if data._m_lead0 > 0:
- m2 = u"0" * (data._m_lead0 - 1) + m1 + m2
- m1 = u"0"
- m2 = m2[: -data._m_lead0] # these should be zeros
- e += data._m_lead0
- while len(m1) < data._prec:
- m1 += m2[0]
- m2 = m2[1:]
- e -= 1
- value = (
- m1
- + u"."
- + m2
- + data._exp
- + u"{:{}0{}d}".format(e, esgn, data._e_width)
- )
-
- if value is None:
- value = to_unicode(repr(data)).lower()
- return self.represent_scalar(u"tag:yaml.org,2002:float", value, anchor=anchor)
-
- def represent_sequence(self, tag, sequence, flow_style=None):
- # type: (Any, Any, Any) -> Any
- value = [] # type: List[Any]
- # if the flow_style is None, the flow style tacked on to the object
- # explicitly will be taken. If that is None as well the default flow
- # style rules
- try:
- flow_style = sequence.fa.flow_style(flow_style)
- except AttributeError:
- flow_style = flow_style
- try:
- anchor = sequence.yaml_anchor()
- except AttributeError:
- anchor = None
- node = SequenceNode(tag, value, flow_style=flow_style, anchor=anchor)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- try:
- comment = getattr(sequence, comment_attrib)
- node.comment = comment.comment
- # reset any comment already printed information
- if node.comment and node.comment[1]:
- for ct in node.comment[1]:
- ct.reset()
- item_comments = comment.items
- for v in item_comments.values():
- if v and v[1]:
- for ct in v[1]:
- ct.reset()
- item_comments = comment.items
- node.comment = comment.comment
- try:
- node.comment.append(comment.end)
- except AttributeError:
- pass
- except AttributeError:
- item_comments = {}
- for idx, item in enumerate(sequence):
- node_item = self.represent_data(item)
- self.merge_comments(node_item, item_comments.get(idx))
- if not (isinstance(node_item, ScalarNode) and not node_item.style):
- best_style = False
- value.append(node_item)
- if flow_style is None:
- if len(sequence) != 0 and self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- def merge_comments(self, node, comments):
- # type: (Any, Any) -> Any
- if comments is None:
- assert hasattr(node, "comment")
- return node
- if getattr(node, "comment", None) is not None:
- for idx, val in enumerate(comments):
- if idx >= len(node.comment):
- continue
- nc = node.comment[idx]
- if nc is not None:
- assert val is None or val == nc
- comments[idx] = nc
- node.comment = comments
- return node
-
- def represent_key(self, data):
- # type: (Any) -> Any
- if isinstance(data, CommentedKeySeq):
- self.alias_key = None
- return self.represent_sequence(
- u"tag:yaml.org,2002:seq", data, flow_style=True
- )
- if isinstance(data, CommentedKeyMap):
- self.alias_key = None
- return self.represent_mapping(
- u"tag:yaml.org,2002:map", data, flow_style=True
- )
- return SafeRepresenter.represent_key(self, data)
-
- def represent_mapping(self, tag, mapping, flow_style=None):
- # type: (Any, Any, Any) -> Any
- value = [] # type: List[Any]
- try:
- flow_style = mapping.fa.flow_style(flow_style)
- except AttributeError:
- flow_style = flow_style
- try:
- anchor = mapping.yaml_anchor()
- except AttributeError:
- anchor = None
- node = MappingNode(tag, value, flow_style=flow_style, anchor=anchor)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- # no sorting! !!
- try:
- comment = getattr(mapping, comment_attrib)
- node.comment = comment.comment
- if node.comment and node.comment[1]:
- for ct in node.comment[1]:
- ct.reset()
- item_comments = comment.items
- for v in item_comments.values():
- if v and v[1]:
- for ct in v[1]:
- ct.reset()
- try:
- node.comment.append(comment.end)
- except AttributeError:
- pass
- except AttributeError:
- item_comments = {}
- merge_list = [m[1] for m in getattr(mapping, merge_attrib, [])]
- try:
- merge_pos = getattr(mapping, merge_attrib, [[0]])[0][0]
- except IndexError:
- merge_pos = 0
- item_count = 0
- if bool(merge_list):
- items = mapping.non_merged_items()
- else:
- items = mapping.items()
- for item_key, item_value in items:
- item_count += 1
- node_key = self.represent_key(item_key)
- node_value = self.represent_data(item_value)
- item_comment = item_comments.get(item_key)
- if item_comment:
- assert getattr(node_key, "comment", None) is None
- node_key.comment = item_comment[:2]
- nvc = getattr(node_value, "comment", None)
- if nvc is not None: # end comment already there
- nvc[0] = item_comment[2]
- nvc[1] = item_comment[3]
- else:
- node_value.comment = item_comment[2:]
- if not (isinstance(node_key, ScalarNode) and not node_key.style):
- best_style = False
- if not (isinstance(node_value, ScalarNode) and not node_value.style):
- best_style = False
- value.append((node_key, node_value))
- if flow_style is None:
- if (
- (item_count != 0) or bool(merge_list)
- ) and self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- if bool(merge_list):
- # because of the call to represent_data here, the anchors
- # are marked as being used and thereby created
- if len(merge_list) == 1:
- arg = self.represent_data(merge_list[0])
- else:
- arg = self.represent_data(merge_list)
- arg.flow_style = True
- value.insert(merge_pos, (ScalarNode(u"tag:yaml.org,2002:merge", "<<"), arg))
- return node
-
- def represent_omap(self, tag, omap, flow_style=None):
- # type: (Any, Any, Any) -> Any
- value = [] # type: List[Any]
- try:
- flow_style = omap.fa.flow_style(flow_style)
- except AttributeError:
- flow_style = flow_style
- try:
- anchor = omap.yaml_anchor()
- except AttributeError:
- anchor = None
- node = SequenceNode(tag, value, flow_style=flow_style, anchor=anchor)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- try:
- comment = getattr(omap, comment_attrib)
- node.comment = comment.comment
- if node.comment and node.comment[1]:
- for ct in node.comment[1]:
- ct.reset()
- item_comments = comment.items
- for v in item_comments.values():
- if v and v[1]:
- for ct in v[1]:
- ct.reset()
- try:
- node.comment.append(comment.end)
- except AttributeError:
- pass
- except AttributeError:
- item_comments = {}
- for item_key in omap:
- item_val = omap[item_key]
- node_item = self.represent_data({item_key: item_val})
- # node_item.flow_style = False
- # node item has two scalars in value: node_key and node_value
- item_comment = item_comments.get(item_key)
- if item_comment:
- if item_comment[1]:
- node_item.comment = [None, item_comment[1]]
- assert getattr(node_item.value[0][0], "comment", None) is None
- node_item.value[0][0].comment = [item_comment[0], None]
- nvc = getattr(node_item.value[0][1], "comment", None)
- if nvc is not None: # end comment already there
- nvc[0] = item_comment[2]
- nvc[1] = item_comment[3]
- else:
- node_item.value[0][1].comment = item_comment[2:]
- # if not (isinstance(node_item, ScalarNode) \
- # and not node_item.style):
- # best_style = False
- value.append(node_item)
- if flow_style is None:
- if self.default_flow_style is not None:
- node.flow_style = self.default_flow_style
- else:
- node.flow_style = best_style
- return node
-
- def represent_set(self, setting):
- # type: (Any) -> Any
- flow_style = False
- tag = u"tag:yaml.org,2002:set"
- # return self.represent_mapping(tag, value)
- value = [] # type: List[Any]
- flow_style = setting.fa.flow_style(flow_style)
- try:
- anchor = setting.yaml_anchor()
- except AttributeError:
- anchor = None
- node = MappingNode(tag, value, flow_style=flow_style, anchor=anchor)
- if self.alias_key is not None:
- self.represented_objects[self.alias_key] = node
- best_style = True
- # no sorting! !!
- try:
- comment = getattr(setting, comment_attrib)
- node.comment = comment.comment
- if node.comment and node.comment[1]:
- for ct in node.comment[1]:
- ct.reset()
- item_comments = comment.items
- for v in item_comments.values():
- if v and v[1]:
- for ct in v[1]:
- ct.reset()
- try:
- node.comment.append(comment.end)
- except AttributeError:
- pass
- except AttributeError:
- item_comments = {}
- for item_key in setting.odict:
- node_key = self.represent_key(item_key)
- node_value = self.represent_data(None)
- item_comment = item_comments.get(item_key)
- if item_comment:
- assert getattr(node_key, "comment", None) is None
- node_key.comment = item_comment[:2]
- node_key.style = node_value.style = "?"
- if not (isinstance(node_key, ScalarNode) and not node_key.style):
- best_style = False
- if not (isinstance(node_value, ScalarNode) and not node_value.style):
- best_style = False
- value.append((node_key, node_value))
- best_style = best_style
- return node
-
- def represent_dict(self, data):
- # type: (Any) -> Any
- """write out tag if saved on loading"""
- try:
- t = data.tag.value
- except AttributeError:
- t = None
- if t:
- if t.startswith("!!"):
- tag = "tag:yaml.org,2002:" + t[2:]
- else:
- tag = t
- else:
- tag = u"tag:yaml.org,2002:map"
- return self.represent_mapping(tag, data)
-
- def represent_list(self, data):
- # type: (Any) -> Any
- try:
- t = data.tag.value
- except AttributeError:
- t = None
- if t:
- if t.startswith("!!"):
- tag = "tag:yaml.org,2002:" + t[2:]
- else:
- tag = t
- else:
- tag = u"tag:yaml.org,2002:seq"
- return self.represent_sequence(tag, data)
-
- def represent_datetime(self, data):
- # type: (Any) -> Any
- inter = "T" if data._yaml["t"] else " "
- _yaml = data._yaml
- if _yaml["delta"]:
- data += _yaml["delta"]
- value = data.isoformat(inter)
- else:
- value = data.isoformat(inter)
- if _yaml["tz"]:
- value += _yaml["tz"]
- return self.represent_scalar(u"tag:yaml.org,2002:timestamp", to_unicode(value))
-
- def represent_tagged_scalar(self, data):
- # type: (Any) -> Any
- try:
- tag = data.tag.value
- except AttributeError:
- tag = None
- try:
- anchor = data.yaml_anchor()
- except AttributeError:
- anchor = None
- return self.represent_scalar(tag, data.value, style=data.style, anchor=anchor)
-
- def represent_scalar_bool(self, data):
- # type: (Any) -> Any
- try:
- anchor = data.yaml_anchor()
- except AttributeError:
- anchor = None
- return SafeRepresenter.represent_bool(self, data, anchor=anchor)
-
-
-RoundTripRepresenter.add_representer(type(None), RoundTripRepresenter.represent_none)
-
-RoundTripRepresenter.add_representer(
- LiteralScalarString, RoundTripRepresenter.represent_literal_scalarstring
-)
-
-RoundTripRepresenter.add_representer(
- FoldedScalarString, RoundTripRepresenter.represent_folded_scalarstring
-)
-
-RoundTripRepresenter.add_representer(
- SingleQuotedScalarString, RoundTripRepresenter.represent_single_quoted_scalarstring
-)
-
-RoundTripRepresenter.add_representer(
- DoubleQuotedScalarString, RoundTripRepresenter.represent_double_quoted_scalarstring
-)
-
-RoundTripRepresenter.add_representer(
- PlainScalarString, RoundTripRepresenter.represent_plain_scalarstring
-)
-
-# RoundTripRepresenter.add_representer(tuple, Representer.represent_tuple)
-
-RoundTripRepresenter.add_representer(
- ScalarInt, RoundTripRepresenter.represent_scalar_int
-)
-
-RoundTripRepresenter.add_representer(
- BinaryInt, RoundTripRepresenter.represent_binary_int
-)
-
-RoundTripRepresenter.add_representer(OctalInt, RoundTripRepresenter.represent_octal_int)
-
-RoundTripRepresenter.add_representer(HexInt, RoundTripRepresenter.represent_hex_int)
-
-RoundTripRepresenter.add_representer(
- HexCapsInt, RoundTripRepresenter.represent_hex_caps_int
-)
-
-RoundTripRepresenter.add_representer(
- ScalarFloat, RoundTripRepresenter.represent_scalar_float
-)
-
-RoundTripRepresenter.add_representer(
- ScalarBoolean, RoundTripRepresenter.represent_scalar_bool
-)
-
-RoundTripRepresenter.add_representer(CommentedSeq, RoundTripRepresenter.represent_list)
-
-RoundTripRepresenter.add_representer(CommentedMap, RoundTripRepresenter.represent_dict)
-
-RoundTripRepresenter.add_representer(
- CommentedOrderedMap, RoundTripRepresenter.represent_ordereddict
-)
-
-if sys.version_info >= (2, 7):
- import collections
-
- RoundTripRepresenter.add_representer(
- collections.OrderedDict, RoundTripRepresenter.represent_ordereddict
- )
-
-RoundTripRepresenter.add_representer(CommentedSet, RoundTripRepresenter.represent_set)
-
-RoundTripRepresenter.add_representer(
- TaggedScalar, RoundTripRepresenter.represent_tagged_scalar
-)
-
-RoundTripRepresenter.add_representer(TimeStamp, RoundTripRepresenter.represent_datetime)
diff --git a/srsly/ruamel_yaml/resolver.py b/srsly/ruamel_yaml/resolver.py
deleted file mode 100755
index edcee9b..0000000
--- a/srsly/ruamel_yaml/resolver.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-import re
-
-if False: # MYPY
- from typing import Any, Dict, List, Union, Text, Optional # NOQA
- from .compat import VersionType # NOQA
-
-from .compat import string_types, _DEFAULT_YAML_VERSION # NOQA
-from .error import * # NOQA
-from .nodes import MappingNode, ScalarNode, SequenceNode # NOQA
-from .util import RegExp # NOQA
-
-__all__ = ["BaseResolver", "Resolver", "VersionedResolver"]
-
-
-# fmt: off
-# resolvers consist of
-# - a list of applicable version
-# - a tag
-# - a regexp
-# - a list of first characters to match
-implicit_resolvers = [
- ([(1, 2)],
- u'tag:yaml.org,2002:bool',
- RegExp(u'''^(?:true|True|TRUE|false|False|FALSE)$''', re.X),
- list(u'tTfF')),
- ([(1, 1)],
- u'tag:yaml.org,2002:bool',
- RegExp(u'''^(?:y|Y|yes|Yes|YES|n|N|no|No|NO
- |true|True|TRUE|false|False|FALSE
- |on|On|ON|off|Off|OFF)$''', re.X),
- list(u'yYnNtTfFoO')),
- ([(1, 2)],
- u'tag:yaml.org,2002:float',
- RegExp(u'''^(?:
- [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
- |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
- |[-+]?\\.[0-9_]+(?:[eE][-+][0-9]+)?
- |[-+]?\\.(?:inf|Inf|INF)
- |\\.(?:nan|NaN|NAN))$''', re.X),
- list(u'-+0123456789.')),
- ([(1, 1)],
- u'tag:yaml.org,2002:float',
- RegExp(u'''^(?:
- [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
- |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
- |\\.[0-9_]+(?:[eE][-+][0-9]+)?
- |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* # sexagesimal float
- |[-+]?\\.(?:inf|Inf|INF)
- |\\.(?:nan|NaN|NAN))$''', re.X),
- list(u'-+0123456789.')),
- ([(1, 2)],
- u'tag:yaml.org,2002:int',
- RegExp(u'''^(?:[-+]?0b[0-1_]+
- |[-+]?0o?[0-7_]+
- |[-+]?[0-9_]+
- |[-+]?0x[0-9a-fA-F_]+)$''', re.X),
- list(u'-+0123456789')),
- ([(1, 1)],
- u'tag:yaml.org,2002:int',
- RegExp(u'''^(?:[-+]?0b[0-1_]+
- |[-+]?0?[0-7_]+
- |[-+]?(?:0|[1-9][0-9_]*)
- |[-+]?0x[0-9a-fA-F_]+
- |[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X), # sexagesimal int
- list(u'-+0123456789')),
- ([(1, 2), (1, 1)],
- u'tag:yaml.org,2002:merge',
- RegExp(u'^(?:<<)$'),
- [u'<']),
- ([(1, 2), (1, 1)],
- u'tag:yaml.org,2002:null',
- RegExp(u'''^(?: ~
- |null|Null|NULL
- | )$''', re.X),
- [u'~', u'n', u'N', u'']),
- ([(1, 2), (1, 1)],
- u'tag:yaml.org,2002:timestamp',
- RegExp(u'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
- |[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
- (?:[Tt]|[ \\t]+)[0-9][0-9]?
- :[0-9][0-9] :[0-9][0-9] (?:\\.[0-9]*)?
- (?:[ \\t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
- list(u'0123456789')),
- ([(1, 2), (1, 1)],
- u'tag:yaml.org,2002:value',
- RegExp(u'^(?:=)$'),
- [u'=']),
- # The following resolver is only for documentation purposes. It cannot work
- # because plain scalars cannot start with '!', '&', or '*'.
- ([(1, 2), (1, 1)],
- u'tag:yaml.org,2002:yaml',
- RegExp(u'^(?:!|&|\\*)$'),
- list(u'!&*')),
-]
-# fmt: on
-
-
-class ResolverError(YAMLError):
- pass
-
-
-class BaseResolver(object):
-
- DEFAULT_SCALAR_TAG = u"tag:yaml.org,2002:str"
- DEFAULT_SEQUENCE_TAG = u"tag:yaml.org,2002:seq"
- DEFAULT_MAPPING_TAG = u"tag:yaml.org,2002:map"
-
- yaml_implicit_resolvers = {} # type: Dict[Any, Any]
- yaml_path_resolvers = {} # type: Dict[Any, Any]
-
- def __init__(self, loadumper=None):
- # type: (Any, Any) -> None
- self.loadumper = loadumper
- if (
- self.loadumper is not None
- and getattr(self.loadumper, "_resolver", None) is None
- ):
- self.loadumper._resolver = self.loadumper
- self._loader_version = None # type: Any
- self.resolver_exact_paths = [] # type: List[Any]
- self.resolver_prefix_paths = [] # type: List[Any]
-
- @property
- def parser(self):
- # type: () -> Any
- if self.loadumper is not None:
- if hasattr(self.loadumper, "typ"):
- return self.loadumper.parser
- return self.loadumper._parser
- return None
-
- @classmethod
- def add_implicit_resolver_base(cls, tag, regexp, first):
- # type: (Any, Any, Any) -> None
- if "yaml_implicit_resolvers" not in cls.__dict__:
- # deepcopy doesn't work here
- cls.yaml_implicit_resolvers = dict(
- (k, cls.yaml_implicit_resolvers[k][:])
- for k in cls.yaml_implicit_resolvers
- )
- if first is None:
- first = [None]
- for ch in first:
- cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))
-
- @classmethod
- def add_implicit_resolver(cls, tag, regexp, first):
- # type: (Any, Any, Any) -> None
- if "yaml_implicit_resolvers" not in cls.__dict__:
- # deepcopy doesn't work here
- cls.yaml_implicit_resolvers = dict(
- (k, cls.yaml_implicit_resolvers[k][:])
- for k in cls.yaml_implicit_resolvers
- )
- if first is None:
- first = [None]
- for ch in first:
- cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))
- implicit_resolvers.append(([(1, 2), (1, 1)], tag, regexp, first))
-
- # @classmethod
- # def add_implicit_resolver(cls, tag, regexp, first):
-
- @classmethod
- def add_path_resolver(cls, tag, path, kind=None):
- # type: (Any, Any, Any) -> None
- # Note: `add_path_resolver` is experimental. The API could be changed.
- # `new_path` is a pattern that is matched against the path from the
- # root to the node that is being considered. `node_path` elements are
- # tuples `(node_check, index_check)`. `node_check` is a node class:
- # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None`
- # matches any kind of a node. `index_check` could be `None`, a boolean
- # value, a string value, or a number. `None` and `False` match against
- # any _value_ of sequence and mapping nodes. `True` matches against
- # any _key_ of a mapping node. A string `index_check` matches against
- # a mapping value that corresponds to a scalar key which content is
- # equal to the `index_check` value. An integer `index_check` matches
- # against a sequence value with the index equal to `index_check`.
- if "yaml_path_resolvers" not in cls.__dict__:
- cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
- new_path = [] # type: List[Any]
- for element in path:
- if isinstance(element, (list, tuple)):
- if len(element) == 2:
- node_check, index_check = element
- elif len(element) == 1:
- node_check = element[0]
- index_check = True
- else:
- raise ResolverError("Invalid path element: %s" % (element,))
- else:
- node_check = None
- index_check = element
- if node_check is str:
- node_check = ScalarNode
- elif node_check is list:
- node_check = SequenceNode
- elif node_check is dict:
- node_check = MappingNode
- elif (
- node_check not in [ScalarNode, SequenceNode, MappingNode]
- and not isinstance(node_check, string_types)
- and node_check is not None
- ):
- raise ResolverError("Invalid node checker: %s" % (node_check,))
- if (
- not isinstance(index_check, (string_types, int))
- and index_check is not None
- ):
- raise ResolverError("Invalid index checker: %s" % (index_check,))
- new_path.append((node_check, index_check))
- if kind is str:
- kind = ScalarNode
- elif kind is list:
- kind = SequenceNode
- elif kind is dict:
- kind = MappingNode
- elif kind not in [ScalarNode, SequenceNode, MappingNode] and kind is not None:
- raise ResolverError("Invalid node kind: %s" % (kind,))
- cls.yaml_path_resolvers[tuple(new_path), kind] = tag
-
- def descend_resolver(self, current_node, current_index):
- # type: (Any, Any) -> None
- if not self.yaml_path_resolvers:
- return
- exact_paths = {}
- prefix_paths = []
- if current_node:
- depth = len(self.resolver_prefix_paths)
- for path, kind in self.resolver_prefix_paths[-1]:
- if self.check_resolver_prefix(
- depth, path, kind, current_node, current_index
- ):
- if len(path) > depth:
- prefix_paths.append((path, kind))
- else:
- exact_paths[kind] = self.yaml_path_resolvers[path, kind]
- else:
- for path, kind in self.yaml_path_resolvers:
- if not path:
- exact_paths[kind] = self.yaml_path_resolvers[path, kind]
- else:
- prefix_paths.append((path, kind))
- self.resolver_exact_paths.append(exact_paths)
- self.resolver_prefix_paths.append(prefix_paths)
-
- def ascend_resolver(self):
- # type: () -> None
- if not self.yaml_path_resolvers:
- return
- self.resolver_exact_paths.pop()
- self.resolver_prefix_paths.pop()
-
- def check_resolver_prefix(self, depth, path, kind, current_node, current_index):
- # type: (int, Text, Any, Any, Any) -> bool
- node_check, index_check = path[depth - 1]
- if isinstance(node_check, string_types):
- if current_node.tag != node_check:
- return False
- elif node_check is not None:
- if not isinstance(current_node, node_check):
- return False
- if index_check is True and current_index is not None:
- return False
- if (index_check is False or index_check is None) and current_index is None:
- return False
- if isinstance(index_check, string_types):
- if not (
- isinstance(current_index, ScalarNode)
- and index_check == current_index.value
- ):
- return False
- elif isinstance(index_check, int) and not isinstance(index_check, bool):
- if index_check != current_index:
- return False
- return True
-
- def resolve(self, kind, value, implicit):
- # type: (Any, Any, Any) -> Any
- if kind is ScalarNode and implicit[0]:
- if value == "":
- resolvers = self.yaml_implicit_resolvers.get("", [])
- else:
- resolvers = self.yaml_implicit_resolvers.get(value[0], [])
- resolvers += self.yaml_implicit_resolvers.get(None, [])
- for tag, regexp in resolvers:
- if regexp.match(value):
- return tag
- implicit = implicit[1]
- if bool(self.yaml_path_resolvers):
- exact_paths = self.resolver_exact_paths[-1]
- if kind in exact_paths:
- return exact_paths[kind]
- if None in exact_paths:
- return exact_paths[None]
- if kind is ScalarNode:
- return self.DEFAULT_SCALAR_TAG
- elif kind is SequenceNode:
- return self.DEFAULT_SEQUENCE_TAG
- elif kind is MappingNode:
- return self.DEFAULT_MAPPING_TAG
-
- @property
- def processing_version(self):
- # type: () -> Any
- return None
-
-
-class Resolver(BaseResolver):
- pass
-
-
-for ir in implicit_resolvers:
- if (1, 2) in ir[0]:
- Resolver.add_implicit_resolver_base(*ir[1:])
-
-
-class VersionedResolver(BaseResolver):
- """
- contrary to the "normal" resolver, the smart resolver delays loading
- the pattern matching rules. That way it can decide to load 1.1 rules
- or the (default) 1.2 rules, that no longer support octal without 0o, sexagesimals
- and Yes/No/On/Off booleans.
- """
-
- def __init__(self, version=None, loader=None, loadumper=None):
- # type: (Optional[VersionType], Any, Any) -> None
- if loader is None and loadumper is not None:
- loader = loadumper
- BaseResolver.__init__(self, loader)
- self._loader_version = self.get_loader_version(version)
- self._version_implicit_resolver = {} # type: Dict[Any, Any]
-
- def add_version_implicit_resolver(self, version, tag, regexp, first):
- # type: (VersionType, Any, Any, Any) -> None
- if first is None:
- first = [None]
- impl_resolver = self._version_implicit_resolver.setdefault(version, {})
- for ch in first:
- impl_resolver.setdefault(ch, []).append((tag, regexp))
-
- def get_loader_version(self, version):
- # type: (Optional[VersionType]) -> Any
- if version is None or isinstance(version, tuple):
- return version
- if isinstance(version, list):
- return tuple(version)
- # assume string
- return tuple(map(int, version.split(u".")))
-
- @property
- def versioned_resolver(self):
- # type: () -> Any
- """
- select the resolver based on the version we are parsing
- """
- version = self.processing_version
- if version not in self._version_implicit_resolver:
- for x in implicit_resolvers:
- if version in x[0]:
- self.add_version_implicit_resolver(version, x[1], x[2], x[3])
- return self._version_implicit_resolver[version]
-
- def resolve(self, kind, value, implicit):
- # type: (Any, Any, Any) -> Any
- if kind is ScalarNode and implicit[0]:
- if value == "":
- resolvers = self.versioned_resolver.get("", [])
- else:
- resolvers = self.versioned_resolver.get(value[0], [])
- resolvers += self.versioned_resolver.get(None, [])
- for tag, regexp in resolvers:
- if regexp.match(value):
- return tag
- implicit = implicit[1]
- if bool(self.yaml_path_resolvers):
- exact_paths = self.resolver_exact_paths[-1]
- if kind in exact_paths:
- return exact_paths[kind]
- if None in exact_paths:
- return exact_paths[None]
- if kind is ScalarNode:
- return self.DEFAULT_SCALAR_TAG
- elif kind is SequenceNode:
- return self.DEFAULT_SEQUENCE_TAG
- elif kind is MappingNode:
- return self.DEFAULT_MAPPING_TAG
-
- @property
- def processing_version(self):
- # type: () -> Any
- try:
- version = self.loadumper._scanner.yaml_version
- except AttributeError:
- try:
- if hasattr(self.loadumper, "typ"):
- version = self.loadumper.version
- else:
- version = self.loadumper._serializer.use_version # dumping
- except AttributeError:
- version = None
- if version is None:
- version = self._loader_version
- if version is None:
- version = _DEFAULT_YAML_VERSION
- return version
diff --git a/srsly/ruamel_yaml/scalarbool.py b/srsly/ruamel_yaml/scalarbool.py
deleted file mode 100755
index 277f4d1..0000000
--- a/srsly/ruamel_yaml/scalarbool.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-"""
-You cannot subclass bool, and this is necessary for round-tripping anchored
-bool values (and also if you want to preserve the original way of writing)
-
-bool.__bases__ is type 'int', so that is what is used as the basis for ScalarBoolean as well.
-
-You can use these in an if statement, but not when testing equivalence
-"""
-
-from .anchor import Anchor
-
-if False: # MYPY
- from typing import Text, Any, Dict, List # NOQA
-
-__all__ = ["ScalarBoolean"]
-
-# no need for no_limit_int -> int
-
-
-class ScalarBoolean(int):
- def __new__(cls, *args, **kw):
- # type: (Any, Any, Any) -> Any
- anchor = kw.pop("anchor", None) # type: ignore
- b = int.__new__(cls, *args, **kw) # type: ignore
- if anchor is not None:
- b.yaml_set_anchor(anchor, always_dump=True)
- return b
-
- @property
- def anchor(self):
- # type: () -> Any
- if not hasattr(self, Anchor.attrib):
- setattr(self, Anchor.attrib, Anchor())
- return getattr(self, Anchor.attrib)
-
- def yaml_anchor(self, any=False):
- # type: (bool) -> Any
- if not hasattr(self, Anchor.attrib):
- return None
- if any or self.anchor.always_dump:
- return self.anchor
- return None
-
- def yaml_set_anchor(self, value, always_dump=False):
- # type: (Any, bool) -> None
- self.anchor.value = value
- self.anchor.always_dump = always_dump
diff --git a/srsly/ruamel_yaml/scalarfloat.py b/srsly/ruamel_yaml/scalarfloat.py
deleted file mode 100755
index 57fd7f9..0000000
--- a/srsly/ruamel_yaml/scalarfloat.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-import sys
-from .compat import no_limit_int # NOQA
-from .anchor import Anchor
-
-if False: # MYPY
- from typing import Text, Any, Dict, List # NOQA
-
-__all__ = ["ScalarFloat", "ExponentialFloat", "ExponentialCapsFloat"]
-
-
-class ScalarFloat(float):
- def __new__(cls, *args, **kw):
- # type: (Any, Any, Any) -> Any
- width = kw.pop("width", None) # type: ignore
- prec = kw.pop("prec", None) # type: ignore
- m_sign = kw.pop("m_sign", None) # type: ignore
- m_lead0 = kw.pop("m_lead0", 0) # type: ignore
- exp = kw.pop("exp", None) # type: ignore
- e_width = kw.pop("e_width", None) # type: ignore
- e_sign = kw.pop("e_sign", None) # type: ignore
- underscore = kw.pop("underscore", None) # type: ignore
- anchor = kw.pop("anchor", None) # type: ignore
- v = float.__new__(cls, *args, **kw) # type: ignore
- v._width = width
- v._prec = prec
- v._m_sign = m_sign
- v._m_lead0 = m_lead0
- v._exp = exp
- v._e_width = e_width
- v._e_sign = e_sign
- v._underscore = underscore
- if anchor is not None:
- v.yaml_set_anchor(anchor, always_dump=True)
- return v
-
- def __iadd__(self, a): # type: ignore
- # type: (Any) -> Any
- return float(self) + a
- x = type(self)(self + a)
- x._width = self._width
- x._underscore = (
- self._underscore[:] if self._underscore is not None else None
- ) # NOQA
- return x
-
- def __ifloordiv__(self, a): # type: ignore
- # type: (Any) -> Any
- return float(self) // a
- x = type(self)(self // a)
- x._width = self._width
- x._underscore = (
- self._underscore[:] if self._underscore is not None else None
- ) # NOQA
- return x
-
- def __imul__(self, a): # type: ignore
- # type: (Any) -> Any
- return float(self) * a
- x = type(self)(self * a)
- x._width = self._width
- x._underscore = (
- self._underscore[:] if self._underscore is not None else None
- ) # NOQA
- x._prec = self._prec # check for others
- return x
-
- def __ipow__(self, a): # type: ignore
- # type: (Any) -> Any
- return float(self) ** a
- x = type(self)(self ** a)
- x._width = self._width
- x._underscore = (
- self._underscore[:] if self._underscore is not None else None
- ) # NOQA
- return x
-
- def __isub__(self, a): # type: ignore
- # type: (Any) -> Any
- return float(self) - a
- x = type(self)(self - a)
- x._width = self._width
- x._underscore = (
- self._underscore[:] if self._underscore is not None else None
- ) # NOQA
- return x
-
- @property
- def anchor(self):
- # type: () -> Any
- if not hasattr(self, Anchor.attrib):
- setattr(self, Anchor.attrib, Anchor())
- return getattr(self, Anchor.attrib)
-
- def yaml_anchor(self, any=False):
- # type: (bool) -> Any
- if not hasattr(self, Anchor.attrib):
- return None
- if any or self.anchor.always_dump:
- return self.anchor
- return None
-
- def yaml_set_anchor(self, value, always_dump=False):
- # type: (Any, bool) -> None
- self.anchor.value = value
- self.anchor.always_dump = always_dump
-
- def dump(self, out=sys.stdout):
- # type: (Any) -> Any
- out.write(
- "ScalarFloat({}| w:{}, p:{}, s:{}, lz:{}, _:{}|{}, w:{}, s:{})\n".format(
- self,
- self._width, # type: ignore
- self._prec, # type: ignore
- self._m_sign, # type: ignore
- self._m_lead0, # type: ignore
- self._underscore, # type: ignore
- self._exp, # type: ignore
- self._e_width, # type: ignore
- self._e_sign, # type: ignore
- )
- )
-
-
-class ExponentialFloat(ScalarFloat):
- def __new__(cls, value, width=None, underscore=None):
- # type: (Any, Any, Any) -> Any
- return ScalarFloat.__new__(cls, value, width=width, underscore=underscore)
-
-
-class ExponentialCapsFloat(ScalarFloat):
- def __new__(cls, value, width=None, underscore=None):
- # type: (Any, Any, Any) -> Any
- return ScalarFloat.__new__(cls, value, width=width, underscore=underscore)
diff --git a/srsly/ruamel_yaml/scalarint.py b/srsly/ruamel_yaml/scalarint.py
deleted file mode 100755
index 383e25c..0000000
--- a/srsly/ruamel_yaml/scalarint.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-from .compat import no_limit_int # NOQA
-from .anchor import Anchor
-
-if False: # MYPY
- from typing import Text, Any, Dict, List # NOQA
-
-__all__ = ["ScalarInt", "BinaryInt", "OctalInt", "HexInt", "HexCapsInt", "DecimalInt"]
-
-
-class ScalarInt(no_limit_int):
- def __new__(cls, *args, **kw):
- # type: (Any, Any, Any) -> Any
- width = kw.pop("width", None) # type: ignore
- underscore = kw.pop("underscore", None) # type: ignore
- anchor = kw.pop("anchor", None) # type: ignore
- v = no_limit_int.__new__(cls, *args, **kw) # type: ignore
- v._width = width
- v._underscore = underscore
- if anchor is not None:
- v.yaml_set_anchor(anchor, always_dump=True)
- return v
-
- def __iadd__(self, a): # type: ignore
- # type: (Any) -> Any
- x = type(self)(self + a)
- x._width = self._width # type: ignore
- x._underscore = ( # type: ignore
- self._underscore[:]
- if self._underscore is not None
- else None # type: ignore
- ) # NOQA
- return x
-
- def __ifloordiv__(self, a): # type: ignore
- # type: (Any) -> Any
- x = type(self)(self // a)
- x._width = self._width # type: ignore
- x._underscore = ( # type: ignore
- self._underscore[:]
- if self._underscore is not None
- else None # type: ignore
- ) # NOQA
- return x
-
- def __imul__(self, a): # type: ignore
- # type: (Any) -> Any
- x = type(self)(self * a)
- x._width = self._width # type: ignore
- x._underscore = ( # type: ignore
- self._underscore[:]
- if self._underscore is not None
- else None # type: ignore
- ) # NOQA
- return x
-
- def __ipow__(self, a): # type: ignore
- # type: (Any) -> Any
- x = type(self)(self ** a)
- x._width = self._width # type: ignore
- x._underscore = ( # type: ignore
- self._underscore[:]
- if self._underscore is not None
- else None # type: ignore
- ) # NOQA
- return x
-
- def __isub__(self, a): # type: ignore
- # type: (Any) -> Any
- x = type(self)(self - a)
- x._width = self._width # type: ignore
- x._underscore = ( # type: ignore
- self._underscore[:]
- if self._underscore is not None
- else None # type: ignore
- ) # NOQA
- return x
-
- @property
- def anchor(self):
- # type: () -> Any
- if not hasattr(self, Anchor.attrib):
- setattr(self, Anchor.attrib, Anchor())
- return getattr(self, Anchor.attrib)
-
- def yaml_anchor(self, any=False):
- # type: (bool) -> Any
- if not hasattr(self, Anchor.attrib):
- return None
- if any or self.anchor.always_dump:
- return self.anchor
- return None
-
- def yaml_set_anchor(self, value, always_dump=False):
- # type: (Any, bool) -> None
- self.anchor.value = value
- self.anchor.always_dump = always_dump
-
-
-class BinaryInt(ScalarInt):
- def __new__(cls, value, width=None, underscore=None, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- return ScalarInt.__new__(
- cls, value, width=width, underscore=underscore, anchor=anchor
- )
-
-
-class OctalInt(ScalarInt):
- def __new__(cls, value, width=None, underscore=None, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- return ScalarInt.__new__(
- cls, value, width=width, underscore=underscore, anchor=anchor
- )
-
-
-# mixed casing of A-F is not supported, when loading the first non digit
-# determines the case
-
-
-class HexInt(ScalarInt):
- """uses lower case (a-f)"""
-
- def __new__(cls, value, width=None, underscore=None, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- return ScalarInt.__new__(
- cls, value, width=width, underscore=underscore, anchor=anchor
- )
-
-
-class HexCapsInt(ScalarInt):
- """uses upper case (A-F)"""
-
- def __new__(cls, value, width=None, underscore=None, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- return ScalarInt.__new__(
- cls, value, width=width, underscore=underscore, anchor=anchor
- )
-
-
-class DecimalInt(ScalarInt):
- """needed if anchor"""
-
- def __new__(cls, value, width=None, underscore=None, anchor=None):
- # type: (Any, Any, Any, Any) -> Any
- return ScalarInt.__new__(
- cls, value, width=width, underscore=underscore, anchor=anchor
- )
diff --git a/srsly/ruamel_yaml/scalarstring.py b/srsly/ruamel_yaml/scalarstring.py
deleted file mode 100755
index 0638e4b..0000000
--- a/srsly/ruamel_yaml/scalarstring.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-from .compat import text_type
-from .anchor import Anchor
-
-if False: # MYPY
- from typing import Text, Any, Dict, List # NOQA
-
-__all__ = [
- "ScalarString",
- "LiteralScalarString",
- "FoldedScalarString",
- "SingleQuotedScalarString",
- "DoubleQuotedScalarString",
- "PlainScalarString",
- # PreservedScalarString is the old name, as it was the first to be preserved on rt,
- # use LiteralScalarString instead
- "PreservedScalarString",
-]
-
-
-class ScalarString(text_type):
- __slots__ = Anchor.attrib
-
- def __new__(cls, *args, **kw):
- # type: (Any, Any) -> Any
- anchor = kw.pop("anchor", None) # type: ignore
- ret_val = text_type.__new__(cls, *args, **kw) # type: ignore
- if anchor is not None:
- ret_val.yaml_set_anchor(anchor, always_dump=True)
- return ret_val
-
- def replace(self, old, new, maxreplace=-1):
- # type: (Any, Any, int) -> Any
- return type(self)((text_type.replace(self, old, new, maxreplace)))
-
- @property
- def anchor(self):
- # type: () -> Any
- if not hasattr(self, Anchor.attrib):
- setattr(self, Anchor.attrib, Anchor())
- return getattr(self, Anchor.attrib)
-
- def yaml_anchor(self, any=False):
- # type: (bool) -> Any
- if not hasattr(self, Anchor.attrib):
- return None
- if any or self.anchor.always_dump:
- return self.anchor
- return None
-
- def yaml_set_anchor(self, value, always_dump=False):
- # type: (Any, bool) -> None
- self.anchor.value = value
- self.anchor.always_dump = always_dump
-
-
-class LiteralScalarString(ScalarString):
- __slots__ = "comment" # the comment after the | on the first line
-
- style = "|"
-
- def __new__(cls, value, anchor=None):
- # type: (Text, Any) -> Any
- return ScalarString.__new__(cls, value, anchor=anchor)
-
-
-PreservedScalarString = LiteralScalarString
-
-
-class FoldedScalarString(ScalarString):
- __slots__ = ("fold_pos", "comment") # the comment after the > on the first line
-
- style = ">"
-
- def __new__(cls, value, anchor=None):
- # type: (Text, Any) -> Any
- return ScalarString.__new__(cls, value, anchor=anchor)
-
-
-class SingleQuotedScalarString(ScalarString):
- __slots__ = ()
-
- style = "'"
-
- def __new__(cls, value, anchor=None):
- # type: (Text, Any) -> Any
- return ScalarString.__new__(cls, value, anchor=anchor)
-
-
-class DoubleQuotedScalarString(ScalarString):
- __slots__ = ()
-
- style = '"'
-
- def __new__(cls, value, anchor=None):
- # type: (Text, Any) -> Any
- return ScalarString.__new__(cls, value, anchor=anchor)
-
-
-class PlainScalarString(ScalarString):
- __slots__ = ()
-
- style = ""
-
- def __new__(cls, value, anchor=None):
- # type: (Text, Any) -> Any
- return ScalarString.__new__(cls, value, anchor=anchor)
-
-
-def preserve_literal(s):
- # type: (Text) -> Text
- return LiteralScalarString(s.replace("\r\n", "\n").replace("\r", "\n"))
-
-
-def walk_tree(base, map=None):
- # type: (Any, Any) -> None
- """
- the routine here walks over a simple yaml tree (recursing in
- dict values and list items) and converts strings that
- have multiple lines to literal scalars
-
- You can also provide an explicit (ordered) mapping for multiple transforms
- (first of which is executed):
- map = .compat.ordereddict
- map['\n'] = preserve_literal
- map[':'] = SingleQuotedScalarString
- walk_tree(data, map=map)
- """
- from .compat import string_types
- from .compat import MutableMapping, MutableSequence # type: ignore
-
- if map is None:
- map = {"\n": preserve_literal}
-
- if isinstance(base, MutableMapping):
- for k in base:
- v = base[k] # type: Text
- if isinstance(v, string_types):
- for ch in map:
- if ch in v:
- base[k] = map[ch](v)
- break
- else:
- walk_tree(v)
- elif isinstance(base, MutableSequence):
- for idx, elem in enumerate(base):
- if isinstance(elem, string_types):
- for ch in map:
- if ch in elem: # type: ignore
- base[idx] = map[ch](elem)
- break
- else:
- walk_tree(elem)
diff --git a/srsly/ruamel_yaml/scanner.py b/srsly/ruamel_yaml/scanner.py
deleted file mode 100755
index a1d1ad0..0000000
--- a/srsly/ruamel_yaml/scanner.py
+++ /dev/null
@@ -1,2011 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-# Scanner produces tokens of the following types:
-# STREAM-START
-# STREAM-END
-# DIRECTIVE(name, value)
-# DOCUMENT-START
-# DOCUMENT-END
-# BLOCK-SEQUENCE-START
-# BLOCK-MAPPING-START
-# BLOCK-END
-# FLOW-SEQUENCE-START
-# FLOW-MAPPING-START
-# FLOW-SEQUENCE-END
-# FLOW-MAPPING-END
-# BLOCK-ENTRY
-# FLOW-ENTRY
-# KEY
-# VALUE
-# ALIAS(value)
-# ANCHOR(value)
-# TAG(value)
-# SCALAR(value, plain, style)
-#
-# RoundTripScanner
-# COMMENT(value)
-#
-# Read comments in the Scanner code for more details.
-#
-
-from .error import MarkedYAMLError
-from .tokens import * # NOQA
-from .compat import utf8, unichr, PY3, check_anchorname_char, nprint # NOQA
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Union, Text # NOQA
- from .compat import VersionType # NOQA
-
-__all__ = ["Scanner", "RoundTripScanner", "ScannerError"]
-
-
-_THE_END = "\n\0\r\x85\u2028\u2029"
-_THE_END_SPACE_TAB = " \n\0\t\r\x85\u2028\u2029"
-_SPACE_TAB = " \t"
-
-
-class ScannerError(MarkedYAMLError):
- pass
-
-
-class SimpleKey(object):
- # See below simple keys treatment.
-
- def __init__(self, token_number, required, index, line, column, mark):
- # type: (Any, Any, int, int, int, Any) -> None
- self.token_number = token_number
- self.required = required
- self.index = index
- self.line = line
- self.column = column
- self.mark = mark
-
-
-class Scanner(object):
- def __init__(self, loader=None):
- # type: (Any) -> None
- """Initialize the scanner."""
- # It is assumed that Scanner and Reader will have a common descendant.
- # Reader do the dirty work of checking for BOM and converting the
- # input data to Unicode. It also adds NUL to the end.
- #
- # Reader supports the following methods
- # self.peek(i=0) # peek the next i-th character
- # self.prefix(l=1) # peek the next l characters
- # self.forward(l=1) # read the next l characters and move the pointer
-
- self.loader = loader
- if self.loader is not None and getattr(self.loader, "_scanner", None) is None:
- self.loader._scanner = self
- self.reset_scanner()
- self.first_time = False
- self.yaml_version = None # type: Any
-
- @property
- def flow_level(self):
- # type: () -> int
- return len(self.flow_context)
-
- def reset_scanner(self):
- # type: () -> None
- # Had we reached the end of the stream?
- self.done = False
-
- # flow_context is an expanding/shrinking list consisting of '{' and '['
- # for each unclosed flow context. If empty list that means block context
- self.flow_context = [] # type: List[Text]
-
- # List of processed tokens that are not yet emitted.
- self.tokens = [] # type: List[Any]
-
- # Add the STREAM-START token.
- self.fetch_stream_start()
-
- # Number of tokens that were emitted through the `get_token` method.
- self.tokens_taken = 0
-
- # The current indentation level.
- self.indent = -1
-
- # Past indentation levels.
- self.indents = [] # type: List[int]
-
- # Variables related to simple keys treatment.
-
- # A simple key is a key that is not denoted by the '?' indicator.
- # Example of simple keys:
- # ---
- # block simple key: value
- # ? not a simple key:
- # : { flow simple key: value }
- # We emit the KEY token before all keys, so when we find a potential
- # simple key, we try to locate the corresponding ':' indicator.
- # Simple keys should be limited to a single line and 1024 characters.
-
- # Can a simple key start at the current position? A simple key may
- # start:
- # - at the beginning of the line, not counting indentation spaces
- # (in block context),
- # - after '{', '[', ',' (in the flow context),
- # - after '?', ':', '-' (in the block context).
- # In the block context, this flag also signifies if a block collection
- # may start at the current position.
- self.allow_simple_key = True
-
- # Keep track of possible simple keys. This is a dictionary. The key
- # is `flow_level`; there can be no more that one possible simple key
- # for each level. The value is a SimpleKey record:
- # (token_number, required, index, line, column, mark)
- # A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow),
- # '[', or '{' tokens.
- self.possible_simple_keys = {} # type: Dict[Any, Any]
-
- @property
- def reader(self):
- # type: () -> Any
- try:
- return self._scanner_reader # type: ignore
- except AttributeError:
- if hasattr(self.loader, "typ"):
- self._scanner_reader = self.loader.reader
- else:
- self._scanner_reader = self.loader._reader
- return self._scanner_reader
-
- @property
- def scanner_processing_version(self): # prefix until un-composited
- # type: () -> Any
- if hasattr(self.loader, "typ"):
- return self.loader.resolver.processing_version
- return self.loader.processing_version
-
- # Public methods.
-
- def check_token(self, *choices):
- # type: (Any) -> bool
- # Check if the next token is one of the given types.
- while self.need_more_tokens():
- self.fetch_more_tokens()
- if bool(self.tokens):
- if not choices:
- return True
- for choice in choices:
- if isinstance(self.tokens[0], choice):
- return True
- return False
-
- def peek_token(self):
- # type: () -> Any
- # Return the next token, but do not delete if from the queue.
- while self.need_more_tokens():
- self.fetch_more_tokens()
- if bool(self.tokens):
- return self.tokens[0]
-
- def get_token(self):
- # type: () -> Any
- # Return the next token.
- while self.need_more_tokens():
- self.fetch_more_tokens()
- if bool(self.tokens):
- self.tokens_taken += 1
- return self.tokens.pop(0)
-
- # Private methods.
-
- def need_more_tokens(self):
- # type: () -> bool
- if self.done:
- return False
- if not self.tokens:
- return True
- # The current token may be a potential simple key, so we
- # need to look further.
- self.stale_possible_simple_keys()
- if self.next_possible_simple_key() == self.tokens_taken:
- return True
- return False
-
- def fetch_comment(self, comment):
- # type: (Any) -> None
- raise NotImplementedError
-
- def fetch_more_tokens(self):
- # type: () -> Any
- # Eat whitespaces and comments until we reach the next token.
- comment = self.scan_to_next_token()
- if comment is not None: # never happens for base scanner
- return self.fetch_comment(comment)
- # Remove obsolete possible simple keys.
- self.stale_possible_simple_keys()
-
- # Compare the current indentation and column. It may add some tokens
- # and decrease the current indentation level.
- self.unwind_indent(self.reader.column)
-
- # Peek the next character.
- ch = self.reader.peek()
-
- # Is it the end of stream?
- if ch == "\0":
- return self.fetch_stream_end()
-
- # Is it a directive?
- if ch == "%" and self.check_directive():
- return self.fetch_directive()
-
- # Is it the document start?
- if ch == "-" and self.check_document_start():
- return self.fetch_document_start()
-
- # Is it the document end?
- if ch == "." and self.check_document_end():
- return self.fetch_document_end()
-
- # TODO: support for BOM within a stream.
- # if ch == u'\uFEFF':
- # return self.fetch_bom() <-- issue BOMToken
-
- # Note: the order of the following checks is NOT significant.
-
- # Is it the flow sequence start indicator?
- if ch == "[":
- return self.fetch_flow_sequence_start()
-
- # Is it the flow mapping start indicator?
- if ch == "{":
- return self.fetch_flow_mapping_start()
-
- # Is it the flow sequence end indicator?
- if ch == "]":
- return self.fetch_flow_sequence_end()
-
- # Is it the flow mapping end indicator?
- if ch == "}":
- return self.fetch_flow_mapping_end()
-
- # Is it the flow entry indicator?
- if ch == ",":
- return self.fetch_flow_entry()
-
- # Is it the block entry indicator?
- if ch == "-" and self.check_block_entry():
- return self.fetch_block_entry()
-
- # Is it the key indicator?
- if ch == "?" and self.check_key():
- return self.fetch_key()
-
- # Is it the value indicator?
- if ch == ":" and self.check_value():
- return self.fetch_value()
-
- # Is it an alias?
- if ch == "*":
- return self.fetch_alias()
-
- # Is it an anchor?
- if ch == "&":
- return self.fetch_anchor()
-
- # Is it a tag?
- if ch == "!":
- return self.fetch_tag()
-
- # Is it a literal scalar?
- if ch == "|" and not self.flow_level:
- return self.fetch_literal()
-
- # Is it a folded scalar?
- if ch == ">" and not self.flow_level:
- return self.fetch_folded()
-
- # Is it a single quoted scalar?
- if ch == "'":
- return self.fetch_single()
-
- # Is it a double quoted scalar?
- if ch == '"':
- return self.fetch_double()
-
- # It must be a plain scalar then.
- if self.check_plain():
- return self.fetch_plain()
-
- # No? It's an error. Let's produce a nice error message.
- raise ScannerError(
- "while scanning for the next token",
- None,
- "found character %r that cannot start any token" % utf8(ch),
- self.reader.get_mark(),
- )
-
- # Simple keys treatment.
-
- def next_possible_simple_key(self):
- # type: () -> Any
- # Return the number of the nearest possible simple key. Actually we
- # don't need to loop through the whole dictionary. We may replace it
- # with the following code:
- # if not self.possible_simple_keys:
- # return None
- # return self.possible_simple_keys[
- # min(self.possible_simple_keys.keys())].token_number
- min_token_number = None
- for level in self.possible_simple_keys:
- key = self.possible_simple_keys[level]
- if min_token_number is None or key.token_number < min_token_number:
- min_token_number = key.token_number
- return min_token_number
-
- def stale_possible_simple_keys(self):
- # type: () -> None
- # Remove entries that are no longer possible simple keys. According to
- # the YAML specification, simple keys
- # - should be limited to a single line,
- # - should be no longer than 1024 characters.
- # Disabling this procedure will allow simple keys of any length and
- # height (may cause problems if indentation is broken though).
- for level in list(self.possible_simple_keys):
- key = self.possible_simple_keys[level]
- if key.line != self.reader.line or self.reader.index - key.index > 1024:
- if key.required:
- raise ScannerError(
- "while scanning a simple key",
- key.mark,
- "could not find expected ':'",
- self.reader.get_mark(),
- )
- del self.possible_simple_keys[level]
-
- def save_possible_simple_key(self):
- # type: () -> None
- # The next token may start a simple key. We check if it's possible
- # and save its position. This function is called for
- # ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'.
-
- # Check if a simple key is required at the current position.
- required = not self.flow_level and self.indent == self.reader.column
-
- # The next token might be a simple key. Let's save it's number and
- # position.
- if self.allow_simple_key:
- self.remove_possible_simple_key()
- token_number = self.tokens_taken + len(self.tokens)
- key = SimpleKey(
- token_number,
- required,
- self.reader.index,
- self.reader.line,
- self.reader.column,
- self.reader.get_mark(),
- )
- self.possible_simple_keys[self.flow_level] = key
-
- def remove_possible_simple_key(self):
- # type: () -> None
- # Remove the saved possible key position at the current flow level.
- if self.flow_level in self.possible_simple_keys:
- key = self.possible_simple_keys[self.flow_level]
-
- if key.required:
- raise ScannerError(
- "while scanning a simple key",
- key.mark,
- "could not find expected ':'",
- self.reader.get_mark(),
- )
-
- del self.possible_simple_keys[self.flow_level]
-
- # Indentation functions.
-
- def unwind_indent(self, column):
- # type: (Any) -> None
- # In flow context, tokens should respect indentation.
- # Actually the condition should be `self.indent >= column` according to
- # the spec. But this condition will prohibit intuitively correct
- # constructions such as
- # key : {
- # }
- # ####
- # if self.flow_level and self.indent > column:
- # raise ScannerError(None, None,
- # "invalid intendation or unclosed '[' or '{'",
- # self.reader.get_mark())
-
- # In the flow context, indentation is ignored. We make the scanner less
- # restrictive then specification requires.
- if bool(self.flow_level):
- return
-
- # In block context, we may need to issue the BLOCK-END tokens.
- while self.indent > column:
- mark = self.reader.get_mark()
- self.indent = self.indents.pop()
- self.tokens.append(BlockEndToken(mark, mark))
-
- def add_indent(self, column):
- # type: (int) -> bool
- # Check if we need to increase indentation.
- if self.indent < column:
- self.indents.append(self.indent)
- self.indent = column
- return True
- return False
-
- # Fetchers.
-
- def fetch_stream_start(self):
- # type: () -> None
- # We always add STREAM-START as the first token and STREAM-END as the
- # last token.
- # Read the token.
- mark = self.reader.get_mark()
- # Add STREAM-START.
- self.tokens.append(StreamStartToken(mark, mark, encoding=self.reader.encoding))
-
- def fetch_stream_end(self):
- # type: () -> None
- # Set the current intendation to -1.
- self.unwind_indent(-1)
- # Reset simple keys.
- self.remove_possible_simple_key()
- self.allow_simple_key = False
- self.possible_simple_keys = {}
- # Read the token.
- mark = self.reader.get_mark()
- # Add STREAM-END.
- self.tokens.append(StreamEndToken(mark, mark))
- # The steam is finished.
- self.done = True
-
- def fetch_directive(self):
- # type: () -> None
- # Set the current intendation to -1.
- self.unwind_indent(-1)
-
- # Reset simple keys.
- self.remove_possible_simple_key()
- self.allow_simple_key = False
-
- # Scan and add DIRECTIVE.
- self.tokens.append(self.scan_directive())
-
- def fetch_document_start(self):
- # type: () -> None
- self.fetch_document_indicator(DocumentStartToken)
-
- def fetch_document_end(self):
- # type: () -> None
- self.fetch_document_indicator(DocumentEndToken)
-
- def fetch_document_indicator(self, TokenClass):
- # type: (Any) -> None
- # Set the current intendation to -1.
- self.unwind_indent(-1)
-
- # Reset simple keys. Note that there could not be a block collection
- # after '---'.
- self.remove_possible_simple_key()
- self.allow_simple_key = False
-
- # Add DOCUMENT-START or DOCUMENT-END.
- start_mark = self.reader.get_mark()
- self.reader.forward(3)
- end_mark = self.reader.get_mark()
- self.tokens.append(TokenClass(start_mark, end_mark))
-
- def fetch_flow_sequence_start(self):
- # type: () -> None
- self.fetch_flow_collection_start(FlowSequenceStartToken, to_push="[")
-
- def fetch_flow_mapping_start(self):
- # type: () -> None
- self.fetch_flow_collection_start(FlowMappingStartToken, to_push="{")
-
- def fetch_flow_collection_start(self, TokenClass, to_push):
- # type: (Any, Text) -> None
- # '[' and '{' may start a simple key.
- self.save_possible_simple_key()
- # Increase the flow level.
- self.flow_context.append(to_push)
- # Simple keys are allowed after '[' and '{'.
- self.allow_simple_key = True
- # Add FLOW-SEQUENCE-START or FLOW-MAPPING-START.
- start_mark = self.reader.get_mark()
- self.reader.forward()
- end_mark = self.reader.get_mark()
- self.tokens.append(TokenClass(start_mark, end_mark))
-
- def fetch_flow_sequence_end(self):
- # type: () -> None
- self.fetch_flow_collection_end(FlowSequenceEndToken)
-
- def fetch_flow_mapping_end(self):
- # type: () -> None
- self.fetch_flow_collection_end(FlowMappingEndToken)
-
- def fetch_flow_collection_end(self, TokenClass):
- # type: (Any) -> None
- # Reset possible simple key on the current level.
- self.remove_possible_simple_key()
- # Decrease the flow level.
- try:
- popped = self.flow_context.pop() # NOQA
- except IndexError:
- # We must not be in a list or object.
- # Defer error handling to the parser.
- pass
- # No simple keys after ']' or '}'.
- self.allow_simple_key = False
- # Add FLOW-SEQUENCE-END or FLOW-MAPPING-END.
- start_mark = self.reader.get_mark()
- self.reader.forward()
- end_mark = self.reader.get_mark()
- self.tokens.append(TokenClass(start_mark, end_mark))
-
- def fetch_flow_entry(self):
- # type: () -> None
- # Simple keys are allowed after ','.
- self.allow_simple_key = True
- # Reset possible simple key on the current level.
- self.remove_possible_simple_key()
- # Add FLOW-ENTRY.
- start_mark = self.reader.get_mark()
- self.reader.forward()
- end_mark = self.reader.get_mark()
- self.tokens.append(FlowEntryToken(start_mark, end_mark))
-
- def fetch_block_entry(self):
- # type: () -> None
- # Block context needs additional checks.
- if not self.flow_level:
- # Are we allowed to start a new entry?
- if not self.allow_simple_key:
- raise ScannerError(
- None,
- None,
- "sequence entries are not allowed here",
- self.reader.get_mark(),
- )
- # We may need to add BLOCK-SEQUENCE-START.
- if self.add_indent(self.reader.column):
- mark = self.reader.get_mark()
- self.tokens.append(BlockSequenceStartToken(mark, mark))
- # It's an error for the block entry to occur in the flow context,
- # but we let the parser detect this.
- else:
- pass
- # Simple keys are allowed after '-'.
- self.allow_simple_key = True
- # Reset possible simple key on the current level.
- self.remove_possible_simple_key()
-
- # Add BLOCK-ENTRY.
- start_mark = self.reader.get_mark()
- self.reader.forward()
- end_mark = self.reader.get_mark()
- self.tokens.append(BlockEntryToken(start_mark, end_mark))
-
- def fetch_key(self):
- # type: () -> None
- # Block context needs additional checks.
- if not self.flow_level:
-
- # Are we allowed to start a key (not nessesary a simple)?
- if not self.allow_simple_key:
- raise ScannerError(
- None,
- None,
- "mapping keys are not allowed here",
- self.reader.get_mark(),
- )
-
- # We may need to add BLOCK-MAPPING-START.
- if self.add_indent(self.reader.column):
- mark = self.reader.get_mark()
- self.tokens.append(BlockMappingStartToken(mark, mark))
-
- # Simple keys are allowed after '?' in the block context.
- self.allow_simple_key = not self.flow_level
-
- # Reset possible simple key on the current level.
- self.remove_possible_simple_key()
-
- # Add KEY.
- start_mark = self.reader.get_mark()
- self.reader.forward()
- end_mark = self.reader.get_mark()
- self.tokens.append(KeyToken(start_mark, end_mark))
-
- def fetch_value(self):
- # type: () -> None
- # Do we determine a simple key?
- if self.flow_level in self.possible_simple_keys:
- # Add KEY.
- key = self.possible_simple_keys[self.flow_level]
- del self.possible_simple_keys[self.flow_level]
- self.tokens.insert(
- key.token_number - self.tokens_taken, KeyToken(key.mark, key.mark)
- )
-
- # If this key starts a new block mapping, we need to add
- # BLOCK-MAPPING-START.
- if not self.flow_level:
- if self.add_indent(key.column):
- self.tokens.insert(
- key.token_number - self.tokens_taken,
- BlockMappingStartToken(key.mark, key.mark),
- )
-
- # There cannot be two simple keys one after another.
- self.allow_simple_key = False
-
- # It must be a part of a complex key.
- else:
-
- # Block context needs additional checks.
- # (Do we really need them? They will be caught by the parser
- # anyway.)
- if not self.flow_level:
-
- # We are allowed to start a complex value if and only if
- # we can start a simple key.
- if not self.allow_simple_key:
- raise ScannerError(
- None,
- None,
- "mapping values are not allowed here",
- self.reader.get_mark(),
- )
-
- # If this value starts a new block mapping, we need to add
- # BLOCK-MAPPING-START. It will be detected as an error later by
- # the parser.
- if not self.flow_level:
- if self.add_indent(self.reader.column):
- mark = self.reader.get_mark()
- self.tokens.append(BlockMappingStartToken(mark, mark))
-
- # Simple keys are allowed after ':' in the block context.
- self.allow_simple_key = not self.flow_level
-
- # Reset possible simple key on the current level.
- self.remove_possible_simple_key()
-
- # Add VALUE.
- start_mark = self.reader.get_mark()
- self.reader.forward()
- end_mark = self.reader.get_mark()
- self.tokens.append(ValueToken(start_mark, end_mark))
-
- def fetch_alias(self):
- # type: () -> None
- # ALIAS could be a simple key.
- self.save_possible_simple_key()
- # No simple keys after ALIAS.
- self.allow_simple_key = False
- # Scan and add ALIAS.
- self.tokens.append(self.scan_anchor(AliasToken))
-
- def fetch_anchor(self):
- # type: () -> None
- # ANCHOR could start a simple key.
- self.save_possible_simple_key()
- # No simple keys after ANCHOR.
- self.allow_simple_key = False
- # Scan and add ANCHOR.
- self.tokens.append(self.scan_anchor(AnchorToken))
-
- def fetch_tag(self):
- # type: () -> None
- # TAG could start a simple key.
- self.save_possible_simple_key()
- # No simple keys after TAG.
- self.allow_simple_key = False
- # Scan and add TAG.
- self.tokens.append(self.scan_tag())
-
- def fetch_literal(self):
- # type: () -> None
- self.fetch_block_scalar(style="|")
-
- def fetch_folded(self):
- # type: () -> None
- self.fetch_block_scalar(style=">")
-
- def fetch_block_scalar(self, style):
- # type: (Any) -> None
- # A simple key may follow a block scalar.
- self.allow_simple_key = True
- # Reset possible simple key on the current level.
- self.remove_possible_simple_key()
- # Scan and add SCALAR.
- self.tokens.append(self.scan_block_scalar(style))
-
- def fetch_single(self):
- # type: () -> None
- self.fetch_flow_scalar(style="'")
-
- def fetch_double(self):
- # type: () -> None
- self.fetch_flow_scalar(style='"')
-
- def fetch_flow_scalar(self, style):
- # type: (Any) -> None
- # A flow scalar could be a simple key.
- self.save_possible_simple_key()
- # No simple keys after flow scalars.
- self.allow_simple_key = False
- # Scan and add SCALAR.
- self.tokens.append(self.scan_flow_scalar(style))
-
- def fetch_plain(self):
- # type: () -> None
- # A plain scalar could be a simple key.
- self.save_possible_simple_key()
- # No simple keys after plain scalars. But note that `scan_plain` will
- # change this flag if the scan is finished at the beginning of the
- # line.
- self.allow_simple_key = False
- # Scan and add SCALAR. May change `allow_simple_key`.
- self.tokens.append(self.scan_plain())
-
- # Checkers.
-
- def check_directive(self):
- # type: () -> Any
- # DIRECTIVE: ^ '%' ...
- # The '%' indicator is already checked.
- if self.reader.column == 0:
- return True
- return None
-
- def check_document_start(self):
- # type: () -> Any
- # DOCUMENT-START: ^ '---' (' '|'\n')
- if self.reader.column == 0:
- if (
- self.reader.prefix(3) == "---"
- and self.reader.peek(3) in _THE_END_SPACE_TAB
- ):
- return True
- return None
-
- def check_document_end(self):
- # type: () -> Any
- # DOCUMENT-END: ^ '...' (' '|'\n')
- if self.reader.column == 0:
- if (
- self.reader.prefix(3) == "..."
- and self.reader.peek(3) in _THE_END_SPACE_TAB
- ):
- return True
- return None
-
- def check_block_entry(self):
- # type: () -> Any
- # BLOCK-ENTRY: '-' (' '|'\n')
- return self.reader.peek(1) in _THE_END_SPACE_TAB
-
- def check_key(self):
- # type: () -> Any
- # KEY(flow context): '?'
- if bool(self.flow_level):
- return True
- # KEY(block context): '?' (' '|'\n')
- return self.reader.peek(1) in _THE_END_SPACE_TAB
-
- def check_value(self):
- # type: () -> Any
- # VALUE(flow context): ':'
- if self.scanner_processing_version == (1, 1):
- if bool(self.flow_level):
- return True
- else:
- if bool(self.flow_level):
- if self.flow_context[-1] == "[":
- if self.reader.peek(1) not in _THE_END_SPACE_TAB:
- return False
- elif self.tokens and isinstance(self.tokens[-1], ValueToken):
- # mapping flow context scanning a value token
- if self.reader.peek(1) not in _THE_END_SPACE_TAB:
- return False
- return True
- # VALUE(block context): ':' (' '|'\n')
- return self.reader.peek(1) in _THE_END_SPACE_TAB
-
- def check_plain(self):
- # type: () -> Any
- # A plain scalar may start with any non-space character except:
- # '-', '?', ':', ',', '[', ']', '{', '}',
- # '#', '&', '*', '!', '|', '>', '\'', '\"',
- # '%', '@', '`'.
- #
- # It may also start with
- # '-', '?', ':'
- # if it is followed by a non-space character.
- #
- # Note that we limit the last rule to the block context (except the
- # '-' character) because we want the flow context to be space
- # independent.
- srp = self.reader.peek
- ch = srp()
- if self.scanner_processing_version == (1, 1):
- return ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`" or (
- srp(1) not in _THE_END_SPACE_TAB
- and (ch == "-" or (not self.flow_level and ch in "?:"))
- )
- # YAML 1.2
- if ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`":
- # ################### ^ ???
- return True
- ch1 = srp(1)
- if ch == "-" and ch1 not in _THE_END_SPACE_TAB:
- return True
- if ch == ":" and bool(self.flow_level) and ch1 not in _SPACE_TAB:
- return True
-
- return srp(1) not in _THE_END_SPACE_TAB and (
- ch == "-" or (not self.flow_level and ch in "?:")
- )
-
- # Scanners.
-
- def scan_to_next_token(self):
- # type: () -> Any
- # We ignore spaces, line breaks and comments.
- # If we find a line break in the block context, we set the flag
- # `allow_simple_key` on.
- # The byte order mark is stripped if it's the first character in the
- # stream. We do not yet support BOM inside the stream as the
- # specification requires. Any such mark will be considered as a part
- # of the document.
- #
- # TODO: We need to make tab handling rules more sane. A good rule is
- # Tabs cannot precede tokens
- # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END,
- # KEY(block), VALUE(block), BLOCK-ENTRY
- # So the checking code is
- # if :
- # self.allow_simple_keys = False
- # We also need to add the check for `allow_simple_keys == True` to
- # `unwind_indent` before issuing BLOCK-END.
- # Scanners for block, flow, and plain scalars need to be modified.
- srp = self.reader.peek
- srf = self.reader.forward
- if self.reader.index == 0 and srp() == "\uFEFF":
- srf()
- found = False
- _the_end = _THE_END
- while not found:
- while srp() == " ":
- srf()
- if srp() == "#":
- while srp() not in _the_end:
- srf()
- if self.scan_line_break():
- if not self.flow_level:
- self.allow_simple_key = True
- else:
- found = True
- return None
-
- def scan_directive(self):
- # type: () -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- start_mark = self.reader.get_mark()
- srf()
- name = self.scan_directive_name(start_mark)
- value = None
- if name == "YAML":
- value = self.scan_yaml_directive_value(start_mark)
- end_mark = self.reader.get_mark()
- elif name == "TAG":
- value = self.scan_tag_directive_value(start_mark)
- end_mark = self.reader.get_mark()
- else:
- end_mark = self.reader.get_mark()
- while srp() not in _THE_END:
- srf()
- self.scan_directive_ignored_line(start_mark)
- return DirectiveToken(name, value, start_mark, end_mark)
-
- def scan_directive_name(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- length = 0
- srp = self.reader.peek
- ch = srp(length)
- while "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_:.":
- length += 1
- ch = srp(length)
- if not length:
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected alphabetic or numeric character, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- value = self.reader.prefix(length)
- self.reader.forward(length)
- ch = srp()
- if ch not in "\0 \r\n\x85\u2028\u2029":
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected alphabetic or numeric character, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- return value
-
- def scan_yaml_directive_value(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- while srp() == " ":
- srf()
- major = self.scan_yaml_directive_number(start_mark)
- if srp() != ".":
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected a digit or '.', but found %r" % utf8(srp()),
- self.reader.get_mark(),
- )
- srf()
- minor = self.scan_yaml_directive_number(start_mark)
- if srp() not in "\0 \r\n\x85\u2028\u2029":
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected a digit or ' ', but found %r" % utf8(srp()),
- self.reader.get_mark(),
- )
- self.yaml_version = (major, minor)
- return self.yaml_version
-
- def scan_yaml_directive_number(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- ch = srp()
- if not ("0" <= ch <= "9"):
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected a digit, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- length = 0
- while "0" <= srp(length) <= "9":
- length += 1
- value = int(self.reader.prefix(length))
- srf(length)
- return value
-
- def scan_tag_directive_value(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- while srp() == " ":
- srf()
- handle = self.scan_tag_directive_handle(start_mark)
- while srp() == " ":
- srf()
- prefix = self.scan_tag_directive_prefix(start_mark)
- return (handle, prefix)
-
- def scan_tag_directive_handle(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- value = self.scan_tag_handle("directive", start_mark)
- ch = self.reader.peek()
- if ch != " ":
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected ' ', but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- return value
-
- def scan_tag_directive_prefix(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- value = self.scan_tag_uri("directive", start_mark)
- ch = self.reader.peek()
- if ch not in "\0 \r\n\x85\u2028\u2029":
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected ' ', but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- return value
-
- def scan_directive_ignored_line(self, start_mark):
- # type: (Any) -> None
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- while srp() == " ":
- srf()
- if srp() == "#":
- while srp() not in _THE_END:
- srf()
- ch = srp()
- if ch not in _THE_END:
- raise ScannerError(
- "while scanning a directive",
- start_mark,
- "expected a comment or a line break, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- self.scan_line_break()
-
- def scan_anchor(self, TokenClass):
- # type: (Any) -> Any
- # The specification does not restrict characters for anchors and
- # aliases. This may lead to problems, for instance, the document:
- # [ *alias, value ]
- # can be interpteted in two ways, as
- # [ "value" ]
- # and
- # [ *alias , "value" ]
- # Therefore we restrict aliases to numbers and ASCII letters.
- srp = self.reader.peek
- start_mark = self.reader.get_mark()
- indicator = srp()
- if indicator == "*":
- name = "alias"
- else:
- name = "anchor"
- self.reader.forward()
- length = 0
- ch = srp(length)
- # while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' or u'a' <= ch <= u'z' \
- # or ch in u'-_':
- while check_anchorname_char(ch):
- length += 1
- ch = srp(length)
- if not length:
- raise ScannerError(
- "while scanning an %s" % (name,),
- start_mark,
- "expected alphabetic or numeric character, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- value = self.reader.prefix(length)
- self.reader.forward(length)
- # ch1 = ch
- # ch = srp() # no need to peek, ch is already set
- # assert ch1 == ch
- if ch not in "\0 \t\r\n\x85\u2028\u2029?:,[]{}%@`":
- raise ScannerError(
- "while scanning an %s" % (name,),
- start_mark,
- "expected alphabetic or numeric character, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- end_mark = self.reader.get_mark()
- return TokenClass(value, start_mark, end_mark)
-
- def scan_tag(self):
- # type: () -> Any
- # See the specification for details.
- srp = self.reader.peek
- start_mark = self.reader.get_mark()
- ch = srp(1)
- if ch == "<":
- handle = None
- self.reader.forward(2)
- suffix = self.scan_tag_uri("tag", start_mark)
- if srp() != ">":
- raise ScannerError(
- "while parsing a tag",
- start_mark,
- "expected '>', but found %r" % utf8(srp()),
- self.reader.get_mark(),
- )
- self.reader.forward()
- elif ch in _THE_END_SPACE_TAB:
- handle = None
- suffix = "!"
- self.reader.forward()
- else:
- length = 1
- use_handle = False
- while ch not in "\0 \r\n\x85\u2028\u2029":
- if ch == "!":
- use_handle = True
- break
- length += 1
- ch = srp(length)
- handle = "!"
- if use_handle:
- handle = self.scan_tag_handle("tag", start_mark)
- else:
- handle = "!"
- self.reader.forward()
- suffix = self.scan_tag_uri("tag", start_mark)
- ch = srp()
- if ch not in "\0 \r\n\x85\u2028\u2029":
- raise ScannerError(
- "while scanning a tag",
- start_mark,
- "expected ' ', but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- value = (handle, suffix)
- end_mark = self.reader.get_mark()
- return TagToken(value, start_mark, end_mark)
-
- def scan_block_scalar(self, style, rt=False):
- # type: (Any, Optional[bool]) -> Any
- # See the specification for details.
- srp = self.reader.peek
- if style == ">":
- folded = True
- else:
- folded = False
-
- chunks = [] # type: List[Any]
- start_mark = self.reader.get_mark()
-
- # Scan the header.
- self.reader.forward()
- chomping, increment = self.scan_block_scalar_indicators(start_mark)
- # block scalar comment e.g. : |+ # comment text
- block_scalar_comment = self.scan_block_scalar_ignored_line(start_mark)
-
- # Determine the indentation level and go to the first non-empty line.
- min_indent = self.indent + 1
- if increment is None:
- # no increment and top level, min_indent could be 0
- if min_indent < 1 and (
- style not in "|>"
- or (self.scanner_processing_version == (1, 1))
- and getattr(
- self.loader,
- "top_level_block_style_scalar_no_indent_error_1_1",
- False,
- )
- ):
- min_indent = 1
- breaks, max_indent, end_mark = self.scan_block_scalar_indentation()
- indent = max(min_indent, max_indent)
- else:
- if min_indent < 1:
- min_indent = 1
- indent = min_indent + increment - 1
- breaks, end_mark = self.scan_block_scalar_breaks(indent)
- line_break = ""
-
- # Scan the inner part of the block scalar.
- while self.reader.column == indent and srp() != "\0":
- chunks.extend(breaks)
- leading_non_space = srp() not in " \t"
- length = 0
- while srp(length) not in _THE_END:
- length += 1
- chunks.append(self.reader.prefix(length))
- self.reader.forward(length)
- line_break = self.scan_line_break()
- breaks, end_mark = self.scan_block_scalar_breaks(indent)
- if style in "|>" and min_indent == 0:
- # at the beginning of a line, if in block style see if
- # end of document/start_new_document
- if self.check_document_start() or self.check_document_end():
- break
- if self.reader.column == indent and srp() != "\0":
-
- # Unfortunately, folding rules are ambiguous.
- #
- # This is the folding according to the specification:
-
- if rt and folded and line_break == "\n":
- chunks.append("\a")
- if (
- folded
- and line_break == "\n"
- and leading_non_space
- and srp() not in " \t"
- ):
- if not breaks:
- chunks.append(" ")
- else:
- chunks.append(line_break)
-
- # This is Clark Evans's interpretation (also in the spec
- # examples):
- #
- # if folded and line_break == u'\n':
- # if not breaks:
- # if srp() not in ' \t':
- # chunks.append(u' ')
- # else:
- # chunks.append(line_break)
- # else:
- # chunks.append(line_break)
- else:
- break
-
- # Process trailing line breaks. The 'chomping' setting determines
- # whether they are included in the value.
- trailing = [] # type: List[Any]
- if chomping in [None, True]:
- chunks.append(line_break)
- if chomping is True:
- chunks.extend(breaks)
- elif chomping in [None, False]:
- trailing.extend(breaks)
-
- # We are done.
- token = ScalarToken("".join(chunks), False, start_mark, end_mark, style)
- if block_scalar_comment is not None:
- token.add_pre_comments([block_scalar_comment])
- if len(trailing) > 0:
- # nprint('trailing 1', trailing) # XXXXX
- # Eat whitespaces and comments until we reach the next token.
- comment = self.scan_to_next_token()
- while comment:
- trailing.append(" " * comment[1].column + comment[0])
- comment = self.scan_to_next_token()
-
- # Keep track of the trailing whitespace and following comments
- # as a comment token, if isn't all included in the actual value.
- comment_end_mark = self.reader.get_mark()
- comment = CommentToken("".join(trailing), end_mark, comment_end_mark)
- token.add_post_comment(comment)
- return token
-
- def scan_block_scalar_indicators(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- chomping = None
- increment = None
- ch = srp()
- if ch in "+-":
- if ch == "+":
- chomping = True
- else:
- chomping = False
- self.reader.forward()
- ch = srp()
- if ch in "0123456789":
- increment = int(ch)
- if increment == 0:
- raise ScannerError(
- "while scanning a block scalar",
- start_mark,
- "expected indentation indicator in the range 1-9, "
- "but found 0",
- self.reader.get_mark(),
- )
- self.reader.forward()
- elif ch in "0123456789":
- increment = int(ch)
- if increment == 0:
- raise ScannerError(
- "while scanning a block scalar",
- start_mark,
- "expected indentation indicator in the range 1-9, " "but found 0",
- self.reader.get_mark(),
- )
- self.reader.forward()
- ch = srp()
- if ch in "+-":
- if ch == "+":
- chomping = True
- else:
- chomping = False
- self.reader.forward()
- ch = srp()
- if ch not in "\0 \r\n\x85\u2028\u2029":
- raise ScannerError(
- "while scanning a block scalar",
- start_mark,
- "expected chomping or indentation indicators, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- return chomping, increment
-
- def scan_block_scalar_ignored_line(self, start_mark):
- # type: (Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- prefix = ""
- comment = None
- while srp() == " ":
- prefix += srp()
- srf()
- if srp() == "#":
- comment = prefix
- while srp() not in _THE_END:
- comment += srp()
- srf()
- ch = srp()
- if ch not in _THE_END:
- raise ScannerError(
- "while scanning a block scalar",
- start_mark,
- "expected a comment or a line break, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- self.scan_line_break()
- return comment
-
- def scan_block_scalar_indentation(self):
- # type: () -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- chunks = []
- max_indent = 0
- end_mark = self.reader.get_mark()
- while srp() in " \r\n\x85\u2028\u2029":
- if srp() != " ":
- chunks.append(self.scan_line_break())
- end_mark = self.reader.get_mark()
- else:
- srf()
- if self.reader.column > max_indent:
- max_indent = self.reader.column
- return chunks, max_indent, end_mark
-
- def scan_block_scalar_breaks(self, indent):
- # type: (int) -> Any
- # See the specification for details.
- chunks = []
- srp = self.reader.peek
- srf = self.reader.forward
- end_mark = self.reader.get_mark()
- while self.reader.column < indent and srp() == " ":
- srf()
- while srp() in "\r\n\x85\u2028\u2029":
- chunks.append(self.scan_line_break())
- end_mark = self.reader.get_mark()
- while self.reader.column < indent and srp() == " ":
- srf()
- return chunks, end_mark
-
- def scan_flow_scalar(self, style):
- # type: (Any) -> Any
- # See the specification for details.
- # Note that we loose indentation rules for quoted scalars. Quoted
- # scalars don't need to adhere indentation because " and ' clearly
- # mark the beginning and the end of them. Therefore we are less
- # restrictive then the specification requires. We only need to check
- # that document separators are not included in scalars.
- if style == '"':
- double = True
- else:
- double = False
- srp = self.reader.peek
- chunks = [] # type: List[Any]
- start_mark = self.reader.get_mark()
- quote = srp()
- self.reader.forward()
- chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark))
- while srp() != quote:
- chunks.extend(self.scan_flow_scalar_spaces(double, start_mark))
- chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark))
- self.reader.forward()
- end_mark = self.reader.get_mark()
- return ScalarToken("".join(chunks), False, start_mark, end_mark, style)
-
- ESCAPE_REPLACEMENTS = {
- "0": "\0",
- "a": "\x07",
- "b": "\x08",
- "t": "\x09",
- "\t": "\x09",
- "n": "\x0A",
- "v": "\x0B",
- "f": "\x0C",
- "r": "\x0D",
- "e": "\x1B",
- " ": "\x20",
- '"': '"',
- "/": "/", # as per http://www.json.org/
- "\\": "\\",
- "N": "\x85",
- "_": "\xA0",
- "L": "\u2028",
- "P": "\u2029",
- }
-
- ESCAPE_CODES = {"x": 2, "u": 4, "U": 8}
-
- def scan_flow_scalar_non_spaces(self, double, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- chunks = [] # type: List[Any]
- srp = self.reader.peek
- srf = self.reader.forward
- while True:
- length = 0
- while srp(length) not in " \n'\"\\\0\t\r\x85\u2028\u2029":
- length += 1
- if length != 0:
- chunks.append(self.reader.prefix(length))
- srf(length)
- ch = srp()
- if not double and ch == "'" and srp(1) == "'":
- chunks.append("'")
- srf(2)
- elif (double and ch == "'") or (not double and ch in '"\\'):
- chunks.append(ch)
- srf()
- elif double and ch == "\\":
- srf()
- ch = srp()
- if ch in self.ESCAPE_REPLACEMENTS:
- chunks.append(self.ESCAPE_REPLACEMENTS[ch])
- srf()
- elif ch in self.ESCAPE_CODES:
- length = self.ESCAPE_CODES[ch]
- srf()
- for k in range(length):
- if srp(k) not in "0123456789ABCDEFabcdef":
- raise ScannerError(
- "while scanning a double-quoted scalar",
- start_mark,
- "expected escape sequence of %d hexdecimal "
- "numbers, but found %r" % (length, utf8(srp(k))),
- self.reader.get_mark(),
- )
- code = int(self.reader.prefix(length), 16)
- chunks.append(unichr(code))
- srf(length)
- elif ch in "\n\r\x85\u2028\u2029":
- self.scan_line_break()
- chunks.extend(self.scan_flow_scalar_breaks(double, start_mark))
- else:
- raise ScannerError(
- "while scanning a double-quoted scalar",
- start_mark,
- "found unknown escape character %r" % utf8(ch),
- self.reader.get_mark(),
- )
- else:
- return chunks
-
- def scan_flow_scalar_spaces(self, double, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- chunks = []
- length = 0
- while srp(length) in " \t":
- length += 1
- whitespaces = self.reader.prefix(length)
- self.reader.forward(length)
- ch = srp()
- if ch == "\0":
- raise ScannerError(
- "while scanning a quoted scalar",
- start_mark,
- "found unexpected end of stream",
- self.reader.get_mark(),
- )
- elif ch in "\r\n\x85\u2028\u2029":
- line_break = self.scan_line_break()
- breaks = self.scan_flow_scalar_breaks(double, start_mark)
- if line_break != "\n":
- chunks.append(line_break)
- elif not breaks:
- chunks.append(" ")
- chunks.extend(breaks)
- else:
- chunks.append(whitespaces)
- return chunks
-
- def scan_flow_scalar_breaks(self, double, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- chunks = [] # type: List[Any]
- srp = self.reader.peek
- srf = self.reader.forward
- while True:
- # Instead of checking indentation, we check for document
- # separators.
- prefix = self.reader.prefix(3)
- if (prefix == "---" or prefix == "...") and srp(3) in _THE_END_SPACE_TAB:
- raise ScannerError(
- "while scanning a quoted scalar",
- start_mark,
- "found unexpected document separator",
- self.reader.get_mark(),
- )
- while srp() in " \t":
- srf()
- if srp() in "\r\n\x85\u2028\u2029":
- chunks.append(self.scan_line_break())
- else:
- return chunks
-
- def scan_plain(self):
- # type: () -> Any
- # See the specification for details.
- # We add an additional restriction for the flow context:
- # plain scalars in the flow context cannot contain ',', ': ' and '?'.
- # We also keep track of the `allow_simple_key` flag here.
- # Indentation rules are loosed for the flow context.
- srp = self.reader.peek
- srf = self.reader.forward
- chunks = [] # type: List[Any]
- start_mark = self.reader.get_mark()
- end_mark = start_mark
- indent = self.indent + 1
- # We allow zero indentation for scalars, but then we need to check for
- # document separators at the beginning of the line.
- # if indent == 0:
- # indent = 1
- spaces = [] # type: List[Any]
- while True:
- length = 0
- if srp() == "#":
- break
- while True:
- ch = srp(length)
- if ch == ":" and srp(length + 1) not in _THE_END_SPACE_TAB:
- pass
- elif ch == "?" and self.scanner_processing_version != (1, 1):
- pass
- elif (
- ch in _THE_END_SPACE_TAB
- or (
- not self.flow_level
- and ch == ":"
- and srp(length + 1) in _THE_END_SPACE_TAB
- )
- or (self.flow_level and ch in ",:?[]{}")
- ):
- break
- length += 1
- # It's not clear what we should do with ':' in the flow context.
- if (
- self.flow_level
- and ch == ":"
- and srp(length + 1) not in "\0 \t\r\n\x85\u2028\u2029,[]{}"
- ):
- srf(length)
- raise ScannerError(
- "while scanning a plain scalar",
- start_mark,
- "found unexpected ':'",
- self.reader.get_mark(),
- "Please check "
- "http://pyyaml.org/wiki/YAMLColonInFlowContext "
- "for details.",
- )
- if length == 0:
- break
- self.allow_simple_key = False
- chunks.extend(spaces)
- chunks.append(self.reader.prefix(length))
- srf(length)
- end_mark = self.reader.get_mark()
- spaces = self.scan_plain_spaces(indent, start_mark)
- if (
- not spaces
- or srp() == "#"
- or (not self.flow_level and self.reader.column < indent)
- ):
- break
-
- token = ScalarToken("".join(chunks), True, start_mark, end_mark)
- if spaces and spaces[0] == "\n":
- # Create a comment token to preserve the trailing line breaks.
- comment = CommentToken("".join(spaces) + "\n", start_mark, end_mark)
- token.add_post_comment(comment)
- return token
-
- def scan_plain_spaces(self, indent, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- # The specification is really confusing about tabs in plain scalars.
- # We just forbid them completely. Do not use tabs in YAML!
- srp = self.reader.peek
- srf = self.reader.forward
- chunks = []
- length = 0
- while srp(length) in " ":
- length += 1
- whitespaces = self.reader.prefix(length)
- self.reader.forward(length)
- ch = srp()
- if ch in "\r\n\x85\u2028\u2029":
- line_break = self.scan_line_break()
- self.allow_simple_key = True
- prefix = self.reader.prefix(3)
- if (prefix == "---" or prefix == "...") and srp(3) in _THE_END_SPACE_TAB:
- return
- breaks = []
- while srp() in " \r\n\x85\u2028\u2029":
- if srp() == " ":
- srf()
- else:
- breaks.append(self.scan_line_break())
- prefix = self.reader.prefix(3)
- if (prefix == "---" or prefix == "...") and srp(
- 3
- ) in _THE_END_SPACE_TAB:
- return
- if line_break != "\n":
- chunks.append(line_break)
- elif not breaks:
- chunks.append(" ")
- chunks.extend(breaks)
- elif whitespaces:
- chunks.append(whitespaces)
- return chunks
-
- def scan_tag_handle(self, name, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- # For some strange reasons, the specification does not allow '_' in
- # tag handles. I have allowed it anyway.
- srp = self.reader.peek
- ch = srp()
- if ch != "!":
- raise ScannerError(
- "while scanning a %s" % (name,),
- start_mark,
- "expected '!', but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- length = 1
- ch = srp(length)
- if ch != " ":
- while (
- "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_"
- ):
- length += 1
- ch = srp(length)
- if ch != "!":
- self.reader.forward(length)
- raise ScannerError(
- "while scanning a %s" % (name,),
- start_mark,
- "expected '!', but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- length += 1
- value = self.reader.prefix(length)
- self.reader.forward(length)
- return value
-
- def scan_tag_uri(self, name, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- # Note: we do not check if URI is well-formed.
- srp = self.reader.peek
- chunks = []
- length = 0
- ch = srp(length)
- while (
- "0" <= ch <= "9"
- or "A" <= ch <= "Z"
- or "a" <= ch <= "z"
- or ch in "-;/?:@&=+$,_.!~*'()[]%"
- or ((self.scanner_processing_version > (1, 1)) and ch == "#")
- ):
- if ch == "%":
- chunks.append(self.reader.prefix(length))
- self.reader.forward(length)
- length = 0
- chunks.append(self.scan_uri_escapes(name, start_mark))
- else:
- length += 1
- ch = srp(length)
- if length != 0:
- chunks.append(self.reader.prefix(length))
- self.reader.forward(length)
- length = 0
- if not chunks:
- raise ScannerError(
- "while parsing a %s" % (name,),
- start_mark,
- "expected URI, but found %r" % utf8(ch),
- self.reader.get_mark(),
- )
- return "".join(chunks)
-
- def scan_uri_escapes(self, name, start_mark):
- # type: (Any, Any) -> Any
- # See the specification for details.
- srp = self.reader.peek
- srf = self.reader.forward
- code_bytes = [] # type: List[Any]
- mark = self.reader.get_mark()
- while srp() == "%":
- srf()
- for k in range(2):
- if srp(k) not in "0123456789ABCDEFabcdef":
- raise ScannerError(
- "while scanning a %s" % (name,),
- start_mark,
- "expected URI escape sequence of 2 hexdecimal numbers,"
- " but found %r" % utf8(srp(k)),
- self.reader.get_mark(),
- )
- if PY3:
- code_bytes.append(int(self.reader.prefix(2), 16))
- else:
- code_bytes.append(chr(int(self.reader.prefix(2), 16)))
- srf(2)
- try:
- if PY3:
- value = bytes(code_bytes).decode("utf-8")
- else:
- value = unicode(b"".join(code_bytes), "utf-8")
- except UnicodeDecodeError as exc:
- raise ScannerError(
- "while scanning a %s" % (name,), start_mark, str(exc), mark
- )
- return value
-
- def scan_line_break(self):
- # type: () -> Any
- # Transforms:
- # '\r\n' : '\n'
- # '\r' : '\n'
- # '\n' : '\n'
- # '\x85' : '\n'
- # '\u2028' : '\u2028'
- # '\u2029 : '\u2029'
- # default : ''
- ch = self.reader.peek()
- if ch in "\r\n\x85":
- if self.reader.prefix(2) == "\r\n":
- self.reader.forward(2)
- else:
- self.reader.forward()
- return "\n"
- elif ch in "\u2028\u2029":
- self.reader.forward()
- return ch
- return ""
-
-
-class RoundTripScanner(Scanner):
- def check_token(self, *choices):
- # type: (Any) -> bool
- # Check if the next token is one of the given types.
- while self.need_more_tokens():
- self.fetch_more_tokens()
- self._gather_comments()
- if bool(self.tokens):
- if not choices:
- return True
- for choice in choices:
- if isinstance(self.tokens[0], choice):
- return True
- return False
-
- def peek_token(self):
- # type: () -> Any
- # Return the next token, but do not delete if from the queue.
- while self.need_more_tokens():
- self.fetch_more_tokens()
- self._gather_comments()
- if bool(self.tokens):
- return self.tokens[0]
- return None
-
- def _gather_comments(self):
- # type: () -> Any
- """combine multiple comment lines"""
- comments = [] # type: List[Any]
- if not self.tokens:
- return comments
- if isinstance(self.tokens[0], CommentToken):
- comment = self.tokens.pop(0)
- self.tokens_taken += 1
- comments.append(comment)
- while self.need_more_tokens():
- self.fetch_more_tokens()
- if not self.tokens:
- return comments
- if isinstance(self.tokens[0], CommentToken):
- self.tokens_taken += 1
- comment = self.tokens.pop(0)
- # nprint('dropping2', comment)
- comments.append(comment)
- if len(comments) >= 1:
- self.tokens[0].add_pre_comments(comments)
- # pull in post comment on e.g. ':'
- if not self.done and len(self.tokens) < 2:
- self.fetch_more_tokens()
-
- def get_token(self):
- # type: () -> Any
- # Return the next token.
- while self.need_more_tokens():
- self.fetch_more_tokens()
- self._gather_comments()
- if bool(self.tokens):
- # nprint('tk', self.tokens)
- # only add post comment to single line tokens:
- # scalar, value token. FlowXEndToken, otherwise
- # hidden streamtokens could get them (leave them and they will be
- # pre comments for the next map/seq
- if (
- len(self.tokens) > 1
- and isinstance(
- self.tokens[0],
- (
- ScalarToken,
- ValueToken,
- FlowSequenceEndToken,
- FlowMappingEndToken,
- ),
- )
- and isinstance(self.tokens[1], CommentToken)
- and self.tokens[0].end_mark.line == self.tokens[1].start_mark.line
- ):
- self.tokens_taken += 1
- c = self.tokens.pop(1)
- self.fetch_more_tokens()
- while len(self.tokens) > 1 and isinstance(self.tokens[1], CommentToken):
- self.tokens_taken += 1
- c1 = self.tokens.pop(1)
- c.value = c.value + (" " * c1.start_mark.column) + c1.value
- self.fetch_more_tokens()
- self.tokens[0].add_post_comment(c)
- elif (
- len(self.tokens) > 1
- and isinstance(self.tokens[0], ScalarToken)
- and isinstance(self.tokens[1], CommentToken)
- and self.tokens[0].end_mark.line != self.tokens[1].start_mark.line
- ):
- self.tokens_taken += 1
- c = self.tokens.pop(1)
- c.value = (
- "\n" * (c.start_mark.line - self.tokens[0].end_mark.line)
- + (" " * c.start_mark.column)
- + c.value
- )
- self.tokens[0].add_post_comment(c)
- self.fetch_more_tokens()
- while len(self.tokens) > 1 and isinstance(self.tokens[1], CommentToken):
- self.tokens_taken += 1
- c1 = self.tokens.pop(1)
- c.value = c.value + (" " * c1.start_mark.column) + c1.value
- self.fetch_more_tokens()
- self.tokens_taken += 1
- return self.tokens.pop(0)
- return None
-
- def fetch_comment(self, comment):
- # type: (Any) -> None
- value, start_mark, end_mark = comment
- while value and value[-1] == " ":
- # empty line within indented key context
- # no need to update end-mark, that is not used
- value = value[:-1]
- self.tokens.append(CommentToken(value, start_mark, end_mark))
-
- # scanner
-
- def scan_to_next_token(self):
- # type: () -> Any
- # We ignore spaces, line breaks and comments.
- # If we find a line break in the block context, we set the flag
- # `allow_simple_key` on.
- # The byte order mark is stripped if it's the first character in the
- # stream. We do not yet support BOM inside the stream as the
- # specification requires. Any such mark will be considered as a part
- # of the document.
- #
- # TODO: We need to make tab handling rules more sane. A good rule is
- # Tabs cannot precede tokens
- # BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END,
- # KEY(block), VALUE(block), BLOCK-ENTRY
- # So the checking code is
- # if :
- # self.allow_simple_keys = False
- # We also need to add the check for `allow_simple_keys == True` to
- # `unwind_indent` before issuing BLOCK-END.
- # Scanners for block, flow, and plain scalars need to be modified.
-
- srp = self.reader.peek
- srf = self.reader.forward
- if self.reader.index == 0 and srp() == "\uFEFF":
- srf()
- found = False
- while not found:
- while srp() == " ":
- srf()
- ch = srp()
- if ch == "#":
- start_mark = self.reader.get_mark()
- comment = ch
- srf()
- while ch not in _THE_END:
- ch = srp()
- if ch == "\0": # don't gobble the end-of-stream character
- # but add an explicit newline as "YAML processors should terminate
- # the stream with an explicit line break
- # https://yaml.org/spec/1.2/spec.html#id2780069
- comment += "\n"
- break
- comment += ch
- srf()
- # gather any blank lines following the comment too
- ch = self.scan_line_break()
- while len(ch) > 0:
- comment += ch
- ch = self.scan_line_break()
- end_mark = self.reader.get_mark()
- if not self.flow_level:
- self.allow_simple_key = True
- return comment, start_mark, end_mark
- if bool(self.scan_line_break()):
- start_mark = self.reader.get_mark()
- if not self.flow_level:
- self.allow_simple_key = True
- ch = srp()
- if ch == "\n": # empty toplevel lines
- start_mark = self.reader.get_mark()
- comment = ""
- while ch:
- ch = self.scan_line_break(empty_line=True)
- comment += ch
- if srp() == "#":
- # empty line followed by indented real comment
- comment = comment.rsplit("\n", 1)[0] + "\n"
- end_mark = self.reader.get_mark()
- return comment, start_mark, end_mark
- else:
- found = True
- return None
-
- def scan_line_break(self, empty_line=False):
- # type: (bool) -> Text
- # Transforms:
- # '\r\n' : '\n'
- # '\r' : '\n'
- # '\n' : '\n'
- # '\x85' : '\n'
- # '\u2028' : '\u2028'
- # '\u2029 : '\u2029'
- # default : ''
- ch = self.reader.peek() # type: Text
- if ch in "\r\n\x85":
- if self.reader.prefix(2) == "\r\n":
- self.reader.forward(2)
- else:
- self.reader.forward()
- return "\n"
- elif ch in "\u2028\u2029":
- self.reader.forward()
- return ch
- elif empty_line and ch in "\t ":
- self.reader.forward()
- return ch
- return ""
-
- def scan_block_scalar(self, style, rt=True):
- # type: (Any, Optional[bool]) -> Any
- return Scanner.scan_block_scalar(self, style, rt=rt)
-
-
-# try:
-# import psyco
-# psyco.bind(Scanner)
-# except ImportError:
-# pass
diff --git a/srsly/ruamel_yaml/serializer.py b/srsly/ruamel_yaml/serializer.py
deleted file mode 100755
index 7888cdc..0000000
--- a/srsly/ruamel_yaml/serializer.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-
-from .error import YAMLError
-from .compat import nprint, DBG_NODE, dbg, string_types, nprintf # NOQA
-from .util import RegExp
-
-from .events import (
- StreamStartEvent,
- StreamEndEvent,
- MappingStartEvent,
- MappingEndEvent,
- SequenceStartEvent,
- SequenceEndEvent,
- AliasEvent,
- ScalarEvent,
- DocumentStartEvent,
- DocumentEndEvent,
-)
-from .nodes import MappingNode, ScalarNode, SequenceNode
-
-if False: # MYPY
- from typing import Any, Dict, Union, Text, Optional # NOQA
- from .compat import VersionType # NOQA
-
-__all__ = ["Serializer", "SerializerError"]
-
-
-class SerializerError(YAMLError):
- pass
-
-
-class Serializer(object):
-
- # 'id' and 3+ numbers, but not 000
- ANCHOR_TEMPLATE = u"id%03d"
- ANCHOR_RE = RegExp(u"id(?!000$)\\d{3,}")
-
- def __init__(
- self,
- encoding=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- tags=None,
- dumper=None,
- ):
- # type: (Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> None # NOQA
- self.dumper = dumper
- if self.dumper is not None:
- self.dumper._serializer = self
- self.use_encoding = encoding
- self.use_explicit_start = explicit_start
- self.use_explicit_end = explicit_end
- if isinstance(version, string_types):
- self.use_version = tuple(map(int, version.split(".")))
- else:
- self.use_version = version # type: ignore
- self.use_tags = tags
- self.serialized_nodes = {} # type: Dict[Any, Any]
- self.anchors = {} # type: Dict[Any, Any]
- self.last_anchor_id = 0
- self.closed = None # type: Optional[bool]
- self._templated_id = None
-
- @property
- def emitter(self):
- # type: () -> Any
- if hasattr(self.dumper, "typ"):
- return self.dumper.emitter
- return self.dumper._emitter
-
- @property
- def resolver(self):
- # type: () -> Any
- if hasattr(self.dumper, "typ"):
- self.dumper.resolver
- return self.dumper._resolver
-
- def open(self):
- # type: () -> None
- if self.closed is None:
- self.emitter.emit(StreamStartEvent(encoding=self.use_encoding))
- self.closed = False
- elif self.closed:
- raise SerializerError("serializer is closed")
- else:
- raise SerializerError("serializer is already opened")
-
- def close(self):
- # type: () -> None
- if self.closed is None:
- raise SerializerError("serializer is not opened")
- elif not self.closed:
- self.emitter.emit(StreamEndEvent())
- self.closed = True
-
- # def __del__(self):
- # self.close()
-
- def serialize(self, node):
- # type: (Any) -> None
- if dbg(DBG_NODE):
- nprint("Serializing nodes")
- node.dump()
- if self.closed is None:
- raise SerializerError("serializer is not opened")
- elif self.closed:
- raise SerializerError("serializer is closed")
- self.emitter.emit(
- DocumentStartEvent(
- explicit=self.use_explicit_start,
- version=self.use_version,
- tags=self.use_tags,
- )
- )
- self.anchor_node(node)
- self.serialize_node(node, None, None)
- self.emitter.emit(DocumentEndEvent(explicit=self.use_explicit_end))
- self.serialized_nodes = {}
- self.anchors = {}
- self.last_anchor_id = 0
-
- def anchor_node(self, node):
- # type: (Any) -> None
- if node in self.anchors:
- if self.anchors[node] is None:
- self.anchors[node] = self.generate_anchor(node)
- else:
- anchor = None
- try:
- if node.anchor.always_dump:
- anchor = node.anchor.value
- except: # NOQA
- pass
- self.anchors[node] = anchor
- if isinstance(node, SequenceNode):
- for item in node.value:
- self.anchor_node(item)
- elif isinstance(node, MappingNode):
- for key, value in node.value:
- self.anchor_node(key)
- self.anchor_node(value)
-
- def generate_anchor(self, node):
- # type: (Any) -> Any
- try:
- anchor = node.anchor.value
- except: # NOQA
- anchor = None
- if anchor is None:
- self.last_anchor_id += 1
- return self.ANCHOR_TEMPLATE % self.last_anchor_id
- return anchor
-
- def serialize_node(self, node, parent, index):
- # type: (Any, Any, Any) -> None
- alias = self.anchors[node]
- if node in self.serialized_nodes:
- self.emitter.emit(AliasEvent(alias))
- else:
- self.serialized_nodes[node] = True
- self.resolver.descend_resolver(parent, index)
- if isinstance(node, ScalarNode):
- # here check if the node.tag equals the one that would result from parsing
- # if not equal quoting is necessary for strings
- detected_tag = self.resolver.resolve(
- ScalarNode, node.value, (True, False)
- )
- default_tag = self.resolver.resolve(
- ScalarNode, node.value, (False, True)
- )
- implicit = (
- (node.tag == detected_tag),
- (node.tag == default_tag),
- node.tag.startswith("tag:yaml.org,2002:"),
- )
- self.emitter.emit(
- ScalarEvent(
- alias,
- node.tag,
- implicit,
- node.value,
- style=node.style,
- comment=node.comment,
- )
- )
- elif isinstance(node, SequenceNode):
- implicit = node.tag == self.resolver.resolve(
- SequenceNode, node.value, True
- )
- comment = node.comment
- end_comment = None
- seq_comment = None
- if node.flow_style is True:
- if comment: # eol comment on flow style sequence
- seq_comment = comment[0]
- # comment[0] = None
- if comment and len(comment) > 2:
- end_comment = comment[2]
- else:
- end_comment = None
- self.emitter.emit(
- SequenceStartEvent(
- alias,
- node.tag,
- implicit,
- flow_style=node.flow_style,
- comment=node.comment,
- )
- )
- index = 0
- for item in node.value:
- self.serialize_node(item, node, index)
- index += 1
- self.emitter.emit(SequenceEndEvent(comment=[seq_comment, end_comment]))
- elif isinstance(node, MappingNode):
- implicit = node.tag == self.resolver.resolve(
- MappingNode, node.value, True
- )
- comment = node.comment
- end_comment = None
- map_comment = None
- if node.flow_style is True:
- if comment: # eol comment on flow style sequence
- map_comment = comment[0]
- # comment[0] = None
- if comment and len(comment) > 2:
- end_comment = comment[2]
- self.emitter.emit(
- MappingStartEvent(
- alias,
- node.tag,
- implicit,
- flow_style=node.flow_style,
- comment=node.comment,
- nr_items=len(node.value),
- )
- )
- for key, value in node.value:
- self.serialize_node(key, node, None)
- self.serialize_node(value, node, key)
- self.emitter.emit(MappingEndEvent(comment=[map_comment, end_comment]))
- self.resolver.ascend_resolver()
-
-
-def templated_id(s):
- # type: (Text) -> Any
- return Serializer.ANCHOR_RE.match(s)
diff --git a/srsly/ruamel_yaml/timestamp.py b/srsly/ruamel_yaml/timestamp.py
deleted file mode 100755
index 374e4c0..0000000
--- a/srsly/ruamel_yaml/timestamp.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-import datetime
-import copy
-
-# ToDo: at least on PY3 you could probably attach the tzinfo correctly to the object
-# a more complete datetime might be used by safe loading as well
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List # NOQA
-
-
-class TimeStamp(datetime.datetime):
- def __init__(self, *args, **kw):
- # type: (Any, Any) -> None
- self._yaml = dict(t=False, tz=None, delta=0) # type: Dict[Any, Any]
-
- def __new__(cls, *args, **kw): # datetime is immutable
- # type: (Any, Any) -> Any
- return datetime.datetime.__new__(cls, *args, **kw) # type: ignore
-
- def __deepcopy__(self, memo):
- # type: (Any) -> Any
- ts = TimeStamp(self.year, self.month, self.day, self.hour, self.minute, self.second)
- ts._yaml = copy.deepcopy(self._yaml)
- return ts
diff --git a/srsly/ruamel_yaml/tokens.py b/srsly/ruamel_yaml/tokens.py
deleted file mode 100755
index 5f5a663..0000000
--- a/srsly/ruamel_yaml/tokens.py
+++ /dev/null
@@ -1,286 +0,0 @@
-# # header
-# coding: utf-8
-
-from __future__ import unicode_literals
-
-if False: # MYPY
- from typing import Text, Any, Dict, Optional, List # NOQA
- from .error import StreamMark # NOQA
-
-SHOWLINES = True
-
-
-class Token(object):
- __slots__ = 'start_mark', 'end_mark', '_comment'
-
- def __init__(self, start_mark, end_mark):
- # type: (StreamMark, StreamMark) -> None
- self.start_mark = start_mark
- self.end_mark = end_mark
-
- def __repr__(self):
- # type: () -> Any
- # attributes = [key for key in self.__slots__ if not key.endswith('_mark') and
- # hasattr('self', key)]
- attributes = [key for key in self.__slots__ if not key.endswith('_mark')]
- attributes.sort()
- arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) for key in attributes])
- if SHOWLINES:
- try:
- arguments += ', line: ' + str(self.start_mark.line)
- except: # NOQA
- pass
- try:
- arguments += ', comment: ' + str(self._comment)
- except: # NOQA
- pass
- return '{}({})'.format(self.__class__.__name__, arguments)
-
- def add_post_comment(self, comment):
- # type: (Any) -> None
- if not hasattr(self, '_comment'):
- self._comment = [None, None]
- self._comment[0] = comment
-
- def add_pre_comments(self, comments):
- # type: (Any) -> None
- if not hasattr(self, '_comment'):
- self._comment = [None, None]
- assert self._comment[1] is None
- self._comment[1] = comments
-
- def get_comment(self):
- # type: () -> Any
- return getattr(self, '_comment', None)
-
- @property
- def comment(self):
- # type: () -> Any
- return getattr(self, '_comment', None)
-
- def move_comment(self, target, empty=False):
- # type: (Any, bool) -> Any
- """move a comment from this token to target (normally next token)
- used to combine e.g. comments before a BlockEntryToken to the
- ScalarToken that follows it
- empty is a special for empty values -> comment after key
- """
- c = self.comment
- if c is None:
- return
- # don't push beyond last element
- if isinstance(target, (StreamEndToken, DocumentStartToken)):
- return
- delattr(self, '_comment')
- tc = target.comment
- if not tc: # target comment, just insert
- # special for empty value in key: value issue 25
- if empty:
- c = [c[0], c[1], None, None, c[0]]
- target._comment = c
- # nprint('mco2:', self, target, target.comment, empty)
- return self
- if c[0] and tc[0] or c[1] and tc[1]:
- raise NotImplementedError('overlap in comment %r %r' % (c, tc))
- if c[0]:
- tc[0] = c[0]
- if c[1]:
- tc[1] = c[1]
- return self
-
- def split_comment(self):
- # type: () -> Any
- """ split the post part of a comment, and return it
- as comment to be added. Delete second part if [None, None]
- abc: # this goes to sequence
- # this goes to first element
- - first element
- """
- comment = self.comment
- if comment is None or comment[0] is None:
- return None # nothing to do
- ret_val = [comment[0], None]
- if comment[1] is None:
- delattr(self, '_comment')
- return ret_val
-
-
-# class BOMToken(Token):
-# id = ''
-
-
-class DirectiveToken(Token):
- __slots__ = 'name', 'value'
- id = ''
-
- def __init__(self, name, value, start_mark, end_mark):
- # type: (Any, Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.name = name
- self.value = value
-
-
-class DocumentStartToken(Token):
- __slots__ = ()
- id = ''
-
-
-class DocumentEndToken(Token):
- __slots__ = ()
- id = ''
-
-
-class StreamStartToken(Token):
- __slots__ = ('encoding',)
- id = ''
-
- def __init__(self, start_mark=None, end_mark=None, encoding=None):
- # type: (Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.encoding = encoding
-
-
-class StreamEndToken(Token):
- __slots__ = ()
- id = ''
-
-
-class BlockSequenceStartToken(Token):
- __slots__ = ()
- id = ''
-
-
-class BlockMappingStartToken(Token):
- __slots__ = ()
- id = ''
-
-
-class BlockEndToken(Token):
- __slots__ = ()
- id = ''
-
-
-class FlowSequenceStartToken(Token):
- __slots__ = ()
- id = '['
-
-
-class FlowMappingStartToken(Token):
- __slots__ = ()
- id = '{'
-
-
-class FlowSequenceEndToken(Token):
- __slots__ = ()
- id = ']'
-
-
-class FlowMappingEndToken(Token):
- __slots__ = ()
- id = '}'
-
-
-class KeyToken(Token):
- __slots__ = ()
- id = '?'
-
- # def x__repr__(self):
- # return 'KeyToken({})'.format(
- # self.start_mark.buffer[self.start_mark.index:].split(None, 1)[0])
-
-
-class ValueToken(Token):
- __slots__ = ()
- id = ':'
-
-
-class BlockEntryToken(Token):
- __slots__ = ()
- id = '-'
-
-
-class FlowEntryToken(Token):
- __slots__ = ()
- id = ','
-
-
-class AliasToken(Token):
- __slots__ = ('value',)
- id = ''
-
- def __init__(self, value, start_mark, end_mark):
- # type: (Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.value = value
-
-
-class AnchorToken(Token):
- __slots__ = ('value',)
- id = ''
-
- def __init__(self, value, start_mark, end_mark):
- # type: (Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.value = value
-
-
-class TagToken(Token):
- __slots__ = ('value',)
- id = ''
-
- def __init__(self, value, start_mark, end_mark):
- # type: (Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.value = value
-
-
-class ScalarToken(Token):
- __slots__ = 'value', 'plain', 'style'
- id = ''
-
- def __init__(self, value, plain, start_mark, end_mark, style=None):
- # type: (Any, Any, Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.value = value
- self.plain = plain
- self.style = style
-
-
-class CommentToken(Token):
- __slots__ = 'value', 'pre_done'
- id = ''
-
- def __init__(self, value, start_mark, end_mark):
- # type: (Any, Any, Any) -> None
- Token.__init__(self, start_mark, end_mark)
- self.value = value
-
- def reset(self):
- # type: () -> None
- if hasattr(self, 'pre_done'):
- delattr(self, 'pre_done')
-
- def __repr__(self):
- # type: () -> Any
- v = '{!r}'.format(self.value)
- if SHOWLINES:
- try:
- v += ', line: ' + str(self.start_mark.line)
- v += ', col: ' + str(self.start_mark.column)
- except: # NOQA
- pass
- return 'CommentToken({})'.format(v)
-
- def __eq__(self, other):
- # type: (Any) -> bool
- if self.start_mark != other.start_mark:
- return False
- if self.end_mark != other.end_mark:
- return False
- if self.value != other.value:
- return False
- return True
-
- def __ne__(self, other):
- # type: (Any) -> bool
- return not self.__eq__(other)
diff --git a/srsly/ruamel_yaml/util.py b/srsly/ruamel_yaml/util.py
deleted file mode 100755
index 3eb7d76..0000000
--- a/srsly/ruamel_yaml/util.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# coding: utf-8
-
-"""
-some helper functions that might be generally useful
-"""
-
-from __future__ import absolute_import, print_function
-
-from functools import partial
-import re
-
-from .compat import text_type, binary_type
-
-if False: # MYPY
- from typing import Any, Dict, Optional, List, Text # NOQA
- from .compat import StreamTextType # NOQA
-
-
-class LazyEval(object):
- """
- Lightweight wrapper around lazily evaluated func(*args, **kwargs).
-
- func is only evaluated when any attribute of its return value is accessed.
- Every attribute access is passed through to the wrapped value.
- (This only excludes special cases like method-wrappers, e.g., __hash__.)
- The sole additional attribute is the lazy_self function which holds the
- return value (or, prior to evaluation, func and arguments), in its closure.
- """
-
- def __init__(self, func, *args, **kwargs):
- # type: (Any, Any, Any) -> None
- def lazy_self():
- # type: () -> Any
- return_value = func(*args, **kwargs)
- object.__setattr__(self, 'lazy_self', lambda: return_value)
- return return_value
-
- object.__setattr__(self, 'lazy_self', lazy_self)
-
- def __getattribute__(self, name):
- # type: (Any) -> Any
- lazy_self = object.__getattribute__(self, 'lazy_self')
- if name == 'lazy_self':
- return lazy_self
- return getattr(lazy_self(), name)
-
- def __setattr__(self, name, value):
- # type: (Any, Any) -> None
- setattr(self.lazy_self(), name, value)
-
-
-RegExp = partial(LazyEval, re.compile)
-
-
-# originally as comment
-# https://github.com/pre-commit/pre-commit/pull/211#issuecomment-186466605
-# if you use this in your code, I suggest adding a test in your test suite
-# that check this routines output against a known piece of your YAML
-# before upgrades to this code break your round-tripped YAML
-def load_yaml_guess_indent(stream, **kw):
- # type: (StreamTextType, Any) -> Any
- """guess the indent and block sequence indent of yaml stream/string
-
- returns round_trip_loaded stream, indent level, block sequence indent
- - block sequence indent is the number of spaces before a dash relative to previous indent
- - if there are no block sequences, indent is taken from nested mappings, block sequence
- indent is unset (None) in that case
- """
- from .main import round_trip_load
-
- # load a yaml file guess the indentation, if you use TABs ...
- def leading_spaces(l):
- # type: (Any) -> int
- idx = 0
- while idx < len(l) and l[idx] == ' ':
- idx += 1
- return idx
-
- if isinstance(stream, text_type):
- yaml_str = stream # type: Any
- elif isinstance(stream, binary_type):
- # most likely, but the Reader checks BOM for this
- yaml_str = stream.decode('utf-8')
- else:
- yaml_str = stream.read()
- map_indent = None
- indent = None # default if not found for some reason
- block_seq_indent = None
- prev_line_key_only = None
- key_indent = 0
- for line in yaml_str.splitlines():
- rline = line.rstrip()
- lline = rline.lstrip()
- if lline.startswith('- '):
- l_s = leading_spaces(line)
- block_seq_indent = l_s - key_indent
- idx = l_s + 1
- while line[idx] == ' ': # this will end as we rstripped
- idx += 1
- if line[idx] == '#': # comment after -
- continue
- indent = idx - key_indent
- break
- if map_indent is None and prev_line_key_only is not None and rline:
- idx = 0
- while line[idx] in ' -':
- idx += 1
- if idx > prev_line_key_only:
- map_indent = idx - prev_line_key_only
- if rline.endswith(':'):
- key_indent = leading_spaces(line)
- idx = 0
- while line[idx] == ' ': # this will end on ':'
- idx += 1
- prev_line_key_only = idx
- continue
- prev_line_key_only = None
- if indent is None and map_indent is not None:
- indent = map_indent
- return round_trip_load(yaml_str, **kw), indent, block_seq_indent
-
-
-def configobj_walker(cfg):
- # type: (Any) -> Any
- """
- walks over a ConfigObj (INI file with comments) generating
- corresponding YAML output (including comments
- """
- from configobj import ConfigObj # type: ignore
-
- assert isinstance(cfg, ConfigObj)
- for c in cfg.initial_comment:
- if c.strip():
- yield c
- for s in _walk_section(cfg):
- if s.strip():
- yield s
- for c in cfg.final_comment:
- if c.strip():
- yield c
-
-
-def _walk_section(s, level=0):
- # type: (Any, int) -> Any
- from configobj import Section
-
- assert isinstance(s, Section)
- indent = u' ' * level
- for name in s.scalars:
- for c in s.comments[name]:
- yield indent + c.strip()
- x = s[name]
- if u'\n' in x:
- i = indent + u' '
- x = u'|\n' + i + x.strip().replace(u'\n', u'\n' + i)
- elif ':' in x:
- x = u"'" + x.replace(u"'", u"''") + u"'"
- line = u'{0}{1}: {2}'.format(indent, name, x)
- c = s.inline_comments[name]
- if c:
- line += u' ' + c
- yield line
- for name in s.sections:
- for c in s.comments[name]:
- yield indent + c.strip()
- line = u'{0}{1}:'.format(indent, name)
- c = s.inline_comments[name]
- if c:
- line += u' ' + c
- yield line
- for val in _walk_section(s[name], level=level + 1):
- yield val
-
-
-# def config_obj_2_rt_yaml(cfg):
-# from .comments import CommentedMap, CommentedSeq
-# from configobj import ConfigObj
-# assert isinstance(cfg, ConfigObj)
-# #for c in cfg.initial_comment:
-# # if c.strip():
-# # pass
-# cm = CommentedMap()
-# for name in s.sections:
-# cm[name] = d = CommentedMap()
-#
-#
-# #for c in cfg.final_comment:
-# # if c.strip():
-# # yield c
-# return cm
diff --git a/srsly/tests/cloudpickle/__init__.py b/srsly/tests/cloudpickle/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/srsly/tests/cloudpickle/cloudpickle_file_test.py b/srsly/tests/cloudpickle/cloudpickle_file_test.py
deleted file mode 100644
index 218566f..0000000
--- a/srsly/tests/cloudpickle/cloudpickle_file_test.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-import shutil
-import sys
-import tempfile
-import unittest
-
-import pytest
-
-import srsly.cloudpickle as cloudpickle
-from srsly.cloudpickle.compat import pickle
-
-
-class CloudPickleFileTests(unittest.TestCase):
- """In Cloudpickle, expected behaviour when pickling an opened file
- is to send its contents over the wire and seek to the same position."""
-
- def setUp(self):
- self.tmpdir = tempfile.mkdtemp()
- self.tmpfilepath = os.path.join(self.tmpdir, 'testfile')
- self.teststring = 'Hello world!'
-
- def tearDown(self):
- shutil.rmtree(self.tmpdir)
-
- def test_empty_file(self):
- # Empty file
- open(self.tmpfilepath, 'w').close()
- with open(self.tmpfilepath, 'r') as f:
- self.assertEqual('', pickle.loads(cloudpickle.dumps(f)).read())
- os.remove(self.tmpfilepath)
-
- def test_closed_file(self):
- # Write & close
- with open(self.tmpfilepath, 'w') as f:
- f.write(self.teststring)
- with pytest.raises(pickle.PicklingError) as excinfo:
- cloudpickle.dumps(f)
- assert "Cannot pickle closed files" in str(excinfo.value)
- os.remove(self.tmpfilepath)
-
- def test_r_mode(self):
- # Write & close
- with open(self.tmpfilepath, 'w') as f:
- f.write(self.teststring)
- # Open for reading
- with open(self.tmpfilepath, 'r') as f:
- new_f = pickle.loads(cloudpickle.dumps(f))
- self.assertEqual(self.teststring, new_f.read())
- os.remove(self.tmpfilepath)
-
- def test_w_mode(self):
- with open(self.tmpfilepath, 'w') as f:
- f.write(self.teststring)
- f.seek(0)
- self.assertRaises(pickle.PicklingError,
- lambda: cloudpickle.dumps(f))
- os.remove(self.tmpfilepath)
-
- def test_plus_mode(self):
- # Write, then seek to 0
- with open(self.tmpfilepath, 'w+') as f:
- f.write(self.teststring)
- f.seek(0)
- new_f = pickle.loads(cloudpickle.dumps(f))
- self.assertEqual(self.teststring, new_f.read())
- os.remove(self.tmpfilepath)
-
- def test_seek(self):
- # Write, then seek to arbitrary position
- with open(self.tmpfilepath, 'w+') as f:
- f.write(self.teststring)
- f.seek(4)
- unpickled = pickle.loads(cloudpickle.dumps(f))
- # unpickled StringIO is at position 4
- self.assertEqual(4, unpickled.tell())
- self.assertEqual(self.teststring[4:], unpickled.read())
- # but unpickled StringIO also contained the start
- unpickled.seek(0)
- self.assertEqual(self.teststring, unpickled.read())
- os.remove(self.tmpfilepath)
-
- @pytest.mark.skip(reason="Requires pytest -s to pass")
- def test_pickling_special_file_handles(self):
- # Warning: if you want to run your tests with nose, add -s option
- for out in sys.stdout, sys.stderr: # Regression test for SPARK-3415
- self.assertEqual(out, pickle.loads(cloudpickle.dumps(out)))
- self.assertRaises(pickle.PicklingError,
- lambda: cloudpickle.dumps(sys.stdin))
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/srsly/tests/cloudpickle/cloudpickle_test.py b/srsly/tests/cloudpickle/cloudpickle_test.py
deleted file mode 100644
index 25451df..0000000
--- a/srsly/tests/cloudpickle/cloudpickle_test.py
+++ /dev/null
@@ -1,2847 +0,0 @@
-import _collections_abc
-import abc
-import collections
-import base64
-import functools
-import io
-import itertools
-import logging
-import math
-import multiprocessing
-from operator import itemgetter, attrgetter
-import pickletools
-import platform
-import random
-import re
-import shutil
-import subprocess
-import sys
-import tempfile
-import textwrap
-import types
-import unittest
-import weakref
-import os
-import enum
-import typing
-from functools import wraps
-
-import pytest
-
-try:
- # try importing numpy and scipy. These are not hard dependencies and
- # tests should be skipped if these modules are not available
- import numpy as np
- import scipy.special as spp
-except (ImportError, RuntimeError):
- np = None
- spp = None
-
-try:
- # Ditto for Tornado
- import tornado
-except ImportError:
- tornado = None
-
-import srsly.cloudpickle as cloudpickle
-from srsly.cloudpickle.compat import pickle
-from srsly.cloudpickle import register_pickle_by_value
-from srsly.cloudpickle import unregister_pickle_by_value
-from srsly.cloudpickle import list_registry_pickle_by_value
-from srsly.cloudpickle.cloudpickle import _should_pickle_by_reference
-from srsly.cloudpickle.cloudpickle import _make_empty_cell, cell_set
-from srsly.cloudpickle.cloudpickle import _extract_class_dict, _whichmodule
-from srsly.cloudpickle.cloudpickle import _lookup_module_and_qualname
-
-from .testutils import subprocess_pickle_echo
-from .testutils import subprocess_pickle_string
-from .testutils import assert_run_python_script
-from .testutils import subprocess_worker
-
-
-_TEST_GLOBAL_VARIABLE = "default_value"
-_TEST_GLOBAL_VARIABLE2 = "another_value"
-
-
-class RaiserOnPickle:
-
- def __init__(self, exc):
- self.exc = exc
-
- def __reduce__(self):
- raise self.exc
-
-
-def pickle_depickle(obj, protocol=cloudpickle.DEFAULT_PROTOCOL):
- """Helper function to test whether object pickled with cloudpickle can be
- depickled with pickle
- """
- return pickle.loads(cloudpickle.dumps(obj, protocol=protocol))
-
-
-def _escape(raw_filepath):
- # Ugly hack to embed filepaths in code templates for windows
- return raw_filepath.replace("\\", r"\\\\")
-
-
-def _maybe_remove(list_, item):
- try:
- list_.remove(item)
- except ValueError:
- pass
- return list_
-
-
-def test_extract_class_dict():
- class A(int):
- """A docstring"""
- def method(self):
- return "a"
-
- class B:
- """B docstring"""
- B_CONSTANT = 42
-
- def method(self):
- return "b"
-
- class C(A, B):
- C_CONSTANT = 43
-
- def method_c(self):
- return "c"
-
- clsdict = _extract_class_dict(C)
- if sys.version_info >= (3, 13):
- expected_keys = ["C_CONSTANT", "__doc__", "__firstlineno__", "method_c"]
- else:
- expected_keys = ["C_CONSTANT", "__doc__", "method_c"]
- assert sorted(clsdict) == expected_keys
- assert clsdict["C_CONSTANT"] == 43
- assert clsdict["__doc__"] is None
- assert clsdict["method_c"](C()) == C().method_c()
-
-
-class CloudPickleTest(unittest.TestCase):
-
- protocol = cloudpickle.DEFAULT_PROTOCOL
-
- def setUp(self):
- self.tmpdir = tempfile.mkdtemp(prefix="tmp_cloudpickle_test_")
-
- def tearDown(self):
- shutil.rmtree(self.tmpdir)
-
- @pytest.mark.skipif(
- platform.python_implementation() != "CPython" or
- (sys.version_info >= (3, 8, 0) and sys.version_info < (3, 8, 2)),
- reason="Underlying bug fixed upstream starting Python 3.8.2")
- def test_reducer_override_reference_cycle(self):
- # Early versions of Python 3.8 introduced a reference cycle between a
- # Pickler and it's reducer_override method. Because a Pickler
- # object references every object it has pickled through its memo, this
- # cycle prevented the garbage-collection of those external pickled
- # objects. See #327 as well as https://bugs.python.org/issue39492
- # This bug was fixed in Python 3.8.2, but is still present using
- # cloudpickle and Python 3.8.0/1, hence the skipif directive.
- class MyClass:
- pass
-
- my_object = MyClass()
- wr = weakref.ref(my_object)
-
- cloudpickle.dumps(my_object)
- del my_object
- assert wr() is None, "'del'-ed my_object has not been collected"
-
- def test_itemgetter(self):
- d = range(10)
- getter = itemgetter(1)
-
- getter2 = pickle_depickle(getter, protocol=self.protocol)
- self.assertEqual(getter(d), getter2(d))
-
- getter = itemgetter(0, 3)
- getter2 = pickle_depickle(getter, protocol=self.protocol)
- self.assertEqual(getter(d), getter2(d))
-
- def test_attrgetter(self):
- class C:
- def __getattr__(self, item):
- return item
- d = C()
- getter = attrgetter("a")
- getter2 = pickle_depickle(getter, protocol=self.protocol)
- self.assertEqual(getter(d), getter2(d))
- getter = attrgetter("a", "b")
- getter2 = pickle_depickle(getter, protocol=self.protocol)
- self.assertEqual(getter(d), getter2(d))
-
- d.e = C()
- getter = attrgetter("e.a")
- getter2 = pickle_depickle(getter, protocol=self.protocol)
- self.assertEqual(getter(d), getter2(d))
- getter = attrgetter("e.a", "e.b")
- getter2 = pickle_depickle(getter, protocol=self.protocol)
- self.assertEqual(getter(d), getter2(d))
-
- # Regression test for SPARK-3415
- @pytest.mark.skip(reason="Requires pytest -s to pass")
- def test_pickling_file_handles(self):
- out1 = sys.stderr
- out2 = pickle.loads(cloudpickle.dumps(out1, protocol=self.protocol))
- self.assertEqual(out1, out2)
-
- def test_func_globals(self):
- class Unpicklable:
- def __reduce__(self):
- raise Exception("not picklable")
-
- global exit
- exit = Unpicklable()
-
- self.assertRaises(Exception, lambda: cloudpickle.dumps(
- exit, protocol=self.protocol))
-
- def foo():
- sys.exit(0)
-
- self.assertTrue("exit" in foo.__code__.co_names)
- cloudpickle.dumps(foo)
-
- def test_buffer(self):
- try:
- buffer_obj = buffer("Hello")
- buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol)
- self.assertEqual(buffer_clone, str(buffer_obj))
- buffer_obj = buffer("Hello", 2, 3)
- buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol)
- self.assertEqual(buffer_clone, str(buffer_obj))
- except NameError: # Python 3 does no longer support buffers
- pass
-
- def test_memoryview(self):
- buffer_obj = memoryview(b"Hello")
- self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
- buffer_obj.tobytes())
-
- def test_dict_keys(self):
- keys = {"a": 1, "b": 2}.keys()
- results = pickle_depickle(keys)
- self.assertEqual(results, keys)
- assert isinstance(results, _collections_abc.dict_keys)
-
- def test_dict_values(self):
- values = {"a": 1, "b": 2}.values()
- results = pickle_depickle(values)
- self.assertEqual(sorted(results), sorted(values))
- assert isinstance(results, _collections_abc.dict_values)
-
- def test_dict_items(self):
- items = {"a": 1, "b": 2}.items()
- results = pickle_depickle(items)
- self.assertEqual(results, items)
- assert isinstance(results, _collections_abc.dict_items)
-
- def test_odict_keys(self):
- keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys()
- results = pickle_depickle(keys)
- self.assertEqual(results, keys)
- assert type(keys) == type(results)
-
- def test_odict_values(self):
- values = collections.OrderedDict([("a", 1), ("b", 2)]).values()
- results = pickle_depickle(values)
- self.assertEqual(list(results), list(values))
- assert type(values) == type(results)
-
- def test_odict_items(self):
- items = collections.OrderedDict([("a", 1), ("b", 2)]).items()
- results = pickle_depickle(items)
- self.assertEqual(results, items)
- assert type(items) == type(results)
-
- def test_sliced_and_non_contiguous_memoryview(self):
- buffer_obj = memoryview(b"Hello!" * 3)[2:15:2]
- self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
- buffer_obj.tobytes())
-
- def test_large_memoryview(self):
- buffer_obj = memoryview(b"Hello!" * int(1e7))
- self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
- buffer_obj.tobytes())
-
- def test_lambda(self):
- self.assertEqual(
- pickle_depickle(lambda: 1, protocol=self.protocol)(), 1)
-
- def test_nested_lambdas(self):
- a, b = 1, 2
- f1 = lambda x: x + a
- f2 = lambda x: f1(x) // b
- self.assertEqual(pickle_depickle(f2, protocol=self.protocol)(1), 1)
-
- def test_recursive_closure(self):
- def f1():
- def g():
- return g
- return g
-
- def f2(base):
- def g(n):
- return base if n <= 1 else n * g(n - 1)
- return g
-
- g1 = pickle_depickle(f1(), protocol=self.protocol)
- self.assertEqual(g1(), g1)
-
- g2 = pickle_depickle(f2(2), protocol=self.protocol)
- self.assertEqual(g2(5), 240)
-
- def test_closure_none_is_preserved(self):
- def f():
- """a function with no closure cells
- """
-
- self.assertTrue(
- f.__closure__ is None,
- msg='f actually has closure cells!',
- )
-
- g = pickle_depickle(f, protocol=self.protocol)
-
- self.assertTrue(
- g.__closure__ is None,
- msg='g now has closure cells even though f does not',
- )
-
- def test_empty_cell_preserved(self):
- def f():
- if False: # pragma: no cover
- cell = None
-
- def g():
- cell # NameError, unbound free variable
-
- return g
-
- g1 = f()
- with pytest.raises(NameError):
- g1()
-
- g2 = pickle_depickle(g1, protocol=self.protocol)
- with pytest.raises(NameError):
- g2()
-
- def test_unhashable_closure(self):
- def f():
- s = {1, 2} # mutable set is unhashable
-
- def g():
- return len(s)
-
- return g
-
- g = pickle_depickle(f(), protocol=self.protocol)
- self.assertEqual(g(), 2)
-
- def test_dynamically_generated_class_that_uses_super(self):
-
- class Base:
- def method(self):
- return 1
-
- class Derived(Base):
- "Derived Docstring"
- def method(self):
- return super().method() + 1
-
- self.assertEqual(Derived().method(), 2)
-
- # Pickle and unpickle the class.
- UnpickledDerived = pickle_depickle(Derived, protocol=self.protocol)
- self.assertEqual(UnpickledDerived().method(), 2)
-
- # We have special logic for handling __doc__ because it's a readonly
- # attribute on PyPy.
- self.assertEqual(UnpickledDerived.__doc__, "Derived Docstring")
-
- # Pickle and unpickle an instance.
- orig_d = Derived()
- d = pickle_depickle(orig_d, protocol=self.protocol)
- self.assertEqual(d.method(), 2)
-
- def test_cycle_in_classdict_globals(self):
-
- class C:
-
- def it_works(self):
- return "woohoo!"
-
- C.C_again = C
- C.instance_of_C = C()
-
- depickled_C = pickle_depickle(C, protocol=self.protocol)
- depickled_instance = pickle_depickle(C())
-
- # Test instance of depickled class.
- self.assertEqual(depickled_C().it_works(), "woohoo!")
- self.assertEqual(depickled_C.C_again().it_works(), "woohoo!")
- self.assertEqual(depickled_C.instance_of_C.it_works(), "woohoo!")
- self.assertEqual(depickled_instance.it_works(), "woohoo!")
-
- def test_locally_defined_function_and_class(self):
- LOCAL_CONSTANT = 42
-
- def some_function(x, y):
- # Make sure the __builtins__ are not broken (see #211)
- sum(range(10))
- return (x + y) / LOCAL_CONSTANT
-
- # pickle the function definition
- self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(41, 1), 1)
- self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(81, 3), 2)
-
- hidden_constant = lambda: LOCAL_CONSTANT
-
- class SomeClass:
- """Overly complicated class with nested references to symbols"""
- def __init__(self, value):
- self.value = value
-
- def one(self):
- return LOCAL_CONSTANT / hidden_constant()
-
- def some_method(self, x):
- return self.one() + some_function(x, 1) + self.value
-
- # pickle the class definition
- clone_class = pickle_depickle(SomeClass, protocol=self.protocol)
- self.assertEqual(clone_class(1).one(), 1)
- self.assertEqual(clone_class(5).some_method(41), 7)
- clone_class = subprocess_pickle_echo(SomeClass, protocol=self.protocol)
- self.assertEqual(clone_class(5).some_method(41), 7)
-
- # pickle the class instances
- self.assertEqual(pickle_depickle(SomeClass(1)).one(), 1)
- self.assertEqual(pickle_depickle(SomeClass(5)).some_method(41), 7)
- new_instance = subprocess_pickle_echo(SomeClass(5),
- protocol=self.protocol)
- self.assertEqual(new_instance.some_method(41), 7)
-
- # pickle the method instances
- self.assertEqual(pickle_depickle(SomeClass(1).one)(), 1)
- self.assertEqual(pickle_depickle(SomeClass(5).some_method)(41), 7)
- new_method = subprocess_pickle_echo(SomeClass(5).some_method,
- protocol=self.protocol)
- self.assertEqual(new_method(41), 7)
-
- def test_partial(self):
- partial_obj = functools.partial(min, 1)
- partial_clone = pickle_depickle(partial_obj, protocol=self.protocol)
- self.assertEqual(partial_clone(4), 1)
-
- @pytest.mark.skipif(platform.python_implementation() == 'PyPy',
- reason="Skip numpy and scipy tests on PyPy")
- def test_ufunc(self):
- # test a numpy ufunc (universal function), which is a C-based function
- # that is applied on a numpy array
-
- if np:
- # simple ufunc: np.add
- self.assertEqual(pickle_depickle(np.add, protocol=self.protocol),
- np.add)
- else: # skip if numpy is not available
- pass
-
- if spp:
- # custom ufunc: scipy.special.iv
- self.assertEqual(pickle_depickle(spp.iv, protocol=self.protocol),
- spp.iv)
- else: # skip if scipy is not available
- pass
-
- def test_loads_namespace(self):
- obj = 1, 2, 3, 4
- returned_obj = cloudpickle.loads(cloudpickle.dumps(
- obj, protocol=self.protocol))
- self.assertEqual(obj, returned_obj)
-
- def test_load_namespace(self):
- obj = 1, 2, 3, 4
- bio = io.BytesIO()
- cloudpickle.dump(obj, bio)
- bio.seek(0)
- returned_obj = cloudpickle.load(bio)
- self.assertEqual(obj, returned_obj)
-
- def test_generator(self):
-
- def some_generator(cnt):
- for i in range(cnt):
- yield i
-
- gen2 = pickle_depickle(some_generator, protocol=self.protocol)
-
- assert type(gen2(3)) == type(some_generator(3))
- assert list(gen2(3)) == list(range(3))
-
- def test_classmethod(self):
- class A:
- @staticmethod
- def test_sm():
- return "sm"
- @classmethod
- def test_cm(cls):
- return "cm"
-
- sm = A.__dict__["test_sm"]
- cm = A.__dict__["test_cm"]
-
- A.test_sm = pickle_depickle(sm, protocol=self.protocol)
- A.test_cm = pickle_depickle(cm, protocol=self.protocol)
-
- self.assertEqual(A.test_sm(), "sm")
- self.assertEqual(A.test_cm(), "cm")
-
- def test_bound_classmethod(self):
- class A:
- @classmethod
- def test_cm(cls):
- return "cm"
-
- A.test_cm = pickle_depickle(A.test_cm, protocol=self.protocol)
- self.assertEqual(A.test_cm(), "cm")
-
- def test_method_descriptors(self):
- f = pickle_depickle(str.upper)
- self.assertEqual(f('abc'), 'ABC')
-
- def test_instancemethods_without_self(self):
- class F:
- def f(self, x):
- return x + 1
-
- g = pickle_depickle(F.f, protocol=self.protocol)
- self.assertEqual(g.__name__, F.f.__name__)
- # self.assertEqual(g(F(), 1), 2) # still fails
-
- def test_module(self):
- pickle_clone = pickle_depickle(pickle, protocol=self.protocol)
- self.assertEqual(pickle, pickle_clone)
-
- def test_dynamic_module(self):
- mod = types.ModuleType('mod')
- code = '''
- x = 1
- def f(y):
- return x + y
-
- class Foo:
- def method(self, x):
- return f(x)
- '''
- exec(textwrap.dedent(code), mod.__dict__)
- mod2 = pickle_depickle(mod, protocol=self.protocol)
- self.assertEqual(mod.x, mod2.x)
- self.assertEqual(mod.f(5), mod2.f(5))
- self.assertEqual(mod.Foo().method(5), mod2.Foo().method(5))
-
- if platform.python_implementation() != 'PyPy':
- # XXX: this fails with excessive recursion on PyPy.
- mod3 = subprocess_pickle_echo(mod, protocol=self.protocol)
- self.assertEqual(mod.x, mod3.x)
- self.assertEqual(mod.f(5), mod3.f(5))
- self.assertEqual(mod.Foo().method(5), mod3.Foo().method(5))
-
- # Test dynamic modules when imported back are singletons
- mod1, mod2 = pickle_depickle([mod, mod])
- self.assertEqual(id(mod1), id(mod2))
-
- # Ensure proper pickling of mod's functions when module "looks" like a
- # file-backed module even though it is not:
- try:
- sys.modules['mod'] = mod
- depickled_f = pickle_depickle(mod.f, protocol=self.protocol)
- self.assertEqual(mod.f(5), depickled_f(5))
- finally:
- sys.modules.pop('mod', None)
-
- def test_module_locals_behavior(self):
- # Makes sure that a local function defined in another module is
- # correctly serialized. This notably checks that the globals are
- # accessible and that there is no issue with the builtins (see #211)
-
- pickled_func_path = os.path.join(self.tmpdir, 'local_func_g.pkl')
-
- child_process_script = '''
- from srsly.cloudpickle.compat import pickle
- import gc
- with open("{pickled_func_path}", 'rb') as f:
- func = pickle.load(f)
-
- assert func(range(10)) == 45
- '''
-
- child_process_script = child_process_script.format(
- pickled_func_path=_escape(pickled_func_path))
-
- try:
-
- from srsly.tests.cloudpickle.testutils import make_local_function
-
- g = make_local_function()
- with open(pickled_func_path, 'wb') as f:
- cloudpickle.dump(g, f, protocol=self.protocol)
-
- assert_run_python_script(textwrap.dedent(child_process_script))
-
- finally:
- os.unlink(pickled_func_path)
-
- def test_dynamic_module_with_unpicklable_builtin(self):
- # Reproducer of https://github.com/cloudpipe/cloudpickle/issues/316
- # Some modules such as scipy inject some unpicklable objects into the
- # __builtins__ module, which appears in every module's __dict__ under
- # the '__builtins__' key. In such cases, cloudpickle used to fail
- # when pickling dynamic modules.
- class UnpickleableObject:
- def __reduce__(self):
- raise ValueError('Unpicklable object')
-
- mod = types.ModuleType("mod")
-
- exec('f = lambda x: abs(x)', mod.__dict__)
- assert mod.f(-1) == 1
- assert '__builtins__' in mod.__dict__
-
- unpicklable_obj = UnpickleableObject()
- with pytest.raises(ValueError):
- cloudpickle.dumps(unpicklable_obj)
-
- # Emulate the behavior of scipy by injecting an unpickleable object
- # into mod's builtins.
- # The __builtins__ entry of mod's __dict__ can either be the
- # __builtins__ module, or the __builtins__ module's __dict__. #316
- # happens only in the latter case.
- if isinstance(mod.__dict__['__builtins__'], dict):
- mod.__dict__['__builtins__']['unpickleable_obj'] = unpicklable_obj
- elif isinstance(mod.__dict__['__builtins__'], types.ModuleType):
- mod.__dict__['__builtins__'].unpickleable_obj = unpicklable_obj
-
- depickled_mod = pickle_depickle(mod, protocol=self.protocol)
- assert '__builtins__' in depickled_mod.__dict__
-
- if isinstance(depickled_mod.__dict__['__builtins__'], dict):
- assert "abs" in depickled_mod.__builtins__
- elif isinstance(
- depickled_mod.__dict__['__builtins__'], types.ModuleType):
- assert hasattr(depickled_mod.__builtins__, "abs")
- assert depickled_mod.f(-1) == 1
-
- # Additional check testing that the issue #425 is fixed: without the
- # fix for #425, `mod.f` would not have access to `__builtins__`, and
- # thus calling `mod.f(-1)` (which relies on the `abs` builtin) would
- # fail.
- assert mod.f(-1) == 1
-
- def test_load_dynamic_module_in_grandchild_process(self):
- # Make sure that when loaded, a dynamic module preserves its dynamic
- # property. Otherwise, this will lead to an ImportError if pickled in
- # the child process and reloaded in another one.
-
- # We create a new dynamic module
- mod = types.ModuleType('mod')
- code = '''
- x = 1
- '''
- exec(textwrap.dedent(code), mod.__dict__)
-
- # This script will be ran in a separate child process. It will import
- # the pickled dynamic module, and then re-pickle it under a new name.
- # Finally, it will create a child process that will load the re-pickled
- # dynamic module.
- parent_process_module_file = os.path.join(
- self.tmpdir, 'dynamic_module_from_parent_process.pkl')
- child_process_module_file = os.path.join(
- self.tmpdir, 'dynamic_module_from_child_process.pkl')
- child_process_script = '''
- from srsly.cloudpickle.compat import pickle
- import textwrap
-
- import srsly.cloudpickle as cloudpickle
- from srsly.tests.cloudpickle.testutils import assert_run_python_script
-
-
- child_of_child_process_script = {child_of_child_process_script}
-
- with open('{parent_process_module_file}', 'rb') as f:
- mod = pickle.load(f)
-
- with open('{child_process_module_file}', 'wb') as f:
- cloudpickle.dump(mod, f, protocol={protocol})
-
- assert_run_python_script(textwrap.dedent(child_of_child_process_script))
- '''
-
- # The script ran by the process created by the child process
- child_of_child_process_script = """ '''
- from srsly.cloudpickle.compat import pickle
- with open('{child_process_module_file}','rb') as fid:
- mod = pickle.load(fid)
- ''' """
-
- # Filling the two scripts with the pickled modules filepaths and,
- # for the first child process, the script to be executed by its
- # own child process.
- child_of_child_process_script = child_of_child_process_script.format(
- child_process_module_file=child_process_module_file)
-
- child_process_script = child_process_script.format(
- parent_process_module_file=_escape(parent_process_module_file),
- child_process_module_file=_escape(child_process_module_file),
- child_of_child_process_script=_escape(child_of_child_process_script),
- protocol=self.protocol)
-
- try:
- with open(parent_process_module_file, 'wb') as fid:
- cloudpickle.dump(mod, fid, protocol=self.protocol)
-
- assert_run_python_script(textwrap.dedent(child_process_script))
-
- finally:
- # Remove temporary created files
- if os.path.exists(parent_process_module_file):
- os.unlink(parent_process_module_file)
- if os.path.exists(child_process_module_file):
- os.unlink(child_process_module_file)
-
- def test_correct_globals_import(self):
- def nested_function(x):
- return x + 1
-
- def unwanted_function(x):
- return math.exp(x)
-
- def my_small_function(x, y):
- return nested_function(x) + y
-
- b = cloudpickle.dumps(my_small_function, protocol=self.protocol)
-
- # Make sure that the pickle byte string only includes the definition
- # of my_small_function and its dependency nested_function while
- # extra functions and modules such as unwanted_function and the math
- # module are not included so as to keep the pickle payload as
- # lightweight as possible.
-
- assert b'my_small_function' in b
- assert b'nested_function' in b
-
- assert b'unwanted_function' not in b
- assert b'math' not in b
-
- def test_module_importability(self):
- pytest.importorskip("_cloudpickle_testpkg")
- from srsly.cloudpickle.compat import pickle
- import os.path
- import collections
- import collections.abc
-
- assert _should_pickle_by_reference(pickle)
- assert _should_pickle_by_reference(os.path) # fake (aliased) module
- assert _should_pickle_by_reference(collections) # package
- assert _should_pickle_by_reference(collections.abc) # module in package
-
- dynamic_module = types.ModuleType('dynamic_module')
- assert not _should_pickle_by_reference(dynamic_module)
-
- if platform.python_implementation() == 'PyPy':
- import _codecs
- assert _should_pickle_by_reference(_codecs)
-
- # #354: Check that modules created dynamically during the import of
- # their parent modules are considered importable by cloudpickle.
- # See the mod_with_dynamic_submodule documentation for more
- # details of this use case.
- import _cloudpickle_testpkg.mod.dynamic_submodule as m
- assert _should_pickle_by_reference(m)
- assert pickle_depickle(m, protocol=self.protocol) is m
-
- # Check for similar behavior for a module that cannot be imported by
- # attribute lookup.
- from _cloudpickle_testpkg.mod import dynamic_submodule_two as m2
- # Note: import _cloudpickle_testpkg.mod.dynamic_submodule_two as m2
- # works only for Python 3.7+
- assert _should_pickle_by_reference(m2)
- assert pickle_depickle(m2, protocol=self.protocol) is m2
-
- # Submodule_three is a dynamic module only importable via module lookup
- with pytest.raises(ImportError):
- import _cloudpickle_testpkg.mod.submodule_three # noqa
- from _cloudpickle_testpkg.mod import submodule_three as m3
- assert not _should_pickle_by_reference(m3)
-
- # This module cannot be pickled using attribute lookup (as it does not
- # have a `__module__` attribute like classes and functions.
- assert not hasattr(m3, '__module__')
- depickled_m3 = pickle_depickle(m3, protocol=self.protocol)
- assert depickled_m3 is not m3
- assert m3.f(1) == depickled_m3.f(1)
-
- # Do the same for an importable dynamic submodule inside a dynamic
- # module inside a file-backed module.
- import _cloudpickle_testpkg.mod.dynamic_submodule.dynamic_subsubmodule as sm # noqa
- assert _should_pickle_by_reference(sm)
- assert pickle_depickle(sm, protocol=self.protocol) is sm
-
- expected = "cannot check importability of object instances"
- with pytest.raises(TypeError, match=expected):
- _should_pickle_by_reference(object())
-
- def test_Ellipsis(self):
- self.assertEqual(Ellipsis,
- pickle_depickle(Ellipsis, protocol=self.protocol))
-
- def test_NotImplemented(self):
- ExcClone = pickle_depickle(NotImplemented, protocol=self.protocol)
- self.assertEqual(NotImplemented, ExcClone)
-
- def test_NoneType(self):
- res = pickle_depickle(type(None), protocol=self.protocol)
- self.assertEqual(type(None), res)
-
- def test_EllipsisType(self):
- res = pickle_depickle(type(Ellipsis), protocol=self.protocol)
- self.assertEqual(type(Ellipsis), res)
-
- def test_NotImplementedType(self):
- res = pickle_depickle(type(NotImplemented), protocol=self.protocol)
- self.assertEqual(type(NotImplemented), res)
-
- def test_builtin_function(self):
- # Note that builtin_function_or_method are special-cased by cloudpickle
- # only in python2.
-
- # builtin function from the __builtin__ module
- assert pickle_depickle(zip, protocol=self.protocol) is zip
-
- from os import mkdir
- # builtin function from a "regular" module
- assert pickle_depickle(mkdir, protocol=self.protocol) is mkdir
-
- def test_builtin_type_constructor(self):
- # This test makes sure that cloudpickling builtin-type
- # constructors works for all python versions/implementation.
-
- # pickle_depickle some builtin methods of the __builtin__ module
- for t in list, tuple, set, frozenset, dict, object:
- cloned_new = pickle_depickle(t.__new__, protocol=self.protocol)
- assert isinstance(cloned_new(t), t)
-
- # The next 4 tests cover all cases into which builtin python methods can
- # appear.
- # There are 4 kinds of method: 'classic' methods, classmethods,
- # staticmethods and slotmethods. They will appear under different types
- # depending on whether they are called from the __dict__ of their
- # class, their class itself, or an instance of their class. This makes
- # 12 total combinations.
- # This discussion and the following tests are relevant for the CPython
- # implementation only. In PyPy, there is no builtin method or builtin
- # function types/flavours. The only way into which a builtin method can be
- # identified is with it's builtin-code __code__ attribute.
-
- def test_builtin_classicmethod(self):
- obj = 1.5 # float object
-
- bound_classicmethod = obj.hex # builtin_function_or_method
- unbound_classicmethod = type(obj).hex # method_descriptor
- clsdict_classicmethod = type(obj).__dict__['hex'] # method_descriptor
-
- assert unbound_classicmethod is clsdict_classicmethod
-
- depickled_bound_meth = pickle_depickle(
- bound_classicmethod, protocol=self.protocol)
- depickled_unbound_meth = pickle_depickle(
- unbound_classicmethod, protocol=self.protocol)
- depickled_clsdict_meth = pickle_depickle(
- clsdict_classicmethod, protocol=self.protocol)
-
- # No identity on the bound methods they are bound to different float
- # instances
- assert depickled_bound_meth() == bound_classicmethod()
- assert depickled_unbound_meth is unbound_classicmethod
- assert depickled_clsdict_meth is clsdict_classicmethod
-
-
- @pytest.mark.skipif(
- (platform.machine() == "aarch64" and sys.version_info[:2] >= (3, 10))
- or platform.python_implementation() == "PyPy"
- or (sys.version_info[:2] == (3, 10) and sys.version_info >= (3, 10, 8))
- # Skipping tests on 3.11 due to https://github.com/cloudpipe/cloudpickle/pull/486.
- or sys.version_info[:2] >= (3, 11),
- reason="Fails on aarch64 + python 3.10+ in cibuildwheel, currently unable to replicate failure elsewhere; fails sometimes for pypy on conda-forge; fails for python 3.10.8+ and 3.11+")
- def test_builtin_classmethod(self):
- obj = 1.5 # float object
-
- bound_clsmethod = obj.fromhex # builtin_function_or_method
- unbound_clsmethod = type(obj).fromhex # builtin_function_or_method
- clsdict_clsmethod = type(
- obj).__dict__['fromhex'] # classmethod_descriptor
-
- depickled_bound_meth = pickle_depickle(
- bound_clsmethod, protocol=self.protocol)
- depickled_unbound_meth = pickle_depickle(
- unbound_clsmethod, protocol=self.protocol)
- depickled_clsdict_meth = pickle_depickle(
- clsdict_clsmethod, protocol=self.protocol)
-
- # float.fromhex takes a string as input.
- arg = "0x1"
-
- # Identity on both the bound and the unbound methods cannot be
- # tested: the bound methods are bound to different objects, and the
- # unbound methods are actually recreated at each call.
- assert depickled_bound_meth(arg) == bound_clsmethod(arg)
- assert depickled_unbound_meth(arg) == unbound_clsmethod(arg)
-
- if platform.python_implementation() == 'CPython':
- # Roundtripping a classmethod_descriptor results in a
- # builtin_function_or_method (CPython upstream issue).
- assert depickled_clsdict_meth(arg) == clsdict_clsmethod(float, arg)
- if platform.python_implementation() == 'PyPy':
- # builtin-classmethods are simple classmethod in PyPy (not
- # callable). We test equality of types and the functionality of the
- # __func__ attribute instead. We do not test the the identity of
- # the functions as __func__ attributes of classmethods are not
- # pickleable and must be reconstructed at depickling time.
- assert type(depickled_clsdict_meth) == type(clsdict_clsmethod)
- assert depickled_clsdict_meth.__func__(
- float, arg) == clsdict_clsmethod.__func__(float, arg)
-
- def test_builtin_slotmethod(self):
- obj = 1.5 # float object
-
- bound_slotmethod = obj.__repr__ # method-wrapper
- unbound_slotmethod = type(obj).__repr__ # wrapper_descriptor
- clsdict_slotmethod = type(obj).__dict__['__repr__'] # ditto
-
- depickled_bound_meth = pickle_depickle(
- bound_slotmethod, protocol=self.protocol)
- depickled_unbound_meth = pickle_depickle(
- unbound_slotmethod, protocol=self.protocol)
- depickled_clsdict_meth = pickle_depickle(
- clsdict_slotmethod, protocol=self.protocol)
-
- # No identity tests on the bound slotmethod are they are bound to
- # different float instances
- assert depickled_bound_meth() == bound_slotmethod()
- assert depickled_unbound_meth is unbound_slotmethod
- assert depickled_clsdict_meth is clsdict_slotmethod
-
- @pytest.mark.skipif(
- platform.python_implementation() == "PyPy",
- reason="No known staticmethod example in the pypy stdlib")
- def test_builtin_staticmethod(self):
- obj = "foo" # str object
-
- bound_staticmethod = obj.maketrans # builtin_function_or_method
- unbound_staticmethod = type(obj).maketrans # ditto
- clsdict_staticmethod = type(obj).__dict__['maketrans'] # staticmethod
-
- assert bound_staticmethod is unbound_staticmethod
-
- depickled_bound_meth = pickle_depickle(
- bound_staticmethod, protocol=self.protocol)
- depickled_unbound_meth = pickle_depickle(
- unbound_staticmethod, protocol=self.protocol)
- depickled_clsdict_meth = pickle_depickle(
- clsdict_staticmethod, protocol=self.protocol)
-
- assert depickled_bound_meth is bound_staticmethod
- assert depickled_unbound_meth is unbound_staticmethod
-
- # staticmethod objects are recreated at depickling time, but the
- # underlying __func__ object is pickled by attribute.
- assert depickled_clsdict_meth.__func__ is clsdict_staticmethod.__func__
- type(depickled_clsdict_meth) is type(clsdict_staticmethod)
-
- @pytest.mark.skipif(tornado is None,
- reason="test needs Tornado installed")
- def test_tornado_coroutine(self):
- # Pickling a locally defined coroutine function
- from tornado import gen, ioloop
-
- @gen.coroutine
- def f(x, y):
- yield gen.sleep(x)
- raise gen.Return(y + 1)
-
- @gen.coroutine
- def g(y):
- res = yield f(0.01, y)
- raise gen.Return(res + 1)
-
- data = cloudpickle.dumps([g, g], protocol=self.protocol)
- f = g = None
- g2, g3 = pickle.loads(data)
- self.assertTrue(g2 is g3)
- loop = ioloop.IOLoop.current()
- res = loop.run_sync(functools.partial(g2, 5))
- self.assertEqual(res, 7)
-
- @pytest.mark.skipif(
- (3, 11, 0, 'beta') <= sys.version_info < (3, 11, 0, 'beta', 4),
- reason="https://github.com/python/cpython/issues/92932"
- )
- def test_extended_arg(self):
- # Functions with more than 65535 global vars prefix some global
- # variable references with the EXTENDED_ARG opcode.
- nvars = 65537 + 258
- names = ['g%d' % i for i in range(1, nvars)]
- r = random.Random(42)
- d = {name: r.randrange(100) for name in names}
- # def f(x):
- # x = g1, g2, ...
- # return zlib.crc32(bytes(bytearray(x)))
- code = """
- import zlib
-
- def f():
- x = {tup}
- return zlib.crc32(bytes(bytearray(x)))
- """.format(tup=', '.join(names))
- exec(textwrap.dedent(code), d, d)
- f = d['f']
- res = f()
- data = cloudpickle.dumps([f, f], protocol=self.protocol)
- d = f = None
- f2, f3 = pickle.loads(data)
- self.assertTrue(f2 is f3)
- self.assertEqual(f2(), res)
-
- def test_submodule(self):
- # Function that refers (by attribute) to a sub-module of a package.
-
- # Choose any module NOT imported by __init__ of its parent package
- # examples in standard library include:
- # - http.cookies, unittest.mock, curses.textpad, xml.etree.ElementTree
-
- global xml # imitate performing this import at top of file
- import xml.etree.ElementTree
- def example():
- x = xml.etree.ElementTree.Comment # potential AttributeError
-
- s = cloudpickle.dumps(example, protocol=self.protocol)
-
- # refresh the environment, i.e., unimport the dependency
- del xml
- for item in list(sys.modules):
- if item.split('.')[0] == 'xml':
- del sys.modules[item]
-
- # deserialise
- f = pickle.loads(s)
- f() # perform test for error
-
- def test_submodule_closure(self):
- # Same as test_submodule except the package is not a global
- def scope():
- import xml.etree.ElementTree
- def example():
- x = xml.etree.ElementTree.Comment # potential AttributeError
- return example
- example = scope()
-
- s = cloudpickle.dumps(example, protocol=self.protocol)
-
- # refresh the environment (unimport dependency)
- for item in list(sys.modules):
- if item.split('.')[0] == 'xml':
- del sys.modules[item]
-
- f = cloudpickle.loads(s)
- f() # test
-
- def test_multiprocess(self):
- # running a function pickled by another process (a la dask.distributed)
- def scope():
- def example():
- x = xml.etree.ElementTree.Comment
- return example
- global xml
- import xml.etree.ElementTree
- example = scope()
-
- s = cloudpickle.dumps(example, protocol=self.protocol)
-
- # choose "subprocess" rather than "multiprocessing" because the latter
- # library uses fork to preserve the parent environment.
- command = ("import base64; "
- "from srsly.cloudpickle.compat import pickle; "
- "pickle.loads(base64.b32decode('" +
- base64.b32encode(s).decode('ascii') +
- "'))()")
- assert not subprocess.call([sys.executable, '-c', command])
-
- def test_import(self):
- # like test_multiprocess except subpackage modules referenced directly
- # (unlike test_submodule)
- global etree
- def scope():
- import xml.etree as foobar
- def example():
- x = etree.Comment
- x = foobar.ElementTree
- return example
- example = scope()
- import xml.etree.ElementTree as etree
-
- s = cloudpickle.dumps(example, protocol=self.protocol)
-
- command = ("import base64; "
- "from srsly.cloudpickle.compat import pickle; "
- "pickle.loads(base64.b32decode('" +
- base64.b32encode(s).decode('ascii') +
- "'))()")
- assert not subprocess.call([sys.executable, '-c', command])
-
- def test_multiprocessing_lock_raises(self):
- lock = multiprocessing.Lock()
- with pytest.raises(RuntimeError, match="only be shared between processes through inheritance"):
- cloudpickle.dumps(lock)
-
- def test_cell_manipulation(self):
- cell = _make_empty_cell()
-
- with pytest.raises(ValueError):
- cell.cell_contents
-
- ob = object()
- cell_set(cell, ob)
- self.assertTrue(
- cell.cell_contents is ob,
- msg='cell contents not set correctly',
- )
-
- def check_logger(self, name):
- logger = logging.getLogger(name)
- pickled = pickle_depickle(logger, protocol=self.protocol)
- self.assertTrue(pickled is logger, (pickled, logger))
-
- dumped = cloudpickle.dumps(logger)
-
- code = """if 1:
- import base64, srsly.cloudpickle as cloudpickle, logging
-
- logging.basicConfig(level=logging.INFO)
- logger = cloudpickle.loads(base64.b32decode(b'{}'))
- logger.info('hello')
- """.format(base64.b32encode(dumped).decode('ascii'))
- proc = subprocess.Popen([sys.executable, "-W ignore", "-c", code],
- stdout=subprocess.PIPE,
- stderr=subprocess.STDOUT)
- out, _ = proc.communicate()
- self.assertEqual(proc.wait(), 0)
- self.assertEqual(out.strip().decode(),
- f'INFO:{logger.name}:hello')
-
- def test_logger(self):
- # logging.RootLogger object
- self.check_logger(None)
- # logging.Logger object
- self.check_logger('cloudpickle.dummy_test_logger')
-
- def test_getset_descriptor(self):
- assert isinstance(float.real, types.GetSetDescriptorType)
- depickled_descriptor = pickle_depickle(float.real)
- self.assertIs(depickled_descriptor, float.real)
-
- def test_abc_cache_not_pickled(self):
- # cloudpickle issue #302: make sure that cloudpickle does not pickle
- # the caches populated during instance/subclass checks of abc.ABCMeta
- # instances.
- MyClass = abc.ABCMeta('MyClass', (), {})
-
- class MyUnrelatedClass:
- pass
-
- class MyRelatedClass:
- pass
-
- MyClass.register(MyRelatedClass)
-
- assert not issubclass(MyUnrelatedClass, MyClass)
- assert issubclass(MyRelatedClass, MyClass)
-
- s = cloudpickle.dumps(MyClass)
-
- assert b"MyUnrelatedClass" not in s
- assert b"MyRelatedClass" in s
-
- depickled_class = cloudpickle.loads(s)
- assert not issubclass(MyUnrelatedClass, depickled_class)
- assert issubclass(MyRelatedClass, depickled_class)
-
- def test_abc(self):
-
- class AbstractClass(abc.ABC):
- @abc.abstractmethod
- def some_method(self):
- """A method"""
-
- @classmethod
- @abc.abstractmethod
- def some_classmethod(cls):
- """A classmethod"""
-
- @staticmethod
- @abc.abstractmethod
- def some_staticmethod():
- """A staticmethod"""
-
- @property
- @abc.abstractmethod
- def some_property():
- """A property"""
-
- class ConcreteClass(AbstractClass):
- def some_method(self):
- return 'it works!'
-
- @classmethod
- def some_classmethod(cls):
- assert cls == ConcreteClass
- return 'it works!'
-
- @staticmethod
- def some_staticmethod():
- return 'it works!'
-
- @property
- def some_property(self):
- return 'it works!'
-
- # This abstract class is locally defined so we can safely register
- # tuple in it to verify the unpickled class also register tuple.
- AbstractClass.register(tuple)
-
- concrete_instance = ConcreteClass()
- depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol)
- depickled_class = pickle_depickle(ConcreteClass,
- protocol=self.protocol)
- depickled_instance = pickle_depickle(concrete_instance)
-
- assert issubclass(tuple, AbstractClass)
- assert issubclass(tuple, depickled_base)
-
- self.assertEqual(depickled_class().some_method(), 'it works!')
- self.assertEqual(depickled_instance.some_method(), 'it works!')
-
- self.assertEqual(depickled_class.some_classmethod(), 'it works!')
- self.assertEqual(depickled_instance.some_classmethod(), 'it works!')
-
- self.assertEqual(depickled_class().some_staticmethod(), 'it works!')
- self.assertEqual(depickled_instance.some_staticmethod(), 'it works!')
-
- self.assertEqual(depickled_class().some_property, 'it works!')
- self.assertEqual(depickled_instance.some_property, 'it works!')
- self.assertRaises(TypeError, depickled_base)
-
- class DepickledBaseSubclass(depickled_base):
- def some_method(self):
- return 'it works for realz!'
-
- @classmethod
- def some_classmethod(cls):
- assert cls == DepickledBaseSubclass
- return 'it works for realz!'
-
- @staticmethod
- def some_staticmethod():
- return 'it works for realz!'
-
- @property
- def some_property():
- return 'it works for realz!'
-
- self.assertEqual(DepickledBaseSubclass().some_method(),
- 'it works for realz!')
-
- class IncompleteBaseSubclass(depickled_base):
- def some_method(self):
- return 'this class lacks some concrete methods'
-
- self.assertRaises(TypeError, IncompleteBaseSubclass)
-
- def test_abstracts(self):
- # Same as `test_abc` but using deprecated `abc.abstract*` methods.
- # See https://github.com/cloudpipe/cloudpickle/issues/367
-
- class AbstractClass(abc.ABC):
- @abc.abstractmethod
- def some_method(self):
- """A method"""
-
- @abc.abstractclassmethod
- def some_classmethod(cls):
- """A classmethod"""
-
- @abc.abstractstaticmethod
- def some_staticmethod():
- """A staticmethod"""
-
- @abc.abstractproperty
- def some_property(self):
- """A property"""
-
- class ConcreteClass(AbstractClass):
- def some_method(self):
- return 'it works!'
-
- @classmethod
- def some_classmethod(cls):
- assert cls == ConcreteClass
- return 'it works!'
-
- @staticmethod
- def some_staticmethod():
- return 'it works!'
-
- @property
- def some_property(self):
- return 'it works!'
-
- # This abstract class is locally defined so we can safely register
- # tuple in it to verify the unpickled class also register tuple.
- AbstractClass.register(tuple)
-
- concrete_instance = ConcreteClass()
- depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol)
- depickled_class = pickle_depickle(ConcreteClass,
- protocol=self.protocol)
- depickled_instance = pickle_depickle(concrete_instance)
-
- assert issubclass(tuple, AbstractClass)
- assert issubclass(tuple, depickled_base)
-
- self.assertEqual(depickled_class().some_method(), 'it works!')
- self.assertEqual(depickled_instance.some_method(), 'it works!')
-
- self.assertEqual(depickled_class.some_classmethod(), 'it works!')
- self.assertEqual(depickled_instance.some_classmethod(), 'it works!')
-
- self.assertEqual(depickled_class().some_staticmethod(), 'it works!')
- self.assertEqual(depickled_instance.some_staticmethod(), 'it works!')
-
- self.assertEqual(depickled_class().some_property, 'it works!')
- self.assertEqual(depickled_instance.some_property, 'it works!')
- self.assertRaises(TypeError, depickled_base)
-
- class DepickledBaseSubclass(depickled_base):
- def some_method(self):
- return 'it works for realz!'
-
- @classmethod
- def some_classmethod(cls):
- assert cls == DepickledBaseSubclass
- return 'it works for realz!'
-
- @staticmethod
- def some_staticmethod():
- return 'it works for realz!'
-
- @property
- def some_property(self):
- return 'it works for realz!'
-
- self.assertEqual(DepickledBaseSubclass().some_method(),
- 'it works for realz!')
-
- class IncompleteBaseSubclass(depickled_base):
- def some_method(self):
- return 'this class lacks some concrete methods'
-
- self.assertRaises(TypeError, IncompleteBaseSubclass)
-
- def test_weakset_identity_preservation(self):
- # Test that weaksets don't lose all their inhabitants if they're
- # pickled in a larger data structure that includes other references to
- # their inhabitants.
-
- class SomeClass:
- def __init__(self, x):
- self.x = x
-
- obj1, obj2, obj3 = SomeClass(1), SomeClass(2), SomeClass(3)
-
- things = [weakref.WeakSet([obj1, obj2]), obj1, obj2, obj3]
- result = pickle_depickle(things, protocol=self.protocol)
-
- weakset, depickled1, depickled2, depickled3 = result
-
- self.assertEqual(depickled1.x, 1)
- self.assertEqual(depickled2.x, 2)
- self.assertEqual(depickled3.x, 3)
- self.assertEqual(len(weakset), 2)
-
- self.assertEqual(set(weakset), {depickled1, depickled2})
-
- def test_non_module_object_passing_whichmodule_test(self):
- # https://github.com/cloudpipe/cloudpickle/pull/326: cloudpickle should
- # not try to instrospect non-modules object when trying to discover the
- # module of a function/class. This happenened because codecov injects
- # tuples (and not modules) into sys.modules, but type-checks were not
- # carried out on the entries of sys.modules, causing cloupdickle to
- # then error in unexpected ways
- def func(x):
- return x ** 2
-
- # Trigger a loop during the execution of whichmodule(func) by
- # explicitly setting the function's module to None
- func.__module__ = None
-
- class NonModuleObject:
- def __ini__(self):
- self.some_attr = None
-
- def __getattr__(self, name):
- # We whitelist func so that a _whichmodule(func, None) call
- # returns the NonModuleObject instance if a type check on the
- # entries of sys.modules is not carried out, but manipulating
- # this instance thinking it really is a module later on in the
- # pickling process of func errors out
- if name == 'func':
- return func
- else:
- raise AttributeError
-
- non_module_object = NonModuleObject()
-
- assert func(2) == 4
- assert func is non_module_object.func
-
- # Any manipulation of non_module_object relying on attribute access
- # will raise an Exception
- with pytest.raises(AttributeError):
- _ = non_module_object.some_attr
-
- try:
- sys.modules['NonModuleObject'] = non_module_object
-
- func_module_name = _whichmodule(func, None)
- assert func_module_name != 'NonModuleObject'
- assert func_module_name is None
-
- depickled_func = pickle_depickle(func, protocol=self.protocol)
- assert depickled_func(2) == 4
-
- finally:
- sys.modules.pop('NonModuleObject')
-
- def test_unrelated_faulty_module(self):
- # Check that pickling a dynamically defined function or class does not
- # fail when introspecting the currently loaded modules in sys.modules
- # as long as those faulty modules are unrelated to the class or
- # function we are currently pickling.
- for base_class in (object, types.ModuleType):
- for module_name in ['_missing_module', None]:
- class FaultyModule(base_class):
- def __getattr__(self, name):
- # This throws an exception while looking up within
- # pickle.whichmodule or getattr(module, name, None)
- raise Exception()
-
- class Foo:
- __module__ = module_name
-
- def foo(self):
- return "it works!"
-
- def foo():
- return "it works!"
-
- foo.__module__ = module_name
-
- if base_class is types.ModuleType: # noqa
- faulty_module = FaultyModule('_faulty_module')
- else:
- faulty_module = FaultyModule()
- sys.modules["_faulty_module"] = faulty_module
-
- try:
- # Test whichmodule in save_global.
- self.assertEqual(pickle_depickle(Foo()).foo(), "it works!")
-
- # Test whichmodule in save_function.
- cloned = pickle_depickle(foo, protocol=self.protocol)
- self.assertEqual(cloned(), "it works!")
- finally:
- sys.modules.pop("_faulty_module", None)
-
- @pytest.mark.skip(reason="fails for pytest v7.2.0")
- def test_dynamic_pytest_module(self):
- # Test case for pull request https://github.com/cloudpipe/cloudpickle/pull/116
- import py
-
- def f():
- s = py.builtin.set([1])
- return s.pop()
-
- # some setup is required to allow pytest apimodules to be correctly
- # serializable.
- from srsly.cloudpickle import CloudPickler
- from srsly.cloudpickle import cloudpickle_fast as cp_fast
- CloudPickler.dispatch_table[type(py.builtin)] = cp_fast._module_reduce
-
- g = cloudpickle.loads(cloudpickle.dumps(f, protocol=self.protocol))
-
- result = g()
- self.assertEqual(1, result)
-
- def test_function_module_name(self):
- func = lambda x: x
- cloned = pickle_depickle(func, protocol=self.protocol)
- self.assertEqual(cloned.__module__, func.__module__)
-
- def test_function_qualname(self):
- def func(x):
- return x
- # Default __qualname__ attribute (Python 3 only)
- if hasattr(func, '__qualname__'):
- cloned = pickle_depickle(func, protocol=self.protocol)
- self.assertEqual(cloned.__qualname__, func.__qualname__)
-
- # Mutated __qualname__ attribute
- func.__qualname__ = ''
- cloned = pickle_depickle(func, protocol=self.protocol)
- self.assertEqual(cloned.__qualname__, func.__qualname__)
-
- def test_property(self):
- # Note that the @property decorator only has an effect on new-style
- # classes.
- class MyObject:
- _read_only_value = 1
- _read_write_value = 1
-
- @property
- def read_only_value(self):
- "A read-only attribute"
- return self._read_only_value
-
- @property
- def read_write_value(self):
- return self._read_write_value
-
- @read_write_value.setter
- def read_write_value(self, value):
- self._read_write_value = value
-
-
-
- my_object = MyObject()
-
- assert my_object.read_only_value == 1
- assert MyObject.read_only_value.__doc__ == "A read-only attribute"
-
- with pytest.raises(AttributeError):
- my_object.read_only_value = 2
- my_object.read_write_value = 2
-
- depickled_obj = pickle_depickle(my_object)
-
- assert depickled_obj.read_only_value == 1
- assert depickled_obj.read_write_value == 2
-
- # make sure the depickled read_only_value attribute is still read-only
- with pytest.raises(AttributeError):
- my_object.read_only_value = 2
-
- # make sure the depickled read_write_value attribute is writeable
- depickled_obj.read_write_value = 3
- assert depickled_obj.read_write_value == 3
- type(depickled_obj).read_only_value.__doc__ == "A read-only attribute"
-
-
- def test_namedtuple(self):
- MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c'])
- t1 = MyTuple(1, 2, 3)
- t2 = MyTuple(3, 2, 1)
-
- depickled_t1, depickled_MyTuple, depickled_t2 = pickle_depickle(
- [t1, MyTuple, t2], protocol=self.protocol)
-
- assert isinstance(depickled_t1, MyTuple)
- assert depickled_t1 == t1
- assert depickled_MyTuple is MyTuple
- assert isinstance(depickled_t2, MyTuple)
- assert depickled_t2 == t2
-
- @pytest.mark.skipif(platform.python_implementation() == "PyPy",
- reason="fails sometimes for pypy on conda-forge")
- def test_interactively_defined_function(self):
- # Check that callables defined in the __main__ module of a Python
- # script (or jupyter kernel) can be pickled / unpickled / executed.
- code = """\
- from srsly.tests.cloudpickle.testutils import subprocess_pickle_echo
-
- CONSTANT = 42
-
- class Foo(object):
-
- def method(self, x):
- return x
-
- foo = Foo()
-
- def f0(x):
- return x ** 2
-
- def f1():
- return Foo
-
- def f2(x):
- return Foo().method(x)
-
- def f3():
- return Foo().method(CONSTANT)
-
- def f4(x):
- return foo.method(x)
-
- def f5(x):
- # Recursive call to a dynamically defined function.
- if x <= 0:
- return f4(x)
- return f5(x - 1) + 1
-
- cloned = subprocess_pickle_echo(lambda x: x**2, protocol={protocol})
- assert cloned(3) == 9
-
- cloned = subprocess_pickle_echo(f0, protocol={protocol})
- assert cloned(3) == 9
-
- cloned = subprocess_pickle_echo(Foo, protocol={protocol})
- assert cloned().method(2) == Foo().method(2)
-
- cloned = subprocess_pickle_echo(Foo(), protocol={protocol})
- assert cloned.method(2) == Foo().method(2)
-
- cloned = subprocess_pickle_echo(f1, protocol={protocol})
- assert cloned()().method('a') == f1()().method('a')
-
- cloned = subprocess_pickle_echo(f2, protocol={protocol})
- assert cloned(2) == f2(2)
-
- cloned = subprocess_pickle_echo(f3, protocol={protocol})
- assert cloned() == f3()
-
- cloned = subprocess_pickle_echo(f4, protocol={protocol})
- assert cloned(2) == f4(2)
-
- cloned = subprocess_pickle_echo(f5, protocol={protocol})
- assert cloned(7) == f5(7) == 7
- """.format(protocol=self.protocol)
- assert_run_python_script(textwrap.dedent(code))
-
- def test_interactively_defined_global_variable(self):
- # Check that callables defined in the __main__ module of a Python
- # script (or jupyter kernel) correctly retrieve global variables.
- code_template = """\
- from srsly.tests.cloudpickle.testutils import subprocess_pickle_echo
- from srsly.cloudpickle import dumps, loads
-
- def local_clone(obj, protocol=None):
- return loads(dumps(obj, protocol=protocol))
-
- VARIABLE = "default_value"
-
- def f0():
- global VARIABLE
- VARIABLE = "changed_by_f0"
-
- def f1():
- return VARIABLE
-
- assert f0.__globals__ is f1.__globals__
-
- # pickle f0 and f1 inside the same pickle_string
- cloned_f0, cloned_f1 = {clone_func}([f0, f1], protocol={protocol})
-
- # cloned_f0 and cloned_f1 now share a global namespace that is isolated
- # from any previously existing namespace
- assert cloned_f0.__globals__ is cloned_f1.__globals__
- assert cloned_f0.__globals__ is not f0.__globals__
-
- # pickle f1 another time, but in a new pickle string
- pickled_f1 = dumps(f1, protocol={protocol})
-
- # Change the value of the global variable in f0's new global namespace
- cloned_f0()
-
- # thanks to cloudpickle isolation, depickling and calling f0 and f1
- # should not affect the globals of already existing modules
- assert VARIABLE == "default_value", VARIABLE
-
- # Ensure that cloned_f1 and cloned_f0 share the same globals, as f1 and
- # f0 shared the same globals at pickling time, and cloned_f1 was
- # depickled from the same pickle string as cloned_f0
- shared_global_var = cloned_f1()
- assert shared_global_var == "changed_by_f0", shared_global_var
-
- # f1 is unpickled another time, but because it comes from another
- # pickle string than pickled_f1 and pickled_f0, it will not share the
- # same globals as the latter two.
- new_cloned_f1 = loads(pickled_f1)
- assert new_cloned_f1.__globals__ is not cloned_f1.__globals__
- assert new_cloned_f1.__globals__ is not f1.__globals__
-
- # get the value of new_cloned_f1's VARIABLE
- new_global_var = new_cloned_f1()
- assert new_global_var == "default_value", new_global_var
- """
- for clone_func in ['local_clone', 'subprocess_pickle_echo']:
- code = code_template.format(protocol=self.protocol,
- clone_func=clone_func)
- assert_run_python_script(textwrap.dedent(code))
-
- def test_closure_interacting_with_a_global_variable(self):
- global _TEST_GLOBAL_VARIABLE
- assert _TEST_GLOBAL_VARIABLE == "default_value"
- orig_value = _TEST_GLOBAL_VARIABLE
- try:
- def f0():
- global _TEST_GLOBAL_VARIABLE
- _TEST_GLOBAL_VARIABLE = "changed_by_f0"
-
- def f1():
- return _TEST_GLOBAL_VARIABLE
-
- # pickle f0 and f1 inside the same pickle_string
- cloned_f0, cloned_f1 = pickle_depickle([f0, f1],
- protocol=self.protocol)
-
- # cloned_f0 and cloned_f1 now share a global namespace that is
- # isolated from any previously existing namespace
- assert cloned_f0.__globals__ is cloned_f1.__globals__
- assert cloned_f0.__globals__ is not f0.__globals__
-
- # pickle f1 another time, but in a new pickle string
- pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol)
-
- # Change the global variable's value in f0's new global namespace
- cloned_f0()
-
- # depickling f0 and f1 should not affect the globals of already
- # existing modules
- assert _TEST_GLOBAL_VARIABLE == "default_value"
-
- # Ensure that cloned_f1 and cloned_f0 share the same globals, as f1
- # and f0 shared the same globals at pickling time, and cloned_f1
- # was depickled from the same pickle string as cloned_f0
- shared_global_var = cloned_f1()
- assert shared_global_var == "changed_by_f0", shared_global_var
-
- # f1 is unpickled another time, but because it comes from another
- # pickle string than pickled_f1 and pickled_f0, it will not share
- # the same globals as the latter two.
- new_cloned_f1 = pickle.loads(pickled_f1)
- assert new_cloned_f1.__globals__ is not cloned_f1.__globals__
- assert new_cloned_f1.__globals__ is not f1.__globals__
-
- # get the value of new_cloned_f1's VARIABLE
- new_global_var = new_cloned_f1()
- assert new_global_var == "default_value", new_global_var
- finally:
- _TEST_GLOBAL_VARIABLE = orig_value
-
- def test_interactive_remote_function_calls(self):
- code = """if __name__ == "__main__":
- from srsly.tests.cloudpickle.testutils import subprocess_worker
-
- def interactive_function(x):
- return x + 1
-
- with subprocess_worker(protocol={protocol}) as w:
-
- assert w.run(interactive_function, 41) == 42
-
- # Define a new function that will call an updated version of
- # the previously called function:
-
- def wrapper_func(x):
- return interactive_function(x)
-
- def interactive_function(x):
- return x - 1
-
- # The change in the definition of interactive_function in the main
- # module of the main process should be reflected transparently
- # in the worker process: the worker process does not recall the
- # previous definition of `interactive_function`:
-
- assert w.run(wrapper_func, 41) == 40
- """.format(protocol=self.protocol)
- assert_run_python_script(code)
-
- def test_interactive_remote_function_calls_no_side_effect(self):
- code = """if __name__ == "__main__":
- from srsly.tests.cloudpickle.testutils import subprocess_worker
- import sys
-
- with subprocess_worker(protocol={protocol}) as w:
-
- GLOBAL_VARIABLE = 0
-
- class CustomClass(object):
-
- def mutate_globals(self):
- global GLOBAL_VARIABLE
- GLOBAL_VARIABLE += 1
- return GLOBAL_VARIABLE
-
- custom_object = CustomClass()
- assert w.run(custom_object.mutate_globals) == 1
-
- # The caller global variable is unchanged in the main process.
-
- assert GLOBAL_VARIABLE == 0
-
- # Calling the same function again starts again from zero. The
- # worker process is stateless: it has no memory of the past call:
-
- assert w.run(custom_object.mutate_globals) == 1
-
- # The symbols defined in the main process __main__ module are
- # not set in the worker process main module to leave the worker
- # as stateless as possible:
-
- def is_in_main(name):
- return hasattr(sys.modules["__main__"], name)
-
- assert is_in_main("CustomClass")
- assert not w.run(is_in_main, "CustomClass")
-
- assert is_in_main("GLOBAL_VARIABLE")
- assert not w.run(is_in_main, "GLOBAL_VARIABLE")
-
- """.format(protocol=self.protocol)
- assert_run_python_script(code)
-
- def test_interactive_dynamic_type_and_remote_instances(self):
- code = """if __name__ == "__main__":
- from srsly.tests.cloudpickle.testutils import subprocess_worker
-
- with subprocess_worker(protocol={protocol}) as w:
-
- class CustomCounter:
- def __init__(self):
- self.count = 0
- def increment(self):
- self.count += 1
- return self
-
- counter = CustomCounter().increment()
- assert counter.count == 1
-
- returned_counter = w.run(counter.increment)
- assert returned_counter.count == 2, returned_counter.count
-
- # Check that the class definition of the returned instance was
- # matched back to the original class definition living in __main__.
-
- assert isinstance(returned_counter, CustomCounter)
-
- # Check that memoization does not break provenance tracking:
-
- def echo(*args):
- return args
-
- C1, C2, c1, c2 = w.run(echo, CustomCounter, CustomCounter,
- CustomCounter(), returned_counter)
- assert C1 is CustomCounter
- assert C2 is CustomCounter
- assert isinstance(c1, CustomCounter)
- assert isinstance(c2, CustomCounter)
-
- """.format(protocol=self.protocol)
- assert_run_python_script(code)
-
- def test_interactive_dynamic_type_and_stored_remote_instances(self):
- """Simulate objects stored on workers to check isinstance semantics
-
- Such instances stored in the memory of running worker processes are
- similar to dask-distributed futures for instance.
- """
- code = """if __name__ == "__main__":
- import srsly.cloudpickle as cloudpickle, uuid
- from srsly.tests.cloudpickle.testutils import subprocess_worker
-
- with subprocess_worker(protocol={protocol}) as w:
-
- class A:
- '''Original class definition'''
- pass
-
- def store(x):
- storage = getattr(cloudpickle, "_test_storage", None)
- if storage is None:
- storage = cloudpickle._test_storage = dict()
- obj_id = uuid.uuid4().hex
- storage[obj_id] = x
- return obj_id
-
- def lookup(obj_id):
- return cloudpickle._test_storage[obj_id]
-
- id1 = w.run(store, A())
-
- # The stored object on the worker is matched to a singleton class
- # definition thanks to provenance tracking:
- assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1)
-
- # Retrieving the object from the worker yields a local copy that
- # is matched back the local class definition this instance
- # originally stems from.
- assert isinstance(w.run(lookup, id1), A)
-
- # Changing the local class definition should be taken into account
- # in all subsequent calls. In particular the old instances on the
- # worker do not map back to the new class definition, neither on
- # the worker itself, nor locally on the main program when the old
- # instance is retrieved:
-
- class A:
- '''Updated class definition'''
- pass
-
- assert not w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1)
- retrieved1 = w.run(lookup, id1)
- assert not isinstance(retrieved1, A)
- assert retrieved1.__class__ is not A
- assert retrieved1.__class__.__doc__ == "Original class definition"
-
- # New instances on the other hand are proper instances of the new
- # class definition everywhere:
-
- a = A()
- id2 = w.run(store, a)
- assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id2)
- assert isinstance(w.run(lookup, id2), A)
-
- # Monkeypatch the class defintion in the main process to a new
- # class method:
- A.echo = lambda cls, x: x
-
- # Calling this method on an instance will automatically update
- # the remote class definition on the worker to propagate the monkey
- # patch dynamically.
- assert w.run(a.echo, 42) == 42
-
- # The stored instance can therefore also access the new class
- # method:
- assert w.run(lambda obj_id: lookup(obj_id).echo(43), id2) == 43
-
- """.format(protocol=self.protocol)
- assert_run_python_script(code)
-
- @pytest.mark.skip(reason="Seems to have issues outside of linux and CPython")
- def test_interactive_remote_function_calls_no_memory_leak(self):
- code = """if __name__ == "__main__":
- from srsly.tests.cloudpickle.testutils import subprocess_worker
- import struct
-
- with subprocess_worker(protocol={protocol}) as w:
-
- reference_size = w.memsize()
- assert reference_size > 0
-
-
- def make_big_closure(i):
- # Generate a byte string of size 1MB
- itemsize = len(struct.pack("l", 1))
- data = struct.pack("l", i) * (int(1e6) // itemsize)
- def process_data():
- return len(data)
- return process_data
-
- for i in range(100):
- func = make_big_closure(i)
- result = w.run(func)
- assert result == int(1e6), result
-
- import gc
- w.run(gc.collect)
-
- # By this time the worker process has processed 100MB worth of data
- # passed in the closures. The worker memory size should not have
- # grown by more than a few MB as closures are garbage collected at
- # the end of each remote function call.
- growth = w.memsize() - reference_size
-
- # For some reason, the memory growth after processing 100MB of
- # data is ~10MB on MacOS, and ~1MB on Linux, so the upper bound on
- # memory growth we use is only tight for MacOS. However,
- # - 10MB is still 10x lower than the expected memory growth in case
- # of a leak (which would be the total size of the processed data,
- # 100MB)
- # - the memory usage growth does not increase if using 10000
- # iterations instead of 100 as used now (100x more data)
- assert growth < 1.5e7, growth
-
- """.format(protocol=self.protocol)
- assert_run_python_script(code)
-
- def test_pickle_reraise(self):
- for exc_type in [Exception, ValueError, TypeError, RuntimeError]:
- obj = RaiserOnPickle(exc_type("foo"))
- with pytest.raises((exc_type, pickle.PicklingError)):
- cloudpickle.dumps(obj, protocol=self.protocol)
-
- def test_unhashable_function(self):
- d = {'a': 1}
- depickled_method = pickle_depickle(d.get, protocol=self.protocol)
- self.assertEqual(depickled_method('a'), 1)
- self.assertEqual(depickled_method('b'), None)
-
- @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Deprecation warning in python 3.12 about future deprecation in python 3.14")
- def test_itertools_count(self):
- counter = itertools.count(1, step=2)
-
- # advance the counter a bit
- next(counter)
- next(counter)
-
- new_counter = pickle_depickle(counter, protocol=self.protocol)
-
- self.assertTrue(counter is not new_counter)
-
- for _ in range(10):
- self.assertEqual(next(counter), next(new_counter))
-
- def test_wraps_preserves_function_name(self):
- from functools import wraps
-
- def f():
- pass
-
- @wraps(f)
- def g():
- f()
-
- f2 = pickle_depickle(g, protocol=self.protocol)
-
- self.assertEqual(f2.__name__, f.__name__)
-
- def test_wraps_preserves_function_doc(self):
- from functools import wraps
-
- def f():
- """42"""
- pass
-
- @wraps(f)
- def g():
- f()
-
- f2 = pickle_depickle(g, protocol=self.protocol)
-
- self.assertEqual(f2.__doc__, f.__doc__)
-
- def test_wraps_preserves_function_annotations(self):
- def f(x):
- pass
-
- f.__annotations__ = {'x': 1, 'return': float}
-
- @wraps(f)
- def g(x):
- f(x)
-
- f2 = pickle_depickle(g, protocol=self.protocol)
-
- self.assertEqual(f2.__annotations__, f.__annotations__)
-
- def test_type_hint(self):
- t = typing.Union[list, int]
- assert pickle_depickle(t) == t
-
- def test_instance_with_slots(self):
- for slots in [["registered_attribute"], "registered_attribute"]:
- class ClassWithSlots:
- __slots__ = slots
-
- def __init__(self):
- self.registered_attribute = 42
-
- initial_obj = ClassWithSlots()
- depickled_obj = pickle_depickle(
- initial_obj, protocol=self.protocol)
-
- for obj in [initial_obj, depickled_obj]:
- self.assertEqual(obj.registered_attribute, 42)
- with pytest.raises(AttributeError):
- obj.non_registered_attribute = 1
-
- class SubclassWithSlots(ClassWithSlots):
- def __init__(self):
- self.unregistered_attribute = 1
-
- obj = SubclassWithSlots()
- s = cloudpickle.dumps(obj, protocol=self.protocol)
- del SubclassWithSlots
- depickled_obj = cloudpickle.loads(s)
- assert depickled_obj.unregistered_attribute == 1
-
-
- @unittest.skipIf(not hasattr(types, "MappingProxyType"),
- "Old versions of Python do not have this type.")
- def test_mappingproxy(self):
- mp = types.MappingProxyType({"some_key": "some value"})
- assert mp == pickle_depickle(mp, protocol=self.protocol)
-
- def test_dataclass(self):
- dataclasses = pytest.importorskip("dataclasses")
-
- DataClass = dataclasses.make_dataclass('DataClass', [('x', int)])
- data = DataClass(x=42)
-
- pickle_depickle(DataClass, protocol=self.protocol)
- assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42
-
- def test_locally_defined_enum(self):
- class StringEnum(str, enum.Enum):
- """Enum when all members are also (and must be) strings"""
-
- class Color(StringEnum):
- """3-element color space"""
- RED = "1"
- GREEN = "2"
- BLUE = "3"
-
- def is_green(self):
- return self is Color.GREEN
-
- green1, green2, ClonedColor = pickle_depickle(
- [Color.GREEN, Color.GREEN, Color], protocol=self.protocol)
- assert green1 is green2
- assert green1 is ClonedColor.GREEN
- assert green1 is not ClonedColor.BLUE
- assert isinstance(green1, str)
- assert green1.is_green()
-
- # cloudpickle systematically tracks provenance of class definitions
- # and ensure reconciliation in case of round trips:
- assert green1 is Color.GREEN
- assert ClonedColor is Color
-
- green3 = pickle_depickle(Color.GREEN, protocol=self.protocol)
- assert green3 is Color.GREEN
-
- def test_locally_defined_intenum(self):
- # Try again with a IntEnum defined with the functional API
- DynamicColor = enum.IntEnum("Color", {"RED": 1, "GREEN": 2, "BLUE": 3})
-
- green1, green2, ClonedDynamicColor = pickle_depickle(
- [DynamicColor.GREEN, DynamicColor.GREEN, DynamicColor],
- protocol=self.protocol)
-
- assert green1 is green2
- assert green1 is ClonedDynamicColor.GREEN
- assert green1 is not ClonedDynamicColor.BLUE
- assert ClonedDynamicColor is DynamicColor
-
- def test_interactively_defined_enum(self):
- code = """if __name__ == "__main__":
- from enum import Enum
- from srsly.tests.cloudpickle.testutils import subprocess_worker
-
- with subprocess_worker(protocol={protocol}) as w:
-
- class Color(Enum):
- RED = 1
- GREEN = 2
-
- def check_positive(x):
- return Color.GREEN if x >= 0 else Color.RED
-
- result = w.run(check_positive, 1)
-
- # Check that the returned enum instance is reconciled with the
- # locally defined Color enum type definition:
-
- assert result is Color.GREEN
-
- # Check that changing the definition of the Enum class is taken
- # into account on the worker for subsequent calls:
-
- class Color(Enum):
- RED = 1
- BLUE = 2
-
- def check_positive(x):
- return Color.BLUE if x >= 0 else Color.RED
-
- result = w.run(check_positive, 1)
- assert result is Color.BLUE
- """.format(protocol=self.protocol)
- assert_run_python_script(code)
-
- def test_relative_import_inside_function(self):
- pytest.importorskip("_cloudpickle_testpkg")
- # Make sure relative imports inside round-tripped functions is not
- # broken. This was a bug in cloudpickle versions <= 0.5.3 and was
- # re-introduced in 0.8.0.
- from _cloudpickle_testpkg import relative_imports_factory
- f, g = relative_imports_factory()
- for func, source in zip([f, g], ["module", "package"]):
- # Make sure relative imports are initially working
- assert func() == f"hello from a {source}!"
-
- # Make sure relative imports still work after round-tripping
- cloned_func = pickle_depickle(func, protocol=self.protocol)
- assert cloned_func() == f"hello from a {source}!"
-
- def test_interactively_defined_func_with_keyword_only_argument(self):
- # fixes https://github.com/cloudpipe/cloudpickle/issues/263
- def f(a, *, b=1):
- return a + b
-
- depickled_f = pickle_depickle(f, protocol=self.protocol)
-
- for func in (f, depickled_f):
- assert func(2) == 3
- assert func.__kwdefaults__ == {'b': 1}
-
- @pytest.mark.skipif(not hasattr(types.CodeType, "co_posonlyargcount"),
- reason="Requires positional-only argument syntax")
- def test_interactively_defined_func_with_positional_only_argument(self):
- # Fixes https://github.com/cloudpipe/cloudpickle/issues/266
- # The source code of this test is bundled in a string and is ran from
- # the __main__ module of a subprocess in order to avoid a SyntaxError
- # in versions of python that do not support positional-only argument
- # syntax.
- code = """
- import pytest
- from srsly.cloudpickle import loads, dumps
-
- def f(a, /, b=1):
- return a + b
-
- depickled_f = loads(dumps(f, protocol={protocol}))
-
- for func in (f, depickled_f):
- assert func(2) == 3
- assert func.__code__.co_posonlyargcount == 1
- with pytest.raises(TypeError):
- func(a=2)
-
- """.format(protocol=self.protocol)
- assert_run_python_script(textwrap.dedent(code))
-
- def test___reduce___returns_string(self):
- # Non regression test for objects with a __reduce__ method returning a
- # string, meaning "save by attribute using save_global"
- pytest.importorskip("_cloudpickle_testpkg")
- from _cloudpickle_testpkg import some_singleton
- assert some_singleton.__reduce__() == "some_singleton"
- depickled_singleton = pickle_depickle(
- some_singleton, protocol=self.protocol)
- assert depickled_singleton is some_singleton
-
- def test_cloudpickle_extract_nested_globals(self):
- def function_factory():
- def inner_function():
- global _TEST_GLOBAL_VARIABLE
- return _TEST_GLOBAL_VARIABLE
- return inner_function
-
- globals_ = set(cloudpickle.cloudpickle._extract_code_globals(
- function_factory.__code__).keys())
- assert globals_ == {'_TEST_GLOBAL_VARIABLE'}
-
- depickled_factory = pickle_depickle(function_factory,
- protocol=self.protocol)
- inner_func = depickled_factory()
- assert inner_func() == _TEST_GLOBAL_VARIABLE
-
- def test_recursion_during_pickling(self):
- class A:
- def __getattribute__(self, name):
- return getattr(self, name)
-
- a = A()
- with pytest.raises(pickle.PicklingError, match='recursion'):
- cloudpickle.dumps(a)
-
- def test_out_of_band_buffers(self):
- if self.protocol < 5:
- pytest.skip("Need Pickle Protocol 5 or later")
- np = pytest.importorskip("numpy")
-
- class LocallyDefinedClass:
- data = np.zeros(10)
-
- data_instance = LocallyDefinedClass()
- buffers = []
- pickle_bytes = cloudpickle.dumps(data_instance, protocol=self.protocol,
- buffer_callback=buffers.append)
- assert len(buffers) == 1
- reconstructed = pickle.loads(pickle_bytes, buffers=buffers)
- np.testing.assert_allclose(reconstructed.data, data_instance.data)
-
- def test_pickle_dynamic_typevar(self):
- T = typing.TypeVar('T')
- depickled_T = pickle_depickle(T, protocol=self.protocol)
- attr_list = [
- "__name__", "__bound__", "__constraints__", "__covariant__",
- "__contravariant__"
- ]
- for attr in attr_list:
- assert getattr(T, attr) == getattr(depickled_T, attr)
-
- def test_pickle_dynamic_typevar_tracking(self):
- T = typing.TypeVar("T")
- T2 = subprocess_pickle_echo(T, protocol=self.protocol)
- assert T is T2
-
- def test_pickle_dynamic_typevar_memoization(self):
- T = typing.TypeVar('T')
- depickled_T1, depickled_T2 = pickle_depickle((T, T),
- protocol=self.protocol)
- assert depickled_T1 is depickled_T2
-
- def test_pickle_importable_typevar(self):
- pytest.importorskip("_cloudpickle_testpkg")
- from _cloudpickle_testpkg import T
- T1 = pickle_depickle(T, protocol=self.protocol)
- assert T1 is T
-
- # Standard Library TypeVar
- from typing import AnyStr
- assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)
-
- def test_generic_type(self):
- T = typing.TypeVar('T')
-
- class C(typing.Generic[T]):
- pass
-
- assert pickle_depickle(C, protocol=self.protocol) is C
-
- # Identity is not part of the typing contract: only test for
- # equality instead.
- assert pickle_depickle(C[int], protocol=self.protocol) == C[int]
-
- with subprocess_worker(protocol=self.protocol) as worker:
-
- def check_generic(generic, origin, type_value, use_args):
- assert generic.__origin__ is origin
-
- assert len(origin.__orig_bases__) == 1
- ob = origin.__orig_bases__[0]
- assert ob.__origin__ is typing.Generic
-
- if use_args:
- assert len(generic.__args__) == 1
- assert generic.__args__[0] is type_value
- else:
- assert len(generic.__parameters__) == 1
- assert generic.__parameters__[0] is type_value
- assert len(ob.__parameters__) == 1
-
- return "ok"
-
- # backward-compat for old Python 3.5 versions that sometimes relies
- # on __parameters__
- use_args = getattr(C[int], '__args__', ()) != ()
- assert check_generic(C[int], C, int, use_args) == "ok"
- assert worker.run(check_generic, C[int], C, int, use_args) == "ok"
-
- def test_generic_subclass(self):
- T = typing.TypeVar('T')
-
- class Base(typing.Generic[T]):
- pass
-
- class DerivedAny(Base):
- pass
-
- class LeafAny(DerivedAny):
- pass
-
- class DerivedInt(Base[int]):
- pass
-
- class LeafInt(DerivedInt):
- pass
-
- class DerivedT(Base[T]):
- pass
-
- class LeafT(DerivedT[T]):
- pass
-
- klasses = [
- Base, DerivedAny, LeafAny, DerivedInt, LeafInt, DerivedT, LeafT
- ]
- for klass in klasses:
- assert pickle_depickle(klass, protocol=self.protocol) is klass
-
- with subprocess_worker(protocol=self.protocol) as worker:
-
- def check_mro(klass, expected_mro):
- assert klass.mro() == expected_mro
- return "ok"
-
- for klass in klasses:
- mro = klass.mro()
- assert check_mro(klass, mro)
- assert worker.run(check_mro, klass, mro) == "ok"
-
- def test_locally_defined_class_with_type_hints(self):
- with subprocess_worker(protocol=self.protocol) as worker:
- for type_ in _all_types_to_test():
- class MyClass:
- def method(self, arg: type_) -> type_:
- return arg
- MyClass.__annotations__ = {'attribute': type_}
-
- def check_annotations(obj, expected_type, expected_type_str):
- assert obj.__annotations__["attribute"] == expected_type
- assert (
- obj.method.__annotations__["arg"] == expected_type
- )
- assert (
- obj.method.__annotations__["return"]
- == expected_type
- )
- return "ok"
-
- obj = MyClass()
- assert check_annotations(obj, type_, "type_") == "ok"
- assert (
- worker.run(check_annotations, obj, type_, "type_") == "ok"
- )
-
- def test_generic_extensions_literal(self):
- typing_extensions = pytest.importorskip('typing_extensions')
- for obj in [typing_extensions.Literal, typing_extensions.Literal['a']]:
- depickled_obj = pickle_depickle(obj, protocol=self.protocol)
- assert depickled_obj == obj
-
- def test_generic_extensions_final(self):
- typing_extensions = pytest.importorskip('typing_extensions')
- for obj in [typing_extensions.Final, typing_extensions.Final[int]]:
- depickled_obj = pickle_depickle(obj, protocol=self.protocol)
- assert depickled_obj == obj
-
- def test_class_annotations(self):
- class C:
- pass
- C.__annotations__ = {'a': int}
-
- C1 = pickle_depickle(C, protocol=self.protocol)
- assert C1.__annotations__ == C.__annotations__
-
- def test_function_annotations(self):
- def f(a: int) -> str:
- pass
-
- f1 = pickle_depickle(f, protocol=self.protocol)
- assert f1.__annotations__ == f.__annotations__
-
- def test_always_use_up_to_date_copyreg(self):
- # test that updates of copyreg.dispatch_table are taken in account by
- # cloudpickle
- import copyreg
- try:
- class MyClass:
- pass
-
- def reduce_myclass(x):
- return MyClass, (), {'custom_reduce': True}
-
- copyreg.dispatch_table[MyClass] = reduce_myclass
- my_obj = MyClass()
- depickled_myobj = pickle_depickle(my_obj, protocol=self.protocol)
- assert hasattr(depickled_myobj, 'custom_reduce')
- finally:
- copyreg.dispatch_table.pop(MyClass)
-
- def test_literal_misdetection(self):
- # see https://github.com/cloudpipe/cloudpickle/issues/403
- class MyClass:
- @property
- def __values__(self):
- return ()
-
- o = MyClass()
- pickle_depickle(o, protocol=self.protocol)
-
- def test_final_or_classvar_misdetection(self):
- # see https://github.com/cloudpipe/cloudpickle/issues/403
- class MyClass:
- @property
- def __type__(self):
- return int
-
- o = MyClass()
- pickle_depickle(o, protocol=self.protocol)
-
- @pytest.mark.skip(reason="Requires pytest -s to pass")
- def test_pickle_constructs_from_module_registered_for_pickling_by_value(self): # noqa
- _prev_sys_path = sys.path.copy()
- try:
- # We simulate an interactive session that:
- # - we start from the /path/to/cloudpickle/tests directory, where a
- # local .py file (mock_local_file) is located.
- # - uses constructs from mock_local_file in remote workers that do
- # not have access to this file. This situation is
- # the justification behind the
- # (un)register_pickle_by_value(module) api that cloudpickle
- # exposes.
- _mock_interactive_session_cwd = os.path.dirname(__file__)
-
- # First, remove sys.path entries that could point to
- # /path/to/cloudpickle/tests and be in inherited by the worker
- _maybe_remove(sys.path, '')
- _maybe_remove(sys.path, _mock_interactive_session_cwd)
-
- # Add the desired session working directory
- sys.path.insert(0, _mock_interactive_session_cwd)
-
- with subprocess_worker(protocol=self.protocol) as w:
- # Make the module unavailable in the remote worker
- w.run(
- lambda p: sys.path.remove(p), _mock_interactive_session_cwd
- )
- # Import the actual file after starting the module since the
- # worker is started using fork on Linux, which will inherits
- # the parent sys.modules. On Python>3.6, the worker can be
- # started using spawn using mp_context in ProcessPoolExectutor.
- # TODO Once Python 3.6 reaches end of life, rely on mp_context
- # instead.
- import mock_local_folder.mod as mod
- # The constructs whose pickling mechanism is changed using
- # register_pickle_by_value are functions, classes, TypeVar and
- # modules.
- from mock_local_folder.mod import (
- local_function, LocalT, LocalClass
- )
-
- # Make sure the module/constructs are unimportable in the
- # worker.
- with pytest.raises(ImportError):
- w.run(lambda: __import__("mock_local_folder.mod"))
- with pytest.raises(ImportError):
- w.run(
- lambda: __import__("mock_local_folder.subfolder.mod")
- )
-
- for o in [mod, local_function, LocalT, LocalClass]:
- with pytest.raises(ImportError):
- w.run(lambda: o)
-
- register_pickle_by_value(mod)
- # function
- assert w.run(lambda: local_function()) == local_function()
- # typevar
- assert w.run(lambda: LocalT.__name__) == LocalT.__name__
- # classes
- assert (
- w.run(lambda: LocalClass().method())
- == LocalClass().method()
- )
- # modules
- assert (
- w.run(lambda: mod.local_function()) == local_function()
- )
-
- # Constructs from modules inside subfolders should be pickled
- # by value if a namespace module pointing to some parent folder
- # was registered for pickling by value. A "mock_local_folder"
- # namespace module falls into that category, but a
- # "mock_local_folder.mod" one does not.
- from mock_local_folder.subfolder.submod import (
- LocalSubmodClass, LocalSubmodT, local_submod_function
- )
- # Shorter aliases to comply with line-length limits
- _t, _func, _class = (
- LocalSubmodT, local_submod_function, LocalSubmodClass
- )
- with pytest.raises(ImportError):
- w.run(
- lambda: __import__("mock_local_folder.subfolder.mod")
- )
- with pytest.raises(ImportError):
- w.run(lambda: local_submod_function)
-
- unregister_pickle_by_value(mod)
-
- with pytest.raises(ImportError):
- w.run(lambda: local_function)
-
- with pytest.raises(ImportError):
- w.run(lambda: __import__("mock_local_folder.mod"))
-
- # Test the namespace folder case
- import mock_local_folder
- register_pickle_by_value(mock_local_folder)
- assert w.run(lambda: local_function()) == local_function()
- assert w.run(lambda: _func()) == _func()
- unregister_pickle_by_value(mock_local_folder)
-
- with pytest.raises(ImportError):
- w.run(lambda: local_function)
- with pytest.raises(ImportError):
- w.run(lambda: local_submod_function)
-
- # Test the case of registering a single module inside a
- # subfolder.
- import mock_local_folder.subfolder.submod
- register_pickle_by_value(mock_local_folder.subfolder.submod)
- assert w.run(lambda: _func()) == _func()
- assert w.run(lambda: _t.__name__) == _t.__name__
- assert w.run(lambda: _class().method()) == _class().method()
-
- # Registering a module from a subfolder for pickling by value
- # should not make constructs from modules from the parent
- # folder pickleable
- with pytest.raises(ImportError):
- w.run(lambda: local_function)
- with pytest.raises(ImportError):
- w.run(lambda: __import__("mock_local_folder.mod"))
-
- unregister_pickle_by_value(
- mock_local_folder.subfolder.submod
- )
- with pytest.raises(ImportError):
- w.run(lambda: local_submod_function)
-
- # Test the subfolder namespace module case
- import mock_local_folder.subfolder
- register_pickle_by_value(mock_local_folder.subfolder)
- assert w.run(lambda: _func()) == _func()
- assert w.run(lambda: _t.__name__) == _t.__name__
- assert w.run(lambda: _class().method()) == _class().method()
-
- unregister_pickle_by_value(mock_local_folder.subfolder)
- finally:
- _fname = "mock_local_folder"
- sys.path = _prev_sys_path
- for m in [_fname, f"{_fname}.mod", f"{_fname}.subfolder",
- f"{_fname}.subfolder.submod"]:
- mod = sys.modules.pop(m, None)
- if mod and mod.__name__ in list_registry_pickle_by_value():
- unregister_pickle_by_value(mod)
-
- def test_pickle_constructs_from_installed_packages_registered_for_pickling_by_value( # noqa
- self
- ):
- pytest.importorskip("_cloudpickle_testpkg")
- for package_or_module in ["package", "module"]:
- if package_or_module == "package":
- import _cloudpickle_testpkg as m
- f = m.package_function_with_global
- _original_global = m.global_variable
- elif package_or_module == "module":
- import _cloudpickle_testpkg.mod as m
- f = m.module_function_with_global
- _original_global = m.global_variable
- try:
- with subprocess_worker(protocol=self.protocol) as w:
- assert w.run(lambda: f()) == _original_global
-
- # Test that f is pickled by value by modifying a global
- # variable that f uses, and making sure that this
- # modification shows up when calling the function remotely
- register_pickle_by_value(m)
- assert w.run(lambda: f()) == _original_global
- m.global_variable = "modified global"
- assert m.global_variable != _original_global
- assert w.run(lambda: f()) == "modified global"
- unregister_pickle_by_value(m)
- finally:
- m.global_variable = _original_global
- if m.__name__ in list_registry_pickle_by_value():
- unregister_pickle_by_value(m)
-
- def test_pickle_various_versions_of_the_same_function_with_different_pickling_method( # noqa
- self
- ):
- pytest.importorskip("_cloudpickle_testpkg")
- # Make sure that different versions of the same function (possibly
- # pickled in a different way - by value and/or by reference) can
- # peacefully co-exist (e.g. without globals interaction) in a remote
- # worker.
- import _cloudpickle_testpkg
- from _cloudpickle_testpkg import package_function_with_global as f
- _original_global = _cloudpickle_testpkg.global_variable
-
- def _create_registry():
- _main = __import__("sys").modules["__main__"]
- _main._cloudpickle_registry = {}
- # global _cloudpickle_registry
-
- def _add_to_registry(v, k):
- _main = __import__("sys").modules["__main__"]
- _main._cloudpickle_registry[k] = v
-
- def _call_from_registry(k):
- _main = __import__("sys").modules["__main__"]
- return _main._cloudpickle_registry[k]()
-
- try:
- with subprocess_worker(protocol=self.protocol) as w:
- w.run(_create_registry)
- w.run(_add_to_registry, f, "f_by_ref")
-
- register_pickle_by_value(_cloudpickle_testpkg)
- _cloudpickle_testpkg.global_variable = "modified global"
- w.run(_add_to_registry, f, "f_by_val")
- assert (
- w.run(_call_from_registry, "f_by_ref") == _original_global
- )
- assert (
- w.run(_call_from_registry, "f_by_val") == "modified global"
- )
-
- finally:
- _cloudpickle_testpkg.global_variable = _original_global
-
- if "_cloudpickle_testpkg" in list_registry_pickle_by_value():
- unregister_pickle_by_value(_cloudpickle_testpkg)
-
- @pytest.mark.skipif(
- sys.version_info < (3, 7),
- reason="Determinism can only be guaranteed for Python 3.7+"
- )
- def test_deterministic_pickle_bytes_for_function(self):
- # Ensure that functions with references to several global names are
- # pickled to fixed bytes that do not depend on the PYTHONHASHSEED of
- # the Python process.
- vals = set()
-
- def func_with_globals():
- return _TEST_GLOBAL_VARIABLE + _TEST_GLOBAL_VARIABLE2
-
- for i in range(5):
- vals.add(
- subprocess_pickle_string(func_with_globals,
- protocol=self.protocol,
- add_env={"PYTHONHASHSEED": str(i)}))
- if len(vals) > 1:
- # Print additional debug info on stdout with dis:
- for val in vals:
- pickletools.dis(val)
- pytest.fail(
- "Expected a single deterministic payload, got %d/5" % len(vals)
- )
-
-
-class Protocol2CloudPickleTest(CloudPickleTest):
-
- protocol = 2
-
-
-def test_lookup_module_and_qualname_dynamic_typevar():
- T = typing.TypeVar('T')
- module_and_name = _lookup_module_and_qualname(T, name=T.__name__)
- assert module_and_name is None
-
-
-def test_lookup_module_and_qualname_importable_typevar():
- pytest.importorskip("_cloudpickle_testpkg")
- import _cloudpickle_testpkg
- T = _cloudpickle_testpkg.T
- module_and_name = _lookup_module_and_qualname(T, name=T.__name__)
- assert module_and_name is not None
- module, name = module_and_name
- assert module is _cloudpickle_testpkg
- assert name == 'T'
-
-
-def test_lookup_module_and_qualname_stdlib_typevar():
- module_and_name = _lookup_module_and_qualname(typing.AnyStr,
- name=typing.AnyStr.__name__)
- assert module_and_name is not None
- module, name = module_and_name
- assert module is typing
- assert name == 'AnyStr'
-
-
-def test_register_pickle_by_value():
- pytest.importorskip("_cloudpickle_testpkg")
- import _cloudpickle_testpkg as pkg
- import _cloudpickle_testpkg.mod as mod
-
- assert list_registry_pickle_by_value() == set()
-
- register_pickle_by_value(pkg)
- assert list_registry_pickle_by_value() == {pkg.__name__}
-
- register_pickle_by_value(mod)
- assert list_registry_pickle_by_value() == {pkg.__name__, mod.__name__}
-
- unregister_pickle_by_value(mod)
- assert list_registry_pickle_by_value() == {pkg.__name__}
-
- msg = f"Input should be a module object, got {pkg.__name__} instead"
- with pytest.raises(ValueError, match=msg):
- unregister_pickle_by_value(pkg.__name__)
-
- unregister_pickle_by_value(pkg)
- assert list_registry_pickle_by_value() == set()
-
- msg = f"{pkg} is not registered for pickle by value"
- with pytest.raises(ValueError, match=re.escape(msg)):
- unregister_pickle_by_value(pkg)
-
- msg = f"Input should be a module object, got {pkg.__name__} instead"
- with pytest.raises(ValueError, match=msg):
- register_pickle_by_value(pkg.__name__)
-
- dynamic_mod = types.ModuleType('dynamic_mod')
- msg = (
- f"{dynamic_mod} was not imported correctly, have you used an "
- f"`import` statement to access it?"
- )
- with pytest.raises(ValueError, match=re.escape(msg)):
- register_pickle_by_value(dynamic_mod)
-
-
-def _all_types_to_test():
- T = typing.TypeVar('T')
-
- class C(typing.Generic[T]):
- pass
-
- types_to_test = [
- C, C[int],
- T, typing.Any, typing.Optional,
- typing.Generic, typing.Union,
- typing.Optional[int],
- typing.Generic[T],
- typing.Callable[[int], typing.Any],
- typing.Callable[..., typing.Any],
- typing.Callable[[], typing.Any],
- typing.Tuple[int, ...],
- typing.Tuple[int, C[int]],
- typing.List[int],
- typing.Dict[int, str],
- typing.ClassVar,
- typing.ClassVar[C[int]],
- typing.NoReturn,
- ]
- return types_to_test
-
-
-def test_module_level_pickler():
- # #366: cloudpickle should expose its pickle.Pickler subclass as
- # cloudpickle.Pickler
- assert hasattr(cloudpickle, "Pickler")
- assert cloudpickle.Pickler is cloudpickle.CloudPickler
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/srsly/tests/cloudpickle/mock_local_folder/mod.py b/srsly/tests/cloudpickle/mock_local_folder/mod.py
deleted file mode 100644
index 1a1c1da..0000000
--- a/srsly/tests/cloudpickle/mock_local_folder/mod.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""
-In the distributed computing setting, this file plays the role of a "local
-development" file, e.g. a file that is importable locally, but unimportable in
-remote workers. Constructs defined in this file and usually pickled by
-reference should instead flagged to cloudpickle for pickling by value: this is
-done using the register_pickle_by_value api exposed by cloudpickle.
-"""
-import typing
-
-
-def local_function():
- return "hello from a function importable locally!"
-
-
-class LocalClass:
- def method(self):
- return "hello from a class importable locally"
-
-
-LocalT = typing.TypeVar("LocalT")
diff --git a/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py b/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py
deleted file mode 100644
index deebc14..0000000
--- a/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import typing
-
-
-def local_submod_function():
- return "hello from a file located in a locally-importable subfolder!"
-
-
-class LocalSubmodClass:
- def method(self):
- return "hello from a class located in a locally-importable subfolder!"
-
-
-LocalSubmodT = typing.TypeVar("LocalSubmodT")
diff --git a/srsly/tests/cloudpickle/testutils.py b/srsly/tests/cloudpickle/testutils.py
deleted file mode 100644
index e0890b4..0000000
--- a/srsly/tests/cloudpickle/testutils.py
+++ /dev/null
@@ -1,217 +0,0 @@
-import sys
-import os
-import os.path as op
-import tempfile
-from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
-from srsly.cloudpickle.compat import pickle
-from contextlib import contextmanager
-from concurrent.futures import ProcessPoolExecutor
-
-import psutil
-from srsly.cloudpickle import dumps
-from subprocess import TimeoutExpired
-
-loads = pickle.loads
-TIMEOUT = 60
-TEST_GLOBALS = "a test value"
-
-
-def make_local_function():
- def g(x):
- # this function checks that the globals are correctly handled and that
- # the builtins are available
- assert TEST_GLOBALS == "a test value"
- return sum(range(10))
-
- return g
-
-
-def _make_cwd_env():
- """Helper to prepare environment for the child processes"""
- cloudpickle_repo_folder = op.normpath(
- op.join(op.dirname(__file__), '..'))
- env = os.environ.copy()
- pythonpath = "{src}{sep}tests{pathsep}{src}".format(
- src=cloudpickle_repo_folder, sep=os.sep, pathsep=os.pathsep)
- env['PYTHONPATH'] = pythonpath
- return cloudpickle_repo_folder, env
-
-
-def subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT,
- add_env=None):
- """Retrieve pickle string of an object generated by a child Python process
-
- Pickle the input data into a buffer, send it to a subprocess via
- stdin, expect the subprocess to unpickle, re-pickle that data back
- and send it back to the parent process via stdout for final unpickling.
-
- >>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2)
- b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.'
-
- """
- # run then pickle_echo(protocol=protocol) in __main__:
-
- # Protect stderr from any warning, as we will assume an error will happen
- # if it is not empty. A concrete example is pytest using the imp module,
- # which is deprecated in python 3.8
- cmd = [sys.executable, '-W ignore', __file__, "--protocol", str(protocol)]
- cwd, env = _make_cwd_env()
- if add_env:
- env.update(add_env)
- proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env,
- bufsize=4096)
- pickle_string = dumps(input_data, protocol=protocol)
- try:
- comm_kwargs = {}
- comm_kwargs['timeout'] = timeout
- out, err = proc.communicate(pickle_string, **comm_kwargs)
- if proc.returncode != 0 or len(err):
- message = "Subprocess returned %d: " % proc.returncode
- message += err.decode('utf-8')
- raise RuntimeError(message)
- return out
- except TimeoutExpired as e:
- proc.kill()
- out, err = proc.communicate()
- message = "\n".join([out.decode('utf-8'), err.decode('utf-8')])
- raise RuntimeError(message) from e
-
-
-def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT,
- add_env=None):
- """Echo function with a child Python process
- Pickle the input data into a buffer, send it to a subprocess via
- stdin, expect the subprocess to unpickle, re-pickle that data back
- and send it back to the parent process via stdout for final unpickling.
- >>> subprocess_pickle_echo([1, 'a', None])
- [1, 'a', None]
- """
- out = subprocess_pickle_string(input_data,
- protocol=protocol,
- timeout=timeout,
- add_env=add_env)
- return loads(out)
-
-
-def _read_all_bytes(stream_in, chunk_size=4096):
- all_data = b""
- while True:
- data = stream_in.read(chunk_size)
- all_data += data
- if len(data) < chunk_size:
- break
- return all_data
-
-
-def pickle_echo(stream_in=None, stream_out=None, protocol=None):
- """Read a pickle from stdin and pickle it back to stdout"""
- if stream_in is None:
- stream_in = sys.stdin
- if stream_out is None:
- stream_out = sys.stdout
-
- # Force the use of bytes streams under Python 3
- if hasattr(stream_in, 'buffer'):
- stream_in = stream_in.buffer
- if hasattr(stream_out, 'buffer'):
- stream_out = stream_out.buffer
-
- input_bytes = _read_all_bytes(stream_in)
- stream_in.close()
- obj = loads(input_bytes)
- repickled_bytes = dumps(obj, protocol=protocol)
- stream_out.write(repickled_bytes)
- stream_out.close()
-
-
-def call_func(payload, protocol):
- """Remote function call that uses cloudpickle to transport everthing"""
- func, args, kwargs = loads(payload)
- try:
- result = func(*args, **kwargs)
- except BaseException as e:
- result = e
- return dumps(result, protocol=protocol)
-
-
-class _Worker:
- def __init__(self, protocol=None):
- self.protocol = protocol
- self.pool = ProcessPoolExecutor(max_workers=1)
- self.pool.submit(id, 42).result() # start the worker process
-
- def run(self, func, *args, **kwargs):
- """Synchronous remote function call"""
-
- input_payload = dumps((func, args, kwargs), protocol=self.protocol)
- result_payload = self.pool.submit(
- call_func, input_payload, self.protocol).result()
- result = loads(result_payload)
-
- if isinstance(result, BaseException):
- raise result
- return result
-
- def memsize(self):
- workers_pids = [p.pid if hasattr(p, "pid") else p
- for p in list(self.pool._processes)]
- num_workers = len(workers_pids)
- if num_workers == 0:
- return 0
- elif num_workers > 1:
- raise RuntimeError("Unexpected number of workers: %d"
- % num_workers)
- return psutil.Process(workers_pids[0]).memory_info().rss
-
- def close(self):
- self.pool.shutdown(wait=True)
-
-
-@contextmanager
-def subprocess_worker(protocol=None):
- worker = _Worker(protocol=protocol)
- yield worker
- worker.close()
-
-
-def assert_run_python_script(source_code, timeout=TIMEOUT):
- """Utility to help check pickleability of objects defined in __main__
-
- The script provided in the source code should return 0 and not print
- anything on stderr or stdout.
- """
- fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
- os.close(fd)
- try:
- with open(source_file, 'wb') as f:
- f.write(source_code.encode('utf-8'))
- cmd = [sys.executable, '-W ignore', source_file]
- cwd, env = _make_cwd_env()
- kwargs = {
- 'cwd': cwd,
- 'stderr': STDOUT,
- 'env': env,
- }
- # If coverage is running, pass the config file to the subprocess
- coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
- if coverage_rc:
- kwargs['env']['COVERAGE_PROCESS_START'] = coverage_rc
- kwargs['timeout'] = timeout
- try:
- try:
- out = check_output(cmd, **kwargs)
- except CalledProcessError as e:
- raise RuntimeError("script errored with output:\n%s"
- % e.output.decode('utf-8')) from e
- if out != b"":
- raise AssertionError(out.decode('utf-8'))
- except TimeoutExpired as e:
- raise RuntimeError("script timeout, output so far:\n%s"
- % e.output.decode('utf-8')) from e
- finally:
- os.unlink(source_file)
-
-
-if __name__ == '__main__':
- protocol = int(sys.argv[sys.argv.index('--protocol') + 1])
- pickle_echo(protocol=protocol)
diff --git a/srsly/tests/msgpack/__init__.py b/srsly/tests/msgpack/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/srsly/tests/msgpack/test_buffer.py b/srsly/tests/msgpack/test_buffer.py
deleted file mode 100644
index db7b213..0000000
--- a/srsly/tests/msgpack/test_buffer.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from srsly.msgpack import packb, unpackb
-
-
-def test_unpack_buffer():
- from array import array
-
- buf = array("b")
- buf.frombytes(packb((b"foo", b"bar")))
- obj = unpackb(buf, use_list=1)
- assert [b"foo", b"bar"] == obj
-
-
-def test_unpack_bytearray():
- buf = bytearray(packb(("foo", "bar")))
- obj = unpackb(buf, use_list=1)
- assert [b"foo", b"bar"] == obj
- expected_type = bytes
- assert all(type(s) == expected_type for s in obj)
-
-
-def test_unpack_memoryview():
- buf = bytearray(packb(("foo", "bar")))
- view = memoryview(buf)
- obj = unpackb(view, use_list=1)
- assert [b"foo", b"bar"] == obj
- expected_type = bytes
- assert all(type(s) == expected_type for s in obj)
diff --git a/srsly/tests/msgpack/test_case.py b/srsly/tests/msgpack/test_case.py
deleted file mode 100644
index d07fe22..0000000
--- a/srsly/tests/msgpack/test_case.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from srsly.msgpack import packb, unpackb
-
-
-def check(length, obj):
- v = packb(obj)
- assert len(v) == length, "%r length should be %r but get %r" % (obj, length, len(v))
- assert unpackb(v, use_list=0) == obj
-
-
-def test_1():
- for o in [
- None,
- True,
- False,
- 0,
- 1,
- (1 << 6),
- (1 << 7) - 1,
- -1,
- -((1 << 5) - 1),
- -(1 << 5),
- ]:
- check(1, o)
-
-
-def test_2():
- for o in [1 << 7, (1 << 8) - 1, -((1 << 5) + 1), -(1 << 7)]:
- check(2, o)
-
-
-def test_3():
- for o in [1 << 8, (1 << 16) - 1, -((1 << 7) + 1), -(1 << 15)]:
- check(3, o)
-
-
-def test_5():
- for o in [1 << 16, (1 << 32) - 1, -((1 << 15) + 1), -(1 << 31)]:
- check(5, o)
-
-
-def test_9():
- for o in [
- 1 << 32,
- (1 << 64) - 1,
- -((1 << 31) + 1),
- -(1 << 63),
- 1.0,
- 0.1,
- -0.1,
- -1.0,
- ]:
- check(9, o)
-
-
-def check_raw(overhead, num):
- check(num + overhead, b" " * num)
-
-
-def test_fixraw():
- check_raw(1, 0)
- check_raw(1, (1 << 5) - 1)
-
-
-def test_raw16():
- check_raw(3, 1 << 5)
- check_raw(3, (1 << 16) - 1)
-
-
-def test_raw32():
- check_raw(5, 1 << 16)
-
-
-def check_array(overhead, num):
- check(num + overhead, (None,) * num)
-
-
-def test_fixarray():
- check_array(1, 0)
- check_array(1, (1 << 4) - 1)
-
-
-def test_array16():
- check_array(3, 1 << 4)
- check_array(3, (1 << 16) - 1)
-
-
-def test_array32():
- check_array(5, (1 << 16))
-
-
-def match(obj, buf):
- assert packb(obj) == buf
- assert unpackb(buf, use_list=0) == obj
-
-
-def test_match():
- cases = [
- (None, b"\xc0"),
- (False, b"\xc2"),
- (True, b"\xc3"),
- (0, b"\x00"),
- (127, b"\x7f"),
- (128, b"\xcc\x80"),
- (256, b"\xcd\x01\x00"),
- (-1, b"\xff"),
- (-33, b"\xd0\xdf"),
- (-129, b"\xd1\xff\x7f"),
- ({1: 1}, b"\x81\x01\x01"),
- (1.0, b"\xcb\x3f\xf0\x00\x00\x00\x00\x00\x00"),
- ((), b"\x90"),
- (
- tuple(range(15)),
- b"\x9f\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e",
- ),
- (
- tuple(range(16)),
- b"\xdc\x00\x10\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
- ),
- ({}, b"\x80"),
- (
- dict([(x, x) for x in range(15)]),
- b"\x8f\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e",
- ),
- (
- dict([(x, x) for x in range(16)]),
- b"\xde\x00\x10\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e\x0f\x0f",
- ),
- ]
-
- for v, p in cases:
- match(v, p)
-
-
-def test_unicode():
- assert unpackb(packb("foobar"), use_list=1) == b"foobar"
diff --git a/srsly/tests/msgpack/test_except.py b/srsly/tests/msgpack/test_except.py
deleted file mode 100644
index 2322e08..0000000
--- a/srsly/tests/msgpack/test_except.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from pytest import raises
-import datetime
-from srsly.msgpack import packb, unpackb, Unpacker, FormatError, StackError, OutOfData
-
-
-class DummyException(Exception):
- pass
-
-
-def test_raise_on_find_unsupported_value():
- with raises(TypeError):
- packb(datetime.datetime.now())
-
-
-def test_raise_from_object_hook():
- def hook(obj):
- raise DummyException
-
- raises(DummyException, unpackb, packb({}), object_hook=hook)
- raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_hook=hook)
- raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_pairs_hook=hook)
- raises(DummyException, unpackb, packb({"fizz": {"buzz": "spam"}}), object_hook=hook)
- raises(
- DummyException,
- unpackb,
- packb({"fizz": {"buzz": "spam"}}),
- object_pairs_hook=hook,
- )
-
-
-def test_invalidvalue():
- incomplete = b"\xd9\x97#DL_" # raw8 - length=0x97
- with raises(ValueError):
- unpackb(incomplete)
-
- with raises(OutOfData):
- unpacker = Unpacker()
- unpacker.feed(incomplete)
- unpacker.unpack()
-
- with raises(FormatError):
- unpackb(b"\xc1") # (undefined tag)
-
- with raises(FormatError):
- unpackb(b"\x91\xc1") # fixarray(len=1) [ (undefined tag) ]
-
- with raises(StackError):
- unpackb(b"\x91" * 3000) # nested fixarray(len=1)
-
-
-def test_strict_map_key():
- valid = {u"unicode": 1, b"bytes": 2}
- packed = packb(valid, use_bin_type=True)
- assert valid == unpackb(packed, raw=False, strict_map_key=True)
-
- invalid = {42: 1}
- packed = packb(invalid, use_bin_type=True)
- with raises(ValueError):
- unpackb(packed, raw=False, strict_map_key=True)
diff --git a/srsly/tests/msgpack/test_extension.py b/srsly/tests/msgpack/test_extension.py
deleted file mode 100644
index 5cbaafd..0000000
--- a/srsly/tests/msgpack/test_extension.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import array
-from srsly import msgpack
-from srsly.msgpack.ext import ExtType
-
-
-def test_pack_ext_type():
- def p(s):
- packer = msgpack.Packer()
- packer.pack_ext_type(0x42, s)
- return packer.bytes()
-
- assert p(b"A") == b"\xd4\x42A" # fixext 1
- assert p(b"AB") == b"\xd5\x42AB" # fixext 2
- assert p(b"ABCD") == b"\xd6\x42ABCD" # fixext 4
- assert p(b"ABCDEFGH") == b"\xd7\x42ABCDEFGH" # fixext 8
- assert p(b"A" * 16) == b"\xd8\x42" + b"A" * 16 # fixext 16
- assert p(b"ABC") == b"\xc7\x03\x42ABC" # ext 8
- assert p(b"A" * 0x0123) == b"\xc8\x01\x23\x42" + b"A" * 0x0123 # ext 16
- assert (
- p(b"A" * 0x00012345) == b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345
- ) # ext 32
-
-
-def test_unpack_ext_type():
- def check(b, expected):
- assert msgpack.unpackb(b) == expected
-
- check(b"\xd4\x42A", ExtType(0x42, b"A")) # fixext 1
- check(b"\xd5\x42AB", ExtType(0x42, b"AB")) # fixext 2
- check(b"\xd6\x42ABCD", ExtType(0x42, b"ABCD")) # fixext 4
- check(b"\xd7\x42ABCDEFGH", ExtType(0x42, b"ABCDEFGH")) # fixext 8
- check(b"\xd8\x42" + b"A" * 16, ExtType(0x42, b"A" * 16)) # fixext 16
- check(b"\xc7\x03\x42ABC", ExtType(0x42, b"ABC")) # ext 8
- check(b"\xc8\x01\x23\x42" + b"A" * 0x0123, ExtType(0x42, b"A" * 0x0123)) # ext 16
- check(
- b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345,
- ExtType(0x42, b"A" * 0x00012345),
- ) # ext 32
-
-
-def test_extension_type():
- def default(obj):
- print("default called", obj)
- if isinstance(obj, array.array):
- typecode = 123 # application specific typecode
- data = obj.tobytes()
- return ExtType(typecode, data)
- raise TypeError("Unknown type object %r" % (obj,))
-
- def ext_hook(code, data):
- print("ext_hook called", code, data)
- assert code == 123
- obj = array.array("d")
- obj.frombytes(data)
- return obj
-
- obj = [42, b"hello", array.array("d", [1.1, 2.2, 3.3])]
- s = msgpack.packb(obj, default=default)
- obj2 = msgpack.unpackb(s, ext_hook=ext_hook)
- assert obj == obj2
-
-
-def test_overriding_hooks():
- def default(obj):
- if isinstance(obj, int):
- return {"__type__": "long", "__data__": str(obj)}
- else:
- return obj
-
- obj = {"testval": int(1823746192837461928374619)}
- refobj = {"testval": default(obj["testval"])}
- refout = msgpack.packb(refobj)
- assert isinstance(refout, (str, bytes))
- testout = msgpack.packb(obj, default=default)
-
- assert refout == testout
diff --git a/srsly/tests/msgpack/test_format.py b/srsly/tests/msgpack/test_format.py
deleted file mode 100644
index c6f0871..0000000
--- a/srsly/tests/msgpack/test_format.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from srsly.msgpack import unpackb
-
-
-def check(src, should, use_list=0):
- assert unpackb(src, use_list=use_list) == should
-
-
-def testSimpleValue():
- check(b"\x93\xc0\xc2\xc3", (None, False, True))
-
-
-def testFixnum():
- check(b"\x92\x93\x00\x40\x7f\x93\xe0\xf0\xff", ((0, 64, 127), (-32, -16, -1)))
-
-
-def testFixArray():
- check(b"\x92\x90\x91\x91\xc0", ((), ((None,),)))
-
-
-def testFixRaw():
- check(b"\x94\xa0\xa1a\xa2bc\xa3def", (b"", b"a", b"bc", b"def"))
-
-
-def testFixMap():
- check(
- b"\x82\xc2\x81\xc0\xc0\xc3\x81\xc0\x80", {False: {None: None}, True: {None: {}}}
- )
-
-
-def testUnsignedInt():
- check(
- b"\x99\xcc\x00\xcc\x80\xcc\xff\xcd\x00\x00\xcd\x80\x00"
- b"\xcd\xff\xff\xce\x00\x00\x00\x00\xce\x80\x00\x00\x00"
- b"\xce\xff\xff\xff\xff",
- (0, 128, 255, 0, 32768, 65535, 0, 2147483648, 4294967295),
- )
-
-
-def testSignedInt():
- check(
- b"\x99\xd0\x00\xd0\x80\xd0\xff\xd1\x00\x00\xd1\x80\x00"
- b"\xd1\xff\xff\xd2\x00\x00\x00\x00\xd2\x80\x00\x00\x00"
- b"\xd2\xff\xff\xff\xff",
- (0, -128, -1, 0, -32768, -1, 0, -2147483648, -1),
- )
-
-
-def testRaw():
- check(
- b"\x96\xda\x00\x00\xda\x00\x01a\xda\x00\x02ab\xdb\x00\x00"
- b"\x00\x00\xdb\x00\x00\x00\x01a\xdb\x00\x00\x00\x02ab",
- (b"", b"a", b"ab", b"", b"a", b"ab"),
- )
-
-
-def testArray():
- check(
- b"\x96\xdc\x00\x00\xdc\x00\x01\xc0\xdc\x00\x02\xc2\xc3\xdd\x00"
- b"\x00\x00\x00\xdd\x00\x00\x00\x01\xc0\xdd\x00\x00\x00\x02"
- b"\xc2\xc3",
- ((), (None,), (False, True), (), (None,), (False, True)),
- )
-
-
-def testMap():
- check(
- b"\x96"
- b"\xde\x00\x00"
- b"\xde\x00\x01\xc0\xc2"
- b"\xde\x00\x02\xc0\xc2\xc3\xc2"
- b"\xdf\x00\x00\x00\x00"
- b"\xdf\x00\x00\x00\x01\xc0\xc2"
- b"\xdf\x00\x00\x00\x02\xc0\xc2\xc3\xc2",
- (
- {},
- {None: False},
- {True: False, None: False},
- {},
- {None: False},
- {True: False, None: False},
- ),
- )
diff --git a/srsly/tests/msgpack/test_limits.py b/srsly/tests/msgpack/test_limits.py
deleted file mode 100644
index bafa4c1..0000000
--- a/srsly/tests/msgpack/test_limits.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import pytest
-from srsly.msgpack import packb, unpackb, Packer, Unpacker, ExtType
-from srsly.msgpack import PackOverflowError, PackValueError, UnpackValueError
-
-
-def test_integer():
- x = -(2 ** 63)
- assert unpackb(packb(x)) == x
- with pytest.raises(PackOverflowError):
- packb(x - 1)
-
- x = 2 ** 64 - 1
- assert unpackb(packb(x)) == x
- with pytest.raises(PackOverflowError):
- packb(x + 1)
-
-
-def test_array_header():
- packer = Packer()
- packer.pack_array_header(2 ** 32 - 1)
- with pytest.raises(PackValueError):
- packer.pack_array_header(2 ** 32)
-
-
-def test_map_header():
- packer = Packer()
- packer.pack_map_header(2 ** 32 - 1)
- with pytest.raises(PackValueError):
- packer.pack_array_header(2 ** 32)
-
-
-def test_max_str_len():
- d = "x" * 3
- packed = packb(d)
-
- unpacker = Unpacker(max_str_len=3, raw=False)
- unpacker.feed(packed)
- assert unpacker.unpack() == d
-
- unpacker = Unpacker(max_str_len=2, raw=False)
- with pytest.raises(UnpackValueError):
- unpacker.feed(packed)
- unpacker.unpack()
-
-
-def test_max_bin_len():
- d = b"x" * 3
- packed = packb(d, use_bin_type=True)
-
- unpacker = Unpacker(max_bin_len=3)
- unpacker.feed(packed)
- assert unpacker.unpack() == d
-
- unpacker = Unpacker(max_bin_len=2)
- with pytest.raises(UnpackValueError):
- unpacker.feed(packed)
- unpacker.unpack()
-
-
-def test_max_array_len():
- d = [1, 2, 3]
- packed = packb(d)
-
- unpacker = Unpacker(max_array_len=3)
- unpacker.feed(packed)
- assert unpacker.unpack() == d
-
- unpacker = Unpacker(max_array_len=2)
- with pytest.raises(UnpackValueError):
- unpacker.feed(packed)
- unpacker.unpack()
-
-
-def test_max_map_len():
- d = {1: 2, 3: 4, 5: 6}
- packed = packb(d)
-
- unpacker = Unpacker(max_map_len=3)
- unpacker.feed(packed)
- assert unpacker.unpack() == d
-
- unpacker = Unpacker(max_map_len=2)
- with pytest.raises(UnpackValueError):
- unpacker.feed(packed)
- unpacker.unpack()
-
-
-def test_max_ext_len():
- d = ExtType(42, b"abc")
- packed = packb(d)
-
- unpacker = Unpacker(max_ext_len=3)
- unpacker.feed(packed)
- assert unpacker.unpack() == d
-
- unpacker = Unpacker(max_ext_len=2)
- with pytest.raises(UnpackValueError):
- unpacker.feed(packed)
- unpacker.unpack()
-
-
-# PyPy fails following tests because of constant folding?
-# https://bugs.pypy.org/issue1721
-# @pytest.mark.skipif(True, reason="Requires very large memory.")
-# def test_binary():
-# x = b'x' * (2**32 - 1)
-# assert unpackb(packb(x)) == x
-# del x
-# x = b'x' * (2**32)
-# with pytest.raises(ValueError):
-# packb(x)
-#
-#
-# @pytest.mark.skipif(True, reason="Requires very large memory.")
-# def test_string():
-# x = 'x' * (2**32 - 1)
-# assert unpackb(packb(x)) == x
-# x += 'y'
-# with pytest.raises(ValueError):
-# packb(x)
-#
-#
-# @pytest.mark.skipif(True, reason="Requires very large memory.")
-# def test_array():
-# x = [0] * (2**32 - 1)
-# assert unpackb(packb(x)) == x
-# x.append(0)
-# with pytest.raises(ValueError):
-# packb(x)
diff --git a/srsly/tests/msgpack/test_memoryview.py b/srsly/tests/msgpack/test_memoryview.py
deleted file mode 100644
index b182c7f..0000000
--- a/srsly/tests/msgpack/test_memoryview.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from array import array
-from srsly.msgpack import packb, unpackb
-
-
-make_memoryview = memoryview
-
-
-def make_array(f, data):
- a = array(f)
- a.frombytes(data)
- return a
-
-
-def get_data(a):
- return a.tobytes()
-
-
-def _runtest(format, nbytes, expected_header, expected_prefix, use_bin_type):
- # create a new array
- original_array = array(format)
- original_array.fromlist([255] * (nbytes // original_array.itemsize))
- original_data = get_data(original_array)
- view = make_memoryview(original_array)
-
- # pack, unpack, and reconstruct array
- packed = packb(view, use_bin_type=use_bin_type)
- unpacked = unpackb(packed)
- reconstructed_array = make_array(format, unpacked)
-
- # check that we got the right amount of data
- assert len(original_data) == nbytes
- # check packed header
- assert packed[:1] == expected_header
- # check packed length prefix, if any
- assert packed[1 : 1 + len(expected_prefix)] == expected_prefix
- # check packed data
- assert packed[1 + len(expected_prefix) :] == original_data
- # check array unpacked correctly
- assert original_array == reconstructed_array
-
-
-def test_fixstr_from_byte():
- _runtest("B", 1, b"\xa1", b"", False)
- _runtest("B", 31, b"\xbf", b"", False)
-
-
-def test_fixstr_from_float():
- _runtest("f", 4, b"\xa4", b"", False)
- _runtest("f", 28, b"\xbc", b"", False)
-
-
-def test_str16_from_byte():
- _runtest("B", 2 ** 8, b"\xda", b"\x01\x00", False)
- _runtest("B", 2 ** 16 - 1, b"\xda", b"\xff\xff", False)
-
-
-def test_str16_from_float():
- _runtest("f", 2 ** 8, b"\xda", b"\x01\x00", False)
- _runtest("f", 2 ** 16 - 4, b"\xda", b"\xff\xfc", False)
-
-
-def test_str32_from_byte():
- _runtest("B", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False)
-
-
-def test_str32_from_float():
- _runtest("f", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False)
-
-
-def test_bin8_from_byte():
- _runtest("B", 1, b"\xc4", b"\x01", True)
- _runtest("B", 2 ** 8 - 1, b"\xc4", b"\xff", True)
-
-
-def test_bin8_from_float():
- _runtest("f", 4, b"\xc4", b"\x04", True)
- _runtest("f", 2 ** 8 - 4, b"\xc4", b"\xfc", True)
-
-
-def test_bin16_from_byte():
- _runtest("B", 2 ** 8, b"\xc5", b"\x01\x00", True)
- _runtest("B", 2 ** 16 - 1, b"\xc5", b"\xff\xff", True)
-
-
-def test_bin16_from_float():
- _runtest("f", 2 ** 8, b"\xc5", b"\x01\x00", True)
- _runtest("f", 2 ** 16 - 4, b"\xc5", b"\xff\xfc", True)
-
-
-def test_bin32_from_byte():
- _runtest("B", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True)
-
-
-def test_bin32_from_float():
- _runtest("f", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True)
diff --git a/srsly/tests/msgpack/test_newspec.py b/srsly/tests/msgpack/test_newspec.py
deleted file mode 100644
index 316280b..0000000
--- a/srsly/tests/msgpack/test_newspec.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from srsly.msgpack import packb, unpackb, ExtType
-
-
-def test_str8():
- header = b"\xd9"
- data = b"x" * 32
- b = packb(data.decode(), use_bin_type=True)
- assert len(b) == len(data) + 2
- assert b[0:2] == header + b"\x20"
- assert b[2:] == data
- assert unpackb(b) == data
-
- data = b"x" * 255
- b = packb(data.decode(), use_bin_type=True)
- assert len(b) == len(data) + 2
- assert b[0:2] == header + b"\xff"
- assert b[2:] == data
- assert unpackb(b) == data
-
-
-def test_bin8():
- header = b"\xc4"
- data = b""
- b = packb(data, use_bin_type=True)
- assert len(b) == len(data) + 2
- assert b[0:2] == header + b"\x00"
- assert b[2:] == data
- assert unpackb(b) == data
-
- data = b"x" * 255
- b = packb(data, use_bin_type=True)
- assert len(b) == len(data) + 2
- assert b[0:2] == header + b"\xff"
- assert b[2:] == data
- assert unpackb(b) == data
-
-
-def test_bin16():
- header = b"\xc5"
- data = b"x" * 256
- b = packb(data, use_bin_type=True)
- assert len(b) == len(data) + 3
- assert b[0:1] == header
- assert b[1:3] == b"\x01\x00"
- assert b[3:] == data
- assert unpackb(b) == data
-
- data = b"x" * 65535
- b = packb(data, use_bin_type=True)
- assert len(b) == len(data) + 3
- assert b[0:1] == header
- assert b[1:3] == b"\xff\xff"
- assert b[3:] == data
- assert unpackb(b) == data
-
-
-def test_bin32():
- header = b"\xc6"
- data = b"x" * 65536
- b = packb(data, use_bin_type=True)
- assert len(b) == len(data) + 5
- assert b[0:1] == header
- assert b[1:5] == b"\x00\x01\x00\x00"
- assert b[5:] == data
- assert unpackb(b) == data
-
-
-def test_ext():
- def check(ext, packed):
- assert packb(ext) == packed
- assert unpackb(packed) == ext
-
- check(ExtType(0x42, b"Z"), b"\xd4\x42Z") # fixext 1
- check(ExtType(0x42, b"ZZ"), b"\xd5\x42ZZ") # fixext 2
- check(ExtType(0x42, b"Z" * 4), b"\xd6\x42" + b"Z" * 4) # fixext 4
- check(ExtType(0x42, b"Z" * 8), b"\xd7\x42" + b"Z" * 8) # fixext 8
- check(ExtType(0x42, b"Z" * 16), b"\xd8\x42" + b"Z" * 16) # fixext 16
- # ext 8
- check(ExtType(0x42, b""), b"\xc7\x00\x42")
- check(ExtType(0x42, b"Z" * 255), b"\xc7\xff\x42" + b"Z" * 255)
- # ext 16
- check(ExtType(0x42, b"Z" * 256), b"\xc8\x01\x00\x42" + b"Z" * 256)
- check(ExtType(0x42, b"Z" * 0xFFFF), b"\xc8\xff\xff\x42" + b"Z" * 0xFFFF)
- # ext 32
- check(ExtType(0x42, b"Z" * 0x10000), b"\xc9\x00\x01\x00\x00\x42" + b"Z" * 0x10000)
- # needs large memory
- # check(ExtType(0x42, b'Z'*0xffffffff),
- # b'\xc9\xff\xff\xff\xff\x42' + b'Z'*0xffffffff)
diff --git a/srsly/tests/msgpack/test_pack.py b/srsly/tests/msgpack/test_pack.py
deleted file mode 100644
index 22f088d..0000000
--- a/srsly/tests/msgpack/test_pack.py
+++ /dev/null
@@ -1,201 +0,0 @@
-import struct
-import pytest
-from collections import OrderedDict
-from io import BytesIO
-from srsly.msgpack import packb, unpackb, Unpacker, Packer
-
-
-def check(data, use_list=False):
- re = unpackb(packb(data), use_list=use_list)
- assert re == data
-
-
-def testPack():
- test_data = [
- 0,
- 1,
- 127,
- 128,
- 255,
- 256,
- 65535,
- 65536,
- 4294967295,
- 4294967296,
- -1,
- -32,
- -33,
- -128,
- -129,
- -32768,
- -32769,
- -4294967296,
- -4294967297,
- 1.0,
- b"",
- b"a",
- b"a" * 31,
- b"a" * 32,
- None,
- True,
- False,
- (),
- ((),),
- ((), None),
- {None: 0},
- (1 << 23),
- ]
- for td in test_data:
- check(td)
-
-
-def testPackUnicode():
- test_data = ["", "abcd", ["defgh"], "Русский текст"]
- for td in test_data:
- re = unpackb(packb(td), use_list=1, raw=False)
- assert re == td
- packer = Packer()
- data = packer.pack(td)
- re = Unpacker(BytesIO(data), raw=False, use_list=1).unpack()
- assert re == td
-
-
-def testPackUTF32(): # deprecated
- re = unpackb(packb("", encoding="utf-32"), use_list=1, encoding="utf-32")
- assert re == ""
- re = unpackb(packb("abcd", encoding="utf-32"), use_list=1, encoding="utf-32")
- assert re == "abcd"
- re = unpackb(packb(["defgh"], encoding="utf-32"), use_list=1, encoding="utf-32")
- assert re == ["defgh"]
- try:
- packb("Русский текст", encoding="utf-32")
- except LookupError as e:
- pytest.xfail(str(e))
- # try:
- # test_data = ["", "abcd", ["defgh"], "Русский текст"]
- # for td in test_data:
- # except LookupError as e:
- # pytest.xfail(e)
-
-
-def testPackBytes():
- test_data = [b"", b"abcd", (b"defgh",)]
- for td in test_data:
- check(td)
-
-
-def testPackByteArrays():
- test_data = [bytearray(b""), bytearray(b"abcd"), (bytearray(b"defgh"),)]
- for td in test_data:
- check(td)
-
-
-def testIgnoreUnicodeErrors(): # deprecated
- re = unpackb(
- packb(b"abc\xeddef"), encoding="utf-8", unicode_errors="ignore", use_list=1
- )
- assert re == "abcdef"
-
-
-def testStrictUnicodeUnpack():
- with pytest.raises(UnicodeDecodeError):
- unpackb(packb(b"abc\xeddef"), raw=False, use_list=1)
-
-
-def testStrictUnicodePack(): # deprecated
- with pytest.raises(UnicodeEncodeError):
- packb("abc\xeddef", encoding="ascii", unicode_errors="strict")
-
-
-def testIgnoreErrorsPack(): # deprecated
- re = unpackb(
- packb("abcФФФdef", encoding="ascii", unicode_errors="ignore"),
- raw=False,
- use_list=1,
- )
- assert re == "abcdef"
-
-
-def testDecodeBinary():
- re = unpackb(packb(b"abc"), encoding=None, use_list=1)
- assert re == b"abc"
-
-
-def testPackFloat():
- assert packb(1.0, use_single_float=True) == b"\xca" + struct.pack(str(">f"), 1.0)
- assert packb(1.0, use_single_float=False) == b"\xcb" + struct.pack(str(">d"), 1.0)
-
-
-def testArraySize(sizes=[0, 5, 50, 1000]):
- bio = BytesIO()
- packer = Packer()
- for size in sizes:
- bio.write(packer.pack_array_header(size))
- for i in range(size):
- bio.write(packer.pack(i))
-
- bio.seek(0)
- unpacker = Unpacker(bio, use_list=1)
- for size in sizes:
- assert unpacker.unpack() == list(range(size))
-
-
-def test_manualreset(sizes=[0, 5, 50, 1000]):
- packer = Packer(autoreset=False)
- for size in sizes:
- packer.pack_array_header(size)
- for i in range(size):
- packer.pack(i)
-
- bio = BytesIO(packer.bytes())
- unpacker = Unpacker(bio, use_list=1)
- for size in sizes:
- assert unpacker.unpack() == list(range(size))
-
- packer.reset()
- assert packer.bytes() == b""
-
-
-def testMapSize(sizes=[0, 5, 50, 1000]):
- bio = BytesIO()
- packer = Packer()
- for size in sizes:
- bio.write(packer.pack_map_header(size))
- for i in range(size):
- bio.write(packer.pack(i)) # key
- bio.write(packer.pack(i * 2)) # value
-
- bio.seek(0)
- unpacker = Unpacker(bio)
- for size in sizes:
- assert unpacker.unpack() == dict((i, i * 2) for i in range(size))
-
-
-def test_odict():
- seq = [(b"one", 1), (b"two", 2), (b"three", 3), (b"four", 4)]
- od = OrderedDict(seq)
- assert unpackb(packb(od), use_list=1) == dict(seq)
-
- def pair_hook(seq):
- return list(seq)
-
- assert unpackb(packb(od), object_pairs_hook=pair_hook, use_list=1) == seq
-
-
-def test_pairlist():
- pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
- packer = Packer()
- packed = packer.pack_map_pairs(pairlist)
- unpacked = unpackb(packed, object_pairs_hook=list)
- assert pairlist == unpacked
-
-
-def test_get_buffer():
- packer = Packer(autoreset=0, use_bin_type=True)
- packer.pack([1, 2])
- strm = BytesIO()
- strm.write(packer.getbuffer())
- written = strm.getvalue()
-
- expected = packb([1, 2], use_bin_type=True)
- assert written == expected
diff --git a/srsly/tests/msgpack/test_read_size.py b/srsly/tests/msgpack/test_read_size.py
deleted file mode 100644
index 63e8b47..0000000
--- a/srsly/tests/msgpack/test_read_size.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from srsly.msgpack import packb, Unpacker, OutOfData
-
-
-UnexpectedTypeException = ValueError
-
-
-def test_read_array_header():
- unpacker = Unpacker()
- unpacker.feed(packb(["a", "b", "c"]))
- assert unpacker.read_array_header() == 3
- assert unpacker.unpack() == b"a"
- assert unpacker.unpack() == b"b"
- assert unpacker.unpack() == b"c"
- try:
- unpacker.unpack()
- assert 0, "should raise exception"
- except OutOfData:
- assert 1, "okay"
-
-
-def test_read_map_header():
- unpacker = Unpacker()
- unpacker.feed(packb({"a": "A"}))
- assert unpacker.read_map_header() == 1
- assert unpacker.unpack() == b"a"
- assert unpacker.unpack() == b"A"
- try:
- unpacker.unpack()
- assert 0, "should raise exception"
- except OutOfData:
- assert 1, "okay"
-
-
-def test_incorrect_type_array():
- unpacker = Unpacker()
- unpacker.feed(packb(1))
- try:
- unpacker.read_array_header()
- assert 0, "should raise exception"
- except UnexpectedTypeException:
- assert 1, "okay"
-
-
-def test_incorrect_type_map():
- unpacker = Unpacker()
- unpacker.feed(packb(1))
- try:
- unpacker.read_map_header()
- assert 0, "should raise exception"
- except UnexpectedTypeException:
- assert 1, "okay"
-
-
-def test_correct_type_nested_array():
- unpacker = Unpacker()
- unpacker.feed(packb({"a": ["b", "c", "d"]}))
- try:
- unpacker.read_array_header()
- assert 0, "should raise exception"
- except UnexpectedTypeException:
- assert 1, "okay"
-
-
-def test_incorrect_type_nested_map():
- unpacker = Unpacker()
- unpacker.feed(packb([{"a": "b"}]))
- try:
- unpacker.read_map_header()
- assert 0, "should raise exception"
- except UnexpectedTypeException:
- assert 1, "okay"
diff --git a/srsly/tests/msgpack/test_seq.py b/srsly/tests/msgpack/test_seq.py
deleted file mode 100644
index c50e1cf..0000000
--- a/srsly/tests/msgpack/test_seq.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import io
-from srsly import msgpack
-
-
-binarydata = bytes(bytearray(range(256)))
-
-
-def gen_binary_data(idx):
- return binarydata[: idx % 300]
-
-
-def test_exceeding_unpacker_read_size():
- dumpf = io.BytesIO()
-
- packer = msgpack.Packer()
-
- NUMBER_OF_STRINGS = 6
- read_size = 16
- # 5 ok for read_size=16, while 6 glibc detected *** python: double free or corruption (fasttop):
- # 20 ok for read_size=256, while 25 segfaults / glibc detected *** python: double free or corruption (!prev)
- # 40 ok for read_size=1024, while 50 introduces errors
- # 7000 ok for read_size=1024*1024, while 8000 leads to glibc detected *** python: double free or corruption (!prev):
-
- for idx in range(NUMBER_OF_STRINGS):
- data = gen_binary_data(idx)
- dumpf.write(packer.pack(data))
-
- f = io.BytesIO(dumpf.getvalue())
- dumpf.close()
-
- unpacker = msgpack.Unpacker(f, read_size=read_size, use_list=1)
-
- read_count = 0
- for idx, o in enumerate(unpacker):
- assert type(o) == bytes
- assert o == gen_binary_data(idx)
- read_count += 1
-
- assert read_count == NUMBER_OF_STRINGS
diff --git a/srsly/tests/msgpack/test_sequnpack.py b/srsly/tests/msgpack/test_sequnpack.py
deleted file mode 100644
index bfd7afa..0000000
--- a/srsly/tests/msgpack/test_sequnpack.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import io
-import pytest
-from srsly.msgpack import Unpacker, BufferFull
-from srsly.msgpack import pack
-from srsly.msgpack.exceptions import OutOfData
-
-
-def test_partialdata():
- unpacker = Unpacker()
- unpacker.feed(b"\xa5")
- with pytest.raises(StopIteration):
- next(iter(unpacker))
- unpacker.feed(b"h")
- with pytest.raises(StopIteration):
- next(iter(unpacker))
- unpacker.feed(b"a")
- with pytest.raises(StopIteration):
- next(iter(unpacker))
- unpacker.feed(b"l")
- with pytest.raises(StopIteration):
- next(iter(unpacker))
- unpacker.feed(b"l")
- with pytest.raises(StopIteration):
- next(iter(unpacker))
- unpacker.feed(b"o")
- assert next(iter(unpacker)) == b"hallo"
-
-
-def test_foobar():
- unpacker = Unpacker(read_size=3, use_list=1)
- unpacker.feed(b"foobar")
- assert unpacker.unpack() == ord(b"f")
- assert unpacker.unpack() == ord(b"o")
- assert unpacker.unpack() == ord(b"o")
- assert unpacker.unpack() == ord(b"b")
- assert unpacker.unpack() == ord(b"a")
- assert unpacker.unpack() == ord(b"r")
- with pytest.raises(OutOfData):
- unpacker.unpack()
-
- unpacker.feed(b"foo")
- unpacker.feed(b"bar")
-
- k = 0
- for o, e in zip(unpacker, "foobarbaz"):
- assert o == ord(e)
- k += 1
- assert k == len(b"foobar")
-
-
-def test_foobar_skip():
- unpacker = Unpacker(read_size=3, use_list=1)
- unpacker.feed(b"foobar")
- assert unpacker.unpack() == ord(b"f")
- unpacker.skip()
- assert unpacker.unpack() == ord(b"o")
- unpacker.skip()
- assert unpacker.unpack() == ord(b"a")
- unpacker.skip()
- with pytest.raises(OutOfData):
- unpacker.unpack()
-
-
-def test_maxbuffersize():
- with pytest.raises(ValueError):
- Unpacker(read_size=5, max_buffer_size=3)
- unpacker = Unpacker(read_size=3, max_buffer_size=3, use_list=1)
- unpacker.feed(b"fo")
- with pytest.raises(BufferFull):
- unpacker.feed(b"ob")
- unpacker.feed(b"o")
- assert ord("f") == next(unpacker)
- unpacker.feed(b"b")
- assert ord("o") == next(unpacker)
- assert ord("o") == next(unpacker)
- assert ord("b") == next(unpacker)
-
-
-def test_readbytes():
- unpacker = Unpacker(read_size=3)
- unpacker.feed(b"foobar")
- assert unpacker.unpack() == ord(b"f")
- assert unpacker.read_bytes(3) == b"oob"
- assert unpacker.unpack() == ord(b"a")
- assert unpacker.unpack() == ord(b"r")
-
- # Test buffer refill
- unpacker = Unpacker(io.BytesIO(b"foobar"), read_size=3)
- assert unpacker.unpack() == ord(b"f")
- assert unpacker.read_bytes(3) == b"oob"
- assert unpacker.unpack() == ord(b"a")
- assert unpacker.unpack() == ord(b"r")
-
-
-def test_issue124():
- unpacker = Unpacker()
- unpacker.feed(b"\xa1?\xa1!")
- assert tuple(unpacker) == (b"?", b"!")
- assert tuple(unpacker) == ()
- unpacker.feed(b"\xa1?\xa1")
- assert tuple(unpacker) == (b"?",)
- assert tuple(unpacker) == ()
- unpacker.feed(b"!")
- assert tuple(unpacker) == (b"!",)
- assert tuple(unpacker) == ()
-
-
-def test_unpack_tell():
- stream = io.BytesIO()
- messages = [2 ** i - 1 for i in range(65)]
- messages += [-(2 ** i) for i in range(1, 64)]
- messages += [
- b"hello",
- b"hello" * 1000,
- list(range(20)),
- {i: bytes(i) * i for i in range(10)},
- {i: bytes(i) * i for i in range(32)},
- ]
- offsets = []
- for m in messages:
- pack(m, stream)
- offsets.append(stream.tell())
- stream.seek(0)
- unpacker = Unpacker(stream)
- for m, o in zip(messages, offsets):
- m2 = next(unpacker)
- assert m == m2
- assert o == unpacker.tell()
diff --git a/srsly/tests/msgpack/test_stricttype.py b/srsly/tests/msgpack/test_stricttype.py
deleted file mode 100644
index fd99185..0000000
--- a/srsly/tests/msgpack/test_stricttype.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from collections import namedtuple
-from srsly.msgpack import packb, unpackb, ExtType
-
-
-def test_namedtuple():
- T = namedtuple("T", "foo bar")
-
- def default(o):
- if isinstance(o, T):
- return dict(o._asdict())
- raise TypeError("Unsupported type %s" % (type(o),))
-
- packed = packb(T(1, 42), strict_types=True, use_bin_type=True, default=default)
- unpacked = unpackb(packed, raw=False)
- assert unpacked == {"foo": 1, "bar": 42}
-
-
-def test_tuple():
- t = ("one", 2, b"three", (4,))
-
- def default(o):
- if isinstance(o, tuple):
- return {"__type__": "tuple", "value": list(o)}
- raise TypeError("Unsupported type %s" % (type(o),))
-
- def convert(o):
- if o.get("__type__") == "tuple":
- return tuple(o["value"])
- return o
-
- data = packb(t, strict_types=True, use_bin_type=True, default=default)
- expected = unpackb(data, raw=False, object_hook=convert)
-
- assert expected == t
-
-
-def test_tuple_ext():
- t = ("one", 2, b"three", (4,))
-
- MSGPACK_EXT_TYPE_TUPLE = 0
-
- def default(o):
- if isinstance(o, tuple):
- # Convert to list and pack
- payload = packb(
- list(o), strict_types=True, use_bin_type=True, default=default
- )
- return ExtType(MSGPACK_EXT_TYPE_TUPLE, payload)
- raise TypeError(repr(o))
-
- def convert(code, payload):
- if code == MSGPACK_EXT_TYPE_TUPLE:
- # Unpack and convert to tuple
- return tuple(unpackb(payload, raw=False, ext_hook=convert))
- raise ValueError("Unknown Ext code {}".format(code))
-
- data = packb(t, strict_types=True, use_bin_type=True, default=default)
- expected = unpackb(data, raw=False, ext_hook=convert)
-
- assert expected == t
diff --git a/srsly/tests/msgpack/test_subtype.py b/srsly/tests/msgpack/test_subtype.py
deleted file mode 100644
index 7dfe08c..0000000
--- a/srsly/tests/msgpack/test_subtype.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from collections import namedtuple
-from srsly.msgpack import packb
-
-
-class MyList(list):
- pass
-
-
-class MyDict(dict):
- pass
-
-
-class MyTuple(tuple):
- pass
-
-
-MyNamedTuple = namedtuple("MyNamedTuple", "x y")
-
-
-def test_types():
- assert packb(MyDict()) == packb(dict())
- assert packb(MyList()) == packb(list())
- assert packb(MyNamedTuple(1, 2)) == packb((1, 2))
diff --git a/srsly/tests/msgpack/test_unpack.py b/srsly/tests/msgpack/test_unpack.py
deleted file mode 100644
index 6045ea0..0000000
--- a/srsly/tests/msgpack/test_unpack.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from io import BytesIO
-import sys
-import pytest
-from srsly.msgpack import Unpacker, packb, OutOfData, ExtType
-
-
-def test_unpack_array_header_from_file():
- f = BytesIO(packb([1, 2, 3, 4]))
- unpacker = Unpacker(f)
- assert unpacker.read_array_header() == 4
- assert unpacker.unpack() == 1
- assert unpacker.unpack() == 2
- assert unpacker.unpack() == 3
- assert unpacker.unpack() == 4
- with pytest.raises(OutOfData):
- unpacker.unpack()
-
-
-@pytest.mark.skipif(
- "not hasattr(sys, 'getrefcount') == True",
- reason="sys.getrefcount() is needed to pass this test",
-)
-def test_unpacker_hook_refcnt():
- result = []
-
- def hook(x):
- result.append(x)
- return x
-
- basecnt = sys.getrefcount(hook)
-
- up = Unpacker(object_hook=hook, list_hook=hook)
-
- assert sys.getrefcount(hook) >= basecnt + 2
-
- up.feed(packb([{}]))
- up.feed(packb([{}]))
- assert up.unpack() == [{}]
- assert up.unpack() == [{}]
- assert result == [{}, [{}], {}, [{}]]
-
- del up
-
- assert sys.getrefcount(hook) == basecnt
-
-
-def test_unpacker_ext_hook():
- class MyUnpacker(Unpacker):
- def __init__(self):
- super(MyUnpacker, self).__init__(ext_hook=self._hook, raw=False)
-
- def _hook(self, code, data):
- if code == 1:
- return int(data)
- else:
- return ExtType(code, data)
-
- unpacker = MyUnpacker()
- unpacker.feed(packb({"a": 1}))
- assert unpacker.unpack() == {"a": 1}
- unpacker.feed(packb({"a": ExtType(1, b"123")}))
- assert unpacker.unpack() == {"a": 123}
- unpacker.feed(packb({"a": ExtType(2, b"321")}))
- assert unpacker.unpack() == {"a": ExtType(2, b"321")}
-
-
-if __name__ == "__main__":
- test_unpack_array_header_from_file()
- test_unpacker_hook_refcnt()
- test_unpacker_ext_hook()
diff --git a/srsly/tests/ruamel_yaml/__init__.py b/srsly/tests/ruamel_yaml/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/srsly/tests/ruamel_yaml/roundtrip.py b/srsly/tests/ruamel_yaml/roundtrip.py
deleted file mode 100755
index e3fd26f..0000000
--- a/srsly/tests/ruamel_yaml/roundtrip.py
+++ /dev/null
@@ -1,308 +0,0 @@
-"""
-helper routines for testing round trip of commented YAML data
-"""
-import sys
-import textwrap
-from pathlib import Path
-
-enforce = object()
-
-
-def dedent(data):
- try:
- position_of_first_newline = data.index("\n")
- for idx in range(position_of_first_newline):
- if not data[idx].isspace():
- raise ValueError
- except ValueError:
- pass
- else:
- data = data[position_of_first_newline + 1 :]
- return textwrap.dedent(data)
-
-
-def round_trip_load(inp, preserve_quotes=None, version=None):
- import srsly.ruamel_yaml # NOQA
-
- dinp = dedent(inp)
- return srsly.ruamel_yaml.load(
- dinp,
- Loader=srsly.ruamel_yaml.RoundTripLoader,
- preserve_quotes=preserve_quotes,
- version=version,
- )
-
-
-def round_trip_load_all(inp, preserve_quotes=None, version=None):
- import srsly.ruamel_yaml # NOQA
-
- dinp = dedent(inp)
- return srsly.ruamel_yaml.load_all(
- dinp,
- Loader=srsly.ruamel_yaml.RoundTripLoader,
- preserve_quotes=preserve_quotes,
- version=version,
- )
-
-
-def round_trip_dump(
- data,
- stream=None,
- indent=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
-):
- import srsly.ruamel_yaml # NOQA
-
- return srsly.ruamel_yaml.round_trip_dump(
- data,
- stream=stream,
- indent=indent,
- block_seq_indent=block_seq_indent,
- top_level_colon_align=top_level_colon_align,
- prefix_colon=prefix_colon,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- )
-
-
-def diff(inp, outp, file_name="stdin"):
- import difflib
-
- inl = inp.splitlines(True) # True for keepends
- outl = outp.splitlines(True)
- diff = difflib.unified_diff(inl, outl, file_name, "round trip YAML")
- # 2.6 difflib has trailing space on filename lines %-)
- strip_trailing_space = sys.version_info < (2, 7)
- for line in diff:
- if strip_trailing_space and line[:4] in ["--- ", "+++ "]:
- line = line.rstrip() + "\n"
- sys.stdout.write(line)
-
-
-def round_trip(
- inp,
- outp=None,
- extra=None,
- intermediate=None,
- indent=None,
- block_seq_indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- preserve_quotes=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- dump_data=None,
-):
- """
- inp: input string to parse
- outp: expected output (equals input if not specified)
- """
- if outp is None:
- outp = inp
- doutp = dedent(outp)
- if extra is not None:
- doutp += extra
- data = round_trip_load(inp, preserve_quotes=preserve_quotes)
- if dump_data:
- print("data", data)
- if intermediate is not None:
- if isinstance(intermediate, dict):
- for k, v in intermediate.items():
- if data[k] != v:
- print("{0!r} <> {1!r}".format(data[k], v))
- raise ValueError
- res = round_trip_dump(
- data,
- indent=indent,
- block_seq_indent=block_seq_indent,
- top_level_colon_align=top_level_colon_align,
- prefix_colon=prefix_colon,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- )
- if res != doutp:
- diff(doutp, res, "input string")
- print("\nroundtrip data:\n", res, sep="")
- assert res == doutp
- res = round_trip_dump(
- data,
- indent=indent,
- block_seq_indent=block_seq_indent,
- top_level_colon_align=top_level_colon_align,
- prefix_colon=prefix_colon,
- explicit_start=explicit_start,
- explicit_end=explicit_end,
- version=version,
- )
- print("roundtrip second round data:\n", res, sep="")
- assert res == doutp
- return data
-
-
-def na_round_trip(
- inp,
- outp=None,
- extra=None,
- intermediate=None,
- indent=None,
- top_level_colon_align=None,
- prefix_colon=None,
- preserve_quotes=None,
- explicit_start=None,
- explicit_end=None,
- version=None,
- dump_data=None,
-):
- """
- inp: input string to parse
- outp: expected output (equals input if not specified)
- """
- inp = dedent(inp)
- if outp is None:
- outp = inp
- if version is not None:
- version = version
- doutp = dedent(outp)
- if extra is not None:
- doutp += extra
- yaml = YAML()
- yaml.preserve_quotes = preserve_quotes
- yaml.scalar_after_indicator = False # newline after every directives end
- data = yaml.load(inp)
- if dump_data:
- print("data", data)
- if intermediate is not None:
- if isinstance(intermediate, dict):
- for k, v in intermediate.items():
- if data[k] != v:
- print("{0!r} <> {1!r}".format(data[k], v))
- raise ValueError
- yaml.indent = indent
- yaml.top_level_colon_align = top_level_colon_align
- yaml.prefix_colon = prefix_colon
- yaml.explicit_start = explicit_start
- yaml.explicit_end = explicit_end
- res = yaml.dump(data, compare=doutp)
- return res
-
-
-def YAML(**kw):
- import srsly.ruamel_yaml # NOQA
-
- class MyYAML(srsly.ruamel_yaml.YAML):
- """auto dedent string parameters on load"""
-
- def load(self, stream):
- if isinstance(stream, str):
- if stream and stream[0] == "\n":
- stream = stream[1:]
- stream = textwrap.dedent(stream)
- return srsly.ruamel_yaml.YAML.load(self, stream)
-
- def load_all(self, stream):
- if isinstance(stream, str):
- if stream and stream[0] == "\n":
- stream = stream[1:]
- stream = textwrap.dedent(stream)
- for d in srsly.ruamel_yaml.YAML.load_all(self, stream):
- yield d
-
- def dump(self, data, **kw):
- from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA
-
- assert ("stream" in kw) ^ ("compare" in kw)
- if "stream" in kw:
- return srsly.ruamel_yaml.YAML.dump(data, **kw)
- lkw = kw.copy()
- expected = textwrap.dedent(lkw.pop("compare"))
- unordered_lines = lkw.pop("unordered_lines", False)
- if expected and expected[0] == "\n":
- expected = expected[1:]
- lkw["stream"] = st = StringIO()
- srsly.ruamel_yaml.YAML.dump(self, data, **lkw)
- res = st.getvalue()
- print(res)
- if unordered_lines:
- res = sorted(res.splitlines())
- expected = sorted(expected.splitlines())
- assert res == expected
-
- def round_trip(self, stream, **kw):
- from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA
-
- assert isinstance(stream, (srsly.ruamel_yaml.compat.text_type, str))
- lkw = kw.copy()
- if stream and stream[0] == "\n":
- stream = stream[1:]
- stream = textwrap.dedent(stream)
- data = srsly.ruamel_yaml.YAML.load(self, stream)
- outp = lkw.pop("outp", stream)
- lkw["stream"] = st = StringIO()
- srsly.ruamel_yaml.YAML.dump(self, data, **lkw)
- res = st.getvalue()
- if res != outp:
- diff(outp, res, "input string")
- assert res == outp
-
- def round_trip_all(self, stream, **kw):
- from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA
-
- assert isinstance(stream, (srsly.ruamel_yaml.compat.text_type, str))
- lkw = kw.copy()
- if stream and stream[0] == "\n":
- stream = stream[1:]
- stream = textwrap.dedent(stream)
- data = list(srsly.ruamel_yaml.YAML.load_all(self, stream))
- outp = lkw.pop("outp", stream)
- lkw["stream"] = st = StringIO()
- srsly.ruamel_yaml.YAML.dump_all(self, data, **lkw)
- res = st.getvalue()
- if res != outp:
- diff(outp, res, "input string")
- assert res == outp
-
- return MyYAML(**kw)
-
-
-def save_and_run(program, base_dir=None, output=None, file_name=None, optimized=False):
- """
- safe and run a python program, thereby circumventing any restrictions on module level
- imports
- """
- from subprocess import check_output, STDOUT, CalledProcessError
-
- if not hasattr(base_dir, "hash"):
- base_dir = Path(str(base_dir))
- if file_name is None:
- file_name = "safe_and_run_tmp.py"
- file_name = base_dir / file_name
- file_name.write_text(dedent(program))
-
- try:
- cmd = [sys.executable]
- if optimized:
- cmd.append("-O")
- cmd.append(str(file_name))
- print("running:", *cmd)
- res = check_output(cmd, stderr=STDOUT, universal_newlines=True)
- if output is not None:
- if "__pypy__" in sys.builtin_module_names:
- res = res.splitlines(True)
- res = [line for line in res if "no version info" not in line]
- res = "".join(res)
- print("result: ", res, end="")
- print("expected:", output, end="")
- assert res == output
- except CalledProcessError as exception:
- print("##### Running '{} {}' FAILED #####".format(sys.executable, file_name))
- print(exception.output)
- return exception.returncode
- return 0
diff --git a/srsly/tests/ruamel_yaml/test_a_dedent.py b/srsly/tests/ruamel_yaml/test_a_dedent.py
deleted file mode 100755
index 146dfd0..0000000
--- a/srsly/tests/ruamel_yaml/test_a_dedent.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from .roundtrip import dedent
-
-
-class TestDedent:
- def test_start_newline(self):
- # fmt: off
- x = dedent("""
- 123
- 456
- """)
- # fmt: on
- assert x == "123\n 456\n"
-
- def test_start_space_newline(self):
- # special construct to prevent stripping of following whitespace
- # fmt: off
- x = dedent(" " """
- 123
- """)
- # fmt: on
- assert x == "123\n"
-
- def test_start_no_newline(self):
- # special construct to prevent stripping of following whitespac
- x = dedent(
- """\
- 123
- 456
- """
- )
- assert x == "123\n 456\n"
-
- def test_preserve_no_newline_at_end(self):
- x = dedent(
- """
- 123"""
- )
- assert x == "123"
-
- def test_preserve_no_newline_at_all(self):
- x = dedent(
- """\
- 123"""
- )
- assert x == "123"
-
- def test_multiple_dedent(self):
- x = dedent(
- dedent(
- """
- 123
- """
- )
- )
- assert x == "123\n"
diff --git a/srsly/tests/ruamel_yaml/test_add_xxx.py b/srsly/tests/ruamel_yaml/test_add_xxx.py
deleted file mode 100755
index 14ed103..0000000
--- a/srsly/tests/ruamel_yaml/test_add_xxx.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# coding: utf-8
-
-import re
-import pytest # NOQA
-
-from .roundtrip import dedent
-
-
-# from PyYAML docs
-class Dice(tuple):
- def __new__(cls, a, b):
- return tuple.__new__(cls, [a, b])
-
- def __repr__(self):
- return "Dice(%s,%s)" % self
-
-
-def dice_constructor(loader, node):
- value = loader.construct_scalar(node)
- a, b = map(int, value.split("d"))
- return Dice(a, b)
-
-
-def dice_representer(dumper, data):
- return dumper.represent_scalar(u"!dice", u"{}d{}".format(*data))
-
-
-def test_dice_constructor():
- import srsly.ruamel_yaml # NOQA
-
- srsly.ruamel_yaml.add_constructor(u"!dice", dice_constructor)
- with pytest.raises(ValueError):
- data = srsly.ruamel_yaml.load(
- "initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader
- )
- assert str(data) == "{'initial hit points': Dice(8,4)}"
-
-
-def test_dice_constructor_with_loader():
- import srsly.ruamel_yaml # NOQA
-
- with pytest.raises(ValueError):
- srsly.ruamel_yaml.add_constructor(
- u"!dice", dice_constructor, Loader=srsly.ruamel_yaml.Loader
- )
- data = srsly.ruamel_yaml.load(
- "initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader
- )
- assert str(data) == "{'initial hit points': Dice(8,4)}"
-
-
-def test_dice_representer():
- import srsly.ruamel_yaml # NOQA
-
- srsly.ruamel_yaml.add_representer(Dice, dice_representer)
- # srsly.ruamel_yaml 0.15.8+ no longer forces quotes tagged scalars
- assert (
- srsly.ruamel_yaml.dump(dict(gold=Dice(10, 6)), default_flow_style=False)
- == "gold: !dice 10d6\n"
- )
-
-
-def test_dice_implicit_resolver():
- import srsly.ruamel_yaml # NOQA
-
- pattern = re.compile(r"^\d+d\d+$")
- with pytest.raises(ValueError):
- srsly.ruamel_yaml.add_implicit_resolver(u"!dice", pattern)
- assert (
- srsly.ruamel_yaml.dump(dict(treasure=Dice(10, 20)), default_flow_style=False)
- == "treasure: 10d20\n"
- )
- assert srsly.ruamel_yaml.load(
- "damage: 5d10", Loader=srsly.ruamel_yaml.Loader
- ) == dict(damage=Dice(5, 10))
-
-
-class Obj1(dict):
- def __init__(self, suffix):
- self._suffix = suffix
- self._node = None
-
- def add_node(self, n):
- self._node = n
-
- def __repr__(self):
- return "Obj1(%s->%s)" % (self._suffix, self.items())
-
- def dump(self):
- return repr(self._node)
-
-
-class YAMLObj1(object):
- yaml_tag = u"!obj:"
-
- @classmethod
- def from_yaml(cls, loader, suffix, node):
- import srsly.ruamel_yaml # NOQA
-
- obj1 = Obj1(suffix)
- if isinstance(node, srsly.ruamel_yaml.MappingNode):
- obj1.add_node(loader.construct_mapping(node))
- else:
- raise NotImplementedError
- return obj1
-
- @classmethod
- def to_yaml(cls, dumper, data):
- return dumper.represent_scalar(cls.yaml_tag + data._suffix, data.dump())
-
-
-def test_yaml_obj():
- import srsly.ruamel_yaml # NOQA
-
- srsly.ruamel_yaml.add_representer(Obj1, YAMLObj1.to_yaml)
- srsly.ruamel_yaml.add_multi_constructor(YAMLObj1.yaml_tag, YAMLObj1.from_yaml)
- with pytest.raises(ValueError):
- x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader)
- print(x)
- assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n"""
-
-
-def test_yaml_obj_with_loader_and_dumper():
- import srsly.ruamel_yaml # NOQA
-
- srsly.ruamel_yaml.add_representer(
- Obj1, YAMLObj1.to_yaml, Dumper=srsly.ruamel_yaml.Dumper
- )
- srsly.ruamel_yaml.add_multi_constructor(
- YAMLObj1.yaml_tag, YAMLObj1.from_yaml, Loader=srsly.ruamel_yaml.Loader
- )
- with pytest.raises(ValueError):
- x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader)
- # x = srsly.ruamel_yaml.load('!obj:x.2\na: 1')
- print(x)
- assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n"""
-
-
-# ToDo use nullege to search add_multi_representer and add_path_resolver
-# and add some test code
-
-# Issue 127 reported by Tommy Wang
-
-
-def test_issue_127():
- import srsly.ruamel_yaml # NOQA
-
- class Ref(srsly.ruamel_yaml.YAMLObject):
- yaml_constructor = srsly.ruamel_yaml.RoundTripConstructor
- yaml_representer = srsly.ruamel_yaml.RoundTripRepresenter
- yaml_tag = u"!Ref"
-
- def __init__(self, logical_id):
- self.logical_id = logical_id
-
- @classmethod
- def from_yaml(cls, loader, node):
- return cls(loader.construct_scalar(node))
-
- @classmethod
- def to_yaml(cls, dumper, data):
- if isinstance(data.logical_id, srsly.ruamel_yaml.scalarstring.ScalarString):
- style = data.logical_id.style # srsly.ruamel_yaml>0.15.8
- else:
- style = None
- return dumper.represent_scalar(cls.yaml_tag, data.logical_id, style=style)
-
- document = dedent(
- """\
- AList:
- - !Ref One
- - !Ref 'Two'
- - !Ref
- Two and a half
- BList: [!Ref Three, !Ref "Four"]
- CList:
- - Five Six
- - 'Seven Eight'
- """
- )
- data = srsly.ruamel_yaml.round_trip_load(document, preserve_quotes=True)
- assert srsly.ruamel_yaml.round_trip_dump(
- data, indent=4, block_seq_indent=2
- ) == document.replace("\n Two and", " Two and")
diff --git a/srsly/tests/ruamel_yaml/test_anchor.py b/srsly/tests/ruamel_yaml/test_anchor.py
deleted file mode 100755
index ad0859f..0000000
--- a/srsly/tests/ruamel_yaml/test_anchor.py
+++ /dev/null
@@ -1,575 +0,0 @@
-"""
-testing of anchors and the aliases referring to them
-"""
-
-import pytest
-from textwrap import dedent
-import platform
-import srsly
-
-from .roundtrip import (
- round_trip,
- dedent,
- round_trip_load,
- round_trip_dump,
- YAML,
-) # NOQA
-
-
-def load(s):
- return round_trip_load(dedent(s))
-
-
-def compare(d, s):
- assert round_trip_dump(d) == dedent(s)
-
-
-class TestAnchorsAliases:
- def test_anchor_id_renumber(self):
- from srsly.ruamel_yaml.serializer import Serializer
-
- assert Serializer.ANCHOR_TEMPLATE == "id%03d"
- data = load(
- """
- a: &id002
- b: 1
- c: 2
- d: *id002
- """
- )
- compare(
- data,
- """
- a: &id001
- b: 1
- c: 2
- d: *id001
- """,
- )
-
- def test_template_matcher(self):
- """test if id matches the anchor template"""
- from srsly.ruamel_yaml.serializer import templated_id
-
- assert templated_id(u"id001")
- assert templated_id(u"id999")
- assert templated_id(u"id1000")
- assert templated_id(u"id0001")
- assert templated_id(u"id0000")
- assert not templated_id(u"id02")
- assert not templated_id(u"id000")
- assert not templated_id(u"x000")
-
- # def test_re_matcher(self):
- # import re
- # assert re.compile(u'id(?!000)\\d{3,}').match('id001')
- # assert not re.compile(u'id(?!000\\d*)\\d{3,}').match('id000')
- # assert re.compile(u'id(?!000$)\\d{3,}').match('id0001')
-
- def test_anchor_assigned(self):
- from srsly.ruamel_yaml.comments import CommentedMap
-
- data = load(
- """
- a: &id002
- b: 1
- c: 2
- d: *id002
- e: &etemplate
- b: 1
- c: 2
- f: *etemplate
- """
- )
- d = data["d"]
- assert isinstance(d, CommentedMap)
- assert d.yaml_anchor() is None # got dropped as it matches pattern
- e = data["e"]
- assert isinstance(e, CommentedMap)
- assert e.yaml_anchor().value == "etemplate"
- assert e.yaml_anchor().always_dump is False
-
- def test_anchor_id_retained(self):
- data = load(
- """
- a: &id002
- b: 1
- c: 2
- d: *id002
- e: &etemplate
- b: 1
- c: 2
- f: *etemplate
- """
- )
- compare(
- data,
- """
- a: &id001
- b: 1
- c: 2
- d: *id001
- e: &etemplate
- b: 1
- c: 2
- f: *etemplate
- """,
- )
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_alias_before_anchor(self):
- from srsly.ruamel_yaml.composer import ComposerError
-
- with pytest.raises(ComposerError):
- data = load(
- """
- d: *id002
- a: &id002
- b: 1
- c: 2
- """
- )
- data = data
-
- def test_anchor_on_sequence(self):
- # as reported by Bjorn Stabell
- # https://bitbucket.org/ruamel/yaml/issue/7/anchor-names-not-preserved
- from srsly.ruamel_yaml.comments import CommentedSeq
-
- data = load(
- """
- nut1: &alice
- - 1
- - 2
- nut2: &blake
- - some data
- - *alice
- nut3:
- - *blake
- - *alice
- """
- )
- r = data["nut1"]
- assert isinstance(r, CommentedSeq)
- assert r.yaml_anchor() is not None
- assert r.yaml_anchor().value == "alice"
-
- merge_yaml = dedent(
- """
- - &CENTER {x: 1, y: 2}
- - &LEFT {x: 0, y: 2}
- - &BIG {r: 10}
- - &SMALL {r: 1}
- # All the following maps are equal:
- # Explicit keys
- - x: 1
- y: 2
- r: 10
- label: center/small
- # Merge one map
- - <<: *CENTER
- r: 10
- label: center/medium
- # Merge multiple maps
- - <<: [*CENTER, *BIG]
- label: center/big
- # Override
- - <<: [*BIG, *LEFT, *SMALL]
- x: 1
- label: center/huge
- """
- )
-
- def test_merge_00(self):
- data = load(self.merge_yaml)
- d = data[4]
- ok = True
- for k in d:
- for o in [5, 6, 7]:
- x = d.get(k)
- y = data[o].get(k)
- if not isinstance(x, int):
- x = x.split("/")[0]
- y = y.split("/")[0]
- if x != y:
- ok = False
- print("key", k, d.get(k), data[o].get(k))
- assert ok
-
- def test_merge_accessible(self):
- from srsly.ruamel_yaml.comments import CommentedMap, merge_attrib
-
- data = load(
- """
- k: &level_2 { a: 1, b2 }
- l: &level_1 { a: 10, c: 3 }
- m:
- <<: *level_1
- c: 30
- d: 40
- """
- )
- d = data["m"]
- assert isinstance(d, CommentedMap)
- assert hasattr(d, merge_attrib)
-
- def test_merge_01(self):
- data = load(self.merge_yaml)
- compare(data, self.merge_yaml)
-
- def test_merge_nested(self):
- yaml = """
- a:
- <<: &content
- 1: plugh
- 2: plover
- 0: xyzzy
- b:
- <<: *content
- """
- data = round_trip(yaml) # NOQA
-
- def test_merge_nested_with_sequence(self):
- yaml = """
- a:
- <<: &content
- <<: &y2
- 1: plugh
- 2: plover
- 0: xyzzy
- b:
- <<: [*content, *y2]
- """
- data = round_trip(yaml) # NOQA
-
- def test_add_anchor(self):
- from srsly.ruamel_yaml.comments import CommentedMap
-
- data = CommentedMap()
- data_a = CommentedMap()
- data["a"] = data_a
- data_a["c"] = 3
- data["b"] = 2
- data.yaml_set_anchor("klm", always_dump=True)
- data["a"].yaml_set_anchor("xyz", always_dump=True)
- compare(
- data,
- """
- &klm
- a: &xyz
- c: 3
- b: 2
- """,
- )
-
- # this is an error in PyYAML
- def test_reused_anchor(self):
- from srsly.ruamel_yaml.error import ReusedAnchorWarning
-
- yaml = """
- - &a
- x: 1
- - <<: *a
- - &a
- x: 2
- - <<: *a
- """
- with pytest.warns(ReusedAnchorWarning):
- data = round_trip(yaml) # NOQA
-
- def test_issue_130(self):
- # issue 130 reported by Devid Fee
- ys = dedent(
- """\
- components:
- server: &server_component
- type: spark.server:ServerComponent
- host: 0.0.0.0
- port: 8000
- shell: &shell_component
- type: spark.shell:ShellComponent
-
- services:
- server: &server_service
- <<: *server_component
- shell: &shell_service
- <<: *shell_component
- components:
- server: {<<: *server_service}
- """
- )
- data = srsly.ruamel_yaml.safe_load(ys)
- assert data["services"]["shell"]["components"]["server"]["port"] == 8000
-
- def test_issue_130a(self):
- # issue 130 reported by Devid Fee
- ys = dedent(
- """\
- components:
- server: &server_component
- type: spark.server:ServerComponent
- host: 0.0.0.0
- port: 8000
- shell: &shell_component
- type: spark.shell:ShellComponent
-
- services:
- server: &server_service
- <<: *server_component
- port: 4000
- shell: &shell_service
- <<: *shell_component
- components:
- server: {<<: *server_service}
- """
- )
- data = srsly.ruamel_yaml.safe_load(ys)
- assert data["services"]["shell"]["components"]["server"]["port"] == 4000
-
-
-class TestMergeKeysValues:
-
- yaml_str = dedent(
- """\
- - &mx
- a: x1
- b: x2
- c: x3
- - &my
- a: y1
- b: y2 # masked by the one in &mx
- d: y4
- -
- a: 1
- <<: [*mx, *my]
- m: 6
- """
- )
-
- # in the following d always has "expanded" the merges
-
- def test_merge_for(self):
- from srsly.ruamel_yaml import safe_load
-
- d = safe_load(self.yaml_str)
- data = round_trip_load(self.yaml_str)
- count = 0
- for x in data[2]:
- count += 1
- print(count, x)
- assert count == len(d[2])
-
- def test_merge_keys(self):
- from srsly.ruamel_yaml import safe_load
-
- d = safe_load(self.yaml_str)
- data = round_trip_load(self.yaml_str)
- count = 0
- for x in data[2].keys():
- count += 1
- print(count, x)
- assert count == len(d[2])
-
- def test_merge_values(self):
- from srsly.ruamel_yaml import safe_load
-
- d = safe_load(self.yaml_str)
- data = round_trip_load(self.yaml_str)
- count = 0
- for x in data[2].values():
- count += 1
- print(count, x)
- assert count == len(d[2])
-
- def test_merge_items(self):
- from srsly.ruamel_yaml import safe_load
-
- d = safe_load(self.yaml_str)
- data = round_trip_load(self.yaml_str)
- count = 0
- for x in data[2].items():
- count += 1
- print(count, x)
- assert count == len(d[2])
-
- def test_len_items_delete(self):
- from srsly.ruamel_yaml import safe_load
- from srsly.ruamel_yaml.compat import PY3
-
- d = safe_load(self.yaml_str)
- data = round_trip_load(self.yaml_str)
- x = data[2].items()
- print("d2 items", d[2].items(), len(d[2].items()), x, len(x))
- ref = len(d[2].items())
- print("ref", ref)
- assert len(x) == ref
- del data[2]["m"]
- if PY3:
- ref -= 1
- assert len(x) == ref
- del data[2]["d"]
- if PY3:
- ref -= 1
- assert len(x) == ref
- del data[2]["a"]
- if PY3:
- ref -= 1
- assert len(x) == ref
-
- def test_issue_196_cast_of_dict(self, capsys):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- mapping = yaml.load(
- """\
- anchored: &anchor
- a : 1
-
- mapping:
- <<: *anchor
- b: 2
- """
- )["mapping"]
-
- for k in mapping:
- print("k", k)
- for k in mapping.copy():
- print("kc", k)
-
- print("v", list(mapping.keys()))
- print("v", list(mapping.values()))
- print("v", list(mapping.items()))
- print(len(mapping))
- print("-----")
-
- # print({**mapping})
- # print(type({**mapping}))
- # assert 'a' in {**mapping}
- assert "a" in mapping
- x = {}
- for k in mapping:
- x[k] = mapping[k]
- assert "a" in x
- assert "a" in mapping.keys()
- assert mapping["a"] == 1
- assert mapping.__getitem__("a") == 1
- assert "a" in dict(mapping)
- assert "a" in dict(mapping.items())
-
- def test_values_of_merged(self):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- data = yaml.load(dedent(self.yaml_str))
- assert list(data[2].values()) == [1, 6, "x2", "x3", "y4"]
-
- def test_issue_213_copy_of_merge(self):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- d = yaml.load(
- """\
- foo: &foo
- a: a
- foo2:
- <<: *foo
- b: b
- """
- )["foo2"]
- assert d["a"] == "a"
- d2 = d.copy()
- assert d2["a"] == "a"
- print("d", d)
- del d["a"]
- assert "a" not in d
- assert "a" in d2
-
-
-class TestDuplicateKeyThroughAnchor:
- def test_duplicate_key_00(self):
- from srsly.ruamel_yaml import version_info
- from srsly.ruamel_yaml import safe_load, round_trip_load
- from srsly.ruamel_yaml.constructor import (
- DuplicateKeyFutureWarning,
- DuplicateKeyError,
- )
-
- s = dedent(
- """\
- &anchor foo:
- foo: bar
- *anchor : duplicate key
- baz: bat
- *anchor : duplicate key
- """
- )
- if version_info < (0, 15, 1):
- pass
- elif version_info < (0, 16, 0):
- with pytest.warns(DuplicateKeyFutureWarning):
- safe_load(s)
- with pytest.warns(DuplicateKeyFutureWarning):
- round_trip_load(s)
- else:
- with pytest.raises(DuplicateKeyError):
- safe_load(s)
- with pytest.raises(DuplicateKeyError):
- round_trip_load(s)
-
- def test_duplicate_key_01(self):
- # so issue https://stackoverflow.com/a/52852106/1307905
- from srsly.ruamel_yaml import version_info
- from srsly.ruamel_yaml.constructor import DuplicateKeyError
-
- s = dedent(
- """\
- - &name-name
- a: 1
- - &help-name
- b: 2
- - <<: *name-name
- <<: *help-name
- """
- )
- if version_info < (0, 15, 1):
- pass
- else:
- with pytest.raises(DuplicateKeyError):
- yaml = YAML(typ="safe")
- yaml.load(s)
- with pytest.raises(DuplicateKeyError):
- yaml = YAML()
- yaml.load(s)
-
-
-class TestFullCharSetAnchors:
- def test_master_of_orion(self):
- # https://bitbucket.org/ruamel/yaml/issues/72/not-allowed-in-anchor-names
- # submitted by Shalon Wood
- yaml_str = """
- - collection: &Backend.Civilizations.RacialPerk
- items:
- - key: perk_population_growth_modifier
- - *Backend.Civilizations.RacialPerk
- """
- data = load(yaml_str) # NOQA
-
- def test_roundtrip_00(self):
- yaml_str = """
- - &dotted.words.here
- a: 1
- b: 2
- - *dotted.words.here
- """
- data = round_trip(yaml_str) # NOQA
-
- def test_roundtrip_01(self):
- yaml_str = """
- - &dotted.words.here[a, b]
- - *dotted.words.here
- """
- data = load(yaml_str) # NOQA
- compare(data, yaml_str.replace("[", " [")) # an extra space is inserted
diff --git a/srsly/tests/ruamel_yaml/test_api_change.py b/srsly/tests/ruamel_yaml/test_api_change.py
deleted file mode 100755
index d2e4e1e..0000000
--- a/srsly/tests/ruamel_yaml/test_api_change.py
+++ /dev/null
@@ -1,247 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-"""
-testing of anchors and the aliases referring to them
-"""
-
-import sys
-import textwrap
-import pytest
-from pathlib import Path
-
-pytestmark = pytest.mark.filterwarnings(
- "ignore::pytest.PytestUnraisableExceptionWarning"
-)
-
-
-class TestNewAPI:
- def test_duplicate_keys_00(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.constructor import DuplicateKeyError
-
- yaml = YAML()
- with pytest.raises(DuplicateKeyError):
- yaml.load("{a: 1, a: 2}")
-
- def test_duplicate_keys_01(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.constructor import DuplicateKeyError
-
- yaml = YAML(typ="safe", pure=True)
- with pytest.raises(DuplicateKeyError):
- yaml.load("{a: 1, a: 2}")
-
- def test_duplicate_keys_02(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.constructor import DuplicateKeyError
-
- yaml = YAML(typ="safe")
- with pytest.raises(DuplicateKeyError):
- yaml.load("{a: 1, a: 2}")
-
- def test_issue_135(self):
- # reported by Andrzej Ostrowski
- from srsly.ruamel_yaml import YAML
-
- data = {"a": 1, "b": 2}
- yaml = YAML(typ="safe")
- # originally on 2.7: with pytest.raises(TypeError):
- yaml.dump(data, sys.stdout)
-
- def test_issue_135_temporary_workaround(self):
- # never raised error
- from srsly.ruamel_yaml import YAML
-
- data = {"a": 1, "b": 2}
- yaml = YAML(typ="safe", pure=True)
- yaml.dump(data, sys.stdout)
-
-
-class TestWrite:
- def test_dump_path(self, tmpdir):
- from srsly.ruamel_yaml import YAML
-
- fn = Path(str(tmpdir)) / "test.yaml"
- yaml = YAML()
- data = yaml.map()
- data["a"] = 1
- data["b"] = 2
- yaml.dump(data, fn)
- assert fn.read_text() == "a: 1\nb: 2\n"
-
- def test_dump_file(self, tmpdir):
- from srsly.ruamel_yaml import YAML
-
- fn = Path(str(tmpdir)) / "test.yaml"
- yaml = YAML()
- data = yaml.map()
- data["a"] = 1
- data["b"] = 2
- with open(str(fn), "w") as fp:
- yaml.dump(data, fp)
- assert fn.read_text() == "a: 1\nb: 2\n"
-
- def test_dump_missing_stream(self):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- data = yaml.map()
- data["a"] = 1
- data["b"] = 2
- with pytest.raises(TypeError):
- yaml.dump(data)
-
- def test_dump_too_many_args(self, tmpdir):
- from srsly.ruamel_yaml import YAML
-
- fn = Path(str(tmpdir)) / "test.yaml"
- yaml = YAML()
- data = yaml.map()
- data["a"] = 1
- data["b"] = 2
- with pytest.raises(TypeError):
- yaml.dump(data, fn, True)
-
- def test_transform(self, tmpdir):
- from srsly.ruamel_yaml import YAML
-
- def tr(s):
- return s.replace(" ", " ")
-
- fn = Path(str(tmpdir)) / "test.yaml"
- yaml = YAML()
- data = yaml.map()
- data["a"] = 1
- data["b"] = 2
- yaml.dump(data, fn, transform=tr)
- assert fn.read_text() == "a: 1\nb: 2\n"
-
- def test_print(self, capsys):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- data = yaml.map()
- data["a"] = 1
- data["b"] = 2
- yaml.dump(data, sys.stdout)
- out, err = capsys.readouterr()
- assert out == "a: 1\nb: 2\n"
-
-
-class TestRead:
- def test_multi_load(self):
- # make sure reader, scanner, parser get reset
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- yaml.load("a: 1")
- yaml.load("a: 1") # did not work in 0.15.4
-
- def test_parse(self):
- # ensure `parse` method is functional and can parse "unsafe" yaml
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.constructor import ConstructorError
-
- yaml = YAML(typ="safe")
- s = "- !User0 {age: 18, name: Anthon}"
- # should fail to load
- with pytest.raises(ConstructorError):
- yaml.load(s)
- # should parse fine
- yaml = YAML(typ="safe")
- for _ in yaml.parse(s):
- pass
-
-
-class TestLoadAll:
- def test_multi_document_load(self, tmpdir):
- """this went wrong on 3.7 because of StopIteration, PR 37 and Issue 211"""
- from srsly.ruamel_yaml import YAML
-
- fn = Path(str(tmpdir)) / "test.yaml"
- fn.write_text(
- textwrap.dedent(
- u"""\
- ---
- - a
- ---
- - b
- ...
- """
- )
- )
- yaml = YAML()
- assert list(yaml.load_all(fn)) == [["a"], ["b"]]
-
-
-class TestDuplSet:
- def test_dupl_set_00(self):
- # round-trip-loader should except
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.constructor import DuplicateKeyError
-
- yaml = YAML()
- with pytest.raises(DuplicateKeyError):
- yaml.load(
- textwrap.dedent(
- """\
- !!set
- ? a
- ? b
- ? c
- ? a
- """
- )
- )
-
-
-class TestDumpLoadUnicode:
- # test triggered by SamH on stackoverflow (https://stackoverflow.com/q/45281596/1307905)
- # and answer by randomir (https://stackoverflow.com/a/45281922/1307905)
- def test_write_unicode(self, tmpdir):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- text_dict = {"text": u"HELLO_WORLD©"}
- file_name = str(tmpdir) + "/tstFile.yaml"
- yaml.dump(text_dict, open(file_name, "w", encoding="utf8", newline="\n"))
- assert open(file_name, "rb").read().decode("utf-8") == u"text: HELLO_WORLD©\n"
-
- def test_read_unicode(self, tmpdir):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- file_name = str(tmpdir) + "/tstFile.yaml"
- with open(file_name, "wb") as fp:
- fp.write(u"text: HELLO_WORLD©\n".encode("utf-8"))
- with open(file_name, "r", encoding="utf8") as fp:
- text_dict = yaml.load(fp)
- assert text_dict["text"] == u"HELLO_WORLD©"
-
-
-class TestFlowStyle:
- def test_flow_style(self, capsys):
- # https://stackoverflow.com/questions/45791712/
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- yaml.default_flow_style = None
- data = yaml.map()
- data["b"] = 1
- data["a"] = [[1, 2], [3, 4]]
- yaml.dump(data, sys.stdout)
- out, err = capsys.readouterr()
- assert out == "b: 1\na:\n- [1, 2]\n- [3, 4]\n"
-
-
-class TestOldAPI:
- @pytest.mark.skipif(sys.version_info >= (3, 0), reason="ok on Py3")
- def test_duplicate_keys_02(self):
- # Issue 165 unicode keys in error/warning
- from srsly.ruamel_yaml import safe_load
- from srsly.ruamel_yaml.constructor import DuplicateKeyError
-
- with pytest.raises(DuplicateKeyError):
- safe_load("type: Doméstica\ntype: International")
diff --git a/srsly/tests/ruamel_yaml/test_appliance.py b/srsly/tests/ruamel_yaml/test_appliance.py
deleted file mode 100755
index bba0210..0000000
--- a/srsly/tests/ruamel_yaml/test_appliance.py
+++ /dev/null
@@ -1,220 +0,0 @@
-from __future__ import print_function
-
-import sys
-import os
-import types
-import traceback
-import pprint
-import argparse
-from srsly.ruamel_yaml.compat import PY3
-
-# DATA = 'tests/data'
-# determine the position of data dynamically relative to program
-# this allows running test while the current path is not the top of the
-# repository, e.g. from the tests/data directory: python ../test_yaml.py
-DATA = __file__.rsplit(os.sep, 2)[0] + "/data"
-
-
-def find_test_functions(collections):
- if not isinstance(collections, list):
- collections = [collections]
- functions = []
- for collection in collections:
- if not isinstance(collection, dict):
- collection = vars(collection)
- for key in sorted(collection):
- value = collection[key]
- if isinstance(value, types.FunctionType) and hasattr(value, "unittest"):
- functions.append(value)
- return functions
-
-
-def find_test_filenames(directory):
- filenames = {}
- for filename in os.listdir(directory):
- if os.path.isfile(os.path.join(directory, filename)):
- base, ext = os.path.splitext(filename)
- if base.endswith("-py2" if PY3 else "-py3"):
- continue
- filenames.setdefault(base, []).append(ext)
- filenames = sorted(filenames.items())
- return filenames
-
-
-def parse_arguments(args):
- """"""
- parser = argparse.ArgumentParser(
- usage=""" run the yaml tests. By default
- all functions on all appropriate test_files are run. Functions have
- unittest attributes that determine the required extensions to filenames
- that need to be available in order to run that test. E.g.\n\n
- python test_yaml.py test_constructor_types\n
- python test_yaml.py --verbose test_tokens spec-02-05\n\n
- The presence of an extension in the .skip attribute of a function
- disables the test for that function."""
- )
- # ToDo: make into int and test > 0 in functions
- parser.add_argument(
- "--verbose",
- "-v",
- action="store_true",
- default="YAML_TEST_VERBOSE" in os.environ,
- help="set verbosity output",
- )
- parser.add_argument(
- "--list-functions",
- action="store_true",
- help="""list all functions with required file extensions for test files
- """,
- )
- parser.add_argument("function", nargs="?", help="""restrict function to run""")
- parser.add_argument(
- "filenames",
- nargs="*",
- help="""basename of filename set, extensions (.code, .data) have to
- be a superset of those in the unittest attribute of the selected
- function""",
- )
- args = parser.parse_args(args)
- # print('args', args)
- verbose = args.verbose
- include_functions = [args.function] if args.function else []
- include_filenames = args.filenames
- # if args is None:
- # args = sys.argv[1:]
- # verbose = False
- # if '-v' in args:
- # verbose = True
- # args.remove('-v')
- # if '--verbose' in args:
- # verbose = True
- # args.remove('--verbose') # never worked without this
- # if 'YAML_TEST_VERBOSE' in os.environ:
- # verbose = True
- # include_functions = []
- # if args:
- # include_functions.append(args.pop(0))
- if "YAML_TEST_FUNCTIONS" in os.environ:
- include_functions.extend(os.environ["YAML_TEST_FUNCTIONS"].split())
- # include_filenames = []
- # include_filenames.extend(args)
- if "YAML_TEST_FILENAMES" in os.environ:
- include_filenames.extend(os.environ["YAML_TEST_FILENAMES"].split())
- return include_functions, include_filenames, verbose, args
-
-
-def execute(function, filenames, verbose):
- if PY3:
- name = function.__name__
- else:
- if hasattr(function, "unittest_name"):
- name = function.unittest_name
- else:
- name = function.func_name
- if verbose:
- sys.stdout.write("=" * 75 + "\n")
- sys.stdout.write("%s(%s)...\n" % (name, ", ".join(filenames)))
- try:
- function(verbose=verbose, *filenames)
- except Exception as exc:
- info = sys.exc_info()
- if isinstance(exc, AssertionError):
- kind = "FAILURE"
- else:
- kind = "ERROR"
- if verbose:
- traceback.print_exc(limit=1, file=sys.stdout)
- else:
- sys.stdout.write(kind[0])
- sys.stdout.flush()
- else:
- kind = "SUCCESS"
- info = None
- if not verbose:
- sys.stdout.write(".")
- sys.stdout.flush()
- return (name, filenames, kind, info)
-
-
-def display(results, verbose):
- if results and not verbose:
- sys.stdout.write("\n")
- total = len(results)
- failures = 0
- errors = 0
- for name, filenames, kind, info in results:
- if kind == "SUCCESS":
- continue
- if kind == "FAILURE":
- failures += 1
- if kind == "ERROR":
- errors += 1
- sys.stdout.write("=" * 75 + "\n")
- sys.stdout.write("%s(%s): %s\n" % (name, ", ".join(filenames), kind))
- if kind == "ERROR":
- traceback.print_exception(file=sys.stdout, *info)
- else:
- sys.stdout.write("Traceback (most recent call last):\n")
- traceback.print_tb(info[2], file=sys.stdout)
- sys.stdout.write("%s: see below\n" % info[0].__name__)
- sys.stdout.write("~" * 75 + "\n")
- for arg in info[1].args:
- pprint.pprint(arg, stream=sys.stdout)
- for filename in filenames:
- sys.stdout.write("-" * 75 + "\n")
- sys.stdout.write("%s:\n" % filename)
- if PY3:
- with open(filename, "r", errors="replace") as fp:
- data = fp.read()
- else:
- with open(filename, "rb") as fp:
- data = fp.read()
- sys.stdout.write(data)
- if data and data[-1] != "\n":
- sys.stdout.write("\n")
- sys.stdout.write("=" * 75 + "\n")
- sys.stdout.write("TESTS: %s\n" % total)
- ret_val = 0
- if failures:
- sys.stdout.write("FAILURES: %s\n" % failures)
- ret_val = 1
- if errors:
- sys.stdout.write("ERRORS: %s\n" % errors)
- ret_val = 2
- return ret_val
-
-
-def run(collections, args=None):
- test_functions = find_test_functions(collections)
- test_filenames = find_test_filenames(DATA)
- include_functions, include_filenames, verbose, a = parse_arguments(args)
- if a.list_functions:
- print("test functions:")
- for f in test_functions:
- print(" {:30s} {}".format(f.__name__, f.unittest))
- return
- results = []
- for function in test_functions:
- if include_functions and function.__name__ not in include_functions:
- continue
- if function.unittest:
- for base, exts in test_filenames:
- if include_filenames and base not in include_filenames:
- continue
- filenames = []
- for ext in function.unittest:
- if ext not in exts:
- break
- filenames.append(os.path.join(DATA, base + ext))
- else:
- skip_exts = getattr(function, "skip", [])
- for skip_ext in skip_exts:
- if skip_ext in exts:
- break
- else:
- result = execute(function, filenames, verbose)
- results.append(result)
- else:
- result = execute(function, [], verbose)
- results.append(result)
- return display(results, verbose=verbose)
diff --git a/srsly/tests/ruamel_yaml/test_class_register.py b/srsly/tests/ruamel_yaml/test_class_register.py
deleted file mode 100755
index 190f6a2..0000000
--- a/srsly/tests/ruamel_yaml/test_class_register.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# coding: utf-8
-
-"""
-testing of YAML.register_class and @yaml_object
-"""
-
-from .roundtrip import YAML
-
-
-class User0(object):
- def __init__(self, name, age):
- self.name = name
- self.age = age
-
-
-class User1(object):
- yaml_tag = u"!user"
-
- def __init__(self, name, age):
- self.name = name
- self.age = age
-
- @classmethod
- def to_yaml(cls, representer, node):
- return representer.represent_scalar(
- cls.yaml_tag, u"{.name}-{.age}".format(node, node)
- )
-
- @classmethod
- def from_yaml(cls, constructor, node):
- return cls(*node.value.split("-"))
-
-
-class TestRegisterClass(object):
- def test_register_0_rt(self):
- yaml = YAML()
- yaml.register_class(User0)
- ys = """
- - !User0
- name: Anthon
- age: 18
- """
- d = yaml.load(ys)
- yaml.dump(d, compare=ys, unordered_lines=True)
-
- def test_register_0_safe(self):
- # default_flow_style = None
- yaml = YAML(typ="safe")
- yaml.register_class(User0)
- ys = """
- - !User0 {age: 18, name: Anthon}
- """
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_register_0_unsafe(self):
- # default_flow_style = None
- yaml = YAML(typ="unsafe")
- yaml.register_class(User0)
- ys = """
- - !User0 {age: 18, name: Anthon}
- """
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_register_1_rt(self):
- yaml = YAML()
- yaml.register_class(User1)
- ys = """
- - !user Anthon-18
- """
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_register_1_safe(self):
- yaml = YAML(typ="safe")
- yaml.register_class(User1)
- ys = """
- [!user Anthon-18]
- """
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_register_1_unsafe(self):
- yaml = YAML(typ="unsafe")
- yaml.register_class(User1)
- ys = """
- [!user Anthon-18]
- """
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
-
-class TestDecorator(object):
- def test_decorator_implicit(self):
- from srsly.ruamel_yaml import yaml_object
-
- yml = YAML()
-
- @yaml_object(yml)
- class User2(object):
- def __init__(self, name, age):
- self.name = name
- self.age = age
-
- ys = """
- - !User2
- name: Anthon
- age: 18
- """
- d = yml.load(ys)
- yml.dump(d, compare=ys, unordered_lines=True)
-
- def test_decorator_explicit(self):
- from srsly.ruamel_yaml import yaml_object
-
- yml = YAML()
-
- @yaml_object(yml)
- class User3(object):
- yaml_tag = u"!USER"
-
- def __init__(self, name, age):
- self.name = name
- self.age = age
-
- @classmethod
- def to_yaml(cls, representer, node):
- return representer.represent_scalar(
- cls.yaml_tag, u"{.name}-{.age}".format(node, node)
- )
-
- @classmethod
- def from_yaml(cls, constructor, node):
- return cls(*node.value.split("-"))
-
- ys = """
- - !USER Anthon-18
- """
- d = yml.load(ys)
- yml.dump(d, compare=ys)
diff --git a/srsly/tests/ruamel_yaml/test_collections.py b/srsly/tests/ruamel_yaml/test_collections.py
deleted file mode 100755
index 632aae2..0000000
--- a/srsly/tests/ruamel_yaml/test_collections.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# coding: utf-8
-
-"""
-collections.OrderedDict is a new class not supported by PyYAML (issue 83 by Frazer McLean)
-
-This is now so integrated in Python that it can be mapped to !!omap
-
-"""
-
-import pytest # NOQA
-
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-
-class TestOrderedDict:
- def test_ordereddict(self):
- from collections import OrderedDict
- import srsly.ruamel_yaml # NOQA
-
- assert srsly.ruamel_yaml.dump(OrderedDict()) == "!!omap []\n"
diff --git a/srsly/tests/ruamel_yaml/test_comment_manipulation.py b/srsly/tests/ruamel_yaml/test_comment_manipulation.py
deleted file mode 100755
index 95ce934..0000000
--- a/srsly/tests/ruamel_yaml/test_comment_manipulation.py
+++ /dev/null
@@ -1,639 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-import pytest # NOQA
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-
-def load(s):
- return round_trip_load(dedent(s))
-
-
-def compare(data, s, **kw):
- assert round_trip_dump(data, **kw) == dedent(s)
-
-
-def compare_eol(data, s):
- assert "EOL" in s
- ds = dedent(s).replace("EOL", "").replace("\n", "|\n")
- assert round_trip_dump(data).replace("\n", "|\n") == ds
-
-
-class TestCommentsManipulation:
-
- # list
- def test_seq_set_comment_on_existing_explicit_column(self):
- data = load(
- """
- - a # comment 1
- - b
- - c
- """
- )
- data.yaml_add_eol_comment("comment 2", key=1, column=6)
- exp = """
- - a # comment 1
- - b # comment 2
- - c
- """
- compare(data, exp)
-
- def test_seq_overwrite_comment_on_existing_explicit_column(self):
- data = load(
- """
- - a # comment 1
- - b
- - c
- """
- )
- data.yaml_add_eol_comment("comment 2", key=0, column=6)
- exp = """
- - a # comment 2
- - b
- - c
- """
- compare(data, exp)
-
- def test_seq_first_comment_explicit_column(self):
- data = load(
- """
- - a
- - b
- - c
- """
- )
- data.yaml_add_eol_comment("comment 1", key=1, column=6)
- exp = """
- - a
- - b # comment 1
- - c
- """
- compare(data, exp)
-
- def test_seq_set_comment_on_existing_column_prev(self):
- data = load(
- """
- - a # comment 1
- - b
- - c
- - d # comment 3
- """
- )
- data.yaml_add_eol_comment("comment 2", key=1)
- exp = """
- - a # comment 1
- - b # comment 2
- - c
- - d # comment 3
- """
- compare(data, exp)
-
- def test_seq_set_comment_on_existing_column_next(self):
- data = load(
- """
- - a # comment 1
- - b
- - c
- - d # comment 3
- """
- )
- print(data._yaml_comment)
- # print(type(data._yaml_comment._items[0][0].start_mark))
- # srsly.ruamel_yaml.error.Mark
- # print(type(data._yaml_comment._items[0][0].start_mark))
- data.yaml_add_eol_comment("comment 2", key=2)
- exp = """
- - a # comment 1
- - b
- - c # comment 2
- - d # comment 3
- """
- compare(data, exp)
-
- def test_seq_set_comment_on_existing_column_further_away(self):
- """
- no comment line before or after, take the latest before
- the new position
- """
- data = load(
- """
- - a # comment 1
- - b
- - c
- - d
- - e
- - f # comment 3
- """
- )
- print(data._yaml_comment)
- # print(type(data._yaml_comment._items[0][0].start_mark))
- # srsly.ruamel_yaml.error.Mark
- # print(type(data._yaml_comment._items[0][0].start_mark))
- data.yaml_add_eol_comment("comment 2", key=3)
- exp = """
- - a # comment 1
- - b
- - c
- - d # comment 2
- - e
- - f # comment 3
- """
- compare(data, exp)
-
- def test_seq_set_comment_on_existing_explicit_column_with_hash(self):
- data = load(
- """
- - a # comment 1
- - b
- - c
- """
- )
- data.yaml_add_eol_comment("# comment 2", key=1, column=6)
- exp = """
- - a # comment 1
- - b # comment 2
- - c
- """
- compare(data, exp)
-
- # dict
-
- def test_dict_set_comment_on_existing_explicit_column(self):
- data = load(
- """
- a: 1 # comment 1
- b: 2
- c: 3
- d: 4
- e: 5
- """
- )
- data.yaml_add_eol_comment("comment 2", key="c", column=7)
- exp = """
- a: 1 # comment 1
- b: 2
- c: 3 # comment 2
- d: 4
- e: 5
- """
- compare(data, exp)
-
- def test_dict_overwrite_comment_on_existing_explicit_column(self):
- data = load(
- """
- a: 1 # comment 1
- b: 2
- c: 3
- d: 4
- e: 5
- """
- )
- data.yaml_add_eol_comment("comment 2", key="a", column=7)
- exp = """
- a: 1 # comment 2
- b: 2
- c: 3
- d: 4
- e: 5
- """
- compare(data, exp)
-
- def test_map_set_comment_on_existing_column_prev(self):
- data = load(
- """
- a: 1 # comment 1
- b: 2
- c: 3
- d: 4
- e: 5 # comment 3
- """
- )
- data.yaml_add_eol_comment("comment 2", key="b")
- exp = """
- a: 1 # comment 1
- b: 2 # comment 2
- c: 3
- d: 4
- e: 5 # comment 3
- """
- compare(data, exp)
-
- def test_map_set_comment_on_existing_column_next(self):
- data = load(
- """
- a: 1 # comment 1
- b: 2
- c: 3
- d: 4
- e: 5 # comment 3
- """
- )
- data.yaml_add_eol_comment("comment 2", key="d")
- exp = """
- a: 1 # comment 1
- b: 2
- c: 3
- d: 4 # comment 2
- e: 5 # comment 3
- """
- compare(data, exp)
-
- def test_map_set_comment_on_existing_column_further_away(self):
- """
- no comment line before or after, take the latest before
- the new position
- """
- data = load(
- """
- a: 1 # comment 1
- b: 2
- c: 3
- d: 4
- e: 5 # comment 3
- """
- )
- data.yaml_add_eol_comment("comment 2", key="c")
- print(round_trip_dump(data))
- exp = """
- a: 1 # comment 1
- b: 2
- c: 3 # comment 2
- d: 4
- e: 5 # comment 3
- """
- compare(data, exp)
-
- def test_before_top_map_rt(self):
- data = load(
- """
- a: 1
- b: 2
- """
- )
- data.yaml_set_start_comment("Hello\nWorld\n")
- exp = """
- # Hello
- # World
- a: 1
- b: 2
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_top_map_replace(self):
- data = load(
- """
- # abc
- # def
- a: 1 # 1
- b: 2
- """
- )
- data.yaml_set_start_comment("Hello\nWorld\n")
- exp = """
- # Hello
- # World
- a: 1 # 1
- b: 2
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_top_map_from_scratch(self):
- from srsly.ruamel_yaml.comments import CommentedMap
-
- data = CommentedMap()
- data["a"] = 1
- data["b"] = 2
- data.yaml_set_start_comment("Hello\nWorld\n")
- # print(data.ca)
- # print(data.ca._items)
- exp = """
- # Hello
- # World
- a: 1
- b: 2
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_top_seq_rt(self):
- data = load(
- """
- - a
- - b
- """
- )
- data.yaml_set_start_comment("Hello\nWorld\n")
- print(round_trip_dump(data))
- exp = """
- # Hello
- # World
- - a
- - b
- """
- compare(data, exp)
-
- def test_before_top_seq_rt_replace(self):
- s = """
- # this
- # that
- - a
- - b
- """
- data = load(s.format(comment="#"))
- data.yaml_set_start_comment("Hello\nWorld\n")
- print(round_trip_dump(data))
- exp = """
- # Hello
- # World
- - a
- - b
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_top_seq_from_scratch(self):
- from srsly.ruamel_yaml.comments import CommentedSeq
-
- data = CommentedSeq()
- data.append("a")
- data.append("b")
- data.yaml_set_start_comment("Hello\nWorld\n")
- print(round_trip_dump(data))
- exp = """
- # Hello
- # World
- - a
- - b
- """
- compare(data, exp.format(comment="#"))
-
- # nested variants
- def test_before_nested_map_rt(self):
- data = load(
- """
- a: 1
- b:
- c: 2
- d: 3
- """
- )
- data["b"].yaml_set_start_comment("Hello\nWorld\n")
- exp = """
- a: 1
- b:
- # Hello
- # World
- c: 2
- d: 3
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_nested_map_rt_indent(self):
- data = load(
- """
- a: 1
- b:
- c: 2
- d: 3
- """
- )
- data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2)
- exp = """
- a: 1
- b:
- # Hello
- # World
- c: 2
- d: 3
- """
- compare(data, exp.format(comment="#"))
- print(data["b"].ca)
-
- def test_before_nested_map_from_scratch(self):
- from srsly.ruamel_yaml.comments import CommentedMap
-
- data = CommentedMap()
- datab = CommentedMap()
- data["a"] = 1
- data["b"] = datab
- datab["c"] = 2
- datab["d"] = 3
- data["b"].yaml_set_start_comment("Hello\nWorld\n")
- exp = """
- a: 1
- b:
- # Hello
- # World
- c: 2
- d: 3
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_nested_seq_from_scratch(self):
- from srsly.ruamel_yaml.comments import CommentedMap, CommentedSeq
-
- data = CommentedMap()
- datab = CommentedSeq()
- data["a"] = 1
- data["b"] = datab
- datab.append("c")
- datab.append("d")
- data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2)
- exp = """
- a: 1
- b:
- # Hello
- # World
- - c
- - d
- """
- compare(data, exp.format(comment="#"))
-
- def test_before_nested_seq_from_scratch_block_seq_indent(self):
- from srsly.ruamel_yaml.comments import CommentedMap, CommentedSeq
-
- data = CommentedMap()
- datab = CommentedSeq()
- data["a"] = 1
- data["b"] = datab
- datab.append("c")
- datab.append("d")
- data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2)
- exp = """
- a: 1
- b:
- # Hello
- # World
- - c
- - d
- """
- compare(data, exp.format(comment="#"), indent=4, block_seq_indent=2)
-
- def test_map_set_comment_before_and_after_non_first_key_00(self):
- # http://stackoverflow.com/a/40705671/1307905
- data = load(
- """
- xyz:
- a: 1 # comment 1
- b: 2
-
- test1:
- test2:
- test3: 3
- """
- )
- data.yaml_set_comment_before_after_key(
- "test1", "before test1 (top level)", after="before test2"
- )
- data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
- exp = """
- xyz:
- a: 1 # comment 1
- b: 2
-
- # before test1 (top level)
- test1:
- # before test2
- test2:
- # after test2
- test3: 3
- """
- compare(data, exp)
-
- def Xtest_map_set_comment_before_and_after_non_first_key_01(self):
- data = load(
- """
- xyz:
- a: 1 # comment 1
- b: 2
-
- test1:
- test2:
- test3: 3
- """
- )
- data.yaml_set_comment_before_after_key(
- "test1", "before test1 (top level)", after="before test2\n\n"
- )
- data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
- # EOL is needed here as dedenting gets rid of spaces (as well as does Emacs
- exp = """
- xyz:
- a: 1 # comment 1
- b: 2
-
- # before test1 (top level)
- test1:
- # before test2
- EOL
- test2:
- # after test2
- test3: 3
- """
- compare_eol(data, exp)
-
- # EOL is no longer necessary
- # fixed together with issue # 216
- def test_map_set_comment_before_and_after_non_first_key_01(self):
- data = load(
- """
- xyz:
- a: 1 # comment 1
- b: 2
-
- test1:
- test2:
- test3: 3
- """
- )
- data.yaml_set_comment_before_after_key(
- "test1", "before test1 (top level)", after="before test2\n\n"
- )
- data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
- exp = """
- xyz:
- a: 1 # comment 1
- b: 2
-
- # before test1 (top level)
- test1:
- # before test2
-
- test2:
- # after test2
- test3: 3
- """
- compare(data, exp)
-
- def Xtest_map_set_comment_before_and_after_non_first_key_02(self):
- data = load(
- """
- xyz:
- a: 1 # comment 1
- b: 2
-
- test1:
- test2:
- test3: 3
- """
- )
- data.yaml_set_comment_before_after_key(
- "test1",
- "xyz\n\nbefore test1 (top level)",
- after="\nbefore test2",
- after_indent=4,
- )
- data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
- # EOL is needed here as dedenting gets rid of spaces (as well as does Emacs
- exp = """
- xyz:
- a: 1 # comment 1
- b: 2
-
- # xyz
-
- # before test1 (top level)
- test1:
- EOL
- # before test2
- test2:
- # after test2
- test3: 3
- """
- compare_eol(data, exp)
-
- def test_map_set_comment_before_and_after_non_first_key_02(self):
- data = load(
- """
- xyz:
- a: 1 # comment 1
- b: 2
-
- test1:
- test2:
- test3: 3
- """
- )
- data.yaml_set_comment_before_after_key(
- "test1",
- "xyz\n\nbefore test1 (top level)",
- after="\nbefore test2",
- after_indent=4,
- )
- data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
- exp = """
- xyz:
- a: 1 # comment 1
- b: 2
-
- # xyz
-
- # before test1 (top level)
- test1:
-
- # before test2
- test2:
- # after test2
- test3: 3
- """
- compare(data, exp)
diff --git a/srsly/tests/ruamel_yaml/test_comments.py b/srsly/tests/ruamel_yaml/test_comments.py
deleted file mode 100755
index b65198d..0000000
--- a/srsly/tests/ruamel_yaml/test_comments.py
+++ /dev/null
@@ -1,968 +0,0 @@
-# coding: utf-8
-
-"""
-comment testing is all about roundtrips
-these can be done in the "old" way by creating a file.data and file.roundtrip
-but there is little flexibility in doing that
-
-but some things are not easily tested, eog. how a
-roundtrip changes
-
-"""
-
-import pytest
-import sys
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump
-
-
-class TestComments:
- def test_no_end_of_file_eol(self):
- """not excluding comments caused some problems if at the end of
- the file without a newline. First error, then included \0 """
- x = """\
- - europe: 10 # abc"""
- round_trip(x, extra="\n")
- with pytest.raises(AssertionError):
- round_trip(x, extra="a\n")
-
- def test_no_comments(self):
- round_trip(
- """
- - europe: 10
- - usa:
- - ohio: 2
- - california: 9
- """
- )
-
- def test_round_trip_ordering(self):
- round_trip(
- """
- a: 1
- b: 2
- c: 3
- b1: 2
- b2: 2
- d: 4
- e: 5
- f: 6
- """
- )
-
- def test_complex(self):
- round_trip(
- """
- - europe: 10 # top
- - usa:
- - ohio: 2
- - california: 9 # o
- """
- )
-
- def test_dropped(self):
- s = """\
- # comment
- scalar
- ...
- """
- round_trip(s, "scalar\n...\n")
-
- def test_main_mapping_begin_end(self):
- round_trip(
- """
- # C start a
- # C start b
- abc: 1
- ghi: 2
- klm: 3
- # C end a
- # C end b
- """
- )
-
- def test_reindent(self):
- x = """\
- a:
- b: # comment 1
- c: 1 # comment 2
- """
- d = round_trip_load(x)
- y = round_trip_dump(d, indent=4)
- assert y == dedent(
- """\
- a:
- b: # comment 1
- c: 1 # comment 2
- """
- )
-
- def test_main_mapping_begin_end_items_post(self):
- round_trip(
- """
- # C start a
- # C start b
- abc: 1 # abc comment
- ghi: 2
- klm: 3 # klm comment
- # C end a
- # C end b
- """
- )
-
- def test_main_sequence_begin_end(self):
- round_trip(
- """
- # C start a
- # C start b
- - abc
- - ghi
- - klm
- # C end a
- # C end b
- """
- )
-
- def test_main_sequence_begin_end_items_post(self):
- round_trip(
- """
- # C start a
- # C start b
- - abc # abc comment
- - ghi
- - klm # klm comment
- # C end a
- # C end b
- """
- )
-
- def test_main_mapping_begin_end_complex(self):
- round_trip(
- """
- # C start a
- # C start b
- abc: 1
- ghi: 2
- klm:
- 3a: alpha
- 3b: beta # it is all greek to me
- # C end a
- # C end b
- """
- )
-
- def test_09(self): # 2.9 from the examples in the spec
- s = """\
- hr: # 1998 hr ranking
- - Mark McGwire
- - Sammy Sosa
- rbi:
- # 1998 rbi ranking
- - Sammy Sosa
- - Ken Griffey
- """
- round_trip(s, indent=4, block_seq_indent=2)
-
- def test_09a(self):
- round_trip(
- """
- hr: # 1998 hr ranking
- - Mark McGwire
- - Sammy Sosa
- rbi:
- # 1998 rbi ranking
- - Sammy Sosa
- - Ken Griffey
- """
- )
-
- def test_simple_map_middle_comment(self):
- round_trip(
- """
- abc: 1
- # C 3a
- # C 3b
- ghi: 2
- """
- )
-
- def test_map_in_map_0(self):
- round_trip(
- """
- map1: # comment 1
- # comment 2
- map2:
- key1: val1
- """
- )
-
- def test_map_in_map_1(self):
- # comment is moved from value to key
- round_trip(
- """
- map1:
- # comment 1
- map2:
- key1: val1
- """
- )
-
- def test_application_arguments(self):
- # application configur
- round_trip(
- """
- args:
- username: anthon
- passwd: secret
- fullname: Anthon van der Neut
- tmux:
- session-name: test
- loop:
- wait: 10
- """
- )
-
- def test_substitute(self):
- x = """
- args:
- username: anthon # name
- passwd: secret # password
- fullname: Anthon van der Neut
- tmux:
- session-name: test
- loop:
- wait: 10
- """
- data = round_trip_load(x)
- data["args"]["passwd"] = "deleted password"
- # note the requirement to add spaces for alignment of comment
- x = x.replace(": secret ", ": deleted password")
- assert round_trip_dump(data) == dedent(x)
-
- def test_set_comment(self):
- round_trip(
- """
- !!set
- # the beginning
- ? a
- # next one is B (lowercase)
- ? b # You see? Promised you.
- ? c
- # this is the end
- """
- )
-
- def test_omap_comment_roundtrip(self):
- round_trip(
- """
- !!omap
- - a: 1
- - b: 2 # two
- - c: 3 # three
- - d: 4
- """
- )
-
- def test_omap_comment_roundtrip_pre_comment(self):
- round_trip(
- """
- !!omap
- - a: 1
- - b: 2 # two
- - c: 3 # three
- # last one
- - d: 4
- """
- )
-
- def test_non_ascii(self):
- round_trip(
- """
- verbosity: 1 # 0 is minimal output, -1 none
- base_url: http://gopher.net
- special_indices: [1, 5, 8]
- also_special:
- - a
- - 19
- - 32
- asia and europe: &asia_europe
- Turkey: Ankara
- Russia: Moscow
- countries:
- Asia:
- <<: *asia_europe
- Japan: Tokyo # 東京
- Europe:
- <<: *asia_europe
- Spain: Madrid
- Italy: Rome
- """
- )
-
- def test_dump_utf8(self):
- import srsly.ruamel_yaml # NOQA
-
- x = dedent(
- """\
- ab:
- - x # comment
- - y # more comment
- """
- )
- data = round_trip_load(x)
- dumper = srsly.ruamel_yaml.RoundTripDumper
- for utf in [True, False]:
- y = srsly.ruamel_yaml.dump(
- data, default_flow_style=False, Dumper=dumper, allow_unicode=utf
- )
- assert y == x
-
- def test_dump_unicode_utf8(self):
- import srsly.ruamel_yaml # NOQA
-
- x = dedent(
- u"""\
- ab:
- - x # comment
- - y # more comment
- """
- )
- data = round_trip_load(x)
- dumper = srsly.ruamel_yaml.RoundTripDumper
- for utf in [True, False]:
- y = srsly.ruamel_yaml.dump(
- data, default_flow_style=False, Dumper=dumper, allow_unicode=utf
- )
- assert y == x
-
- def test_mlget_00(self):
- x = """\
- a:
- - b:
- c: 42
- - d:
- f: 196
- e:
- g: 3.14
- """
- d = round_trip_load(x)
- assert d.mlget(["a", 1, "d", "f"], list_ok=True) == 196
- with pytest.raises(AssertionError):
- d.mlget(["a", 1, "d", "f"]) == 196
-
-
-class TestInsertPopList:
- """list insertion is more complex than dict insertion, as you
- need to move the values to subsequent keys on insert"""
-
- @property
- def ins(self):
- return """\
- ab:
- - a # a
- - b # b
- - c
- - d # d
-
- de:
- - 1
- - 2
- """
-
- def test_insert_0(self):
- d = round_trip_load(self.ins)
- d["ab"].insert(0, "xyz")
- y = round_trip_dump(d, indent=2)
- assert y == dedent(
- """\
- ab:
- - xyz
- - a # a
- - b # b
- - c
- - d # d
-
- de:
- - 1
- - 2
- """
- )
-
- def test_insert_1(self):
- d = round_trip_load(self.ins)
- d["ab"].insert(4, "xyz")
- y = round_trip_dump(d, indent=2)
- assert y == dedent(
- """\
- ab:
- - a # a
- - b # b
- - c
- - d # d
-
- - xyz
- de:
- - 1
- - 2
- """
- )
-
- def test_insert_2(self):
- d = round_trip_load(self.ins)
- d["ab"].insert(1, "xyz")
- y = round_trip_dump(d, indent=2)
- assert y == dedent(
- """\
- ab:
- - a # a
- - xyz
- - b # b
- - c
- - d # d
-
- de:
- - 1
- - 2
- """
- )
-
- def test_pop_0(self):
- d = round_trip_load(self.ins)
- d["ab"].pop(0)
- y = round_trip_dump(d, indent=2)
- print(y)
- assert y == dedent(
- """\
- ab:
- - b # b
- - c
- - d # d
-
- de:
- - 1
- - 2
- """
- )
-
- def test_pop_1(self):
- d = round_trip_load(self.ins)
- d["ab"].pop(1)
- y = round_trip_dump(d, indent=2)
- print(y)
- assert y == dedent(
- """\
- ab:
- - a # a
- - c
- - d # d
-
- de:
- - 1
- - 2
- """
- )
-
- def test_pop_2(self):
- d = round_trip_load(self.ins)
- d["ab"].pop(2)
- y = round_trip_dump(d, indent=2)
- print(y)
- assert y == dedent(
- """\
- ab:
- - a # a
- - b # b
- - d # d
-
- de:
- - 1
- - 2
- """
- )
-
- def test_pop_3(self):
- d = round_trip_load(self.ins)
- d["ab"].pop(3)
- y = round_trip_dump(d, indent=2)
- print(y)
- assert y == dedent(
- """\
- ab:
- - a # a
- - b # b
- - c
- de:
- - 1
- - 2
- """
- )
-
-
-# inspired by demux' question on stackoverflow
-# http://stackoverflow.com/a/36970608/1307905
-class TestInsertInMapping:
- @property
- def ins(self):
- return """\
- first_name: Art
- occupation: Architect # This is an occupation comment
- about: Art Vandelay is a fictional character that George invents...
- """
-
- def test_insert_at_pos_1(self):
- d = round_trip_load(self.ins)
- d.insert(1, "last name", "Vandelay", comment="new key")
- y = round_trip_dump(d)
- print(y)
- assert y == dedent(
- """\
- first_name: Art
- last name: Vandelay # new key
- occupation: Architect # This is an occupation comment
- about: Art Vandelay is a fictional character that George invents...
- """
- )
-
- def test_insert_at_pos_0(self):
- d = round_trip_load(self.ins)
- d.insert(0, "last name", "Vandelay", comment="new key")
- y = round_trip_dump(d)
- print(y)
- assert y == dedent(
- """\
- last name: Vandelay # new key
- first_name: Art
- occupation: Architect # This is an occupation comment
- about: Art Vandelay is a fictional character that George invents...
- """
- )
-
- def test_insert_at_pos_3(self):
- # much more simple if done with appending.
- d = round_trip_load(self.ins)
- d.insert(3, "last name", "Vandelay", comment="new key")
- y = round_trip_dump(d)
- print(y)
- assert y == dedent(
- """\
- first_name: Art
- occupation: Architect # This is an occupation comment
- about: Art Vandelay is a fictional character that George invents...
- last name: Vandelay # new key
- """
- )
-
-
-class TestCommentedMapMerge:
- def test_in_operator(self):
- data = round_trip_load(
- """
- x: &base
- a: 1
- b: 2
- c: 3
- y:
- <<: *base
- k: 4
- l: 5
- """
- )
- assert data["x"]["a"] == 1
- assert "a" in data["x"]
- assert data["y"]["a"] == 1
- assert "a" in data["y"]
-
- def test_issue_60(self):
- data = round_trip_load(
- """
- x: &base
- a: 1
- y:
- <<: *base
- """
- )
- assert data["x"]["a"] == 1
- assert data["y"]["a"] == 1
- if sys.version_info >= (3, 12):
- assert str(data["y"]) == """ordereddict({'a': 1})"""
- else:
- assert str(data["y"]) == """ordereddict([('a', 1)])"""
-
- def test_issue_60_1(self):
- data = round_trip_load(
- """
- x: &base
- a: 1
- y:
- <<: *base
- b: 2
- """
- )
- assert data["x"]["a"] == 1
- assert data["y"]["a"] == 1
- if sys.version_info >= (3, 12):
- assert str(data["y"]) == """ordereddict({'b': 2, 'a': 1})"""
- else:
- assert str(data["y"]) == """ordereddict([('b', 2), ('a', 1)])"""
-
-
-class TestEmptyLines:
- # prompted by issue 46 from Alex Harvey
- def test_issue_46(self):
- yaml_str = dedent(
- """\
- ---
- # Please add key/value pairs in alphabetical order
-
- aws_s3_bucket: 'mys3bucket'
-
- jenkins_ad_credentials:
- bind_name: 'CN=svc-AAA-BBB-T,OU=Example,DC=COM,DC=EXAMPLE,DC=Local'
- bind_pass: 'xxxxyyyy{'
- """
- )
- d = round_trip_load(yaml_str, preserve_quotes=True)
- y = round_trip_dump(d, explicit_start=True)
- assert yaml_str == y
-
- def test_multispace_map(self):
- round_trip(
- """
- a: 1x
-
- b: 2x
-
-
- c: 3x
-
-
-
- d: 4x
-
- """
- )
-
- @pytest.mark.xfail(strict=True)
- def test_multispace_map_initial(self):
- round_trip(
- """
-
- a: 1x
-
- b: 2x
-
-
- c: 3x
-
-
-
- d: 4x
-
- """
- )
-
- def test_embedded_map(self):
- round_trip(
- """
- - a: 1y
- b: 2y
-
- c: 3y
- """
- )
-
- def test_toplevel_seq(self):
- round_trip(
- """\
- - 1
-
- - 2
-
- - 3
- """
- )
-
- def test_embedded_seq(self):
- round_trip(
- """
- a:
- b:
- - 1
-
- - 2
-
-
- - 3
- """
- )
-
- def test_line_with_only_spaces(self):
- # issue 54
- yaml_str = "---\n\na: 'x'\n \nb: y\n"
- d = round_trip_load(yaml_str, preserve_quotes=True)
- y = round_trip_dump(d, explicit_start=True)
- stripped = ""
- for line in yaml_str.splitlines():
- stripped += line.rstrip() + "\n"
- print(line + "$")
- assert stripped == y
-
- def test_some_eol_spaces(self):
- # spaces after tokens and on empty lines
- yaml_str = '--- \n \na: "x" \n \nb: y \n'
- d = round_trip_load(yaml_str, preserve_quotes=True)
- y = round_trip_dump(d, explicit_start=True)
- stripped = ""
- for line in yaml_str.splitlines():
- stripped += line.rstrip() + "\n"
- print(line + "$")
- assert stripped == y
-
- def test_issue_54_not_ok(self):
- yaml_str = dedent(
- """\
- toplevel:
-
- # some comment
- sublevel: 300
- """
- )
- d = round_trip_load(yaml_str)
- print(d.ca)
- y = round_trip_dump(d, indent=4)
- print(y.replace("\n", "$\n"))
- assert yaml_str == y
-
- def test_issue_54_ok(self):
- yaml_str = dedent(
- """\
- toplevel:
- # some comment
- sublevel: 300
- """
- )
- d = round_trip_load(yaml_str)
- y = round_trip_dump(d, indent=4)
- assert yaml_str == y
-
- def test_issue_93(self):
- round_trip(
- """\
- a:
- b:
- - c1: cat # a1
- # my comment on catfish
- - c2: catfish # a2
- """
- )
-
- def test_issue_93_00(self):
- round_trip(
- """\
- a:
- - - c1: cat # a1
- # my comment on catfish
- - c2: catfish # a2
- """
- )
-
- def test_issue_93_01(self):
- round_trip(
- """\
- - - c1: cat # a1
- # my comment on catfish
- - c2: catfish # a2
- """
- )
-
- def test_issue_93_02(self):
- # never failed as there is no indent
- round_trip(
- """\
- - c1: cat
- # my comment on catfish
- - c2: catfish
- """
- )
-
- def test_issue_96(self):
- # inserted extra line on trailing spaces
- round_trip(
- """\
- a:
- b:
- c: c_val
- d:
-
- e:
- g: g_val
- """
- )
-
-
-class TestUnicodeComments:
- @pytest.mark.skipif(sys.version_info < (2, 7), reason="wide unicode")
- def test_issue_55(self): # reported by Haraguroicha Hsu
- round_trip(
- """\
- name: TEST
- description: test using
- author: Harguroicha
- sql:
- command: |-
- select name from testtbl where no = :no
-
- ci-test:
- - :no: 04043709 # 小花
- - :no: 05161690 # 茶
- - :no: 05293147 # 〇𤋥川
- - :no: 05338777 # 〇〇啓
- - :no: 05273867 # 〇
- - :no: 05205786 # 〇𤦌
- """
- )
-
-
-class TestEmptyValueBeforeComments:
- def test_issue_25a(self):
- round_trip(
- """\
- - a: b
- c: d
- d: # foo
- - e: f
- """
- )
-
- def test_issue_25a1(self):
- round_trip(
- """\
- - a: b
- c: d
- d: # foo
- e: f
- """
- )
-
- def test_issue_25b(self):
- round_trip(
- """\
- var1: #empty
- var2: something #notempty
- """
- )
-
- def test_issue_25c(self):
- round_trip(
- """\
- params:
- a: 1 # comment a
- b: # comment b
- c: 3 # comment c
- """
- )
-
- def test_issue_25c1(self):
- round_trip(
- """\
- params:
- a: 1 # comment a
- b: # comment b
- # extra
- c: 3 # comment c
- """
- )
-
- def test_issue_25_00(self):
- round_trip(
- """\
- params:
- a: 1 # comment a
- b: # comment b
- """
- )
-
- def test_issue_25_01(self):
- round_trip(
- """\
- a: # comment 1
- # comment 2
- - b: # comment 3
- c: 1 # comment 4
- """
- )
-
- def test_issue_25_02(self):
- round_trip(
- """\
- a: # comment 1
- # comment 2
- - b: 2 # comment 3
- """
- )
-
- def test_issue_25_03(self):
- s = """\
- a: # comment 1
- # comment 2
- - b: 2 # comment 3
- """
- round_trip(s, indent=4, block_seq_indent=2)
-
- def test_issue_25_04(self):
- round_trip(
- """\
- a: # comment 1
- # comment 2
- b: 1 # comment 3
- """
- )
-
- def test_flow_seq_within_seq(self):
- round_trip(
- """\
- # comment 1
- - a
- - b
- # comment 2
- - c
- - d
- # comment 3
- - [e]
- - f
- # comment 4
- - []
- """
- )
-
-
-test_block_scalar_commented_line_template = """\
-y: p
-# Some comment
-
-a: |
- x
-{}b: y
-"""
-
-
-class TestBlockScalarWithComments:
- # issue 99 reported by Colm O'Connor
- def test_scalar_with_comments(self):
- import srsly.ruamel_yaml # NOQA
-
- for x in [
- "",
- "\n",
- "\n# Another comment\n",
- "\n\n",
- "\n\n# abc\n#xyz\n",
- "\n\n# abc\n#xyz\n",
- "# abc\n\n#xyz\n",
- "\n\n # abc\n #xyz\n",
- ]:
-
- commented_line = test_block_scalar_commented_line_template.format(x)
- data = srsly.ruamel_yaml.round_trip_load(commented_line)
-
- assert srsly.ruamel_yaml.round_trip_dump(data) == commented_line
diff --git a/srsly/tests/ruamel_yaml/test_contextmanager.py b/srsly/tests/ruamel_yaml/test_contextmanager.py
deleted file mode 100755
index f94f11c..0000000
--- a/srsly/tests/ruamel_yaml/test_contextmanager.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-"""
-testing of anchors and the aliases referring to them
-"""
-
-import sys
-import pytest
-
-
-single_doc = """\
-- a: 1
-- b:
- - 2
- - 3
-"""
-
-single_data = [dict(a=1), dict(b=[2, 3])]
-
-multi_doc = """\
----
-- abc
-- xyz
----
-- a: 1
-- b:
- - 2
- - 3
-"""
-
-multi_doc_data = [["abc", "xyz"], single_data]
-
-
-def get_yaml():
- from srsly.ruamel_yaml import YAML
-
- return YAML()
-
-
-class TestOldStyle:
- def test_single_load(self):
- d = get_yaml().load(single_doc)
- print(d)
- print(type(d[0]))
- assert d == single_data
-
- def test_single_load_no_arg(self):
- with pytest.raises(TypeError):
- assert get_yaml().load() == single_data
-
- def test_multi_load(self):
- data = list(get_yaml().load_all(multi_doc))
- assert data == multi_doc_data
-
- def test_single_dump(self, capsys):
- get_yaml().dump(single_data, sys.stdout)
- out, err = capsys.readouterr()
- assert out == single_doc
-
- def test_multi_dump(self, capsys):
- yaml = get_yaml()
- yaml.explicit_start = True
- yaml.dump_all(multi_doc_data, sys.stdout)
- out, err = capsys.readouterr()
- assert out == multi_doc
-
-
-class TestContextManager:
- def test_single_dump(self, capsys):
- from srsly.ruamel_yaml import YAML
-
- with YAML(output=sys.stdout) as yaml:
- yaml.dump(single_data)
- out, err = capsys.readouterr()
- print(err)
- assert out == single_doc
-
- def test_multi_dump(self, capsys):
- from srsly.ruamel_yaml import YAML
-
- with YAML(output=sys.stdout) as yaml:
- yaml.explicit_start = True
- yaml.dump(multi_doc_data[0])
- yaml.dump(multi_doc_data[1])
-
- out, err = capsys.readouterr()
- print(err)
- assert out == multi_doc
-
- # input is not as simple with a context manager
- # you need to indicate what you expect hence load and load_all
-
- # @pytest.mark.xfail(strict=True)
- # def test_single_load(self):
- # from srsly.ruamel_yaml import YAML
- # with YAML(input=single_doc) as yaml:
- # assert yaml.load() == single_data
- #
- # @pytest.mark.xfail(strict=True)
- # def test_multi_load(self):
- # from srsly.ruamel_yaml import YAML
- # with YAML(input=multi_doc) as yaml:
- # for idx, data in enumerate(yaml.load()):
- # assert data == multi_doc_data[0]
-
- def test_roundtrip(self, capsys):
- from srsly.ruamel_yaml import YAML
-
- with YAML(output=sys.stdout) as yaml:
- yaml.explicit_start = True
- for data in yaml.load_all(multi_doc):
- yaml.dump(data)
-
- out, err = capsys.readouterr()
- print(err)
- assert out == multi_doc
diff --git a/srsly/tests/ruamel_yaml/test_copy.py b/srsly/tests/ruamel_yaml/test_copy.py
deleted file mode 100755
index c17cfb9..0000000
--- a/srsly/tests/ruamel_yaml/test_copy.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# coding: utf-8
-
-"""
-Testing copy and deepcopy, instigated by Issue 84 (Peter Amstutz)
-"""
-
-import copy
-
-import pytest # NOQA
-
-from .roundtrip import dedent, round_trip_load, round_trip_dump
-
-
-class TestDeepCopy:
- def test_preserve_flow_style_simple(self):
- x = dedent(
- """\
- {foo: bar, baz: quux}
- """
- )
- data = round_trip_load(x)
- data_copy = copy.deepcopy(data)
- y = round_trip_dump(data_copy)
- print("x [{}]".format(x))
- print("y [{}]".format(y))
- assert y == x
- assert data.fa.flow_style() == data_copy.fa.flow_style()
-
- def test_deepcopy_flow_style_nested_dict(self):
- x = dedent(
- """\
- a: {foo: bar, baz: quux}
- """
- )
- data = round_trip_load(x)
- assert data["a"].fa.flow_style() is True
- data_copy = copy.deepcopy(data)
- assert data_copy["a"].fa.flow_style() is True
- data_copy["a"].fa.set_block_style()
- assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style()
- assert data["a"].fa._flow_style is True
- assert data_copy["a"].fa._flow_style is False
- y = round_trip_dump(data_copy)
-
- print("x [{}]".format(x))
- print("y [{}]".format(y))
- assert y == dedent(
- """\
- a:
- foo: bar
- baz: quux
- """
- )
-
- def test_deepcopy_flow_style_nested_list(self):
- x = dedent(
- """\
- a: [1, 2, 3]
- """
- )
- data = round_trip_load(x)
- assert data["a"].fa.flow_style() is True
- data_copy = copy.deepcopy(data)
- assert data_copy["a"].fa.flow_style() is True
- data_copy["a"].fa.set_block_style()
- assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style()
- assert data["a"].fa._flow_style is True
- assert data_copy["a"].fa._flow_style is False
- y = round_trip_dump(data_copy)
-
- print("x [{}]".format(x))
- print("y [{}]".format(y))
- assert y == dedent(
- """\
- a:
- - 1
- - 2
- - 3
- """
- )
-
-
-class TestCopy:
- def test_copy_flow_style_nested_dict(self):
- x = dedent(
- """\
- a: {foo: bar, baz: quux}
- """
- )
- data = round_trip_load(x)
- assert data["a"].fa.flow_style() is True
- data_copy = copy.copy(data)
- assert data_copy["a"].fa.flow_style() is True
- data_copy["a"].fa.set_block_style()
- assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style()
- assert data["a"].fa._flow_style is False
- assert data_copy["a"].fa._flow_style is False
- y = round_trip_dump(data_copy)
- z = round_trip_dump(data)
- assert y == z
-
- assert y == dedent(
- """\
- a:
- foo: bar
- baz: quux
- """
- )
-
- def test_copy_flow_style_nested_list(self):
- x = dedent(
- """\
- a: [1, 2, 3]
- """
- )
- data = round_trip_load(x)
- assert data["a"].fa.flow_style() is True
- data_copy = copy.copy(data)
- assert data_copy["a"].fa.flow_style() is True
- data_copy["a"].fa.set_block_style()
- assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style()
- assert data["a"].fa._flow_style is False
- assert data_copy["a"].fa._flow_style is False
- y = round_trip_dump(data_copy)
-
- print("x [{}]".format(x))
- print("y [{}]".format(y))
- assert y == dedent(
- """\
- a:
- - 1
- - 2
- - 3
- """
- )
diff --git a/srsly/tests/ruamel_yaml/test_datetime.py b/srsly/tests/ruamel_yaml/test_datetime.py
deleted file mode 100755
index db88ef8..0000000
--- a/srsly/tests/ruamel_yaml/test_datetime.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# coding: utf-8
-
-"""
-http://yaml.org/type/timestamp.html specifies the regexp to use
-for datetime.date and datetime.datetime construction. Date is simple
-but datetime can have 'T' or 't' as well as 'Z' or a timezone offset (in
-hours and minutes). This information was originally used to create
-a UTC datetime and then discarded
-
-examples from the above:
-
-canonical: 2001-12-15T02:59:43.1Z
-valid iso8601: 2001-12-14t21:59:43.10-05:00
-space separated: 2001-12-14 21:59:43.10 -5
-no time zone (Z): 2001-12-15 2:59:43.10
-date (00:00:00Z): 2002-12-14
-
-Please note that a fraction can only be included if not equal to 0
-
-"""
-
-import copy
-import pytest # NOQA
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-
-class TestDateTime:
- def test_date_only(self):
- inp = """
- - 2011-10-02
- """
- exp = """
- - 2011-10-02
- """
- round_trip(inp, exp)
-
- def test_zero_fraction(self):
- inp = """
- - 2011-10-02 16:45:00.0
- """
- exp = """
- - 2011-10-02 16:45:00
- """
- round_trip(inp, exp)
-
- def test_long_fraction(self):
- inp = """
- - 2011-10-02 16:45:00.1234 # expand with zeros
- - 2011-10-02 16:45:00.123456
- - 2011-10-02 16:45:00.12345612 # round to microseconds
- - 2011-10-02 16:45:00.1234565 # round up
- - 2011-10-02 16:45:00.12345678 # round up
- """
- exp = """
- - 2011-10-02 16:45:00.123400 # expand with zeros
- - 2011-10-02 16:45:00.123456
- - 2011-10-02 16:45:00.123456 # round to microseconds
- - 2011-10-02 16:45:00.123457 # round up
- - 2011-10-02 16:45:00.123457 # round up
- """
- round_trip(inp, exp)
-
- def test_canonical(self):
- inp = """
- - 2011-10-02T16:45:00.1Z
- """
- exp = """
- - 2011-10-02T16:45:00.100000Z
- """
- round_trip(inp, exp)
-
- def test_spaced_timezone(self):
- inp = """
- - 2011-10-02T11:45:00 -5
- """
- exp = """
- - 2011-10-02T11:45:00-5
- """
- round_trip(inp, exp)
-
- def test_normal_timezone(self):
- round_trip(
- """
- - 2011-10-02T11:45:00-5
- - 2011-10-02 11:45:00-5
- - 2011-10-02T11:45:00-05:00
- - 2011-10-02 11:45:00-05:00
- """
- )
-
- def test_no_timezone(self):
- inp = """
- - 2011-10-02 6:45:00
- """
- exp = """
- - 2011-10-02 06:45:00
- """
- round_trip(inp, exp)
-
- def test_explicit_T(self):
- inp = """
- - 2011-10-02T16:45:00
- """
- exp = """
- - 2011-10-02T16:45:00
- """
- round_trip(inp, exp)
-
- def test_explicit_t(self): # to upper
- inp = """
- - 2011-10-02t16:45:00
- """
- exp = """
- - 2011-10-02T16:45:00
- """
- round_trip(inp, exp)
-
- def test_no_T_multi_space(self):
- inp = """
- - 2011-10-02 16:45:00
- """
- exp = """
- - 2011-10-02 16:45:00
- """
- round_trip(inp, exp)
-
- def test_iso(self):
- round_trip(
- """
- - 2011-10-02T15:45:00+01:00
- """
- )
-
- def test_zero_tz(self):
- round_trip(
- """
- - 2011-10-02T15:45:00+0
- """
- )
-
- def test_issue_45(self):
- round_trip(
- """
- dt: 2016-08-19T22:45:47Z
- """
- )
-
- def test_deepcopy_datestring(self):
- # reported by Quuxplusone, http://stackoverflow.com/a/41577841/1307905
- x = dedent(
- """\
- foo: 2016-10-12T12:34:56
- """
- )
- data = copy.deepcopy(round_trip_load(x))
- assert round_trip_dump(data) == x
diff --git a/srsly/tests/ruamel_yaml/test_deprecation.py b/srsly/tests/ruamel_yaml/test_deprecation.py
deleted file mode 100755
index 14acd71..0000000
--- a/srsly/tests/ruamel_yaml/test_deprecation.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-import sys
-import pytest # NOQA
-
-
-@pytest.mark.skipif(sys.version_info < (3, 7) or sys.version_info >= (3, 9),
- reason='collections not available?')
-def test_collections_deprecation():
- with pytest.warns(DeprecationWarning):
- from collections import Hashable # NOQA
diff --git a/srsly/tests/ruamel_yaml/test_documents.py b/srsly/tests/ruamel_yaml/test_documents.py
deleted file mode 100755
index cad6ff2..0000000
--- a/srsly/tests/ruamel_yaml/test_documents.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# coding: utf-8
-
-import pytest # NOQA
-
-from .roundtrip import round_trip, round_trip_load_all
-
-
-class TestDocument:
- def test_single_doc_begin_end(self):
- inp = """\
- ---
- - a
- - b
- ...
- """
- round_trip(inp, explicit_start=True, explicit_end=True)
-
- def test_multi_doc_begin_end(self):
- from srsly.ruamel_yaml import dump_all, RoundTripDumper
-
- inp = """\
- ---
- - a
- ...
- ---
- - b
- ...
- """
- docs = list(round_trip_load_all(inp))
- assert docs == [["a"], ["b"]]
- out = dump_all(
- docs, Dumper=RoundTripDumper, explicit_start=True, explicit_end=True
- )
- assert out == "---\n- a\n...\n---\n- b\n...\n"
-
- def test_multi_doc_no_start(self):
- inp = """\
- - a
- ...
- ---
- - b
- ...
- """
- docs = list(round_trip_load_all(inp))
- assert docs == [["a"], ["b"]]
-
- def test_multi_doc_no_end(self):
- inp = """\
- - a
- ---
- - b
- """
- docs = list(round_trip_load_all(inp))
- assert docs == [["a"], ["b"]]
-
- def test_multi_doc_ends_only(self):
- # this is ok in 1.2
- inp = """\
- - a
- ...
- - b
- ...
- """
- docs = list(round_trip_load_all(inp, version=(1, 2)))
- assert docs == [["a"], ["b"]]
-
- def test_multi_doc_ends_only_1_1(self):
- from srsly.ruamel_yaml import parser
-
- # this is not ok in 1.1
- with pytest.raises(parser.ParserError):
- inp = """\
- - a
- ...
- - b
- ...
- """
- docs = list(round_trip_load_all(inp, version=(1, 1)))
- assert docs == [["a"], ["b"]] # not True, but not reached
diff --git a/srsly/tests/ruamel_yaml/test_fail.py b/srsly/tests/ruamel_yaml/test_fail.py
deleted file mode 100755
index 02cef0b..0000000
--- a/srsly/tests/ruamel_yaml/test_fail.py
+++ /dev/null
@@ -1,255 +0,0 @@
-# coding: utf-8
-
-# there is some work to do
-# provide a failing test xyz and a non-failing xyz_no_fail ( to see
-# what the current failing output is.
-# on fix of srsly.ruamel_yaml, move the marked test to the appropriate test (without mark)
-# and remove remove the xyz_no_fail
-
-import pytest
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump
-
-
-class TestCommentFailures:
- @pytest.mark.xfail(strict=True)
- def test_set_comment_before_tag(self):
- # no comments before tags
- round_trip(
- """
- # the beginning
- !!set
- # or this one?
- ? a
- # next one is B (lowercase)
- ? b # You see? Promised you.
- ? c
- # this is the end
- """
- )
-
- def test_set_comment_before_tag_no_fail(self):
- # no comments before tags
- inp = """
- # the beginning
- !!set
- # or this one?
- ? a
- # next one is B (lowercase)
- ? b # You see? Promised you.
- ? c
- # this is the end
- """
- assert round_trip_dump(round_trip_load(inp)) == dedent(
- """
- !!set
- # or this one?
- ? a
- # next one is B (lowercase)
- ? b # You see? Promised you.
- ? c
- # this is the end
- """
- )
-
- @pytest.mark.xfail(strict=True)
- def test_comment_dash_line(self):
- round_trip(
- """
- - # abc
- a: 1
- b: 2
- """
- )
-
- def test_comment_dash_line_fail(self):
- x = """
- - # abc
- a: 1
- b: 2
- """
- data = round_trip_load(x)
- # this is not nice
- assert round_trip_dump(data) == dedent(
- """
- # abc
- - a: 1
- b: 2
- """
- )
-
-
-class TestIndentFailures:
- @pytest.mark.xfail(strict=True)
- def test_indent_not_retained(self):
- round_trip(
- """
- verbosity: 1 # 0 is minimal output, -1 none
- base_url: http://gopher.net
- special_indices: [1, 5, 8]
- also_special:
- - a
- - 19
- - 32
- asia and europe: &asia_europe
- Turkey: Ankara
- Russia: Moscow
- countries:
- Asia:
- <<: *asia_europe
- Japan: Tokyo # 東京
- Europe:
- <<: *asia_europe
- Spain: Madrid
- Italy: Rome
- Antarctica:
- - too cold
- """
- )
-
- def test_indent_not_retained_no_fail(self):
- inp = """
- verbosity: 1 # 0 is minimal output, -1 none
- base_url: http://gopher.net
- special_indices: [1, 5, 8]
- also_special:
- - a
- - 19
- - 32
- asia and europe: &asia_europe
- Turkey: Ankara
- Russia: Moscow
- countries:
- Asia:
- <<: *asia_europe
- Japan: Tokyo # 東京
- Europe:
- <<: *asia_europe
- Spain: Madrid
- Italy: Rome
- Antarctica:
- - too cold
- """
- assert round_trip_dump(round_trip_load(inp), indent=4) == dedent(
- """
- verbosity: 1 # 0 is minimal output, -1 none
- base_url: http://gopher.net
- special_indices: [1, 5, 8]
- also_special:
- - a
- - 19
- - 32
- asia and europe: &asia_europe
- Turkey: Ankara
- Russia: Moscow
- countries:
- Asia:
- <<: *asia_europe
- Japan: Tokyo # 東京
- Europe:
- <<: *asia_europe
- Spain: Madrid
- Italy: Rome
- Antarctica:
- - too cold
- """
- )
-
- def Xtest_indent_top_level_no_fail(self):
- inp = """
- - a:
- - b
- """
- round_trip(inp, indent=4)
-
-
-class TestTagFailures:
- @pytest.mark.xfail(strict=True)
- def test_standard_short_tag(self):
- round_trip(
- """\
- !!map
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- def test_standard_short_tag_no_fail(self):
- inp = """
- !!map
- name: Anthon
- location: Germany
- language: python
- """
- exp = """
- name: Anthon
- location: Germany
- language: python
- """
- assert round_trip_dump(round_trip_load(inp)) == dedent(exp)
-
-
-class TestFlowValues:
- def test_flow_value_with_colon(self):
- inp = """\
- {a: bcd:efg}
- """
- round_trip(inp)
-
- def test_flow_value_with_colon_quoted(self):
- inp = """\
- {a: 'bcd:efg'}
- """
- round_trip(inp, preserve_quotes=True)
-
-
-class TestMappingKey:
- def test_simple_mapping_key(self):
- inp = """\
- {a: 1, b: 2}: hello world
- """
- round_trip(inp, preserve_quotes=True, dump_data=False)
-
- def test_set_simple_mapping_key(self):
- from srsly.ruamel_yaml.comments import CommentedKeyMap
-
- d = {CommentedKeyMap([("a", 1), ("b", 2)]): "hello world"}
- exp = dedent(
- """\
- {a: 1, b: 2}: hello world
- """
- )
- assert round_trip_dump(d) == exp
-
- def test_change_key_simple_mapping_key(self):
- from srsly.ruamel_yaml.comments import CommentedKeyMap
-
- inp = """\
- {a: 1, b: 2}: hello world
- """
- d = round_trip_load(inp, preserve_quotes=True)
- d[CommentedKeyMap([("b", 1), ("a", 2)])] = d.pop(
- CommentedKeyMap([("a", 1), ("b", 2)])
- )
- exp = dedent(
- """\
- {b: 1, a: 2}: hello world
- """
- )
- assert round_trip_dump(d) == exp
-
- def test_change_value_simple_mapping_key(self):
- from srsly.ruamel_yaml.comments import CommentedKeyMap
-
- inp = """\
- {a: 1, b: 2}: hello world
- """
- d = round_trip_load(inp, preserve_quotes=True)
- d = {CommentedKeyMap([("a", 1), ("b", 2)]): "goodbye"}
- exp = dedent(
- """\
- {a: 1, b: 2}: goodbye
- """
- )
- assert round_trip_dump(d) == exp
diff --git a/srsly/tests/ruamel_yaml/test_float.py b/srsly/tests/ruamel_yaml/test_float.py
deleted file mode 100755
index 4cffc68..0000000
--- a/srsly/tests/ruamel_yaml/test_float.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-import pytest # NOQA
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-# http://yaml.org/type/int.html is where underscores in integers are defined
-
-
-class TestFloat:
- def test_round_trip_non_exp(self):
- data = round_trip(
- """\
- - 1.0
- - 1.00
- - 23.100
- - -1.0
- - -1.00
- - -23.100
- - 42.
- - -42.
- - +42.
- - .5
- - +.5
- - -.5
- """
- )
- print(data)
- assert 0.999 < data[0] < 1.001
- assert 0.999 < data[1] < 1.001
- assert 23.099 < data[2] < 23.101
- assert 0.999 < -data[3] < 1.001
- assert 0.999 < -data[4] < 1.001
- assert 23.099 < -data[5] < 23.101
- assert 41.999 < data[6] < 42.001
- assert 41.999 < -data[7] < 42.001
- assert 41.999 < data[8] < 42.001
- assert 0.49 < data[9] < 0.51
- assert 0.49 < data[10] < 0.51
- assert -0.51 < data[11] < -0.49
-
- def test_round_trip_zeros_0(self):
- data = round_trip(
- """\
- - 0.
- - +0.
- - -0.
- - 0.0
- - +0.0
- - -0.0
- - 0.00
- - +0.00
- - -0.00
- """
- )
- print(data)
- for d in data:
- assert -0.00001 < d < 0.00001
-
- def Xtest_round_trip_non_exp_trailing_dot(self):
- data = round_trip(
- """\
- """
- )
- print(data)
-
- def test_yaml_1_1_no_dot(self):
- from srsly.ruamel_yaml.error import MantissaNoDotYAML1_1Warning
-
- with pytest.warns(MantissaNoDotYAML1_1Warning):
- round_trip_load(
- """\
- %YAML 1.1
- ---
- - 1e6
- """
- )
-
-
-class TestCalculations(object):
- def test_mul_00(self):
- # issue 149 reported by jan.brezina@tul.cz
- d = round_trip_load(
- """\
- - 0.1
- """
- )
- d[0] *= -1
- x = round_trip_dump(d)
- assert x == "- -0.1\n"
diff --git a/srsly/tests/ruamel_yaml/test_flowsequencekey.py b/srsly/tests/ruamel_yaml/test_flowsequencekey.py
deleted file mode 100755
index 8362bec..0000000
--- a/srsly/tests/ruamel_yaml/test_flowsequencekey.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# coding: utf-8
-
-"""
-test flow style sequences as keys roundtrip
-
-"""
-
-# import pytest
-
-from .roundtrip import round_trip # , dedent, round_trip_load, round_trip_dump
-
-
-class TestFlowStyleSequenceKey:
- def test_so_39595807(self):
- inp = """\
- %YAML 1.2
- ---
- [2, 3, 4]:
- a:
- - 1
- - 2
- b: Hello World!
- c: 'Voilà!'
- """
- round_trip(inp, preserve_quotes=True, explicit_start=True, version=(1, 2))
diff --git a/srsly/tests/ruamel_yaml/test_indentation.py b/srsly/tests/ruamel_yaml/test_indentation.py
deleted file mode 100755
index ef1057d..0000000
--- a/srsly/tests/ruamel_yaml/test_indentation.py
+++ /dev/null
@@ -1,365 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import unicode_literals
-
-
-import pytest # NOQA
-
-from .roundtrip import round_trip, round_trip_load, round_trip_dump, dedent, YAML
-
-
-def rt(s):
- import srsly.ruamel_yaml
-
- res = srsly.ruamel_yaml.dump(
- srsly.ruamel_yaml.load(s, Loader=srsly.ruamel_yaml.RoundTripLoader),
- Dumper=srsly.ruamel_yaml.RoundTripDumper,
- )
- return res.strip() + "\n"
-
-
-class TestIndent:
- def test_roundtrip_inline_list(self):
- s = "a: [a, b, c]\n"
- output = rt(s)
- assert s == output
-
- def test_roundtrip_mapping_of_inline_lists(self):
- s = dedent(
- """\
- a: [a, b, c]
- j: [k, l, m]
- """
- )
- output = rt(s)
- assert s == output
-
- def test_roundtrip_mapping_of_inline_lists_comments(self):
- s = dedent(
- """\
- # comment A
- a: [a, b, c]
- # comment B
- j: [k, l, m]
- """
- )
- output = rt(s)
- assert s == output
-
- def test_roundtrip_mapping_of_inline_sequence_eol_comments(self):
- s = dedent(
- """\
- # comment A
- a: [a, b, c] # comment B
- j: [k, l, m] # comment C
- """
- )
- output = rt(s)
- assert s == output
-
- # first test by explicitly setting flow style
- def test_added_inline_list(self):
- import srsly.ruamel_yaml
-
- s1 = dedent(
- """
- a:
- - b
- - c
- - d
- """
- )
- s = "a: [b, c, d]\n"
- data = srsly.ruamel_yaml.load(s1, Loader=srsly.ruamel_yaml.RoundTripLoader)
- val = data["a"]
- val.fa.set_flow_style()
- # print(type(val), '_yaml_format' in dir(val))
- output = srsly.ruamel_yaml.dump(data, Dumper=srsly.ruamel_yaml.RoundTripDumper)
- assert s == output
-
- # ############ flow mappings
-
- def test_roundtrip_flow_mapping(self):
- import srsly.ruamel_yaml
-
- s = dedent(
- """\
- - {a: 1, b: hallo}
- - {j: fka, k: 42}
- """
- )
- data = srsly.ruamel_yaml.load(s, Loader=srsly.ruamel_yaml.RoundTripLoader)
- output = srsly.ruamel_yaml.dump(data, Dumper=srsly.ruamel_yaml.RoundTripDumper)
- assert s == output
-
- def test_roundtrip_sequence_of_inline_mappings_eol_comments(self):
- s = dedent(
- """\
- # comment A
- - {a: 1, b: hallo} # comment B
- - {j: fka, k: 42} # comment C
- """
- )
- output = rt(s)
- assert s == output
-
- def test_indent_top_level(self):
- inp = """
- - a:
- - b
- """
- round_trip(inp, indent=4)
-
- def test_set_indent_5_block_list_indent_1(self):
- inp = """
- a:
- - b: c
- - 1
- - d:
- - 2
- """
- round_trip(inp, indent=5, block_seq_indent=1)
-
- def test_set_indent_4_block_list_indent_2(self):
- inp = """
- a:
- - b: c
- - 1
- - d:
- - 2
- """
- round_trip(inp, indent=4, block_seq_indent=2)
-
- def test_set_indent_3_block_list_indent_0(self):
- inp = """
- a:
- - b: c
- - 1
- - d:
- - 2
- """
- round_trip(inp, indent=3, block_seq_indent=0)
-
- def Xtest_set_indent_3_block_list_indent_2(self):
- inp = """
- a:
- -
- b: c
- -
- 1
- -
- d:
- -
- 2
- """
- round_trip(inp, indent=3, block_seq_indent=2)
-
- def test_set_indent_3_block_list_indent_2(self):
- inp = """
- a:
- - b: c
- - 1
- - d:
- - 2
- """
- round_trip(inp, indent=3, block_seq_indent=2)
-
- def Xtest_set_indent_2_block_list_indent_2(self):
- inp = """
- a:
- -
- b: c
- -
- 1
- -
- d:
- -
- 2
- """
- round_trip(inp, indent=2, block_seq_indent=2)
-
- # this is how it should be: block_seq_indent stretches the indent
- def test_set_indent_2_block_list_indent_2(self):
- inp = """
- a:
- - b: c
- - 1
- - d:
- - 2
- """
- round_trip(inp, indent=2, block_seq_indent=2)
-
- # have to set indent!
- def test_roundtrip_four_space_indents(self):
- # fmt: off
- s = (
- 'a:\n'
- '- foo\n'
- '- bar\n'
- )
- # fmt: on
- round_trip(s, indent=4)
-
- def test_roundtrip_four_space_indents_no_fail(self):
- inp = """
- a:
- - foo
- - bar
- """
- exp = """
- a:
- - foo
- - bar
- """
- assert round_trip_dump(round_trip_load(inp)) == dedent(exp)
-
-
-class TestYpkgIndent:
- def test_00(self):
- inp = """
- name : nano
- version : 2.3.2
- release : 1
- homepage : http://www.nano-editor.org
- source :
- - http://www.nano-editor.org/dist/v2.3/nano-2.3.2.tar.gz : ff30924807ea289f5b60106be8
- license : GPL-2.0
- summary : GNU nano is an easy-to-use text editor
- builddeps :
- - ncurses-devel
- description: |
- GNU nano is an easy-to-use text editor originally designed
- as a replacement for Pico, the ncurses-based editor from the non-free mailer
- package Pine (itself now available under the Apache License as Alpine).
- """
- round_trip(
- inp,
- indent=4,
- block_seq_indent=2,
- top_level_colon_align=True,
- prefix_colon=" ",
- )
-
-
-def guess(s):
- from srsly.ruamel_yaml.util import load_yaml_guess_indent
-
- x, y, z = load_yaml_guess_indent(dedent(s))
- return y, z
-
-
-class TestGuessIndent:
- def test_guess_20(self):
- inp = """\
- a:
- - 1
- """
- assert guess(inp) == (2, 0)
-
- def test_guess_42(self):
- inp = """\
- a:
- - 1
- """
- assert guess(inp) == (4, 2)
-
- def test_guess_42a(self):
- # block seq indent prevails over nested key indent level
- inp = """\
- b:
- a:
- - 1
- """
- assert guess(inp) == (4, 2)
-
- def test_guess_3None(self):
- inp = """\
- b:
- a: 1
- """
- assert guess(inp) == (3, None)
-
-
-class TestSeparateMapSeqIndents:
- # using uncommon 6 indent with 3 push in as 2 push in automatically
- # gets you 4 indent even if not set
- def test_00(self):
- # old style
- yaml = YAML()
- yaml.indent = 6
- yaml.block_seq_indent = 3
- inp = """
- a:
- - 1
- - [1, 2]
- """
- yaml.round_trip(inp)
-
- def test_01(self):
- yaml = YAML()
- yaml.indent(sequence=6)
- yaml.indent(offset=3)
- inp = """
- a:
- - 1
- - {b: 3}
- """
- yaml.round_trip(inp)
-
- def test_02(self):
- yaml = YAML()
- yaml.indent(mapping=5, sequence=6, offset=3)
- inp = """
- a:
- b:
- - 1
- - [1, 2]
- """
- yaml.round_trip(inp)
-
- def test_03(self):
- inp = """
- a:
- b:
- c:
- - 1
- - [1, 2]
- """
- round_trip(inp, indent=4)
-
- def test_04(self):
- yaml = YAML()
- yaml.indent(mapping=5, sequence=6)
- inp = """
- a:
- b:
- - 1
- - [1, 2]
- - {d: 3.14}
- """
- yaml.round_trip(inp)
-
- def test_issue_51(self):
- yaml = YAML()
- # yaml.map_indent = 2 # the default
- yaml.indent(sequence=4, offset=2)
- yaml.preserve_quotes = True
- yaml.round_trip(
- """
- role::startup::author::rsyslog_inputs:
- imfile:
- - ruleset: 'AEM-slinglog'
- File: '/opt/aem/author/crx-quickstart/logs/error.log'
- startmsg.regex: '^[-+T.:[:digit:]]*'
- tag: 'error'
- - ruleset: 'AEM-slinglog'
- File: '/opt/aem/author/crx-quickstart/logs/stdout.log'
- startmsg.regex: '^[-+T.:[:digit:]]*'
- tag: 'stdout'
- """
- )
-
-
-# ############ indentation
diff --git a/srsly/tests/ruamel_yaml/test_int.py b/srsly/tests/ruamel_yaml/test_int.py
deleted file mode 100755
index 6635fcf..0000000
--- a/srsly/tests/ruamel_yaml/test_int.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-import pytest # NOQA
-
-from .roundtrip import dedent, round_trip_load, round_trip_dump
-
-# http://yaml.org/type/int.html is where underscores in integers are defined
-
-
-class TestBinHexOct:
- def test_calculate(self):
- # make sure type, leading zero(s) and underscore are preserved
- s = dedent(
- """\
- - 42
- - 0b101010
- - 0x_2a
- - 0x2A
- - 0o00_52
- """
- )
- d = round_trip_load(s)
- for idx, elem in enumerate(d):
- elem -= 21
- d[idx] = elem
- for idx, elem in enumerate(d):
- elem *= 2
- d[idx] = elem
- for idx, elem in enumerate(d):
- t = elem
- elem **= 2
- elem //= t
- d[idx] = elem
- assert round_trip_dump(d) == s
diff --git a/srsly/tests/ruamel_yaml/test_issues.py b/srsly/tests/ruamel_yaml/test_issues.py
deleted file mode 100755
index 11dfbb7..0000000
--- a/srsly/tests/ruamel_yaml/test_issues.py
+++ /dev/null
@@ -1,971 +0,0 @@
-# coding: utf-8
-
-from __future__ import absolute_import, print_function, unicode_literals
-
-
-import pytest # NOQA
-import sys
-
-
-from .roundtrip import (
- round_trip,
- na_round_trip,
- round_trip_load,
- round_trip_dump,
- dedent,
- save_and_run,
- YAML,
-) # NOQA
-
-
-class TestIssues:
- def test_issue_61(self):
- import srsly.ruamel_yaml
-
- s = dedent(
- """
- def1: &ANCHOR1
- key1: value1
- def: &ANCHOR
- <<: *ANCHOR1
- key: value
- comb:
- <<: *ANCHOR
- """
- )
- data = srsly.ruamel_yaml.round_trip_load(s)
- assert str(data["comb"]) == str(data["def"])
- if sys.version_info >= (3, 12):
- assert (
- str(data["comb"]) == "ordereddict({'key': 'value', 'key1': 'value1'})"
- )
- else:
- assert (
- str(data["comb"]) == "ordereddict([('key', 'value'), ('key1', 'value1')])"
- )
-
- def test_issue_82(self, tmpdir):
- program_src = r'''
- from __future__ import print_function
-
- import srsly.ruamel_yaml as yaml
-
- import re
-
-
- class SINumber(yaml.YAMLObject):
- PREFIXES = {'k': 1e3, 'M': 1e6, 'G': 1e9}
- yaml_loader = yaml.Loader
- yaml_dumper = yaml.Dumper
- yaml_tag = u'!si'
- yaml_implicit_pattern = re.compile(
- r'^(?P[0-9]+(?:\.[0-9]+)?)(?P[kMG])$')
-
- @classmethod
- def from_yaml(cls, loader, node):
- return cls(node.value)
-
- @classmethod
- def to_yaml(cls, dumper, data):
- return dumper.represent_scalar(cls.yaml_tag, str(data))
-
- def __init__(self, *args):
- m = self.yaml_implicit_pattern.match(args[0])
- self.value = float(m.groupdict()['value'])
- self.prefix = m.groupdict()['prefix']
-
- def __str__(self):
- return str(self.value)+self.prefix
-
- def __int__(self):
- return int(self.value*self.PREFIXES[self.prefix])
-
- # This fails:
- yaml.add_implicit_resolver(SINumber.yaml_tag, SINumber.yaml_implicit_pattern)
-
- ret = yaml.load("""
- [1,2,3, !si 10k, 100G]
- """, Loader=yaml.Loader)
- for idx, l in enumerate([1, 2, 3, 10000, 100000000000]):
- assert int(ret[idx]) == l
- '''
- assert save_and_run(dedent(program_src), tmpdir) == 1
-
- def test_issue_82rt(self, tmpdir):
- yaml_str = "[1, 2, 3, !si 10k, 100G]\n"
- x = round_trip(yaml_str, preserve_quotes=True) # NOQA
-
- def test_issue_102(self):
- yaml_str = dedent(
- """
- var1: #empty
- var2: something #notempty
- var3: {} #empty object
- var4: {a: 1} #filled object
- var5: [] #empty array
- """
- )
- x = round_trip(yaml_str, preserve_quotes=True) # NOQA
-
- def test_issue_150(self):
- from srsly.ruamel_yaml import YAML
-
- inp = """\
- base: &base_key
- first: 123
- second: 234
-
- child:
- <<: *base_key
- third: 345
- """
- yaml = YAML()
- data = yaml.load(inp)
- child = data["child"]
- assert "second" in dict(**child)
-
- def test_issue_160(self):
- from srsly.ruamel_yaml.compat import StringIO
-
- s = dedent(
- """\
- root:
- # a comment
- - {some_key: "value"}
-
- foo: 32
- bar: 32
- """
- )
- a = round_trip_load(s)
- del a["root"][0]["some_key"]
- buf = StringIO()
- round_trip_dump(a, buf, block_seq_indent=4)
- exp = dedent(
- """\
- root:
- # a comment
- - {}
-
- foo: 32
- bar: 32
- """
- )
- assert buf.getvalue() == exp
-
- def test_issue_161(self):
- yaml_str = dedent(
- """\
- mapping-A:
- key-A:{}
- mapping-B:
- """
- )
- for comment in ["", " # no-newline", " # some comment\n", "\n"]:
- s = yaml_str.format(comment)
- res = round_trip(s) # NOQA
-
- def test_issue_161a(self):
- yaml_str = dedent(
- """\
- mapping-A:
- key-A:{}
- mapping-B:
- """
- )
- for comment in ["\n# between"]:
- s = yaml_str.format(comment)
- res = round_trip(s) # NOQA
-
- def test_issue_163(self):
- s = dedent(
- """\
- some-list:
- # List comment
- - {}
- """
- )
- x = round_trip(s, preserve_quotes=True) # NOQA
-
- json_str = (
- r'{"sshKeys":[{"name":"AETROS\/google-k80-1","uses":0,"getLastUse":0,'
- '"fingerprint":"MD5:19:dd:41:93:a1:a3:f5:91:4a:8e:9b:d0:ae:ce:66:4c",'
- '"created":1509497961}]}'
- )
-
- json_str2 = '{"abc":[{"a":"1", "uses":0}]}'
-
- def test_issue_172(self):
- x = round_trip_load(TestIssues.json_str2) # NOQA
- x = round_trip_load(TestIssues.json_str) # NOQA
-
- def test_issue_176(self):
- # basic request by Stuart Berg
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- seq = yaml.load("[1,2,3]")
- seq[:] = [1, 2, 3, 4]
-
- def test_issue_176_preserve_comments_on_extended_slice_assignment(self):
- yaml_str = dedent(
- """\
- - a
- - b # comment
- - c # commment c
- # comment c+
- - d
-
- - e # comment
- """
- )
- seq = round_trip_load(yaml_str)
- seq[1::2] = ["B", "D"]
- res = round_trip_dump(seq)
- assert res == yaml_str.replace(" b ", " B ").replace(" d\n", " D\n")
-
- def test_issue_176_test_slicing(self):
- from srsly.ruamel_yaml.compat import PY2
-
- mss = round_trip_load("[0, 1, 2, 3, 4]")
- assert len(mss) == 5
- assert mss[2:2] == []
- assert mss[2:4] == [2, 3]
- assert mss[1::2] == [1, 3]
-
- # slice assignment
- m = mss[:]
- m[2:2] = [42]
- assert m == [0, 1, 42, 2, 3, 4]
-
- m = mss[:]
- m[:3] = [42, 43, 44]
- assert m == [42, 43, 44, 3, 4]
- m = mss[:]
- m[2:] = [42, 43, 44]
- assert m == [0, 1, 42, 43, 44]
- m = mss[:]
- m[:] = [42, 43, 44]
- assert m == [42, 43, 44]
-
- # extend slice assignment
- m = mss[:]
- m[2:4] = [42, 43, 44]
- assert m == [0, 1, 42, 43, 44, 4]
- m = mss[:]
- m[1::2] = [42, 43]
- assert m == [0, 42, 2, 43, 4]
- m = mss[:]
- if PY2:
- with pytest.raises(ValueError, match="attempt to assign"):
- m[1::2] = [42, 43, 44]
- else:
- with pytest.raises(TypeError, match="too many"):
- m[1::2] = [42, 43, 44]
- if PY2:
- with pytest.raises(ValueError, match="attempt to assign"):
- m[1::2] = [42]
- else:
- with pytest.raises(TypeError, match="not enough"):
- m[1::2] = [42]
- m = mss[:]
- m += [5]
- m[1::2] = [42, 43, 44]
- assert m == [0, 42, 2, 43, 4, 44]
-
- # deleting
- m = mss[:]
- del m[1:3]
- assert m == [0, 3, 4]
- m = mss[:]
- del m[::2]
- assert m == [1, 3]
- m = mss[:]
- del m[:]
- assert m == []
-
- def test_issue_184(self):
- yaml_str = dedent(
- """\
- test::test:
- # test
- foo:
- bar: baz
- """
- )
- d = round_trip_load(yaml_str)
- d["bar"] = "foo"
- d.yaml_add_eol_comment("test1", "bar")
- assert round_trip_dump(d) == yaml_str + "bar: foo # test1\n"
-
- def test_issue_219(self):
- yaml_str = dedent(
- """\
- [StackName: AWS::StackName]
- """
- )
- d = round_trip_load(yaml_str) # NOQA
-
- def test_issue_219a(self):
- yaml_str = dedent(
- """\
- [StackName:
- AWS::StackName]
- """
- )
- d = round_trip_load(yaml_str) # NOQA
-
- def test_issue_220(self, tmpdir):
- program_src = r'''
- from srsly.ruamel_yaml import YAML
-
- yaml_str = u"""\
- ---
- foo: ["bar"]
- """
-
- yaml = YAML(typ='safe', pure=True)
- d = yaml.load(yaml_str)
- print(d)
- '''
- assert save_and_run(dedent(program_src), tmpdir, optimized=True) == 0
-
- def test_issue_221_add(self):
- from srsly.ruamel_yaml.comments import CommentedSeq
-
- a = CommentedSeq([1, 2, 3])
- a + [4, 5]
-
- def test_issue_221_sort(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = YAML()
- inp = dedent(
- """\
- - d
- - a # 1
- - c # 3
- - e # 5
- - b # 2
- """
- )
- a = yaml.load(dedent(inp))
- a.sort()
- buf = StringIO()
- yaml.dump(a, buf)
- exp = dedent(
- """\
- - a # 1
- - b # 2
- - c # 3
- - d
- - e # 5
- """
- )
- assert buf.getvalue() == exp
-
- def test_issue_221_sort_reverse(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = YAML()
- inp = dedent(
- """\
- - d
- - a # 1
- - c # 3
- - e # 5
- - b # 2
- """
- )
- a = yaml.load(dedent(inp))
- a.sort(reverse=True)
- buf = StringIO()
- yaml.dump(a, buf)
- exp = dedent(
- """\
- - e # 5
- - d
- - c # 3
- - b # 2
- - a # 1
- """
- )
- assert buf.getvalue() == exp
-
- def test_issue_221_sort_key(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = YAML()
- inp = dedent(
- """\
- - four
- - One # 1
- - Three # 3
- - five # 5
- - two # 2
- """
- )
- a = yaml.load(dedent(inp))
- a.sort(key=str.lower)
- buf = StringIO()
- yaml.dump(a, buf)
- exp = dedent(
- """\
- - five # 5
- - four
- - One # 1
- - Three # 3
- - two # 2
- """
- )
- assert buf.getvalue() == exp
-
- def test_issue_221_sort_key_reverse(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = YAML()
- inp = dedent(
- """\
- - four
- - One # 1
- - Three # 3
- - five # 5
- - two # 2
- """
- )
- a = yaml.load(dedent(inp))
- a.sort(key=str.lower, reverse=True)
- buf = StringIO()
- yaml.dump(a, buf)
- exp = dedent(
- """\
- - two # 2
- - Three # 3
- - One # 1
- - four
- - five # 5
- """
- )
- assert buf.getvalue() == exp
-
- def test_issue_222(self):
- import srsly.ruamel_yaml
- from srsly.ruamel_yaml.compat import StringIO
-
- buf = StringIO()
- srsly.ruamel_yaml.safe_dump(["012923"], buf)
- assert buf.getvalue() == "['012923']\n"
-
- def test_issue_223(self):
- import srsly.ruamel_yaml
-
- yaml = srsly.ruamel_yaml.YAML(typ="safe")
- yaml.load("phone: 0123456789")
-
- def test_issue_232(self):
- import srsly.ruamel_yaml
- import srsly.ruamel_yaml as yaml
-
- with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
- yaml.safe_load("]")
- with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
- yaml.safe_load("{]")
-
- def test_issue_233(self):
- from srsly.ruamel_yaml import YAML
- import json
-
- yaml = YAML()
- data = yaml.load("{}")
- json_str = json.dumps(data) # NOQA
-
- def test_issue_233a(self):
- from srsly.ruamel_yaml import YAML
- import json
-
- yaml = YAML()
- data = yaml.load("[]")
- json_str = json.dumps(data) # NOQA
-
- def test_issue_234(self):
- from srsly.ruamel_yaml import YAML
-
- inp = dedent(
- """\
- - key: key1
- ctx: [one, two]
- help: one
- cmd: >
- foo bar
- foo bar
- """
- )
- yaml = YAML(typ="safe", pure=True)
- data = yaml.load(inp)
- fold = data[0]["cmd"]
- print(repr(fold))
- assert "\a" not in fold
-
- def test_issue_236(self):
- inp = """
- conf:
- xx: {a: "b", c: []}
- asd: "nn"
- """
- d = round_trip(inp, preserve_quotes=True) # NOQA
-
- def test_issue_238(self, tmpdir):
- program_src = r"""
- import srsly.ruamel_yaml
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = srsly.ruamel_yaml.YAML(typ='unsafe')
-
-
- class A:
- def __setstate__(self, d):
- self.__dict__ = d
-
-
- class B:
- pass
-
-
- a = A()
- b = B()
-
- a.x = b
- b.y = [b]
- assert a.x.y[0] == a.x
-
- buf = StringIO()
- yaml.dump(a, buf)
-
- data = yaml.load(buf.getvalue())
- assert data.x.y[0] == data.x
- """
- assert save_and_run(dedent(program_src), tmpdir) == 1
-
- def test_issue_239(self):
- inp = """
- first_name: Art
- occupation: Architect
- # I'm safe
- about: Art Vandelay is a fictional character that George invents...
- # we are not :(
- # help me!
- ---
- # what?!
- hello: world
- # someone call the Batman
- foo: bar # or quz
- # Lost again
- ---
- I: knew
- # final words
- """
- d = YAML().round_trip_all(inp) # NOQA
-
- def test_issue_242(self):
- from srsly.ruamel_yaml.comments import CommentedMap
-
- d0 = CommentedMap([("a", "b")])
- assert d0["a"] == "b"
-
- def test_issue_245(self):
- from srsly.ruamel_yaml import YAML
-
- inp = """
- d: yes
- """
- for typ in ["safepure", "rt", "safe"]:
- if typ.endswith("pure"):
- pure = True
- typ = typ[:-4]
- else:
- pure = None
-
- yaml = YAML(typ=typ, pure=pure)
- yaml.version = (1, 1)
- d = yaml.load(inp)
- print(typ, yaml.parser, yaml.resolver)
- assert d["d"] is True
-
- def test_issue_249(self):
- yaml = YAML()
- inp = dedent(
- """\
- # comment
- -
- - 1
- - 2
- - 3
- """
- )
- exp = dedent(
- """\
- # comment
- - - 1
- - 2
- - 3
- """
- )
- yaml.round_trip(inp, outp=exp) # NOQA
-
- def test_issue_250(self):
- inp = """
- # 1.
- - - 1
- # 2.
- - map: 2
- # 3.
- - 4
- """
- d = round_trip(inp) # NOQA
-
- # @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError)
- def test_issue_279(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = YAML()
- yaml.indent(sequence=4, offset=2)
- inp = dedent(
- """\
- experiments:
- - datasets:
- # ATLAS EWK
- - {dataset: ATLASWZRAP36PB, frac: 1.0}
- - {dataset: ATLASZHIGHMASS49FB, frac: 1.0}
- """
- )
- a = yaml.load(inp)
- buf = StringIO()
- yaml.dump(a, buf)
- print(buf.getvalue())
- assert buf.getvalue() == inp
-
- def test_issue_280(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.representer import RepresenterError
- from collections import namedtuple
- from sys import stdout
-
- T = namedtuple("T", ("a", "b"))
- t = T(1, 2)
- yaml = YAML()
- with pytest.raises(RepresenterError, match="cannot represent"):
- yaml.dump({"t": t}, stdout)
-
- def test_issue_282(self):
- # update from list of tuples caused AttributeError
- import srsly.ruamel_yaml
-
- yaml_data = srsly.ruamel_yaml.comments.CommentedMap(
- [("a", "apple"), ("b", "banana")]
- )
- yaml_data.update([("c", "cantaloupe")])
- yaml_data.update({"d": "date", "k": "kiwi"})
- assert "c" in yaml_data.keys()
- assert "c" in yaml_data._ok
-
- def test_issue_284(self):
- import srsly.ruamel_yaml
-
- inp = dedent(
- """\
- plain key: in-line value
- : # Both empty
- "quoted key":
- - entry
- """
- )
- yaml = srsly.ruamel_yaml.YAML(typ="rt")
- yaml.version = (1, 2)
- d = yaml.load(inp)
- assert d[None] is None
-
- yaml = srsly.ruamel_yaml.YAML(typ="rt")
- yaml.version = (1, 1)
- with pytest.raises(
- srsly.ruamel_yaml.parser.ParserError, match="expected "
- ):
- d = yaml.load(inp)
-
- def test_issue_285(self):
- from srsly.ruamel_yaml import YAML
-
- yaml = YAML()
- inp = dedent(
- """\
- %YAML 1.1
- ---
- - y
- - n
- - Y
- - N
- """
- )
- a = yaml.load(inp)
- assert a[0]
- assert a[2]
- assert not a[1]
- assert not a[3]
-
- def test_issue_286(self):
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml = YAML()
- inp = dedent(
- """\
- parent_key:
- - sub_key: sub_value
-
- # xxx"""
- )
- a = yaml.load(inp)
- a["new_key"] = "new_value"
- buf = StringIO()
- yaml.dump(a, buf)
- assert buf.getvalue().endswith("xxx\nnew_key: new_value\n")
-
- def test_issue_288(self):
- import sys
- from srsly.ruamel_yaml.compat import StringIO
- from srsly.ruamel_yaml import YAML
-
- yamldoc = dedent(
- """\
- ---
- # Reusable values
- aliases:
- # First-element comment
- - &firstEntry First entry
- # Second-element comment
- - &secondEntry Second entry
-
- # Third-element comment is
- # a multi-line value
- - &thirdEntry Third entry
-
- # EOF Comment
- """
- )
-
- yaml = YAML()
- yaml.indent(mapping=2, sequence=4, offset=2)
- yaml.explicit_start = True
- yaml.preserve_quotes = True
- yaml.width = sys.maxsize
- data = yaml.load(yamldoc)
- buf = StringIO()
- yaml.dump(data, buf)
- assert buf.getvalue() == yamldoc
-
- def test_issue_288a(self):
- import sys
- from srsly.ruamel_yaml.compat import StringIO
- from srsly.ruamel_yaml import YAML
-
- yamldoc = dedent(
- """\
- ---
- # Reusable values
- aliases:
- # First-element comment
- - &firstEntry First entry
- # Second-element comment
- - &secondEntry Second entry
-
- # Third-element comment is
- # a multi-line value
- - &thirdEntry Third entry
-
- # EOF Comment
- """
- )
-
- yaml = YAML()
- yaml.indent(mapping=2, sequence=4, offset=2)
- yaml.explicit_start = True
- yaml.preserve_quotes = True
- yaml.width = sys.maxsize
- data = yaml.load(yamldoc)
- buf = StringIO()
- yaml.dump(data, buf)
- assert buf.getvalue() == yamldoc
-
- def test_issue_290(self):
- import sys
- from srsly.ruamel_yaml.compat import StringIO
- from srsly.ruamel_yaml import YAML
-
- yamldoc = dedent(
- """\
- ---
- aliases:
- # Folded-element comment
- # for a multi-line value
- - &FoldedEntry >
- THIS IS A
- FOLDED, MULTI-LINE
- VALUE
-
- # Literal-element comment
- # for a multi-line value
- - &literalEntry |
- THIS IS A
- LITERAL, MULTI-LINE
- VALUE
-
- # Plain-element comment
- - &plainEntry Plain entry
- """
- )
-
- yaml = YAML()
- yaml.indent(mapping=2, sequence=4, offset=2)
- yaml.explicit_start = True
- yaml.preserve_quotes = True
- yaml.width = sys.maxsize
- data = yaml.load(yamldoc)
- buf = StringIO()
- yaml.dump(data, buf)
- assert buf.getvalue() == yamldoc
-
- def test_issue_290a(self):
- import sys
- from srsly.ruamel_yaml.compat import StringIO
- from srsly.ruamel_yaml import YAML
-
- yamldoc = dedent(
- """\
- ---
- aliases:
- # Folded-element comment
- # for a multi-line value
- - &FoldedEntry >
- THIS IS A
- FOLDED, MULTI-LINE
- VALUE
-
- # Literal-element comment
- # for a multi-line value
- - &literalEntry |
- THIS IS A
- LITERAL, MULTI-LINE
- VALUE
-
- # Plain-element comment
- - &plainEntry Plain entry
- """
- )
-
- yaml = YAML()
- yaml.indent(mapping=2, sequence=4, offset=2)
- yaml.explicit_start = True
- yaml.preserve_quotes = True
- yaml.width = sys.maxsize
- data = yaml.load(yamldoc)
- buf = StringIO()
- yaml.dump(data, buf)
- assert buf.getvalue() == yamldoc
-
- # @pytest.mark.xfail(strict=True, reason='should fail pre 0.15.100', raises=AssertionError)
- def test_issue_295(self):
- # deepcopy also makes a copy of the start and end mark, and these did not
- # have any comparison beyond their ID, which of course changed, breaking
- # some old merge_comment code
- import copy
-
- inp = dedent(
- """
- A:
- b:
- # comment
- - l1
- - l2
-
- C:
- d: e
- f:
- # comment2
- - - l31
- - l32
- - l33: '5'
- """
- )
- data = round_trip_load(inp) # NOQA
- dc = copy.deepcopy(data)
- assert round_trip_dump(dc) == inp
-
- def test_issue_300(self):
- from srsly.ruamel_yaml import YAML
-
- inp = dedent(
- """
- %YAML 1.2
- %TAG ! tag:example.com,2019/path#fragment
- ---
- null
- """
- )
- YAML().load(inp)
-
- def test_issue_300a(self):
- import srsly.ruamel_yaml
-
- inp = dedent(
- """
- %YAML 1.1
- %TAG ! tag:example.com,2019/path#fragment
- ---
- null
- """
- )
- yaml = YAML()
- with pytest.raises(
- srsly.ruamel_yaml.scanner.ScannerError, match="while scanning a directive"
- ):
- yaml.load(inp)
-
- def test_issue_304(self):
- inp = """
- %YAML 1.2
- %TAG ! tag:example.com,2019:
- ---
- !foo null
- ...
- """
- d = na_round_trip(inp) # NOQA
-
- def test_issue_305(self):
- inp = """
- %YAML 1.2
- ---
- ! null
- ...
- """
- d = na_round_trip(inp) # NOQA
-
- def test_issue_307(self):
- inp = """
- %YAML 1.2
- %TAG ! tag:example.com,2019/path#
- ---
- null
- ...
- """
- d = na_round_trip(inp) # NOQA
-
-
-# @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError)
-# def test_issue_ xxx(self):
-# inp = """
-# """
-# d = round_trip(inp) # NOQA
diff --git a/srsly/tests/ruamel_yaml/test_json_numbers.py b/srsly/tests/ruamel_yaml/test_json_numbers.py
deleted file mode 100755
index 0998e39..0000000
--- a/srsly/tests/ruamel_yaml/test_json_numbers.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-import pytest # NOQA
-
-import json
-
-
-def load(s, typ=float):
- import srsly.ruamel_yaml
-
- x = '{"low": %s }' % (s)
- print("input: [%s]" % (s), repr(x))
- # just to check it is loadable json
- res = json.loads(x)
- assert isinstance(res["low"], typ)
- ret_val = srsly.ruamel_yaml.load(x, srsly.ruamel_yaml.RoundTripLoader)
- print(ret_val)
- return ret_val["low"]
-
-
-class TestJSONNumbers:
- # based on http://stackoverflow.com/a/30462009/1307905
- # yaml number regex: http://yaml.org/spec/1.2/spec.html#id2804092
- #
- # -? [1-9] ( \. [0-9]* [1-9] )? ( e [-+] [1-9] [0-9]* )?
- #
- # which is not a superset of the JSON numbers
- def test_json_number_float(self):
- for x in (
- y.split("#")[0].strip()
- for y in """
- 1.0 # should fail on YAML spec on 1-9 allowed as single digit
- -1.0
- 1e-06
- 3.1e-5
- 3.1e+5
- 3.1e5 # should fail on YAML spec: no +- after e
- """.splitlines()
- ):
- if not x:
- continue
- res = load(x)
- assert isinstance(res, float)
-
- def test_json_number_int(self):
- for x in (
- y.split("#")[0].strip()
- for y in """
- 42
- """.splitlines()
- ):
- if not x:
- continue
- res = load(x, int)
- assert isinstance(res, int)
diff --git a/srsly/tests/ruamel_yaml/test_line_col.py b/srsly/tests/ruamel_yaml/test_line_col.py
deleted file mode 100755
index 10bbd2e..0000000
--- a/srsly/tests/ruamel_yaml/test_line_col.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# coding: utf-8
-
-import pytest # NOQA
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-
-def load(s):
- return round_trip_load(dedent(s))
-
-
-class TestLineCol:
- def test_item_00(self):
- data = load(
- """
- - a
- - e
- - [b, d]
- - c
- """
- )
- assert data[2].lc.line == 2
- assert data[2].lc.col == 2
-
- def test_item_01(self):
- data = load(
- """
- - a
- - e
- - {x: 3}
- - c
- """
- )
- assert data[2].lc.line == 2
- assert data[2].lc.col == 2
-
- def test_item_02(self):
- data = load(
- """
- - a
- - e
- - !!set {x, y}
- - c
- """
- )
- assert data[2].lc.line == 2
- assert data[2].lc.col == 2
-
- def test_item_03(self):
- data = load(
- """
- - a
- - e
- - !!omap
- - x: 1
- - y: 3
- - c
- """
- )
- assert data[2].lc.line == 2
- assert data[2].lc.col == 2
-
- def test_item_04(self):
- data = load(
- """
- # testing line and column based on SO
- # http://stackoverflow.com/questions/13319067/
- - key1: item 1
- key2: item 2
- - key3: another item 1
- key4: another item 2
- """
- )
- assert data[0].lc.line == 2
- assert data[0].lc.col == 2
- assert data[1].lc.line == 4
- assert data[1].lc.col == 2
-
- def test_pos_mapping(self):
- data = load(
- """
- a: 1
- b: 2
- c: 3
- # comment
- klm: 42
- d: 4
- """
- )
- assert data.lc.key("klm") == (4, 0)
- assert data.lc.value("klm") == (4, 5)
-
- def test_pos_sequence(self):
- data = load(
- """
- - a
- - b
- - c
- # next one!
- - klm
- - d
- """
- )
- assert data.lc.item(3) == (4, 2)
diff --git a/srsly/tests/ruamel_yaml/test_literal.py b/srsly/tests/ruamel_yaml/test_literal.py
deleted file mode 100755
index a6c1624..0000000
--- a/srsly/tests/ruamel_yaml/test_literal.py
+++ /dev/null
@@ -1,335 +0,0 @@
-# coding: utf-8
-from __future__ import print_function
-
-import pytest # NOQA
-
-from .roundtrip import YAML # does an automatic dedent on load
-
-
-"""
-YAML 1.0 allowed root level literal style without indentation:
- "Usually top level nodes are not indented" (example 4.21 in 4.6.3)
-YAML 1.1 is a bit vague but says:
- "Regardless of style, scalar content must always be indented by at least one space"
- (4.4.3)
- "In general, the document’s node is indented as if it has a parent indented at -1 spaces."
- (4.3.3)
-YAML 1.2 is again clear about root literal level scalar after directive in example 9.5:
-
-%YAML 1.2
---- |
-%!PS-Adobe-2.0
-...
-%YAML1.2
----
-# Empty
-...
-"""
-
-
-class TestNoIndent:
- def test_root_literal_scalar_indent_example_9_5(self):
- yaml = YAML()
- s = "%!PS-Adobe-2.0"
- inp = """
- --- |
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_scalar_no_indent(self):
- yaml = YAML()
- s = "testing123"
- inp = """
- --- |
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_scalar_no_indent_1_1(self):
- yaml = YAML()
- s = "testing123"
- inp = """
- %YAML 1.1
- --- |
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_scalar_no_indent_1_1_old_style(self):
- from textwrap import dedent
- from srsly.ruamel_yaml import safe_load
-
- s = "testing123"
- inp = """
- %YAML 1.1
- --- |
- {}
- """
- d = safe_load(dedent(inp.format(s)))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_scalar_no_indent_1_1_no_raise(self):
- # from srsly.ruamel_yaml.parser import ParserError
-
- yaml = YAML()
- yaml.root_level_block_style_scalar_no_indent_error_1_1 = True
- s = "testing123"
- # with pytest.raises(ParserError):
- if True:
- inp = """
- %YAML 1.1
- --- |
- {}
- """
- yaml.load(inp.format(s))
-
- def test_root_literal_scalar_indent_offset_one(self):
- yaml = YAML()
- s = "testing123"
- inp = """
- --- |1
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_scalar_indent_offset_four(self):
- yaml = YAML()
- s = "testing123"
- inp = """
- --- |4
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_scalar_indent_offset_two_leading_space(self):
- yaml = YAML()
- s = " testing123"
- inp = """
- --- |4
- {s}
- {s}
- """
- d = yaml.load(inp.format(s=s))
- print(d)
- assert d == (s + "\n") * 2
-
- def test_root_literal_scalar_no_indent_special(self):
- yaml = YAML()
- s = "%!PS-Adobe-2.0"
- inp = """
- --- |
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_folding_scalar_indent(self):
- yaml = YAML()
- s = "%!PS-Adobe-2.0"
- inp = """
- --- >
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_folding_scalar_no_indent(self):
- yaml = YAML()
- s = "testing123"
- inp = """
- --- >
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_folding_scalar_no_indent_special(self):
- yaml = YAML()
- s = "%!PS-Adobe-2.0"
- inp = """
- --- >
- {}
- """
- d = yaml.load(inp.format(s))
- print(d)
- assert d == s + "\n"
-
- def test_root_literal_multi_doc(self):
- yaml = YAML(typ="safe", pure=True)
- s1 = "abc"
- s2 = "klm"
- inp = """
- --- |-
- {}
- --- |
- {}
- """
- for idx, d1 in enumerate(yaml.load_all(inp.format(s1, s2))):
- print("d1:", d1)
- assert ["abc", "klm\n"][idx] == d1
-
- def test_root_literal_doc_indent_directives_end(self):
- yaml = YAML()
- yaml.explicit_start = True
- inp = """
- --- |-
- %YAML 1.3
- ---
- this: is a test
- """
- yaml.round_trip(inp)
-
- def test_root_literal_doc_indent_document_end(self):
- yaml = YAML()
- yaml.explicit_start = True
- inp = """
- --- |-
- some more
- ...
- text
- """
- yaml.round_trip(inp)
-
- def test_root_literal_doc_indent_marker(self):
- yaml = YAML()
- yaml.explicit_start = True
- inp = """
- --- |2
- some more
- text
- """
- d = yaml.load(inp)
- print(type(d), repr(d))
- yaml.round_trip(inp)
-
- def test_nested_literal_doc_indent_marker(self):
- yaml = YAML()
- yaml.explicit_start = True
- inp = """
- ---
- a: |2
- some more
- text
- """
- d = yaml.load(inp)
- print(type(d), repr(d))
- yaml.round_trip(inp)
-
-
-class Test_RoundTripLiteral:
- def test_rt_root_literal_scalar_no_indent(self):
- yaml = YAML()
- yaml.explicit_start = True
- s = "testing123"
- ys = """
- --- |
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_root_literal_scalar_indent(self):
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent = 4
- s = "testing123"
- ys = """
- --- |
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_root_plain_scalar_no_indent(self):
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent = 0
- s = "testing123"
- ys = """
- ---
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_root_plain_scalar_expl_indent(self):
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent = 4
- s = "testing123"
- ys = """
- ---
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_root_sq_scalar_expl_indent(self):
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent = 4
- s = "'testing: 123'"
- ys = """
- ---
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_root_dq_scalar_expl_indent(self):
- # if yaml.indent is the default (None)
- # then write after the directive indicator
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent = 0
- s = '"\'testing123"'
- ys = """
- ---
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_root_literal_scalar_no_indent_no_eol(self):
- yaml = YAML()
- yaml.explicit_start = True
- s = "testing123"
- ys = """
- --- |-
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
-
- def test_rt_non_root_literal_scalar(self):
- yaml = YAML()
- s = "testing123"
- ys = """
- - |
- {}
- """
- ys = ys.format(s)
- d = yaml.load(ys)
- yaml.dump(d, compare=ys)
diff --git a/srsly/tests/ruamel_yaml/test_none.py b/srsly/tests/ruamel_yaml/test_none.py
deleted file mode 100755
index 878c7e7..0000000
--- a/srsly/tests/ruamel_yaml/test_none.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# coding: utf-8
-
-
-import pytest # NOQA
-
-
-class TestNone:
- def test_dump00(self):
- import srsly.ruamel_yaml # NOQA
-
- data = None
- s = srsly.ruamel_yaml.round_trip_dump(data)
- assert s == "null\n...\n"
- d = srsly.ruamel_yaml.round_trip_load(s)
- assert d == data
-
- def test_dump01(self):
- import srsly.ruamel_yaml # NOQA
-
- data = None
- s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=True)
- assert s == "null\n...\n"
- d = srsly.ruamel_yaml.round_trip_load(s)
- assert d == data
-
- def test_dump02(self):
- import srsly.ruamel_yaml # NOQA
-
- data = None
- s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=False)
- assert s == "null\n...\n"
- d = srsly.ruamel_yaml.round_trip_load(s)
- assert d == data
-
- def test_dump03(self):
- import srsly.ruamel_yaml # NOQA
-
- data = None
- s = srsly.ruamel_yaml.round_trip_dump(data, explicit_start=True)
- assert s == "---\n...\n"
- d = srsly.ruamel_yaml.round_trip_load(s)
- assert d == data
-
- def test_dump04(self):
- import srsly.ruamel_yaml # NOQA
-
- data = None
- s = srsly.ruamel_yaml.round_trip_dump(
- data, explicit_start=True, explicit_end=False
- )
- assert s == "---\n...\n"
- d = srsly.ruamel_yaml.round_trip_load(s)
- assert d == data
diff --git a/srsly/tests/ruamel_yaml/test_numpy.py b/srsly/tests/ruamel_yaml/test_numpy.py
deleted file mode 100755
index 2c21854..0000000
--- a/srsly/tests/ruamel_yaml/test_numpy.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, absolute_import, division, unicode_literals
-
-try:
- import numpy
-except: # NOQA
- numpy = None
-
-
-def Xtest_numpy():
- import srsly.ruamel_yaml
-
- if numpy is None:
- return
- data = numpy.arange(10)
- print("data", type(data), data)
-
- yaml_str = srsly.ruamel_yaml.dump(data)
- datb = srsly.ruamel_yaml.load(yaml_str)
- print("datb", type(datb), datb)
-
- print("\nYAML", yaml_str)
- assert data == datb
diff --git a/srsly/tests/ruamel_yaml/test_program_config.py b/srsly/tests/ruamel_yaml/test_program_config.py
deleted file mode 100755
index 7e5331b..0000000
--- a/srsly/tests/ruamel_yaml/test_program_config.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import pytest # NOQA
-
-# import srsly.ruamel_yaml
-from .roundtrip import round_trip
-
-
-class TestProgramConfig:
- def test_application_arguments(self):
- # application configur
- round_trip(
- """
- args:
- username: anthon
- passwd: secret
- fullname: Anthon van der Neut
- tmux:
- session-name: test
- loop:
- wait: 10
- """
- )
-
- def test_single(self):
- # application configuration
- round_trip(
- """
- # default arguments for the program
- args: # needed to prevent comment wrapping
- # this should be your username
- username: anthon
- passwd: secret # this is plaintext don't reuse \
-# important/system passwords
- fullname: Anthon van der Neut
- tmux:
- session-name: test # make sure this doesn't clash with
- # other sessions
- loop: # looping related defaults
- # experiment with the following
- wait: 10
- # no more argument info to pass
- """
- )
-
- def test_multi(self):
- # application configuration
- round_trip(
- """
- # default arguments for the program
- args: # needed to prevent comment wrapping
- # this should be your username
- username: anthon
- passwd: secret # this is plaintext don't reuse
- # important/system passwords
- fullname: Anthon van der Neut
- tmux:
- session-name: test # make sure this doesn't clash with
- # other sessions
- loop: # looping related defaults
- # experiment with the following
- wait: 10
- # no more argument info to pass
- """
- )
diff --git a/srsly/tests/ruamel_yaml/test_spec_examples.py b/srsly/tests/ruamel_yaml/test_spec_examples.py
deleted file mode 100755
index 666d641..0000000
--- a/srsly/tests/ruamel_yaml/test_spec_examples.py
+++ /dev/null
@@ -1,334 +0,0 @@
-from .roundtrip import YAML
-import pytest # NOQA
-
-
-def test_example_2_1():
- yaml = YAML()
- yaml.round_trip(
- """
- - Mark McGwire
- - Sammy Sosa
- - Ken Griffey
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_2():
- yaml = YAML()
- yaml.mapping_value_align = True
- yaml.round_trip(
- """
- hr: 65 # Home runs
- avg: 0.278 # Batting average
- rbi: 147 # Runs Batted In
- """
- )
-
-
-def test_example_2_3():
- yaml = YAML()
- yaml.indent(sequence=4, offset=2)
- yaml.round_trip(
- """
- american:
- - Boston Red Sox
- - Detroit Tigers
- - New York Yankees
- national:
- - New York Mets
- - Chicago Cubs
- - Atlanta Braves
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_4():
- yaml = YAML()
- yaml.mapping_value_align = True
- yaml.round_trip(
- """
- -
- name: Mark McGwire
- hr: 65
- avg: 0.278
- -
- name: Sammy Sosa
- hr: 63
- avg: 0.288
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_5():
- yaml = YAML()
- yaml.flow_sequence_element_align = True
- yaml.round_trip(
- """
- - [name , hr, avg ]
- - [Mark McGwire, 65, 0.278]
- - [Sammy Sosa , 63, 0.288]
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_6():
- yaml = YAML()
- # yaml.flow_mapping_final_comma = False
- yaml.flow_mapping_one_element_per_line = True
- yaml.round_trip(
- """
- Mark McGwire: {hr: 65, avg: 0.278}
- Sammy Sosa: {
- hr: 63,
- avg: 0.288
- }
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_7():
- yaml = YAML()
- yaml.round_trip_all(
- """
- # Ranking of 1998 home runs
- ---
- - Mark McGwire
- - Sammy Sosa
- - Ken Griffey
-
- # Team ranking
- ---
- - Chicago Cubs
- - St Louis Cardinals
- """
- )
-
-
-def test_example_2_8():
- yaml = YAML()
- yaml.explicit_start = True
- yaml.explicit_end = True
- yaml.round_trip_all(
- """
- ---
- time: 20:03:20
- player: Sammy Sosa
- action: strike (miss)
- ...
- ---
- time: 20:03:47
- player: Sammy Sosa
- action: grand slam
- ...
- """
- )
-
-
-def test_example_2_9():
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent(sequence=4, offset=2)
- yaml.round_trip(
- """
- ---
- hr: # 1998 hr ranking
- - Mark McGwire
- - Sammy Sosa
- rbi:
- # 1998 rbi ranking
- - Sammy Sosa
- - Ken Griffey
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_10():
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent(sequence=4, offset=2)
- yaml.round_trip(
- """
- ---
- hr:
- - Mark McGwire
- # Following node labeled SS
- - &SS Sammy Sosa
- rbi:
- - *SS # Subsequent occurrence
- - Ken Griffey
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_11():
- yaml = YAML()
- yaml.round_trip(
- """
- ? - Detroit Tigers
- - Chicago cubs
- :
- - 2001-07-23
-
- ? [ New York Yankees,
- Atlanta Braves ]
- : [ 2001-07-02, 2001-08-12,
- 2001-08-14 ]
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_12():
- yaml = YAML()
- yaml.explicit_start = True
- yaml.round_trip(
- """
- ---
- # Products purchased
- - item : Super Hoop
- quantity: 1
- - item : Basketball
- quantity: 4
- - item : Big Shoes
- quantity: 1
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_13():
- yaml = YAML()
- yaml.round_trip(
- r"""
- # ASCII Art
- --- |
- \//||\/||
- // || ||__
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_14():
- yaml = YAML()
- yaml.explicit_start = True
- yaml.indent(root_scalar=2) # needs to be added
- yaml.round_trip(
- """
- --- >
- Mark McGwire's
- year was crippled
- by a knee injury.
- """
- )
-
-
-@pytest.mark.xfail(strict=True)
-def test_example_2_15():
- yaml = YAML()
- yaml.round_trip(
- """
- >
- Sammy Sosa completed another
- fine season with great stats.
-
- 63 Home Runs
- 0.288 Batting Average
-
- What a year!
- """
- )
-
-
-def test_example_2_16():
- yaml = YAML()
- yaml.round_trip(
- """
- name: Mark McGwire
- accomplishment: >
- Mark set a major league
- home run record in 1998.
- stats: |
- 65 Home Runs
- 0.278 Batting Average
- """
- )
-
-
-@pytest.mark.xfail(
- strict=True, reason="cannot YAML dump escape sequences (\n) as hex and normal"
-)
-def test_example_2_17():
- yaml = YAML()
- yaml.allow_unicode = False
- yaml.preserve_quotes = True
- yaml.round_trip(
- r"""
- unicode: "Sosa did fine.\u263A"
- control: "\b1998\t1999\t2000\n"
- hex esc: "\x0d\x0a is \r\n"
-
- single: '"Howdy!" he cried.'
- quoted: ' # Not a ''comment''.'
- tie-fighter: '|\-*-/|'
- """
- )
-
-
-@pytest.mark.xfail(
- strict=True, reason="non-literal/folding multiline scalars not supported"
-)
-def test_example_2_18():
- yaml = YAML()
- yaml.round_trip(
- """
- plain:
- This unquoted scalar
- spans many lines.
-
- quoted: "So does this
- quoted scalar.\n"
- """
- )
-
-
-@pytest.mark.xfail(strict=True, reason="leading + on decimal dropped")
-def test_example_2_19():
- yaml = YAML()
- yaml.round_trip(
- """
- canonical: 12345
- decimal: +12345
- octal: 0o14
- hexadecimal: 0xC
- """
- )
-
-
-@pytest.mark.xfail(strict=True, reason="case of NaN not preserved")
-def test_example_2_20():
- yaml = YAML()
- yaml.round_trip(
- """
- canonical: 1.23015e+3
- exponential: 12.3015e+02
- fixed: 1230.15
- negative infinity: -.inf
- not a number: .NaN
- """
- )
-
-
-def Xtest_example_2_X():
- yaml = YAML()
- yaml.round_trip(
- """
- """
- )
diff --git a/srsly/tests/ruamel_yaml/test_string.py b/srsly/tests/ruamel_yaml/test_string.py
deleted file mode 100755
index c82eb81..0000000
--- a/srsly/tests/ruamel_yaml/test_string.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-"""
-various test cases for string scalars in YAML files
-'|' for preserved newlines
-'>' for folded (newlines become spaces)
-
-and the chomping modifiers:
-'-' for stripping: final line break and any trailing empty lines are excluded
-'+' for keeping: final line break and empty lines are preserved
-'' for clipping: final line break preserved, empty lines at end not
- included in content (no modifier)
-
-"""
-
-import pytest
-import platform
-
-# from srsly.ruamel_yaml.compat import ordereddict
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-
-class TestLiteralScalarString:
- def test_basic_string(self):
- round_trip(
- """
- a: abcdefg
- """
- )
-
- def test_quoted_integer_string(self):
- round_trip(
- """
- a: '12345'
- """
- )
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_preserve_string(self):
- inp = """
- a: |
- abc
- def
- """
- round_trip(inp, intermediate=dict(a="abc\ndef\n"))
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_preserve_string_strip(self):
- s = """
- a: |-
- abc
- def
-
- """
- round_trip(s, intermediate=dict(a="abc\ndef"))
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_preserve_string_keep(self):
- # with pytest.raises(AssertionError) as excinfo:
- inp = """
- a: |+
- ghi
- jkl
-
-
- b: x
- """
- round_trip(inp, intermediate=dict(a="ghi\njkl\n\n\n", b="x"))
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_preserve_string_keep_at_end(self):
- # at EOF you have to specify the ... to get proper "closure"
- # of the multiline scalar
- inp = """
- a: |+
- ghi
- jkl
-
- ...
- """
- round_trip(inp, intermediate=dict(a="ghi\njkl\n\n"))
-
- def test_fold_string(self):
- inp = """
- a: >
- abc
- def
-
- """
- round_trip(inp)
-
- def test_fold_string_strip(self):
- inp = """
- a: >-
- abc
- def
-
- """
- round_trip(inp)
-
- def test_fold_string_keep(self):
- with pytest.raises(AssertionError) as excinfo: # NOQA
- inp = """
- a: >+
- abc
- def
-
- """
- round_trip(inp, intermediate=dict(a="abc def\n\n"))
-
-
-class TestQuotedScalarString:
- def test_single_quoted_string(self):
- inp = """
- a: 'abc'
- """
- round_trip(inp, preserve_quotes=True)
-
- def test_double_quoted_string(self):
- inp = """
- a: "abc"
- """
- round_trip(inp, preserve_quotes=True)
-
- def test_non_preserved_double_quoted_string(self):
- inp = """
- a: "abc"
- """
- exp = """
- a: abc
- """
- round_trip(inp, outp=exp)
-
-
-class TestReplace:
- """inspired by issue 110 from sandres23"""
-
- def test_replace_preserved_scalar_string(self):
- import srsly
-
- s = dedent(
- """\
- foo: |
- foo
- foo
- bar
- foo
- """
- )
- data = round_trip_load(s, preserve_quotes=True)
- so = data["foo"].replace("foo", "bar", 2)
- assert isinstance(so, srsly.ruamel_yaml.scalarstring.LiteralScalarString)
- assert so == dedent(
- """
- bar
- bar
- bar
- foo
- """
- )
-
- def test_replace_double_quoted_scalar_string(self):
- import srsly
-
- s = dedent(
- """\
- foo: "foo foo bar foo"
- """
- )
- data = round_trip_load(s, preserve_quotes=True)
- so = data["foo"].replace("foo", "bar", 2)
- assert isinstance(so, srsly.ruamel_yaml.scalarstring.DoubleQuotedScalarString)
- assert so == "bar bar bar foo"
-
-
-class TestWalkTree:
- def test_basic(self):
- from srsly.ruamel_yaml.comments import CommentedMap
- from srsly.ruamel_yaml.scalarstring import walk_tree
-
- data = CommentedMap()
- data[1] = "a"
- data[2] = "with\nnewline\n"
- walk_tree(data)
- exp = """\
- 1: a
- 2: |
- with
- newline
- """
- assert round_trip_dump(data) == dedent(exp)
-
- def test_map(self):
- from srsly.ruamel_yaml.compat import ordereddict
- from srsly.ruamel_yaml.comments import CommentedMap
- from srsly.ruamel_yaml.scalarstring import walk_tree, preserve_literal
- from srsly.ruamel_yaml.scalarstring import DoubleQuotedScalarString as dq
- from srsly.ruamel_yaml.scalarstring import SingleQuotedScalarString as sq
-
- data = CommentedMap()
- data[1] = "a"
- data[2] = "with\nnew : line\n"
- data[3] = "${abc}"
- data[4] = "almost:mapping"
- m = ordereddict([("\n", preserve_literal), ("${", sq), (":", dq)])
- walk_tree(data, map=m)
- exp = """\
- 1: a
- 2: |
- with
- new : line
- 3: '${abc}'
- 4: "almost:mapping"
- """
- assert round_trip_dump(data) == dedent(exp)
diff --git a/srsly/tests/ruamel_yaml/test_tag.py b/srsly/tests/ruamel_yaml/test_tag.py
deleted file mode 100755
index 4e4860c..0000000
--- a/srsly/tests/ruamel_yaml/test_tag.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# coding: utf-8
-
-import pytest # NOQA
-
-from .roundtrip import round_trip, round_trip_load, YAML
-
-
-def register_xxx(**kw):
- import srsly.ruamel_yaml as yaml
-
- class XXX(yaml.comments.CommentedMap):
- @staticmethod
- def yaml_dump(dumper, data):
- return dumper.represent_mapping(u"!xxx", data)
-
- @classmethod
- def yaml_load(cls, constructor, node):
- data = cls()
- yield data
- constructor.construct_mapping(node, data)
-
- yaml.add_constructor(u"!xxx", XXX.yaml_load, constructor=yaml.RoundTripConstructor)
- yaml.add_representer(XXX, XXX.yaml_dump, representer=yaml.RoundTripRepresenter)
-
-
-class TestIndentFailures:
- def test_tag(self):
- round_trip(
- """\
- !!python/object:__main__.Developer
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- def test_full_tag(self):
- round_trip(
- """\
- !!tag:yaml.org,2002:python/object:__main__.Developer
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- def test_standard_tag(self):
- round_trip(
- """\
- !!tag:yaml.org,2002:python/object:map
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- def test_Y1(self):
- round_trip(
- """\
- !yyy
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- def test_Y2(self):
- round_trip(
- """\
- !!yyy
- name: Anthon
- location: Germany
- language: python
- """
- )
-
-
-class TestRoundTripCustom:
- def test_X1(self):
- register_xxx()
- round_trip(
- """\
- !xxx
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- @pytest.mark.xfail(strict=True)
- def test_X_pre_tag_comment(self):
- register_xxx()
- round_trip(
- """\
- -
- # hello
- !xxx
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- @pytest.mark.xfail(strict=True)
- def test_X_post_tag_comment(self):
- register_xxx()
- round_trip(
- """\
- - !xxx
- # hello
- name: Anthon
- location: Germany
- language: python
- """
- )
-
- def test_scalar_00(self):
- # https://stackoverflow.com/a/45967047/1307905
- round_trip(
- """\
- Outputs:
- Vpc:
- Value: !Ref: vpc # first tag
- Export:
- Name: !Sub "${AWS::StackName}-Vpc" # second tag
- """
- )
-
-
-class TestIssue201:
- def test_encoded_unicode_tag(self):
- round_trip_load(
- """
- s: !!python/%75nicode 'abc'
- """
- )
-
-
-class TestImplicitTaggedNodes:
- def test_scalar(self):
- round_trip(
- """\
- - !Scalar abcdefg
- """
- )
-
- def test_mapping(self):
- round_trip(
- """\
- - !Mapping {a: 1, b: 2}
- """
- )
-
- def test_sequence(self):
- yaml = YAML()
- yaml.brace_single_entry_mapping_in_flow_sequence = True
- yaml.mapping_value_align = True
- yaml.round_trip(
- """
- - !Sequence [a, {b: 1}, {c: {d: 3}}]
- """
- )
-
- def test_sequence2(self):
- yaml = YAML()
- yaml.mapping_value_align = True
- yaml.round_trip(
- """
- - !Sequence [a, b: 1, c: {d: 3}]
- """
- )
diff --git a/srsly/tests/ruamel_yaml/test_version.py b/srsly/tests/ruamel_yaml/test_version.py
deleted file mode 100755
index 2dbf53d..0000000
--- a/srsly/tests/ruamel_yaml/test_version.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# coding: utf-8
-
-import pytest # NOQA
-
-from .roundtrip import dedent, round_trip, round_trip_load
-
-
-def load(s, version=None):
- import srsly.ruamel_yaml # NOQA
-
- return srsly.ruamel_yaml.round_trip_load(dedent(s), version)
-
-
-class TestVersions:
- def test_explicit_1_2(self):
- r = load(
- """\
- %YAML 1.2
- ---
- - 12:34:56
- - 012
- - 012345678
- - 0o12
- - on
- - off
- - yes
- - no
- - true
- """
- )
- assert r[0] == "12:34:56"
- assert r[1] == 12
- assert r[2] == 12345678
- assert r[3] == 10
- assert r[4] == "on"
- assert r[5] == "off"
- assert r[6] == "yes"
- assert r[7] == "no"
- assert r[8] is True
-
- def test_explicit_1_1(self):
- r = load(
- """\
- %YAML 1.1
- ---
- - 12:34:56
- - 012
- - 012345678
- - 0o12
- - on
- - off
- - yes
- - no
- - true
- """
- )
- assert r[0] == 45296
- assert r[1] == 10
- assert r[2] == "012345678"
- assert r[3] == "0o12"
- assert r[4] is True
- assert r[5] is False
- assert r[6] is True
- assert r[7] is False
- assert r[8] is True
-
- def test_implicit_1_2(self):
- r = load(
- """\
- - 12:34:56
- - 12:34:56.78
- - 012
- - 012345678
- - 0o12
- - on
- - off
- - yes
- - no
- - true
- """
- )
- assert r[0] == "12:34:56"
- assert r[1] == "12:34:56.78"
- assert r[2] == 12
- assert r[3] == 12345678
- assert r[4] == 10
- assert r[5] == "on"
- assert r[6] == "off"
- assert r[7] == "yes"
- assert r[8] == "no"
- assert r[9] is True
-
- def test_load_version_1_1(self):
- inp = """\
- - 12:34:56
- - 12:34:56.78
- - 012
- - 012345678
- - 0o12
- - on
- - off
- - yes
- - no
- - true
- """
- r = load(inp, version="1.1")
- assert r[0] == 45296
- assert r[1] == 45296.78
- assert r[2] == 10
- assert r[3] == "012345678"
- assert r[4] == "0o12"
- assert r[5] is True
- assert r[6] is False
- assert r[7] is True
- assert r[8] is False
- assert r[9] is True
-
-
-class TestIssue62:
- # bitbucket issue 62, issue_62
- def test_00(self):
- import srsly.ruamel_yaml # NOQA
-
- s = dedent(
- """\
- {}# Outside flow collection:
- - ::vector
- - ": - ()"
- - Up, up, and away!
- - -123
- - http://example.com/foo#bar
- # Inside flow collection:
- - [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar]
- """
- )
- with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
- round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True)
- round_trip(s.format(""), preserve_quotes=True)
-
- def test_00_single_comment(self):
- import srsly.ruamel_yaml # NOQA
-
- s = dedent(
- """\
- {}# Outside flow collection:
- - ::vector
- - ": - ()"
- - Up, up, and away!
- - -123
- - http://example.com/foo#bar
- - [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar]
- """
- )
- with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
- round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True)
- round_trip(s.format(""), preserve_quotes=True)
- # round_trip(s.format('%YAML 1.2\n---\n'), preserve_quotes=True, version=(1, 2))
-
- def test_01(self):
- import srsly.ruamel_yaml # NOQA
-
- s = dedent(
- """\
- {}[random plain value that contains a ? character]
- """
- )
- with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
- round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True)
- round_trip(s.format(""), preserve_quotes=True)
- # note the flow seq on the --- line!
- round_trip(s.format("%YAML 1.2\n--- "), preserve_quotes=True, version="1.2")
-
- def test_so_45681626(self):
- # was not properly parsing
- round_trip_load('{"in":{},"out":{}}')
diff --git a/srsly/tests/ruamel_yaml/test_yamlfile.py b/srsly/tests/ruamel_yaml/test_yamlfile.py
deleted file mode 100755
index f30c47b..0000000
--- a/srsly/tests/ruamel_yaml/test_yamlfile.py
+++ /dev/null
@@ -1,243 +0,0 @@
-from __future__ import print_function
-
-"""
-various test cases for YAML files
-"""
-
-import sys
-import pytest # NOQA
-import platform
-
-from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
-
-
-class TestYAML:
- def test_backslash(self):
- round_trip(
- """
- handlers:
- static_files: applications/\\1/static/\\2
- """
- )
-
- def test_omap_out(self):
- # ordereddict mapped to !!omap
- from srsly.ruamel_yaml.compat import ordereddict
- import srsly.ruamel_yaml # NOQA
-
- x = ordereddict([("a", 1), ("b", 2)])
- res = srsly.ruamel_yaml.dump(x, default_flow_style=False)
- assert res == dedent(
- """
- !!omap
- - a: 1
- - b: 2
- """
- )
-
- def test_omap_roundtrip(self):
- round_trip(
- """
- !!omap
- - a: 1
- - b: 2
- - c: 3
- - d: 4
- """
- )
-
- @pytest.mark.skipif(sys.version_info < (2, 7), reason="collections not available")
- def test_dump_collections_ordereddict(self):
- from collections import OrderedDict
- import srsly.ruamel_yaml # NOQA
-
- # OrderedDict mapped to !!omap
- x = OrderedDict([("a", 1), ("b", 2)])
- res = srsly.ruamel_yaml.dump(
- x, Dumper=srsly.ruamel_yaml.RoundTripDumper, default_flow_style=False
- )
- assert res == dedent(
- """
- !!omap
- - a: 1
- - b: 2
- """
- )
-
- @pytest.mark.skipif(
- sys.version_info >= (3, 0) or platform.python_implementation() != "CPython",
- reason="srsly.ruamel_yaml not available",
- )
- def test_dump_ruamel_ordereddict(self):
- from srsly.ruamel_yaml.compat import ordereddict
- import srsly.ruamel_yaml # NOQA
-
- # OrderedDict mapped to !!omap
- x = ordereddict([("a", 1), ("b", 2)])
- res = srsly.ruamel_yaml.dump(
- x, Dumper=srsly.ruamel_yaml.RoundTripDumper, default_flow_style=False
- )
- assert res == dedent(
- """
- !!omap
- - a: 1
- - b: 2
- """
- )
-
- def test_CommentedSet(self):
- from srsly.ruamel_yaml.constructor import CommentedSet
-
- s = CommentedSet(["a", "b", "c"])
- s.remove("b")
- s.add("d")
- assert s == CommentedSet(["a", "c", "d"])
- s.add("e")
- s.add("f")
- s.remove("e")
- assert s == CommentedSet(["a", "c", "d", "f"])
-
- def test_set_out(self):
- # preferable would be the shorter format without the ': null'
- import srsly.ruamel_yaml # NOQA
-
- x = set(["a", "b", "c"])
- res = srsly.ruamel_yaml.dump(x, default_flow_style=False)
- assert res == dedent(
- """
- !!set
- a: null
- b: null
- c: null
- """
- )
-
- # ordering is not preserved in a set
- def test_set_compact(self):
- # this format is read and also should be written by default
- round_trip(
- """
- !!set
- ? a
- ? b
- ? c
- """
- )
-
- def test_blank_line_after_comment(self):
- round_trip(
- """
- # Comment with spaces after it.
-
-
- a: 1
- """
- )
-
- def test_blank_line_between_seq_items(self):
- round_trip(
- """
- # Seq with empty lines in between items.
- b:
- - bar
-
-
- - baz
- """
- )
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_blank_line_after_literal_chip(self):
- s = """
- c:
- - |
- This item
- has a blank line
- following it.
-
- - |
- To visually separate it from this item.
-
- This item contains a blank line.
-
-
- """
- d = round_trip_load(dedent(s))
- print(d)
- round_trip(s)
- assert d["c"][0].split("it.")[1] == "\n"
- assert d["c"][1].split("line.")[1] == "\n"
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_blank_line_after_literal_keep(self):
- """ have to insert an eof marker in YAML to test this"""
- s = """
- c:
- - |+
- This item
- has a blank line
- following it.
-
- - |+
- To visually separate it from this item.
-
- This item contains a blank line.
-
-
- ...
- """
- d = round_trip_load(dedent(s))
- print(d)
- round_trip(s)
- assert d["c"][0].split("it.")[1] == "\n\n"
- assert d["c"][1].split("line.")[1] == "\n\n\n"
-
- @pytest.mark.skipif(
- platform.python_implementation() == "Jython",
- reason="Jython throws RepresenterError",
- )
- def test_blank_line_after_literal_strip(self):
- s = """
- c:
- - |-
- This item
- has a blank line
- following it.
-
- - |-
- To visually separate it from this item.
-
- This item contains a blank line.
-
-
- """
- d = round_trip_load(dedent(s))
- print(d)
- round_trip(s)
- assert d["c"][0].split("it.")[1] == ""
- assert d["c"][1].split("line.")[1] == ""
-
- def test_load_all_perserve_quotes(self):
- import srsly.ruamel_yaml # NOQA
-
- s = dedent(
- """\
- a: 'hello'
- ---
- b: "goodbye"
- """
- )
- data = []
- for x in srsly.ruamel_yaml.round_trip_load_all(s, preserve_quotes=True):
- data.append(x)
- out = srsly.ruamel_yaml.dump_all(data, Dumper=srsly.ruamel_yaml.RoundTripDumper)
- print(type(data[0]["a"]), data[0]["a"])
- # out = srsly.ruamel_yaml.round_trip_dump_all(data)
- print(out)
- assert out == s
diff --git a/srsly/tests/ruamel_yaml/test_yamlobject.py b/srsly/tests/ruamel_yaml/test_yamlobject.py
deleted file mode 100755
index b0d64b9..0000000
--- a/srsly/tests/ruamel_yaml/test_yamlobject.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function
-
-import sys
-import pytest # NOQA
-
-from .roundtrip import save_and_run # NOQA
-
-
-def test_monster(tmpdir):
- program_src = u'''\
- import srsly.ruamel_yaml
- from textwrap import dedent
-
- class Monster(srsly.ruamel_yaml.YAMLObject):
- yaml_tag = u'!Monster'
-
- def __init__(self, name, hp, ac, attacks):
- self.name = name
- self.hp = hp
- self.ac = ac
- self.attacks = attacks
-
- def __repr__(self):
- return "%s(name=%r, hp=%r, ac=%r, attacks=%r)" % (
- self.__class__.__name__, self.name, self.hp, self.ac, self.attacks)
-
- data = srsly.ruamel_yaml.load(dedent("""\\
- --- !Monster
- name: Cave spider
- hp: [2,6] # 2d6
- ac: 16
- attacks: [BITE, HURT]
- """), Loader=srsly.ruamel_yaml.Loader)
- # normal dump, keys will be sorted
- assert srsly.ruamel_yaml.dump(data) == dedent("""\\
- !Monster
- ac: 16
- attacks: [BITE, HURT]
- hp: [2, 6]
- name: Cave spider
- """)
- '''
- assert save_and_run(program_src, tmpdir) == 1
-
-
-@pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__")
-def test_qualified_name00(tmpdir):
- """issue 214"""
- program_src = u"""\
- from srsly.ruamel_yaml import YAML
- from srsly.ruamel_yaml.compat import StringIO
-
- class A:
- def f(self):
- pass
-
- yaml = YAML(typ='unsafe', pure=True)
- yaml.explicit_end = True
- buf = StringIO()
- yaml.dump(A.f, buf)
- res = buf.getvalue()
- print('res', repr(res))
- assert res == "!!python/name:__main__.A.f ''\\n...\\n"
- x = yaml.load(res)
- assert x == A.f
- """
- assert save_and_run(program_src, tmpdir) == 1
-
-
-@pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__")
-def test_qualified_name01(tmpdir):
- """issue 214"""
- from srsly.ruamel_yaml import YAML
- import srsly.ruamel_yaml.comments
- from srsly.ruamel_yaml.compat import StringIO
-
- with pytest.raises(ValueError):
- yaml = YAML(typ="unsafe", pure=True)
- yaml.explicit_end = True
- buf = StringIO()
- yaml.dump(srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor, buf)
- res = buf.getvalue()
- assert (
- res
- == "!!python/name:srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor ''\n...\n"
- )
- x = yaml.load(res)
- assert x == srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor
diff --git a/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py b/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py
deleted file mode 100755
index 3a3c835..0000000
--- a/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# coding: utf-8
-
-import sys
-import pytest # NOQA
-
-from .roundtrip import round_trip_load, round_trip_dump, dedent
-
-
-class TestLeftOverDebug:
- # idea here is to capture round_trip_output via pytest stdout capture
- # if there is are any leftover debug statements they should show up
- def test_00(self, capsys):
- s = dedent(
- """
- a: 1
- b: []
- c: [a, 1]
- d: {f: 3.14, g: 42}
- """
- )
- d = round_trip_load(s)
- round_trip_dump(d, sys.stdout)
- out, err = capsys.readouterr()
- assert out == s
-
- def test_01(self, capsys):
- s = dedent(
- """
- - 1
- - []
- - [a, 1]
- - {f: 3.14, g: 42}
- - - 123
- """
- )
- d = round_trip_load(s)
- round_trip_dump(d, sys.stdout)
- out, err = capsys.readouterr()
- assert out == s
diff --git a/srsly/tests/ruamel_yaml/test_z_data.py b/srsly/tests/ruamel_yaml/test_z_data.py
deleted file mode 100755
index dbd13a1..0000000
--- a/srsly/tests/ruamel_yaml/test_z_data.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# coding: utf-8
-
-from __future__ import print_function, unicode_literals
-
-import sys
-import pytest # NOQA
-import warnings # NOQA
-
-from pathlib import Path
-
-base_path = Path("data") # that is srsly.ruamel_yaml.data
-PY2 = sys.version_info[0] == 2
-
-
-class YAMLData(object):
- yaml_tag = "!YAML"
-
- def __init__(self, s):
- self._s = s
-
- # Conversion tables for input. E.g. "" is replaced by "\t"
- # fmt: off
- special = {
- 'SPC': ' ',
- 'TAB': '\t',
- '---': '---',
- '...': '...',
- }
- # fmt: on
-
- @property
- def value(self):
- if hasattr(self, "_p"):
- return self._p
- assert " \n" not in self._s
- assert "\t\n" not in self._s
- self._p = self._s
- for k, v in YAMLData.special.items():
- k = "<" + k + ">"
- self._p = self._p.replace(k, v)
- return self._p
-
- def test_rewrite(self, s):
- assert " \n" not in s
- assert "\t\n" not in s
- for k, v in YAMLData.special.items():
- k = "<" + k + ">"
- s = s.replace(k, v)
- return s
-
- @classmethod
- def from_yaml(cls, constructor, node):
- from srsly.ruamel_yaml.nodes import MappingNode
-
- if isinstance(node, MappingNode):
- return cls(constructor.construct_mapping(node))
- return cls(node.value)
-
-
-class Python(YAMLData):
- yaml_tag = "!Python"
-
-
-class Output(YAMLData):
- yaml_tag = "!Output"
-
-
-class Assert(YAMLData):
- yaml_tag = "!Assert"
-
- @property
- def value(self):
- from srsly.ruamel_yaml.compat import Mapping
-
- if hasattr(self, "_pa"):
- return self._pa
- if isinstance(self._s, Mapping):
- self._s["lines"] = self.test_rewrite(self._s["lines"])
- self._pa = self._s
- return self._pa
-
-
-def pytest_generate_tests(metafunc):
- test_yaml = []
- paths = sorted(base_path.glob("**/*.yaml"))
- idlist = []
- for path in paths:
- stem = path.stem
- if stem.startswith(".#"): # skip emacs temporary file
- continue
- idlist.append(stem)
- test_yaml.append([path])
- metafunc.parametrize(["yaml"], test_yaml, ids=idlist, scope="class")
-
-
-class TestYAMLData(object):
- def yaml(self, yaml_version=None):
- from srsly.ruamel_yaml import YAML
-
- y = YAML()
- y.preserve_quotes = True
- if yaml_version:
- y.version = yaml_version
- return y
-
- def docs(self, path):
- from srsly.ruamel_yaml import YAML
-
- tyaml = YAML(typ="safe", pure=True)
- tyaml.register_class(YAMLData)
- tyaml.register_class(Python)
- tyaml.register_class(Output)
- tyaml.register_class(Assert)
- return list(tyaml.load_all(path))
-
- def yaml_load(self, value, yaml_version=None):
- yaml = self.yaml(yaml_version=yaml_version)
- data = yaml.load(value)
- return yaml, data
-
- def round_trip(self, input, output=None, yaml_version=None):
- from srsly.ruamel_yaml.compat import StringIO
-
- yaml, data = self.yaml_load(input.value, yaml_version=yaml_version)
- buf = StringIO()
- yaml.dump(data, buf)
- expected = input.value if output is None else output.value
- value = buf.getvalue()
- if PY2:
- value = value.decode("utf-8")
- print("value", value)
- # print('expected', expected)
- assert value == expected
-
- def load_assert(self, input, confirm, yaml_version=None):
- from srsly.ruamel_yaml.compat import Mapping
-
- d = self.yaml_load(input.value, yaml_version=yaml_version)[1] # NOQA
- print("confirm.value", confirm.value, type(confirm.value))
- if isinstance(confirm.value, Mapping):
- r = range(confirm.value["range"])
- lines = confirm.value["lines"].splitlines()
- for idx in r: # NOQA
- for line in lines:
- line = "assert " + line
- print(line)
- exec(line)
- else:
- for line in confirm.value.splitlines():
- line = "assert " + line
- print(line)
- exec(line)
-
- def run_python(self, python, data, tmpdir):
- from .roundtrip import save_and_run
-
- assert save_and_run(python.value, base_dir=tmpdir, output=data.value) == 0
-
- # this is executed by pytest the methods with names not starting with test_
- # are helpers
- def test_yaml_data(self, yaml, tmpdir):
- from srsly.ruamel_yaml.compat import Mapping
-
- idx = 0
- typ = None
- yaml_version = None
-
- docs = self.docs(yaml)
- if isinstance(docs[0], Mapping):
- d = docs[0]
- typ = d.get("type")
- yaml_version = d.get("yaml_version")
- if "python" in d:
- if not check_python_version(d["python"]):
- pytest.skip("unsupported version")
- idx += 1
- data = output = confirm = python = None
- for doc in docs[idx:]:
- if isinstance(doc, Output):
- output = doc
- elif isinstance(doc, Assert):
- confirm = doc
- elif isinstance(doc, Python):
- python = doc
- if typ is None:
- typ = "python_run"
- elif isinstance(doc, YAMLData):
- data = doc
- else:
- print("no handler for type:", type(doc), repr(doc))
- raise AssertionError()
- if typ is None:
- if data is not None and output is not None:
- typ = "rt"
- elif data is not None and confirm is not None:
- typ = "load_assert"
- else:
- assert data is not None
- typ = "rt"
- print("type:", typ)
- if data is not None:
- print("data:", data.value, end="")
- print("output:", output.value if output is not None else output)
- if typ == "rt":
- self.round_trip(data, output, yaml_version=yaml_version)
- elif typ == "python_run":
- self.run_python(python, output if output is not None else data, tmpdir)
- elif typ == "load_assert":
- self.load_assert(data, confirm, yaml_version=yaml_version)
- else:
- print("\nrun type unknown:", typ)
- raise AssertionError()
-
-
-def check_python_version(match, current=None):
- """
- version indication, return True if version matches.
- match should be something like 3.6+, or [2.7, 3.3] etc. Floats
- are converted to strings. Single values are made into lists.
- """
- if current is None:
- current = list(sys.version_info[:3])
- if not isinstance(match, list):
- match = [match]
- for m in match:
- minimal = False
- if isinstance(m, float):
- m = str(m)
- if m.endswith("+"):
- minimal = True
- m = m[:-1]
- # assert m[0].isdigit()
- # assert m[-1].isdigit()
- m = [int(x) for x in m.split(".")]
- current_len = current[: len(m)]
- # print(m, current, current_len)
- if minimal:
- if current_len >= m:
- return True
- else:
- if current_len == m:
- return True
- return False
diff --git a/srsly/tests/test_json_api.py b/srsly/tests/test_json_api.py
index 89ce400..85f0164 100644
--- a/srsly/tests/test_json_api.py
+++ b/srsly/tests/test_json_api.py
@@ -2,7 +2,6 @@
from io import StringIO
from pathlib import Path
import gzip
-import numpy
from .._json_api import (
read_json,
@@ -55,8 +54,8 @@ def test_write_json_file():
data = {"hello": "world", "test": 123}
# Provide two expected options, depending on how keys are ordered
expected = [
- '{\n "hello":"world",\n "test":123\n}',
- '{\n "test":123,\n "hello":"world"\n}',
+ '{\n "hello": "world",\n "test": 123\n}',
+ '{\n "test": 123,\n "hello": "world"\n}',
]
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
@@ -69,8 +68,8 @@ def test_write_json_file_gzip():
data = {"hello": "world", "test": 123}
# Provide two expected options, depending on how keys are ordered
expected = [
- '{\n "hello":"world",\n "test":123\n}',
- '{\n "test":123,\n "hello":"world"\n}',
+ '{\n "hello": "world",\n "test": 123\n}',
+ '{\n "test": 123,\n "hello": "world"\n}',
]
with make_tempdir() as temp_dir:
file_path = force_string(temp_dir / "tmp.json")
@@ -83,8 +82,8 @@ def test_write_json_stdout(capsys):
data = {"hello": "world", "test": 123}
# Provide two expected options, depending on how keys are ordered
expected = [
- '{\n "hello":"world",\n "test":123\n}\n',
- '{\n "test":123,\n "hello":"world"\n}\n',
+ '{\n "hello": "world",\n "test": 123\n}\n',
+ '{\n "test": 123,\n "hello": "world"\n}\n',
]
write_json("-", data)
captured = capsys.readouterr()
@@ -208,8 +207,14 @@ def test_json_loads_raises(obj):
def test_unsupported_type_error():
+ with pytest.raises(TypeError, match="is not JSON serializable"):
+ s = json_dumps({1, 2})
+
+
+def test_unsupported_type_error_numpy():
+ numpy = pytest.importorskip("numpy")
f = numpy.float32()
- with pytest.raises(TypeError):
+ with pytest.raises(TypeError, match="is not JSON serializable"):
s = json_dumps(f)
diff --git a/srsly/tests/test_msgpack_api.py b/srsly/tests/test_msgpack_api.py
index b516038..13d0f66 100644
--- a/srsly/tests/test_msgpack_api.py
+++ b/srsly/tests/test_msgpack_api.py
@@ -1,8 +1,8 @@
-import pytest
-from pathlib import Path
import datetime
-from mock import patch
-import numpy
+from collections import namedtuple
+from pathlib import Path
+
+import pytest
from .._msgpack_api import read_msgpack, write_msgpack
from .._msgpack_api import msgpack_loads, msgpack_dumps
@@ -54,14 +54,63 @@ def test_write_msgpack_file():
assert f.read() in expected
-@patch("srsly.msgpack._msgpack_numpy.np", None)
-@patch("srsly.msgpack._msgpack_numpy.has_numpy", False)
+def test_msgpack_complex():
+ inp = {"a": 1 + 2j}
+ out = msgpack_loads(msgpack_dumps(inp))
+ assert out == inp
+ # Test that we didn't accidentally convert to np.complex128,
+ # which is a subclass of complex
+ assert type(out["a"]) is complex
+
+
def test_msgpack_without_numpy():
- """Test that msgpack works without numpy and raises correct errors (e.g.
+ """Test that msgpack works with and without numpy and raises correct errors (e.g.
when serializing datetime objects, the error should be msgpack's TypeError,
not a "'np' is not defined error")."""
- with pytest.raises(TypeError):
- msgpack_loads(msgpack_dumps(datetime.datetime.now()))
+ with pytest.raises(TypeError, match="datetime.datetime"):
+ msgpack_dumps(datetime.datetime.now())
+
+
+@pytest.mark.parametrize(
+ "base_cls,data",
+ [
+ (int, 1),
+ (float, 1),
+ (list, [1, 2]),
+ (dict, {"x": 1}),
+ (str, "foo"),
+ (bytes, b"foo"),
+ ],
+)
+def test_msgpack_subtypes(base_cls, data):
+ """Subtypes of base types are cast to their parents."""
+
+ class SubType(base_cls):
+ pass
+
+ inp = SubType(data)
+ out = msgpack_loads(msgpack_dumps(inp))
+ assert type(out) is base_cls
+ assert out == base_cls(data)
+
+
+def test_msgpack_tuple():
+ # There is no difference between list and tuple in msgpack
+ # Outcome is controlled by the use_list=True parameter
+ class MyTuple(tuple):
+ pass
+
+ Named = namedtuple("Named", ["x", "y", "z"])
+
+ b1 = msgpack_dumps((1, 2, 3))
+ b2 = msgpack_dumps([1, 2, 3])
+ b3 = msgpack_dumps(MyTuple((1, 2, 3)))
+ b4 = msgpack_dumps(Named(x=1, y=2, z=3))
+ assert b2 == b1
+ assert b3 == b1
+ assert b4 == b1
+ assert msgpack_loads(b1) == [1, 2, 3]
+ assert msgpack_loads(b1, use_list=False) == (1, 2, 3)
def test_msgpack_custom_encoder_decoder():
@@ -69,15 +118,15 @@ class CustomObject:
def __init__(self, value):
self.value = value
- def serialize_obj(obj, chain=None):
+ def serialize_obj(obj):
if isinstance(obj, CustomObject):
return {"__custom__": obj.value}
- return obj if chain is None else chain(obj)
+ return obj
- def deserialize_obj(obj, chain=None):
+ def deserialize_obj(obj):
if "__custom__" in obj:
return CustomObject(obj["__custom__"])
- return obj if chain is None else chain(obj)
+ return obj
data = {"a": 123, "b": CustomObject({"foo": "bar"})}
with pytest.raises(TypeError):
@@ -91,10 +140,70 @@ def deserialize_obj(obj, chain=None):
assert new_data["a"] == 123
assert isinstance(new_data["b"], CustomObject)
assert new_data["b"].value == {"foo": "bar"}
- # Test that it also works with combinations of encoders/decoders (e.g. numpy)
- data = {"a": numpy.zeros((1, 2, 3)), "b": CustomObject({"foo": "bar"})}
+
+ # Test that it also works with combinations of encoders/decoders (e.g. complex)
+ data = {"a": 1 + 2j, "b": CustomObject({"foo": "bar"})}
bytes_data = msgpack_dumps(data)
new_data = msgpack_loads(bytes_data)
- assert isinstance(new_data["a"], numpy.ndarray)
+ assert isinstance(new_data["a"], complex)
assert isinstance(new_data["b"], CustomObject)
assert new_data["b"].value == {"foo": "bar"}
+
+ # Clean up
+ msgpack_encoders.deregister("custom_object")
+ msgpack_decoders.deregister("custom_object")
+
+
+def test_msgpack_custom_subtype_handler():
+ """By default, subtypes of base types are cast to their parents.
+ Test that the user can define a custom encoder/decoder to preserve
+ the subtype.
+ """
+
+ class MyInt(int):
+ pass
+
+ def encode_myint(obj):
+ if isinstance(obj, MyInt):
+ return {"MyInt": int(obj)}
+ return obj
+
+ def decode_myint(obj):
+ if "MyInt" in obj:
+ return MyInt(obj["MyInt"])
+ return obj
+
+ inp = MyInt(5)
+ out = msgpack_loads(msgpack_dumps(inp))
+ assert out == 5
+ assert type(out) is int
+
+ msgpack_encoders.register("myint", func=encode_myint)
+ msgpack_decoders.register("myint", func=decode_myint)
+ out = msgpack_loads(msgpack_dumps(inp))
+ assert out == MyInt(5)
+ assert type(out) is MyInt
+
+ # Cleanup
+ msgpack_encoders.deregister("myint")
+ msgpack_decoders.deregister("myint")
+
+
+def test_msgpack_numpy_not_installed():
+ """Test that we get a clean ModuleNotFoundError when
+ trying to decode numpy data when numpy is not installed.
+ """
+ # Output of np.float64(1)
+ bin = (
+ b"\x83\xc4\x02nd\xc2\xc4\x04type\xa3", encode_html_chars=True),
- )
-
- def test_doubleLongIssue(self):
- sut = {"a": -4342969734183514}
- encoded = json.dumps(sut)
- decoded = json.loads(encoded)
- self.assertEqual(sut, decoded)
- encoded = ujson.encode(sut)
- decoded = ujson.decode(encoded)
- self.assertEqual(sut, decoded)
-
- def test_doubleLongDecimalIssue(self):
- sut = {"a": -12345678901234.56789012}
- encoded = json.dumps(sut)
- decoded = json.loads(encoded)
- self.assertEqual(sut, decoded)
- encoded = ujson.encode(sut)
- decoded = ujson.decode(encoded)
- self.assertEqual(sut, decoded)
-
- def test_encodeDecodeLongDecimal(self):
- sut = {"a": -528656961.4399388}
- encoded = ujson.dumps(sut)
- ujson.decode(encoded)
-
- def test_decimalDecodeTest(self):
- sut = {"a": 4.56}
- encoded = ujson.encode(sut)
- decoded = ujson.decode(encoded)
- self.assertAlmostEqual(sut[u"a"], decoded[u"a"])
-
- def test_encodeDictWithUnicodeKeys(self):
- input = {
- "key1": "value1",
- "key1": "value1",
- "key1": "value1",
- "key1": "value1",
- "key1": "value1",
- "key1": "value1",
- }
- ujson.encode(input)
-
- input = {
- "بن": "value1",
- "بن": "value1",
- "بن": "value1",
- "بن": "value1",
- "بن": "value1",
- "بن": "value1",
- "بن": "value1",
- }
- ujson.encode(input)
-
- def test_encodeDoubleConversion(self):
- input = math.pi
- output = ujson.encode(input)
- self.assertEqual(round(input, 5), round(json.loads(output), 5))
- self.assertEqual(round(input, 5), round(ujson.decode(output), 5))
-
- def test_encodeWithDecimal(self):
- input = 1.0
- output = ujson.encode(input)
- self.assertEqual(output, "1.0")
-
- def test_encodeDoubleNegConversion(self):
- input = -math.pi
- output = ujson.encode(input)
-
- self.assertEqual(round(input, 5), round(json.loads(output), 5))
- self.assertEqual(round(input, 5), round(ujson.decode(output), 5))
-
- def test_encodeArrayOfNestedArrays(self):
- input = [[[[]]]] * 20
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- # self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeArrayOfDoubles(self):
- input = [31337.31337, 31337.31337, 31337.31337, 31337.31337] * 10
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- # self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeStringConversion2(self):
- input = "A string \\ / \b \f \n \r \t"
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, '"A string \\\\ \\/ \\b \\f \\n \\r \\t"')
- self.assertEqual(input, ujson.decode(output))
-
- def test_decodeUnicodeConversion(self):
- pass
-
- def test_encodeUnicodeConversion1(self):
- input = "Räksmörgås اسامة بن محمد بن عوض بن لادن"
- enc = ujson.encode(input)
- dec = ujson.decode(enc)
- self.assertEqual(enc, json_unicode(input))
- self.assertEqual(dec, json.loads(enc))
-
- def test_encodeControlEscaping(self):
- input = "\x19"
- enc = ujson.encode(input)
- dec = ujson.decode(enc)
- self.assertEqual(input, dec)
- self.assertEqual(enc, json_unicode(input))
-
- def test_encodeUnicodeConversion2(self):
- input = "\xe6\x97\xa5\xd1\x88"
- enc = ujson.encode(input)
- dec = ujson.decode(enc)
- self.assertEqual(enc, json_unicode(input))
- self.assertEqual(dec, json.loads(enc))
-
- def test_encodeUnicodeSurrogatePair(self):
- input = "\xf0\x90\x8d\x86"
- enc = ujson.encode(input)
- dec = ujson.decode(enc)
-
- self.assertEqual(enc, json_unicode(input))
- self.assertEqual(dec, json.loads(enc))
-
- def test_encodeUnicode4BytesUTF8(self):
- input = "\xf0\x91\x80\xb0TRAILINGNORMAL"
- enc = ujson.encode(input)
- dec = ujson.decode(enc)
-
- self.assertEqual(enc, json_unicode(input))
- self.assertEqual(dec, json.loads(enc))
-
- def test_encodeUnicode4BytesUTF8Highest(self):
- input = "\xf3\xbf\xbf\xbfTRAILINGNORMAL"
- enc = ujson.encode(input)
- dec = ujson.decode(enc)
-
- self.assertEqual(enc, json_unicode(input))
- self.assertEqual(dec, json.loads(enc))
-
- # Characters outside of Basic Multilingual Plane(larger than
- # 16 bits) are represented as \UXXXXXXXX in python but should be encoded
- # as \uXXXX\uXXXX in json.
- def testEncodeUnicodeBMP(self):
- s = "\U0001f42e\U0001f42e\U0001F42D\U0001F42D" # 🐮🐮🐭🐭
- encoded = ujson.dumps(s)
- encoded_json = json.dumps(s)
-
- if len(s) == 4:
- self.assertEqual(len(encoded), len(s) * 12 + 2)
- else:
- self.assertEqual(len(encoded), len(s) * 6 + 2)
-
- self.assertEqual(encoded, encoded_json)
- decoded = ujson.loads(encoded)
- self.assertEqual(s, decoded)
-
- # ujson outputs an UTF-8 encoded str object
- encoded = ujson.dumps(s, ensure_ascii=False)
- # json outputs an unicode object
- encoded_json = json.dumps(s, ensure_ascii=False)
- self.assertEqual(len(encoded), len(s) + 2) # original length + quotes
- self.assertEqual(encoded, encoded_json)
- decoded = ujson.loads(encoded)
- self.assertEqual(s, decoded)
-
- def testEncodeSymbols(self):
- s = "\u273f\u2661\u273f" # ✿♡✿
- encoded = ujson.dumps(s)
- encoded_json = json.dumps(s)
- self.assertEqual(len(encoded), len(s) * 6 + 2) # 6 characters + quotes
- self.assertEqual(encoded, encoded_json)
- decoded = ujson.loads(encoded)
- self.assertEqual(s, decoded)
-
- # ujson outputs an UTF-8 encoded str object
- encoded = ujson.dumps(s, ensure_ascii=False)
- # json outputs an unicode object
- encoded_json = json.dumps(s, ensure_ascii=False)
- self.assertEqual(len(encoded), len(s) + 2) # original length + quotes
- self.assertEqual(encoded, encoded_json)
- decoded = ujson.loads(encoded)
- self.assertEqual(s, decoded)
-
- def test_encodeArrayInArray(self):
- input = [[[[]]]]
- output = ujson.encode(input)
-
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeIntConversion(self):
- input = 31337
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeIntNegConversion(self):
- input = -31337
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeLongNegConversion(self):
- input = -9223372036854775808
- output = ujson.encode(input)
-
- json.loads(output)
- ujson.decode(output)
-
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeListConversion(self):
- input = [1, 2, 3, 4]
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeDictConversion(self):
- input = {"k1": 1, "k2": 2, "k3": 3, "k4": 4}
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(input, ujson.decode(output))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeNoneConversion(self):
- input = None
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeTrueConversion(self):
- input = True
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeFalseConversion(self):
- input = False
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeToUTF8(self):
- input = b"\xe6\x97\xa5\xd1\x88"
- input = input.decode("utf-8")
- enc = ujson.encode(input, ensure_ascii=False)
- dec = ujson.decode(enc)
- self.assertEqual(enc, json.dumps(input, ensure_ascii=False))
- self.assertEqual(dec, json.loads(enc))
-
- def test_decodeFromUnicode(self):
- input = '{"obj": 31337}'
- dec1 = ujson.decode(input)
- dec2 = ujson.decode(str(input))
- self.assertEqual(dec1, dec2)
-
- def test_encodeRecursionMax(self):
- # 8 is the max recursion depth
- class O2:
- member = 0
-
- def toDict(self):
- return {"member": self.member}
-
- class O1:
- member = 0
-
- def toDict(self):
- return {"member": self.member}
-
- input = O1()
- input.member = O2()
- input.member.member = input
- self.assertRaises(OverflowError, ujson.encode, input)
-
- def test_encodeDoubleNan(self):
- input = float("nan")
- self.assertRaises(OverflowError, ujson.encode, input)
-
- def test_encodeDoubleInf(self):
- input = float("inf")
- self.assertRaises(OverflowError, ujson.encode, input)
-
- def test_encodeDoubleNegInf(self):
- input = -float("inf")
- self.assertRaises(OverflowError, ujson.encode, input)
-
- def test_encodeOrderedDict(self):
- from collections import OrderedDict
-
- input = OrderedDict([(1, 1), (0, 0), (8, 8), (2, 2)])
- self.assertEqual('{"1":1,"0":0,"8":8,"2":2}', ujson.encode(input))
-
- def test_decodeJibberish(self):
- input = "fdsa sda v9sa fdsa"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenArrayStart(self):
- input = "["
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenObjectStart(self):
- input = "{"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenArrayEnd(self):
- input = "]"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeArrayDepthTooBig(self):
- input = "[" * (1024 * 1024)
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenObjectEnd(self):
- input = "}"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeObjectTrailingCommaFail(self):
- input = '{"one":1,}'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeObjectDepthTooBig(self):
- input = "{" * (1024 * 1024)
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeStringUnterminated(self):
- input = '"TESTING'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeStringUntermEscapeSequence(self):
- input = '"TESTING\\"'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeStringBadEscape(self):
- input = '"TESTING\\"'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeTrueBroken(self):
- input = "tru"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeFalseBroken(self):
- input = "fa"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeNullBroken(self):
- input = "n"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenDictKeyTypeLeakTest(self):
- input = '{{1337:""}}'
- for x in range(1000):
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenDictLeakTest(self):
- input = '{{"key":"}'
- for x in range(1000):
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeBrokenListLeakTest(self):
- input = "[[[true"
- for x in range(1000):
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeDictWithNoKey(self):
- input = "{{{{31337}}}}"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeDictWithNoColonOrValue(self):
- input = '{{{{"key"}}}}'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeDictWithNoValue(self):
- input = '{{{{"key":}}}}'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeNumericIntPos(self):
- input = "31337"
- self.assertEqual(31337, ujson.decode(input))
-
- def test_decodeNumericIntNeg(self):
- input = "-31337"
- self.assertEqual(-31337, ujson.decode(input))
-
- def test_encodeUnicode4BytesUTF8Fail(self):
- input = b"\xfd\xbf\xbf\xbf\xbf\xbf"
- self.assertRaises(OverflowError, ujson.encode, input)
-
- def test_encodeNullCharacter(self):
- input = "31337 \x00 1337"
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- input = "\x00"
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- self.assertEqual('" \\u0000\\r\\n "', ujson.dumps(" \u0000\r\n "))
-
- def test_decodeNullCharacter(self):
- input = '"31337 \\u0000 31337"'
- self.assertEqual(ujson.decode(input), json.loads(input))
-
- def test_encodeListLongConversion(self):
- input = [
- 9223372036854775807,
- 9223372036854775807,
- 9223372036854775807,
- 9223372036854775807,
- 9223372036854775807,
- 9223372036854775807,
- ]
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeListLongUnsignedConversion(self):
- input = [18446744073709551615, 18446744073709551615, 18446744073709551615]
- output = ujson.encode(input)
-
- self.assertEqual(input, json.loads(output))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeLongConversion(self):
- input = 9223372036854775807
- output = ujson.encode(input)
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_encodeLongUnsignedConversion(self):
- input = 18446744073709551615
- output = ujson.encode(input)
-
- self.assertEqual(input, json.loads(output))
- self.assertEqual(output, json.dumps(input))
- self.assertEqual(input, ujson.decode(output))
-
- def test_numericIntExp(self):
- input = "1337E40"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_numericIntFrcExp(self):
- input = "1.337E40"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_decodeNumericIntExpEPLUS(self):
- input = "1337E+9"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_decodeNumericIntExpePLUS(self):
- input = "1.337e+40"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_decodeNumericIntExpE(self):
- input = "1337E40"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_decodeNumericIntExpe(self):
- input = "1337e40"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_decodeNumericIntExpEMinus(self):
- input = "1.337E-4"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_decodeNumericIntExpeMinus(self):
- input = "1.337e-4"
- output = ujson.decode(input)
- self.assertEqual(output, json.loads(input))
-
- def test_dumpToFile(self):
- f = StringIO()
- ujson.dump([1, 2, 3], f)
- self.assertEqual("[1,2,3]", f.getvalue())
-
- def test_dumpToFileLikeObject(self):
- class filelike:
- def __init__(self):
- self.bytes = ""
-
- def write(self, bytes):
- self.bytes += bytes
-
- f = filelike()
- ujson.dump([1, 2, 3], f)
- self.assertEqual("[1,2,3]", f.bytes)
-
- def test_dumpFileArgsError(self):
- self.assertRaises(TypeError, ujson.dump, [], "")
-
- def test_loadFile(self):
- f = StringIO("[1,2,3,4]")
- self.assertEqual([1, 2, 3, 4], ujson.load(f))
-
- def test_loadFileLikeObject(self):
- class filelike:
- def read(self):
- try:
- self.end
- except AttributeError:
- self.end = True
- return "[1,2,3,4]"
-
- f = filelike()
- self.assertEqual([1, 2, 3, 4], ujson.load(f))
-
- def test_loadFileArgsError(self):
- self.assertRaises(TypeError, ujson.load, "[]")
-
- def test_encodeNumericOverflow(self):
- self.assertRaises(OverflowError, ujson.encode, 12839128391289382193812939)
-
- def test_decodeNumberWith32bitSignBit(self):
- # Test that numbers that fit within 32 bits but would have the
- # sign bit set (2**31 <= x < 2**32) are decoded properly.
- docs = (
- '{"id": 3590016419}',
- '{"id": %s}' % 2 ** 31,
- '{"id": %s}' % 2 ** 32,
- '{"id": %s}' % ((2 ** 32) - 1),
- )
- results = (3590016419, 2 ** 31, 2 ** 32, 2 ** 32 - 1)
- for doc, result in zip(docs, results):
- self.assertEqual(ujson.decode(doc)["id"], result)
-
- def test_encodeBigEscape(self):
- for x in range(10):
- base = "\u00e5".encode("utf-8")
- input = base * 1024 * 1024 * 2
- ujson.encode(input)
-
- def test_decodeBigEscape(self):
- for x in range(10):
- base = "\u00e5".encode("utf-8")
- quote = '"'.encode()
- input = quote + (base * 1024 * 1024 * 2) + quote
- ujson.decode(input)
-
- def test_toDict(self):
- d = {"key": 31337}
-
- class DictTest:
- def toDict(self):
- return d
-
- def __json__(self):
- return '"json defined"' # Fallback and shouldn't be called.
-
- o = DictTest()
- output = ujson.encode(o)
- dec = ujson.decode(output)
- self.assertEqual(dec, d)
-
- def test_object_with_json(self):
- # If __json__ returns a string, then that string
- # will be used as a raw JSON snippet in the object.
- output_text = "this is the correct output"
-
- class JSONTest:
- def __json__(self):
- return '"' + output_text + '"'
-
- d = {u"key": JSONTest()}
- output = ujson.encode(d)
- dec = ujson.decode(output)
- self.assertEqual(dec, {u"key": output_text})
-
- def test_object_with_json_unicode(self):
- # If __json__ returns a string, then that string
- # will be used as a raw JSON snippet in the object.
- output_text = u"this is the correct output"
-
- class JSONTest:
- def __json__(self):
- return u'"' + output_text + u'"'
-
- d = {u"key": JSONTest()}
- output = ujson.encode(d)
- dec = ujson.decode(output)
- self.assertEqual(dec, {u"key": output_text})
-
- def test_object_with_complex_json(self):
- # If __json__ returns a string, then that string
- # will be used as a raw JSON snippet in the object.
- obj = {u"foo": [u"bar", u"baz"]}
-
- class JSONTest:
- def __json__(self):
- return ujson.encode(obj)
-
- d = {u"key": JSONTest()}
- output = ujson.encode(d)
- dec = ujson.decode(output)
- self.assertEqual(dec, {u"key": obj})
-
- def test_object_with_json_type_error(self):
- # __json__ must return a string, otherwise it should raise an error.
- for return_value in (None, 1234, 12.34, True, {}):
-
- class JSONTest:
- def __json__(self):
- return return_value
-
- d = {u"key": JSONTest()}
- self.assertRaises(TypeError, ujson.encode, d)
-
- def test_object_with_json_attribute_error(self):
- # If __json__ raises an error, make sure python actually raises it.
- class JSONTest:
- def __json__(self):
- raise AttributeError
-
- d = {u"key": JSONTest()}
- self.assertRaises(AttributeError, ujson.encode, d)
-
- def test_decodeArrayTrailingCommaFail(self):
- input = "[31337,]"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeArrayLeadingCommaFail(self):
- input = "[,31337]"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeArrayOnlyCommaFail(self):
- input = "[,]"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeArrayUnmatchedBracketFail(self):
- input = "[]]"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeArrayEmpty(self):
- input = "[]"
- obj = ujson.decode(input)
- self.assertEqual([], obj)
-
- def test_decodeArrayOneItem(self):
- input = "[31337]"
- ujson.decode(input)
-
- def test_decodeLongUnsignedValue(self):
- input = "18446744073709551615"
- ujson.decode(input)
-
- def test_decodeBigValue(self):
- input = "9223372036854775807"
- ujson.decode(input)
-
- def test_decodeSmallValue(self):
- input = "-9223372036854775808"
- ujson.decode(input)
-
- def test_decodeTooBigValue(self):
- input = "18446744073709551616"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeTooSmallValue(self):
- input = "-90223372036854775809"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeVeryTooBigValue(self):
- input = "18446744073709551616"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeVeryTooSmallValue(self):
- input = "-90223372036854775809"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeWithTrailingWhitespaces(self):
- input = "{}\n\t "
- ujson.decode(input)
-
- def test_decodeWithTrailingNonWhitespaces(self):
- input = "{}\n\t a"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeArrayWithBigInt(self):
- input = "[18446744073709551616]"
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_decodeFloatingPointAdditionalTests(self):
- self.assertAlmostEqual(-1.1234567893, ujson.loads("-1.1234567893"))
- self.assertAlmostEqual(-1.234567893, ujson.loads("-1.234567893"))
- self.assertAlmostEqual(-1.34567893, ujson.loads("-1.34567893"))
- self.assertAlmostEqual(-1.4567893, ujson.loads("-1.4567893"))
- self.assertAlmostEqual(-1.567893, ujson.loads("-1.567893"))
- self.assertAlmostEqual(-1.67893, ujson.loads("-1.67893"))
- self.assertAlmostEqual(-1.7894, ujson.loads("-1.7894"))
- self.assertAlmostEqual(-1.893, ujson.loads("-1.893"))
- self.assertAlmostEqual(-1.3, ujson.loads("-1.3"))
-
- self.assertAlmostEqual(1.1234567893, ujson.loads("1.1234567893"))
- self.assertAlmostEqual(1.234567893, ujson.loads("1.234567893"))
- self.assertAlmostEqual(1.34567893, ujson.loads("1.34567893"))
- self.assertAlmostEqual(1.4567893, ujson.loads("1.4567893"))
- self.assertAlmostEqual(1.567893, ujson.loads("1.567893"))
- self.assertAlmostEqual(1.67893, ujson.loads("1.67893"))
- self.assertAlmostEqual(1.7894, ujson.loads("1.7894"))
- self.assertAlmostEqual(1.893, ujson.loads("1.893"))
- self.assertAlmostEqual(1.3, ujson.loads("1.3"))
-
- def test_ReadBadObjectSyntax(self):
- input = '{"age", 44}'
- self.assertRaises(ValueError, ujson.decode, input)
-
- def test_ReadTrue(self):
- self.assertEqual(True, ujson.loads("true"))
-
- def test_ReadFalse(self):
- self.assertEqual(False, ujson.loads("false"))
-
- def test_ReadNull(self):
- self.assertEqual(None, ujson.loads("null"))
-
- def test_WriteTrue(self):
- self.assertEqual("true", ujson.dumps(True))
-
- def test_WriteFalse(self):
- self.assertEqual("false", ujson.dumps(False))
-
- def test_WriteNull(self):
- self.assertEqual("null", ujson.dumps(None))
-
- def test_ReadArrayOfSymbols(self):
- self.assertEqual([True, False, None], ujson.loads(" [ true, false,null] "))
-
- def test_WriteArrayOfSymbolsFromList(self):
- self.assertEqual("[true,false,null]", ujson.dumps([True, False, None]))
-
- def test_WriteArrayOfSymbolsFromTuple(self):
- self.assertEqual("[true,false,null]", ujson.dumps((True, False, None)))
-
- def test_encodingInvalidUnicodeCharacter(self):
- s = "\udc7f"
- self.assertRaises(UnicodeEncodeError, ujson.dumps, s)
-
- def test_sortKeys(self):
- data = {"a": 1, "c": 1, "b": 1, "e": 1, "f": 1, "d": 1}
- sortedKeys = ujson.dumps(data, sort_keys=True)
- self.assertEqual(sortedKeys, '{"a":1,"b":1,"c":1,"d":1,"e":1,"f":1}')
-
- @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount")
- def test_does_not_leak_dictionary_values(self):
- import gc
-
- gc.collect()
- value = ["abc"]
- data = {"1": value}
- ref_count = sys.getrefcount(value)
- ujson.dumps(data)
- self.assertEqual(ref_count, sys.getrefcount(value))
-
- @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount")
- def test_does_not_leak_dictionary_keys(self):
- import gc
-
- gc.collect()
- key1 = "1"
- key2 = "1"
- value1 = ["abc"]
- value2 = [1, 2, 3]
- data = {key1: value1, key2: value2}
- ref_count1 = sys.getrefcount(key1)
- ref_count2 = sys.getrefcount(key2)
- ujson.dumps(data)
- self.assertEqual(ref_count1, sys.getrefcount(key1))
- self.assertEqual(ref_count2, sys.getrefcount(key2))
-
- @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount")
- def test_does_not_leak_dictionary_string_key(self):
- import gc
-
- gc.collect()
- key1 = "1"
- value1 = 1
- data = {key1: value1}
- ref_count1 = sys.getrefcount(key1)
- ujson.dumps(data)
- self.assertEqual(ref_count1, sys.getrefcount(key1))
-
- @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount")
- def test_does_not_leak_dictionary_tuple_key(self):
- import gc
-
- gc.collect()
- key1 = ("a",)
- value1 = 1
- data = {key1: value1}
- ref_count1 = sys.getrefcount(key1)
- ujson.dumps(data)
- self.assertEqual(ref_count1, sys.getrefcount(key1))
-
- @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount")
- def test_does_not_leak_dictionary_bytes_key(self):
- import gc
-
- gc.collect()
- key1 = b"1"
- value1 = 1
- data = {key1: value1}
- ref_count1 = sys.getrefcount(key1)
- ujson.dumps(data)
- self.assertEqual(ref_count1, sys.getrefcount(key1))
-
- @unittest.skipIf(not hasattr(sys, 'getrefcount') == True, reason="test requires sys.refcount")
- def test_does_not_leak_dictionary_None_key(self):
- import gc
-
- gc.collect()
- key1 = None
- value1 = 1
- data = {key1: value1}
- ref_count1 = sys.getrefcount(key1)
- ujson.dumps(data)
- self.assertEqual(ref_count1, sys.getrefcount(key1))
-
-
-"""
-def test_decodeNumericIntFrcOverflow(self):
-input = "X.Y"
-raise NotImplementedError("Implement this test!")
-
-
-def test_decodeStringUnicodeEscape(self):
-input = "\u3131"
-raise NotImplementedError("Implement this test!")
-
-def test_decodeStringUnicodeBrokenEscape(self):
-input = "\u3131"
-raise NotImplementedError("Implement this test!")
-
-def test_decodeStringUnicodeInvalidEscape(self):
-input = "\u3131"
-raise NotImplementedError("Implement this test!")
-
-def test_decodeStringUTF8(self):
-input = "someutfcharacters"
-raise NotImplementedError("Implement this test!")
-
-"""
-
-if __name__ == "__main__":
- unittest.main()
-
-"""
-# Use this to look for memory leaks
-if __name__ == '__main__':
- from guppy import hpy
- hp = hpy()
- hp.setrelheap()
- while True:
- try:
- unittest.main()
- except SystemExit:
- pass
- heap = hp.heapu()
- print(heap)
-"""
-
-
-@pytest.mark.parametrize("indent", list(range(65537, 65542)))
-def test_dump_huge_indent(indent):
- ujson.encode({"a": True}, indent=indent)
-
-
-@pytest.mark.parametrize("first_length", list(range(2, 7)))
-@pytest.mark.parametrize("second_length", list(range(10919, 10924)))
-def test_dump_long_string(first_length, second_length):
- ujson.dumps(["a" * first_length, "\x00" * second_length])
-
-
-def test_dump_indented_nested_list():
- a = _a = []
- for i in range(20):
- _a.append(list(range(i)))
- _a = _a[-1]
- ujson.dumps(a, indent=i)
-
-
-@pytest.mark.parametrize("indent", [0, 1, 2, 4, 5, 8, 49])
-def test_issue_334(indent):
- path = Path(__file__).with_name("334-reproducer.json")
- a = ujson.loads(path.read_bytes())
- ujson.dumps(a, indent=indent)
-
-
-@pytest.mark.parametrize(
- "test_input, expected",
- [
- # Normal cases
- (r'"\uD83D\uDCA9"', "\U0001F4A9"),
- (r'"a\uD83D\uDCA9b"', "a\U0001F4A9b"),
- # Unpaired surrogates
- (r'"\uD800"', "\uD800"),
- (r'"a\uD800b"', "a\uD800b"),
- (r'"\uDEAD"', "\uDEAD"),
- (r'"a\uDEADb"', "a\uDEADb"),
- (r'"\uD83D\uD83D\uDCA9"', "\uD83D\U0001F4A9"),
- (r'"\uDCA9\uD83D\uDCA9"', "\uDCA9\U0001F4A9"),
- (r'"\uD83D\uDCA9\uD83D"', "\U0001F4A9\uD83D"),
- (r'"\uD83D\uDCA9\uDCA9"', "\U0001F4A9\uDCA9"),
- (r'"\uD83D \uDCA9"', "\uD83D \uDCA9"),
- # No decoding of actual surrogate characters (rather than escaped ones)
- ('"\uD800"', "\uD800"),
- ('"\uDEAD"', "\uDEAD"),
- ('"\uD800a\uDEAD"', "\uD800a\uDEAD"),
- ('"\uD83D\uDCA9"', "\uD83D\uDCA9"),
- ],
-)
-def test_decode_surrogate_characters(test_input, expected):
- assert ujson.loads(test_input) == expected
- assert ujson.loads(test_input.encode("utf-8", "surrogatepass")) == expected
-
- # Ensure that this matches stdlib's behaviour
- assert json.loads(test_input) == expected
diff --git a/srsly/ujson/JSONtoObj.c b/srsly/ujson/JSONtoObj.c
deleted file mode 100644
index 8563970..0000000
--- a/srsly/ujson/JSONtoObj.c
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
- * Copyright (c) 1988-1993 The Regents of the University of California.
- * Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#include "py_defines.h"
-#include
-
-
-//#define PRINTMARK() fprintf(stderr, "%s: MARK(%d)\n", __FILE__, __LINE__)
-#define PRINTMARK()
-
-void Object_objectAddKey(void *prv, JSOBJ obj, JSOBJ name, JSOBJ value)
-{
- PyDict_SetItem (obj, name, value);
- Py_DECREF( (PyObject *) name);
- Py_DECREF( (PyObject *) value);
- return;
-}
-
-void Object_arrayAddItem(void *prv, JSOBJ obj, JSOBJ value)
-{
- PyList_Append(obj, value);
- Py_DECREF( (PyObject *) value);
- return;
-}
-
-/*
-Check that Py_UCS4 is the same as JSUINT32, else Object_newString will fail.
-Based on Linux's check in vbox_vmmdev_types.h.
-This should be replaced with
- _Static_assert(sizeof(Py_UCS4) == sizeof(JSUINT32));
-when C11 is made mandatory (CPython 3.11+, PyPy ?).
-*/
-typedef char assert_py_ucs4_is_jsuint32[1 - 2*!(sizeof(Py_UCS4) == sizeof(JSUINT32))];
-
-static JSOBJ Object_newString(void *prv, JSUINT32 *start, JSUINT32 *end)
-{
- return PyUnicode_FromKindAndData (PyUnicode_4BYTE_KIND, (Py_UCS4 *) start, (end - start));
-}
-
-JSOBJ Object_newTrue(void *prv)
-{
- Py_RETURN_TRUE;
-}
-
-JSOBJ Object_newFalse(void *prv)
-{
- Py_RETURN_FALSE;
-}
-
-JSOBJ Object_newNull(void *prv)
-{
- Py_RETURN_NONE;
-}
-
-JSOBJ Object_newObject(void *prv)
-{
- return PyDict_New();
-}
-
-JSOBJ Object_newArray(void *prv)
-{
- return PyList_New(0);
-}
-
-JSOBJ Object_newInteger(void *prv, JSINT32 value)
-{
- return PyInt_FromLong( (long) value);
-}
-
-JSOBJ Object_newLong(void *prv, JSINT64 value)
-{
- return PyLong_FromLongLong (value);
-}
-
-JSOBJ Object_newUnsignedLong(void *prv, JSUINT64 value)
-{
- return PyLong_FromUnsignedLongLong (value);
-}
-
-JSOBJ Object_newDouble(void *prv, double value)
-{
- return PyFloat_FromDouble(value);
-}
-
-static void Object_releaseObject(void *prv, JSOBJ obj)
-{
- Py_DECREF( ((PyObject *)obj));
-}
-
-static char *g_kwlist[] = {"obj", "precise_float", NULL};
-
-PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs)
-{
- PyObject *ret;
- PyObject *sarg;
- PyObject *arg;
- PyObject *opreciseFloat = NULL;
- JSONObjectDecoder decoder =
- {
- Object_newString,
- Object_objectAddKey,
- Object_arrayAddItem,
- Object_newTrue,
- Object_newFalse,
- Object_newNull,
- Object_newObject,
- Object_newArray,
- Object_newInteger,
- Object_newLong,
- Object_newUnsignedLong,
- Object_newDouble,
- Object_releaseObject,
- PyObject_Malloc,
- PyObject_Free,
- PyObject_Realloc
- };
-
- decoder.preciseFloat = 0;
- decoder.prv = NULL;
-
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", g_kwlist, &arg, &opreciseFloat))
- {
- return NULL;
- }
-
- if (opreciseFloat && PyObject_IsTrue(opreciseFloat))
- {
- decoder.preciseFloat = 1;
- }
-
- if (PyString_Check(arg))
- {
- sarg = arg;
- }
- else
- if (PyUnicode_Check(arg))
- {
- sarg = PyUnicode_AsEncodedString(arg, NULL, "surrogatepass");
- if (sarg == NULL)
- {
- //Exception raised above us by codec according to docs
- return NULL;
- }
- }
- else
- {
- PyErr_Format(PyExc_TypeError, "Expected String or Unicode");
- return NULL;
- }
-
- decoder.errorStr = NULL;
- decoder.errorOffset = NULL;
-
- ret = JSON_DecodeObject(&decoder, PyString_AS_STRING(sarg), PyString_GET_SIZE(sarg));
-
- if (sarg != arg)
- {
- Py_DECREF(sarg);
- }
-
- if (decoder.errorStr)
- {
- /*
- FIXME: It's possible to give a much nicer error message here with actual failing element in input etc*/
-
- PyErr_Format (PyExc_ValueError, "%s", decoder.errorStr);
-
- if (ret)
- {
- Py_DECREF( (PyObject *) ret);
- }
-
- return NULL;
- }
-
- return ret;
-}
-
-PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs)
-{
- PyObject *read;
- PyObject *string;
- PyObject *result;
- PyObject *file = NULL;
- PyObject *argtuple;
-
- if (!PyArg_ParseTuple (args, "O", &file))
- {
- return NULL;
- }
-
- if (!PyObject_HasAttrString (file, "read"))
- {
- PyErr_Format (PyExc_TypeError, "expected file");
- return NULL;
- }
-
- read = PyObject_GetAttrString (file, "read");
-
- if (!PyCallable_Check (read)) {
- Py_XDECREF(read);
- PyErr_Format (PyExc_TypeError, "expected file");
- return NULL;
- }
-
- string = PyObject_CallObject (read, NULL);
- Py_XDECREF(read);
-
- if (string == NULL)
- {
- return NULL;
- }
-
- argtuple = PyTuple_Pack(1, string);
-
- result = JSONToObj (self, argtuple, kwargs);
-
- Py_XDECREF(argtuple);
- Py_XDECREF(string);
-
- if (result == NULL) {
- return NULL;
- }
-
- return result;
-}
diff --git a/srsly/ujson/__init__.py b/srsly/ujson/__init__.py
deleted file mode 100644
index 744ff70..0000000
--- a/srsly/ujson/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .ujson import decode, encode, dump, dumps, load, loads # noqa: F401
diff --git a/srsly/ujson/lib/dconv_wrapper.cc b/srsly/ujson/lib/dconv_wrapper.cc
deleted file mode 100644
index 27e1b9b..0000000
--- a/srsly/ujson/lib/dconv_wrapper.cc
+++ /dev/null
@@ -1,58 +0,0 @@
-#include "double-conversion.h"
-
-namespace double_conversion
-{
- static StringToDoubleConverter* s2d_instance = NULL;
- static DoubleToStringConverter* d2s_instance = NULL;
-
- extern "C"
- {
- void dconv_d2s_init(int flags,
- const char* infinity_symbol,
- const char* nan_symbol,
- char exponent_character,
- int decimal_in_shortest_low,
- int decimal_in_shortest_high,
- int max_leading_padding_zeroes_in_precision_mode,
- int max_trailing_padding_zeroes_in_precision_mode)
- {
- d2s_instance = new DoubleToStringConverter(flags, infinity_symbol, nan_symbol,
- exponent_character, decimal_in_shortest_low,
- decimal_in_shortest_high, max_leading_padding_zeroes_in_precision_mode,
- max_trailing_padding_zeroes_in_precision_mode);
- }
-
- int dconv_d2s(double value, char* buf, int buflen, int* strlength)
- {
- StringBuilder sb(buf, buflen);
- int success = static_cast(d2s_instance->ToShortest(value, &sb));
- *strlength = success ? sb.position() : -1;
- return success;
- }
-
- void dconv_d2s_free()
- {
- delete d2s_instance;
- d2s_instance = NULL;
- }
-
- void dconv_s2d_init(int flags, double empty_string_value,
- double junk_string_value, const char* infinity_symbol,
- const char* nan_symbol)
- {
- s2d_instance = new StringToDoubleConverter(flags, empty_string_value,
- junk_string_value, infinity_symbol, nan_symbol);
- }
-
- double dconv_s2d(const char* buffer, int length, int* processed_characters_count)
- {
- return s2d_instance->StringToDouble(buffer, length, processed_characters_count);
- }
-
- void dconv_s2d_free()
- {
- delete s2d_instance;
- s2d_instance = NULL;
- }
- }
-}
diff --git a/srsly/ujson/lib/ultrajson.h b/srsly/ujson/lib/ultrajson.h
deleted file mode 100644
index a117901..0000000
--- a/srsly/ujson/lib/ultrajson.h
+++ /dev/null
@@ -1,324 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
- * Copyright (c) 1988-1993 The Regents of the University of California.
- * Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-/*
-Ultra fast JSON encoder and decoder
-Developed by Jonas Tarnstrom (jonas@esn.me).
-
-Encoder notes:
-------------------
-
-:: Cyclic references ::
-Cyclic referenced objects are not detected.
-Set JSONObjectEncoder.recursionMax to suitable value or make sure input object
-tree doesn't have cyclic references.
-
-*/
-
-#ifndef __ULTRAJSON_H__
-#define __ULTRAJSON_H__
-
-#include
-
-// Max decimals to encode double floating point numbers with
-#ifndef JSON_DOUBLE_MAX_DECIMALS
-#define JSON_DOUBLE_MAX_DECIMALS 15
-#endif
-
-// Max recursion depth, default for encoder
-#ifndef JSON_MAX_RECURSION_DEPTH
-#define JSON_MAX_RECURSION_DEPTH 1024
-#endif
-
-// Max recursion depth, default for decoder
-#ifndef JSON_MAX_OBJECT_DEPTH
-#define JSON_MAX_OBJECT_DEPTH 1024
-#endif
-
-/*
-Dictates and limits how much stack space for buffers UltraJSON will use before resorting to provided heap functions */
-#ifndef JSON_MAX_STACK_BUFFER_SIZE
-#define JSON_MAX_STACK_BUFFER_SIZE 131072
-#endif
-
-#ifdef _WIN32
-
-typedef __int64 JSINT64;
-typedef unsigned __int64 JSUINT64;
-
-typedef __int32 JSINT32;
-typedef unsigned __int32 JSUINT32;
-typedef unsigned __int8 JSUINT8;
-typedef unsigned __int16 JSUTF16;
-typedef unsigned __int32 JSUTF32;
-typedef __int64 JSLONG;
-
-#define EXPORTFUNCTION __declspec(dllexport)
-
-#define FASTCALL_MSVC __fastcall
-#define FASTCALL_ATTR
-#define INLINE_PREFIX __inline
-
-#else
-
-#include
-typedef int64_t JSINT64;
-typedef uint64_t JSUINT64;
-
-typedef int32_t JSINT32;
-typedef uint32_t JSUINT32;
-
-#define FASTCALL_MSVC
-
-#if !defined __x86_64__
-#define FASTCALL_ATTR __attribute__((fastcall))
-#else
-#define FASTCALL_ATTR
-#endif
-
-#define INLINE_PREFIX inline
-
-typedef uint8_t JSUINT8;
-typedef uint16_t JSUTF16;
-typedef uint32_t JSUTF32;
-
-typedef int64_t JSLONG;
-
-#define EXPORTFUNCTION
-#endif
-
-#if !(defined(__LITTLE_ENDIAN__) || defined(__BIG_ENDIAN__))
-
-#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
-#define __LITTLE_ENDIAN__
-#else
-
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-#define __BIG_ENDIAN__
-#endif
-
-#endif
-
-#endif
-
-#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)
-#error "Endianess not supported"
-#endif
-
-enum JSTYPES
-{
- JT_NULL, // NULL
- JT_TRUE, // boolean true
- JT_FALSE, // boolean false
- JT_INT, // (JSINT32 (signed 32-bit))
- JT_LONG, // (JSINT64 (signed 64-bit))
- JT_ULONG, // (JSUINT64 (unsigned 64-bit))
- JT_DOUBLE, // (double)
- JT_UTF8, // (char 8-bit)
- JT_RAW, // (raw char 8-bit)
- JT_ARRAY, // Array structure
- JT_OBJECT, // Key/Value structure
- JT_INVALID, // Internal, do not return nor expect
-};
-
-typedef void * JSOBJ;
-typedef void * JSITER;
-
-typedef struct __JSONTypeContext
-{
- int type;
- void *prv;
- void *encoder_prv;
-} JSONTypeContext;
-
-/*
-Function pointer declarations, suitable for implementing UltraJSON */
-typedef int (*JSPFN_ITERNEXT)(JSOBJ obj, JSONTypeContext *tc);
-typedef void (*JSPFN_ITEREND)(JSOBJ obj, JSONTypeContext *tc);
-typedef JSOBJ (*JSPFN_ITERGETVALUE)(JSOBJ obj, JSONTypeContext *tc);
-typedef char *(*JSPFN_ITERGETNAME)(JSOBJ obj, JSONTypeContext *tc, size_t *outLen);
-typedef void *(*JSPFN_MALLOC)(size_t size);
-typedef void (*JSPFN_FREE)(void *pptr);
-typedef void *(*JSPFN_REALLOC)(void *base, size_t size);
-
-
-struct __JSONObjectEncoder;
-
-typedef struct __JSONObjectEncoder
-{
- void (*beginTypeContext)(JSOBJ obj, JSONTypeContext *tc, struct __JSONObjectEncoder *enc);
- void (*endTypeContext)(JSOBJ obj, JSONTypeContext *tc);
- const char *(*getStringValue)(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen);
- JSINT64 (*getLongValue)(JSOBJ obj, JSONTypeContext *tc);
- JSUINT64 (*getUnsignedLongValue)(JSOBJ obj, JSONTypeContext *tc);
- JSINT32 (*getIntValue)(JSOBJ obj, JSONTypeContext *tc);
- double (*getDoubleValue)(JSOBJ obj, JSONTypeContext *tc);
-
- /*
- Retrieve next object in an iteration. Should return 0 to indicate iteration has reached end or 1 if there are more items.
- Implementor is responsible for keeping state of the iteration. Use ti->prv fields for this
- */
- JSPFN_ITERNEXT iterNext;
-
- /*
- Ends the iteration of an iteratable object.
- Any iteration state stored in ti->prv can be freed here
- */
- JSPFN_ITEREND iterEnd;
-
- /*
- Returns a reference to the value object of an iterator
- The is responsible for the life-cycle of the returned string. Use iterNext/iterEnd and ti->prv to keep track of current object
- */
- JSPFN_ITERGETVALUE iterGetValue;
-
- /*
- Return name of iterator.
- The is responsible for the life-cycle of the returned string. Use iterNext/iterEnd and ti->prv to keep track of current object
- */
- JSPFN_ITERGETNAME iterGetName;
-
- /*
- Release a value as indicated by setting ti->release = 1 in the previous getValue call.
- The ti->prv array should contain the necessary context to release the value
- */
- void (*releaseObject)(JSOBJ obj);
-
- /* Library functions
- Set to NULL to use STDLIB malloc,realloc,free */
- JSPFN_MALLOC malloc;
- JSPFN_REALLOC realloc;
- JSPFN_FREE free;
-
- /*
- Configuration for max recursion, set to 0 to use default (see JSON_MAX_RECURSION_DEPTH)*/
- int recursionMax;
-
- /*
- Configuration for max decimals of double floating point numbers to encode (0-9) */
- int doublePrecision;
-
- /*
- If true output will be ASCII with all characters above 127 encoded as \uXXXX. If false output will be UTF-8 or what ever charset strings are brought as */
- int forceASCII;
-
- /*
- If true, '<', '>', and '&' characters will be encoded as \u003c, \u003e, and \u0026, respectively. If false, no special encoding will be used. */
- int encodeHTMLChars;
-
- /*
- If true, '/' will be encoded as \/. If false, no escaping. */
- int escapeForwardSlashes;
-
- /*
- If true, dictionaries are iterated through in sorted key order. */
- int sortKeys;
-
- /*
- Configuration for spaces of indent */
- int indent;
-
- /*
- Private pointer to be used by the caller. Passed as encoder_prv in JSONTypeContext */
- void *prv;
-
- /*
- Set to an error message if error occured */
- const char *errorMsg;
- JSOBJ errorObj;
-
- /* Buffer stuff */
- char *start;
- char *offset;
- char *end;
- int heap;
- int level;
-
-} JSONObjectEncoder;
-
-
-/*
-Encode an object structure into JSON.
-
-Arguments:
-obj - An anonymous type representing the object
-enc - Function definitions for querying JSOBJ type
-buffer - Preallocated buffer to store result in. If NULL function allocates own buffer
-cbBuffer - Length of buffer (ignored if buffer is NULL)
-
-Returns:
-Encoded JSON object as a null terminated char string.
-
-NOTE:
-If the supplied buffer wasn't enough to hold the result the function will allocate a new buffer.
-Life cycle of the provided buffer must still be handled by caller.
-
-If the return value doesn't equal the specified buffer caller must release the memory using
-JSONObjectEncoder.free or free() as specified when calling this function.
-*/
-EXPORTFUNCTION char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *buffer, size_t cbBuffer);
-
-
-
-typedef struct __JSONObjectDecoder
-{
- JSOBJ (*newString)(void *prv, JSUINT32 *start, JSUINT32 *end);
- void (*objectAddKey)(void *prv, JSOBJ obj, JSOBJ name, JSOBJ value);
- void (*arrayAddItem)(void *prv, JSOBJ obj, JSOBJ value);
- JSOBJ (*newTrue)(void *prv);
- JSOBJ (*newFalse)(void *prv);
- JSOBJ (*newNull)(void *prv);
- JSOBJ (*newObject)(void *prv);
- JSOBJ (*newArray)(void *prv);
- JSOBJ (*newInt)(void *prv, JSINT32 value);
- JSOBJ (*newLong)(void *prv, JSINT64 value);
- JSOBJ (*newUnsignedLong)(void *prv, JSUINT64 value);
- JSOBJ (*newDouble)(void *prv, double value);
- void (*releaseObject)(void *prv, JSOBJ obj);
- JSPFN_MALLOC malloc;
- JSPFN_FREE free;
- JSPFN_REALLOC realloc;
- char *errorStr;
- char *errorOffset;
- int preciseFloat;
- void *prv;
-} JSONObjectDecoder;
-
-EXPORTFUNCTION JSOBJ JSON_DecodeObject(JSONObjectDecoder *dec, const char *buffer, size_t cbBuffer);
-
-#endif
diff --git a/srsly/ujson/lib/ultrajsondec.c b/srsly/ujson/lib/ultrajsondec.c
deleted file mode 100644
index a88a008..0000000
--- a/srsly/ujson/lib/ultrajsondec.c
+++ /dev/null
@@ -1,891 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
-* Copyright (c) 1988-1993 The Regents of the University of California.
-* Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#include "ultrajson.h"
-#include
-#include
-#include
-#include
-#include
-#include
-
-#ifndef TRUE
-#define TRUE 1
-#define FALSE 0
-#endif
-#ifndef NULL
-#define NULL 0
-#endif
-
-struct DecoderState
-{
- char *start;
- char *end;
- JSUINT32 *escStart;
- JSUINT32 *escEnd;
- int escHeap;
- int lastType;
- JSUINT32 objDepth;
- void *prv;
- JSONObjectDecoder *dec;
-};
-
-JSOBJ FASTCALL_MSVC decode_any( struct DecoderState *ds) FASTCALL_ATTR;
-typedef JSOBJ (*PFN_DECODER)( struct DecoderState *ds);
-
-static JSOBJ SetError( struct DecoderState *ds, int offset, const char *message)
-{
- ds->dec->errorOffset = ds->start + offset;
- ds->dec->errorStr = (char *) message;
- return NULL;
-}
-
-double createDouble(double intNeg, double intValue, double frcValue, int frcDecimalCount)
-{
- static const double g_pow10[] = {1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001,0.0000001, 0.00000001, 0.000000001, 0.0000000001, 0.00000000001, 0.000000000001, 0.0000000000001, 0.00000000000001, 0.000000000000001};
- return (intValue + (frcValue * g_pow10[frcDecimalCount])) * intNeg;
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decodePreciseFloat(struct DecoderState *ds)
-{
- char *end;
- double value;
- errno = 0;
-
- value = strtod(ds->start, &end);
-
- if (errno == ERANGE)
- {
- return SetError(ds, -1, "Range error when decoding numeric as double");
- }
-
- ds->start = end;
- return ds->dec->newDouble(ds->prv, value);
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_numeric (struct DecoderState *ds)
-{
- int intNeg = 1;
- int mantSize = 0;
- JSUINT64 intValue;
- JSUINT64 prevIntValue;
- int chr;
- int decimalCount = 0;
- double frcValue = 0.0;
- double expNeg;
- double expValue;
- char *offset = ds->start;
-
- JSUINT64 overflowLimit = LLONG_MAX;
-
- if (*(offset) == '-')
- {
- offset ++;
- intNeg = -1;
- overflowLimit = LLONG_MIN;
- }
-
- // Scan integer part
- intValue = 0;
-
- while (1)
- {
- chr = (int) (unsigned char) *(offset);
-
- switch (chr)
- {
- case '0':
- case '1':
- case '2':
- case '3':
- case '4':
- case '5':
- case '6':
- case '7':
- case '8':
- case '9':
- {
- //PERF: Don't do 64-bit arithmetic here unless we know we have to
- prevIntValue = intValue;
- intValue = intValue * 10ULL + (JSLONG) (chr - 48);
-
- if (intNeg == 1 && prevIntValue > intValue)
- {
- return SetError(ds, -1, "Value is too big!");
- }
- else if (intNeg == -1 && intValue > overflowLimit)
- {
- return SetError(ds, -1, overflowLimit == LLONG_MAX ? "Value is too big!" : "Value is too small");
- }
-
- offset ++;
- mantSize ++;
- break;
- }
- case '.':
- {
- offset ++;
- goto DECODE_FRACTION;
- break;
- }
- case 'e':
- case 'E':
- {
- offset ++;
- goto DECODE_EXPONENT;
- break;
- }
-
- default:
- {
- goto BREAK_INT_LOOP;
- break;
- }
- }
- }
-
-BREAK_INT_LOOP:
-
- ds->lastType = JT_INT;
- ds->start = offset;
-
- if (intNeg == 1 && (intValue & 0x8000000000000000ULL) != 0)
- {
- return ds->dec->newUnsignedLong(ds->prv, intValue);
- }
- else if ((intValue >> 31))
- {
- return ds->dec->newLong(ds->prv, (JSINT64) (intValue * (JSINT64) intNeg));
- }
- else
- {
- return ds->dec->newInt(ds->prv, (JSINT32) (intValue * intNeg));
- }
-
-DECODE_FRACTION:
-
- if (ds->dec->preciseFloat)
- {
- return decodePreciseFloat(ds);
- }
-
- // Scan fraction part
- frcValue = 0.0;
- for (;;)
- {
- chr = (int) (unsigned char) *(offset);
-
- switch (chr)
- {
- case '0':
- case '1':
- case '2':
- case '3':
- case '4':
- case '5':
- case '6':
- case '7':
- case '8':
- case '9':
- {
- if (decimalCount < JSON_DOUBLE_MAX_DECIMALS)
- {
- frcValue = frcValue * 10.0 + (double) (chr - 48);
- decimalCount ++;
- }
- offset ++;
- break;
- }
- case 'e':
- case 'E':
- {
- offset ++;
- goto DECODE_EXPONENT;
- break;
- }
- default:
- {
- goto BREAK_FRC_LOOP;
- }
- }
- }
-
-BREAK_FRC_LOOP:
- //FIXME: Check for arithemtic overflow here
- ds->lastType = JT_DOUBLE;
- ds->start = offset;
- return ds->dec->newDouble (ds->prv, createDouble( (double) intNeg, (double) intValue, frcValue, decimalCount));
-
-DECODE_EXPONENT:
- if (ds->dec->preciseFloat)
- {
- return decodePreciseFloat(ds);
- }
-
- expNeg = 1.0;
-
- if (*(offset) == '-')
- {
- expNeg = -1.0;
- offset ++;
- }
- else
- if (*(offset) == '+')
- {
- expNeg = +1.0;
- offset ++;
- }
-
- expValue = 0.0;
-
- for (;;)
- {
- chr = (int) (unsigned char) *(offset);
-
- switch (chr)
- {
- case '0':
- case '1':
- case '2':
- case '3':
- case '4':
- case '5':
- case '6':
- case '7':
- case '8':
- case '9':
- {
- expValue = expValue * 10.0 + (double) (chr - 48);
- offset ++;
- break;
- }
- default:
- {
- goto BREAK_EXP_LOOP;
- }
- }
- }
-
-BREAK_EXP_LOOP:
- //FIXME: Check for arithemtic overflow here
- ds->lastType = JT_DOUBLE;
- ds->start = offset;
- return ds->dec->newDouble (ds->prv, createDouble( (double) intNeg, (double) intValue , frcValue, decimalCount) * pow(10.0, expValue * expNeg));
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_true ( struct DecoderState *ds)
-{
- char *offset = ds->start;
- offset ++;
-
- if (*(offset++) != 'r')
- goto SETERROR;
- if (*(offset++) != 'u')
- goto SETERROR;
- if (*(offset++) != 'e')
- goto SETERROR;
-
- ds->lastType = JT_TRUE;
- ds->start = offset;
- return ds->dec->newTrue(ds->prv);
-
-SETERROR:
- return SetError(ds, -1, "Unexpected character found when decoding 'true'");
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_false ( struct DecoderState *ds)
-{
- char *offset = ds->start;
- offset ++;
-
- if (*(offset++) != 'a')
- goto SETERROR;
- if (*(offset++) != 'l')
- goto SETERROR;
- if (*(offset++) != 's')
- goto SETERROR;
- if (*(offset++) != 'e')
- goto SETERROR;
-
- ds->lastType = JT_FALSE;
- ds->start = offset;
- return ds->dec->newFalse(ds->prv);
-
-SETERROR:
- return SetError(ds, -1, "Unexpected character found when decoding 'false'");
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_null ( struct DecoderState *ds)
-{
- char *offset = ds->start;
- offset ++;
-
- if (*(offset++) != 'u')
- goto SETERROR;
- if (*(offset++) != 'l')
- goto SETERROR;
- if (*(offset++) != 'l')
- goto SETERROR;
-
- ds->lastType = JT_NULL;
- ds->start = offset;
- return ds->dec->newNull(ds->prv);
-
-SETERROR:
- return SetError(ds, -1, "Unexpected character found when decoding 'null'");
-}
-
-FASTCALL_ATTR void FASTCALL_MSVC SkipWhitespace(struct DecoderState *ds)
-{
- char *offset = ds->start;
-
- for (;;)
- {
- switch (*offset)
- {
- case ' ':
- case '\t':
- case '\r':
- case '\n':
- offset ++;
- break;
-
- default:
- ds->start = offset;
- return;
- }
- }
-}
-
-enum DECODESTRINGSTATE
-{
- DS_ISNULL = 0x32,
- DS_ISQUOTE,
- DS_ISESCAPE,
- DS_UTFLENERROR,
-
-};
-
-static const JSUINT8 g_decoderLookup[256] =
-{
- /* 0x00 */ DS_ISNULL, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x10 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x20 */ 1, 1, DS_ISQUOTE, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, DS_ISESCAPE, 1, 1, 1,
- /* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x80 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0x90 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0xa0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0xb0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- /* 0xc0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- /* 0xd0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- /* 0xe0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
- /* 0xf0 */ 4, 4, 4, 4, 4, 4, 4, 4, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR,
-};
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_string ( struct DecoderState *ds)
-{
- int index;
- JSUINT32 *escOffset;
- JSUINT32 *escStart;
- size_t escLen = (ds->escEnd - ds->escStart);
- JSUINT8 *inputOffset;
- JSUTF16 ch = 0;
- JSUINT8 *lastHighSurrogate = NULL;
- JSUINT8 oct;
- JSUTF32 ucs;
- ds->lastType = JT_INVALID;
- ds->start ++;
-
- if ( (size_t) (ds->end - ds->start) > escLen)
- {
- size_t newSize = (ds->end - ds->start);
-
- if (ds->escHeap)
- {
- if (newSize > (SIZE_MAX / sizeof(JSUINT32)))
- {
- return SetError(ds, -1, "Could not reserve memory block");
- }
- escStart = (JSUINT32 *)ds->dec->realloc(ds->escStart, newSize * sizeof(JSUINT32));
- if (!escStart)
- {
- ds->dec->free(ds->escStart);
- return SetError(ds, -1, "Could not reserve memory block");
- }
- ds->escStart = escStart;
- }
- else
- {
- JSUINT32 *oldStart = ds->escStart;
- if (newSize > (SIZE_MAX / sizeof(JSUINT32)))
- {
- return SetError(ds, -1, "Could not reserve memory block");
- }
- ds->escStart = (JSUINT32 *) ds->dec->malloc(newSize * sizeof(JSUINT32));
- if (!ds->escStart)
- {
- return SetError(ds, -1, "Could not reserve memory block");
- }
- ds->escHeap = 1;
- memcpy(ds->escStart, oldStart, escLen * sizeof(JSUINT32));
- }
-
- ds->escEnd = ds->escStart + newSize;
- }
-
- escOffset = ds->escStart;
- inputOffset = (JSUINT8 *) ds->start;
-
- for (;;)
- {
- switch (g_decoderLookup[(JSUINT8)(*inputOffset)])
- {
- case DS_ISNULL:
- {
- return SetError(ds, -1, "Unmatched ''\"' when when decoding 'string'");
- }
- case DS_ISQUOTE:
- {
- ds->lastType = JT_UTF8;
- inputOffset ++;
- ds->start += ( (char *) inputOffset - (ds->start));
- return ds->dec->newString(ds->prv, ds->escStart, escOffset);
- }
- case DS_UTFLENERROR:
- {
- return SetError (ds, -1, "Invalid UTF-8 sequence length when decoding 'string'");
- }
- case DS_ISESCAPE:
- inputOffset ++;
- switch (*inputOffset)
- {
- case '\\': *(escOffset++) = '\\'; inputOffset++; continue;
- case '\"': *(escOffset++) = '\"'; inputOffset++; continue;
- case '/': *(escOffset++) = '/'; inputOffset++; continue;
- case 'b': *(escOffset++) = '\b'; inputOffset++; continue;
- case 'f': *(escOffset++) = '\f'; inputOffset++; continue;
- case 'n': *(escOffset++) = '\n'; inputOffset++; continue;
- case 'r': *(escOffset++) = '\r'; inputOffset++; continue;
- case 't': *(escOffset++) = '\t'; inputOffset++; continue;
-
- case 'u':
- {
- int index;
- inputOffset ++;
-
- for (index = 0; index < 4; index ++)
- {
- switch (*inputOffset)
- {
- case '\0': return SetError (ds, -1, "Unterminated unicode escape sequence when decoding 'string'");
- default: return SetError (ds, -1, "Unexpected character in unicode escape sequence when decoding 'string'");
-
- case '0':
- case '1':
- case '2':
- case '3':
- case '4':
- case '5':
- case '6':
- case '7':
- case '8':
- case '9':
- ch = (ch << 4) + (JSUTF16) (*inputOffset - '0');
- break;
-
- case 'a':
- case 'b':
- case 'c':
- case 'd':
- case 'e':
- case 'f':
- ch = (ch << 4) + 10 + (JSUTF16) (*inputOffset - 'a');
- break;
-
- case 'A':
- case 'B':
- case 'C':
- case 'D':
- case 'E':
- case 'F':
- ch = (ch << 4) + 10 + (JSUTF16) (*inputOffset - 'A');
- break;
- }
-
- inputOffset ++;
- }
-
- if ((ch & 0xfc00) == 0xdc00 && lastHighSurrogate == inputOffset - 6 * sizeof(*inputOffset))
- {
- // Low surrogate immediately following a high surrogate
- // Overwrite existing high surrogate with combined character
- *(escOffset-1) = (((*(escOffset-1) - 0xd800) <<10) | (ch - 0xdc00)) + 0x10000;
- }
- else
- {
- *(escOffset++) = (JSUINT32) ch;
- }
- if ((ch & 0xfc00) == 0xd800)
- {
- lastHighSurrogate = inputOffset;
- }
- break;
- }
-
- case '\0': return SetError(ds, -1, "Unterminated escape sequence when decoding 'string'");
- default: return SetError(ds, -1, "Unrecognized escape sequence when decoding 'string'");
- }
- break;
-
- case 1:
- {
- *(escOffset++) = (JSUINT32) (*inputOffset++);
- break;
- }
-
- case 2:
- {
- ucs = (*inputOffset++) & 0x1f;
- ucs <<= 6;
- if (((*inputOffset) & 0x80) != 0x80)
- {
- return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'");
- }
- ucs |= (*inputOffset++) & 0x3f;
- if (ucs < 0x80) return SetError (ds, -1, "Overlong 2 byte UTF-8 sequence detected when decoding 'string'");
- *(escOffset++) = (JSUINT32) ucs;
- break;
- }
-
- case 3:
- {
- JSUTF32 ucs = 0;
- ucs |= (*inputOffset++) & 0x0f;
-
- for (index = 0; index < 2; index ++)
- {
- ucs <<= 6;
- oct = (*inputOffset++);
-
- if ((oct & 0x80) != 0x80)
- {
- return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'");
- }
-
- ucs |= oct & 0x3f;
- }
-
- if (ucs < 0x800) return SetError (ds, -1, "Overlong 3 byte UTF-8 sequence detected when encoding string");
- *(escOffset++) = (JSUINT32) ucs;
- break;
- }
-
- case 4:
- {
- JSUTF32 ucs = 0;
- ucs |= (*inputOffset++) & 0x07;
-
- for (index = 0; index < 3; index ++)
- {
- ucs <<= 6;
- oct = (*inputOffset++);
-
- if ((oct & 0x80) != 0x80)
- {
- return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'");
- }
-
- ucs |= oct & 0x3f;
- }
-
- if (ucs < 0x10000) return SetError (ds, -1, "Overlong 4 byte UTF-8 sequence detected when decoding 'string'");
-
- *(escOffset++) = (JSUINT32) ucs;
- break;
- }
- }
- }
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_array(struct DecoderState *ds)
-{
- JSOBJ itemValue;
- JSOBJ newObj;
- int len;
- ds->objDepth++;
- if (ds->objDepth > JSON_MAX_OBJECT_DEPTH) {
- return SetError(ds, -1, "Reached object decoding depth limit");
- }
-
- newObj = ds->dec->newArray(ds->prv);
- len = 0;
-
- ds->lastType = JT_INVALID;
- ds->start ++;
-
- for (;;)
- {
- SkipWhitespace(ds);
-
- if ((*ds->start) == ']')
- {
- ds->objDepth--;
- if (len == 0)
- {
- ds->start ++;
- return newObj;
- }
-
- ds->dec->releaseObject(ds->prv, newObj);
- return SetError(ds, -1, "Unexpected character found when decoding array value (1)");
- }
-
- itemValue = decode_any(ds);
-
- if (itemValue == NULL)
- {
- ds->dec->releaseObject(ds->prv, newObj);
- return NULL;
- }
-
- ds->dec->arrayAddItem (ds->prv, newObj, itemValue);
-
- SkipWhitespace(ds);
-
- switch (*(ds->start++))
- {
- case ']':
- {
- ds->objDepth--;
- return newObj;
- }
- case ',':
- break;
-
- default:
- ds->dec->releaseObject(ds->prv, newObj);
- return SetError(ds, -1, "Unexpected character found when decoding array value (2)");
- }
-
- len ++;
- }
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_object( struct DecoderState *ds)
-{
- JSOBJ itemName;
- JSOBJ itemValue;
- JSOBJ newObj;
- int len;
-
- ds->objDepth++;
- if (ds->objDepth > JSON_MAX_OBJECT_DEPTH) {
- return SetError(ds, -1, "Reached object decoding depth limit");
- }
-
- newObj = ds->dec->newObject(ds->prv);
- len = 0;
-
- ds->start ++;
-
- for (;;)
- {
- SkipWhitespace(ds);
-
- if ((*ds->start) == '}')
- {
- ds->objDepth--;
- if (len == 0)
- {
- ds->start ++;
- return newObj;
- }
-
- ds->dec->releaseObject(ds->prv, newObj);
- return SetError(ds, -1, "Unexpected character in found when decoding object value");
- }
-
- ds->lastType = JT_INVALID;
- itemName = decode_any(ds);
-
- if (itemName == NULL)
- {
- ds->dec->releaseObject(ds->prv, newObj);
- return NULL;
- }
-
- if (ds->lastType != JT_UTF8)
- {
- ds->dec->releaseObject(ds->prv, newObj);
- ds->dec->releaseObject(ds->prv, itemName);
- return SetError(ds, -1, "Key name of object must be 'string' when decoding 'object'");
- }
-
- SkipWhitespace(ds);
-
- if (*(ds->start++) != ':')
- {
- ds->dec->releaseObject(ds->prv, newObj);
- ds->dec->releaseObject(ds->prv, itemName);
- return SetError(ds, -1, "No ':' found when decoding object value");
- }
-
- SkipWhitespace(ds);
-
- itemValue = decode_any(ds);
-
- if (itemValue == NULL)
- {
- ds->dec->releaseObject(ds->prv, newObj);
- ds->dec->releaseObject(ds->prv, itemName);
- return NULL;
- }
-
- ds->dec->objectAddKey (ds->prv, newObj, itemName, itemValue);
-
- SkipWhitespace(ds);
-
- switch (*(ds->start++))
- {
- case '}':
- {
- ds->objDepth--;
- return newObj;
- }
- case ',':
- break;
-
- default:
- ds->dec->releaseObject(ds->prv, newObj);
- return SetError(ds, -1, "Unexpected character in found when decoding object value");
- }
-
- len++;
- }
-}
-
-FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_any(struct DecoderState *ds)
-{
- for (;;)
- {
- switch (*ds->start)
- {
- case '\"':
- return decode_string (ds);
- case '0':
- case '1':
- case '2':
- case '3':
- case '4':
- case '5':
- case '6':
- case '7':
- case '8':
- case '9':
- case '-':
- return decode_numeric (ds);
-
- case '[': return decode_array (ds);
- case '{': return decode_object (ds);
- case 't': return decode_true (ds);
- case 'f': return decode_false (ds);
- case 'n': return decode_null (ds);
-
- case ' ':
- case '\t':
- case '\r':
- case '\n':
- // White space
- ds->start ++;
- break;
-
- default:
- return SetError(ds, -1, "Expected object or value");
- }
- }
-}
-
-JSOBJ JSON_DecodeObject(JSONObjectDecoder *dec, const char *buffer, size_t cbBuffer)
-{
- /*
- FIXME: Base the size of escBuffer of that of cbBuffer so that the unicode escaping doesn't run into the wall each time */
- struct DecoderState ds;
- JSUINT32 escBuffer[(JSON_MAX_STACK_BUFFER_SIZE / sizeof(JSUINT32))];
- JSOBJ ret;
-
- ds.start = (char *) buffer;
- ds.end = ds.start + cbBuffer;
-
- ds.escStart = escBuffer;
- ds.escEnd = ds.escStart + (JSON_MAX_STACK_BUFFER_SIZE / sizeof(JSUINT32));
- ds.escHeap = 0;
- ds.prv = dec->prv;
- ds.dec = dec;
- ds.dec->errorStr = NULL;
- ds.dec->errorOffset = NULL;
- ds.objDepth = 0;
-
- ds.dec = dec;
-
- ret = decode_any (&ds);
-
- if (ds.escHeap)
- {
- dec->free(ds.escStart);
- }
-
- if (!(dec->errorStr))
- {
- if ((ds.end - ds.start) > 0)
- {
- SkipWhitespace(&ds);
- }
-
- if (ds.start != ds.end && ret)
- {
- dec->releaseObject(ds.prv, ret);
- return SetError(&ds, -1, "Trailing data");
- }
- }
-
- return ret;
-}
diff --git a/srsly/ujson/lib/ultrajsonenc.c b/srsly/ujson/lib/ultrajsonenc.c
deleted file mode 100644
index ea6372b..0000000
--- a/srsly/ujson/lib/ultrajsonenc.c
+++ /dev/null
@@ -1,1067 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
- * Copyright (c) 1988-1993 The Regents of the University of California.
- * Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#include "ultrajson.h"
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#ifndef TRUE
-#define TRUE 1
-#endif
-#ifndef FALSE
-#define FALSE 0
-#endif
-
-#if ( (defined(_WIN32) || defined(WIN32) ) && ( defined(_MSC_VER) ) )
-#define snprintf sprintf_s
-#endif
-
-/*
-Worst cases being:
-
-Control characters (ASCII < 32)
-0x00 (1 byte) input => \u0000 output (6 bytes)
-1 * 6 => 6 (6 bytes required)
-
-or UTF-16 surrogate pairs
-4 bytes input in UTF-8 => \uXXXX\uYYYY (12 bytes).
-
-4 * 6 => 24 bytes (12 bytes required)
-
-The extra 2 bytes are for the quotes around the string
-
-*/
-#define RESERVE_STRING(_len) (2 + ((_len) * 6))
-
-static const double g_pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000, 10000000000, 100000000000, 1000000000000, 10000000000000, 100000000000000, 1000000000000000};
-static const char g_hexChars[] = "0123456789abcdef";
-static const char g_escapeChars[] = "0123456789\\b\\t\\n\\f\\r\\\"\\\\\\/";
-
-/*
-FIXME: While this is fine dandy and working it's a magic value mess which probably only the author understands.
-Needs a cleanup and more documentation */
-
-/*
-Table for pure ascii output escaping all characters above 127 to \uXXXX */
-static const JSUINT8 g_asciiOutputTable[256] =
-{
-/* 0x00 */ 0, 30, 30, 30, 30, 30, 30, 30, 10, 12, 14, 30, 16, 18, 30, 30,
-/* 0x10 */ 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
-/* 0x20 */ 1, 1, 20, 1, 1, 1, 29, 1, 1, 1, 1, 1, 1, 1, 1, 24,
-/* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 29, 1, 29, 1,
-/* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 22, 1, 1, 1,
-/* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0x80 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0x90 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0xa0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0xb0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
-/* 0xc0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
-/* 0xd0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
-/* 0xe0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
-/* 0xf0 */ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 1, 1
-};
-
-static void SetError (JSOBJ obj, JSONObjectEncoder *enc, const char *message)
-{
- enc->errorMsg = message;
- enc->errorObj = obj;
-}
-
-/*
-FIXME: Keep track of how big these get across several encoder calls and try to make an estimate
-That way we won't run our head into the wall each call */
-void Buffer_Realloc (JSONObjectEncoder *enc, size_t cbNeeded)
-{
- size_t free_space = enc->end - enc->offset;
- if (free_space >= cbNeeded)
- {
- return;
- }
- size_t curSize = enc->end - enc->start;
- size_t newSize = curSize;
- size_t offset = enc->offset - enc->start;
-
-#ifdef DEBUG
- // In debug mode, allocate only what is requested so that any miscalculation
- // shows up plainly as a crash.
- newSize = (enc->offset - enc->start) + cbNeeded;
-#else
- while (newSize < curSize + cbNeeded)
- {
- newSize *= 2;
- }
-#endif
-
- if (enc->heap)
- {
- enc->start = (char *) enc->realloc (enc->start, newSize);
- if (!enc->start)
- {
- SetError (NULL, enc, "Could not reserve memory block");
- return;
- }
- }
- else
- {
- char *oldStart = enc->start;
- enc->heap = 1;
- enc->start = (char *) enc->malloc (newSize);
- if (!enc->start)
- {
- SetError (NULL, enc, "Could not reserve memory block");
- return;
- }
- memcpy (enc->start, oldStart, offset);
- }
- enc->offset = enc->start + offset;
- enc->end = enc->start + newSize;
-}
-
-#define Buffer_Reserve(__enc, __len) \
- if ( (size_t) ((__enc)->end - (__enc)->offset) < (size_t) (__len)) \
- { \
- Buffer_Realloc((__enc), (__len));\
- } \
-
-FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC Buffer_AppendShortHexUnchecked (char *outputOffset, unsigned short value)
-{
- *(outputOffset++) = g_hexChars[(value & 0xf000) >> 12];
- *(outputOffset++) = g_hexChars[(value & 0x0f00) >> 8];
- *(outputOffset++) = g_hexChars[(value & 0x00f0) >> 4];
- *(outputOffset++) = g_hexChars[(value & 0x000f) >> 0];
-}
-
-int Buffer_EscapeStringUnvalidated (JSONObjectEncoder *enc, const char *io, const char *end)
-{
- char *of = (char *) enc->offset;
-
- for (;;)
- {
- switch (*io)
- {
- case 0x00:
- {
- if (io < end)
- {
- *(of++) = '\\';
- *(of++) = 'u';
- *(of++) = '0';
- *(of++) = '0';
- *(of++) = '0';
- *(of++) = '0';
- break;
- }
- else
- {
- enc->offset += (of - enc->offset);
- return TRUE;
- }
- }
- case '\"': (*of++) = '\\'; (*of++) = '\"'; break;
- case '\\': (*of++) = '\\'; (*of++) = '\\'; break;
- case '\b': (*of++) = '\\'; (*of++) = 'b'; break;
- case '\f': (*of++) = '\\'; (*of++) = 'f'; break;
- case '\n': (*of++) = '\\'; (*of++) = 'n'; break;
- case '\r': (*of++) = '\\'; (*of++) = 'r'; break;
- case '\t': (*of++) = '\\'; (*of++) = 't'; break;
-
- case '/':
- {
- if (enc->escapeForwardSlashes)
- {
- (*of++) = '\\';
- (*of++) = '/';
- }
- else
- {
- // Same as default case below.
- (*of++) = (*io);
- }
- break;
- }
- case 0x26: // '&'
- case 0x3c: // '<'
- case 0x3e: // '>'
- {
- if (enc->encodeHTMLChars)
- {
- // Fall through to \u00XX case below.
- }
- else
- {
- // Same as default case below.
- (*of++) = (*io);
- break;
- }
- }
- case 0x01:
- case 0x02:
- case 0x03:
- case 0x04:
- case 0x05:
- case 0x06:
- case 0x07:
- case 0x0b:
- case 0x0e:
- case 0x0f:
- case 0x10:
- case 0x11:
- case 0x12:
- case 0x13:
- case 0x14:
- case 0x15:
- case 0x16:
- case 0x17:
- case 0x18:
- case 0x19:
- case 0x1a:
- case 0x1b:
- case 0x1c:
- case 0x1d:
- case 0x1e:
- case 0x1f:
- {
- *(of++) = '\\';
- *(of++) = 'u';
- *(of++) = '0';
- *(of++) = '0';
- *(of++) = g_hexChars[ (unsigned char) (((*io) & 0xf0) >> 4)];
- *(of++) = g_hexChars[ (unsigned char) ((*io) & 0x0f)];
- break;
- }
- default: (*of++) = (*io); break;
- }
- io++;
- }
-}
-
-int Buffer_EscapeStringValidated (JSOBJ obj, JSONObjectEncoder *enc, const char *io, const char *end)
-{
- JSUTF32 ucs;
- char *of = (char *) enc->offset;
-
- for (;;)
- {
-#ifdef DEBUG
- // 6 is the maximum length of a single character (cf. RESERVE_STRING).
- if ((io < end) && (enc->end - of < 6)) {
- fprintf(stderr, "Ran out of buffer space during Buffer_EscapeStringValidated()\n");
- abort();
- }
-#endif
- JSUINT8 utflen = g_asciiOutputTable[(unsigned char) *io];
-
- switch (utflen)
- {
- case 0:
- {
- if (io < end)
- {
- *(of++) = '\\';
- *(of++) = 'u';
- *(of++) = '0';
- *(of++) = '0';
- *(of++) = '0';
- *(of++) = '0';
- io ++;
- continue;
- }
- else
- {
- enc->offset += (of - enc->offset);
- return TRUE;
- }
- }
-
- case 1:
- {
- *(of++)= (*io++);
- continue;
- }
-
- case 2:
- {
- JSUTF32 in;
- JSUTF16 in16;
-
- if (end - io < 1)
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string");
- return FALSE;
- }
-
- memcpy(&in16, io, sizeof(JSUTF16));
- in = (JSUTF32) in16;
-
-#ifdef __LITTLE_ENDIAN__
- ucs = ((in & 0x1f) << 6) | ((in >> 8) & 0x3f);
-#else
- ucs = ((in & 0x1f00) >> 2) | (in & 0x3f);
-#endif
-
- if (ucs < 0x80)
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Overlong 2 byte UTF-8 sequence detected when encoding string");
- return FALSE;
- }
-
- io += 2;
- break;
- }
-
- case 3:
- {
- JSUTF32 in;
- JSUTF16 in16;
- JSUINT8 in8;
-
- if (end - io < 2)
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string");
- return FALSE;
- }
-
- memcpy(&in16, io, sizeof(JSUTF16));
- memcpy(&in8, io + 2, sizeof(JSUINT8));
-#ifdef __LITTLE_ENDIAN__
- in = (JSUTF32) in16;
- in |= in8 << 16;
- ucs = ((in & 0x0f) << 12) | ((in & 0x3f00) >> 2) | ((in & 0x3f0000) >> 16);
-#else
- in = in16 << 8;
- in |= in8;
- ucs = ((in & 0x0f0000) >> 4) | ((in & 0x3f00) >> 2) | (in & 0x3f);
-#endif
-
- if (ucs < 0x800)
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Overlong 3 byte UTF-8 sequence detected when encoding string");
- return FALSE;
- }
-
- io += 3;
- break;
- }
- case 4:
- {
- JSUTF32 in;
-
- if (end - io < 3)
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string");
- return FALSE;
- }
-
- memcpy(&in, io, sizeof(JSUTF32));
-#ifdef __LITTLE_ENDIAN__
- ucs = ((in & 0x07) << 18) | ((in & 0x3f00) << 4) | ((in & 0x3f0000) >> 10) | ((in & 0x3f000000) >> 24);
-#else
- ucs = ((in & 0x07000000) >> 6) | ((in & 0x3f0000) >> 4) | ((in & 0x3f00) >> 2) | (in & 0x3f);
-#endif
- if (ucs < 0x10000)
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Overlong 4 byte UTF-8 sequence detected when encoding string");
- return FALSE;
- }
-
- io += 4;
- break;
- }
-
-
- case 5:
- case 6:
- {
- enc->offset += (of - enc->offset);
- SetError (obj, enc, "Unsupported UTF-8 sequence length when encoding string");
- return FALSE;
- }
-
- case 29:
- {
- if (enc->encodeHTMLChars)
- {
- // Fall through to \u00XX case 30 below.
- }
- else
- {
- // Same as case 1 above.
- *(of++) = (*io++);
- continue;
- }
- }
-
- case 30:
- {
- // \uXXXX encode
- *(of++) = '\\';
- *(of++) = 'u';
- *(of++) = '0';
- *(of++) = '0';
- *(of++) = g_hexChars[ (unsigned char) (((*io) & 0xf0) >> 4)];
- *(of++) = g_hexChars[ (unsigned char) ((*io) & 0x0f)];
- io ++;
- continue;
- }
- case 10:
- case 12:
- case 14:
- case 16:
- case 18:
- case 20:
- case 22:
- {
- *(of++) = *( (char *) (g_escapeChars + utflen + 0));
- *(of++) = *( (char *) (g_escapeChars + utflen + 1));
- io ++;
- continue;
- }
- case 24:
- {
- if (enc->escapeForwardSlashes)
- {
- *(of++) = *( (char *) (g_escapeChars + utflen + 0));
- *(of++) = *( (char *) (g_escapeChars + utflen + 1));
- io ++;
- }
- else
- {
- // Same as case 1 above.
- *(of++) = (*io++);
- }
- continue;
- }
- // This can never happen, it's here to make L4 VC++ happy
- default:
- {
- ucs = 0;
- break;
- }
- }
-
- /*
- If the character is a UTF8 sequence of length > 1 we end up here */
- if (ucs >= 0x10000)
- {
- ucs -= 0x10000;
- *(of++) = '\\';
- *(of++) = 'u';
- Buffer_AppendShortHexUnchecked(of, (unsigned short) (ucs >> 10) + 0xd800);
- of += 4;
-
- *(of++) = '\\';
- *(of++) = 'u';
- Buffer_AppendShortHexUnchecked(of, (unsigned short) (ucs & 0x3ff) + 0xdc00);
- of += 4;
- }
- else
- {
- *(of++) = '\\';
- *(of++) = 'u';
- Buffer_AppendShortHexUnchecked(of, (unsigned short) ucs);
- of += 4;
- }
- }
-}
-
-static FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC Buffer_AppendCharUnchecked(JSONObjectEncoder *enc, char chr)
-{
-#ifdef DEBUG
- if (enc->end <= enc->offset)
- {
- fprintf(stderr, "Overflow writing byte %d '%c'. The last few characters were:\n'''", chr, chr);
- char * recent = enc->offset - 1000;
- if (enc->start > recent)
- {
- recent = enc->start;
- }
- for (; recent < enc->offset; recent++)
- {
- fprintf(stderr, "%c", *recent);
- }
- fprintf(stderr, "'''\n");
- abort();
- }
-#endif
- *(enc->offset++) = chr;
-}
-
-FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC strreverse(char* begin, char* end)
-{
- char aux;
- while (end > begin)
- aux = *end, *end-- = *begin, *begin++ = aux;
-}
-
-void Buffer_AppendIndentNewlineUnchecked(JSONObjectEncoder *enc)
-{
- if (enc->indent > 0) Buffer_AppendCharUnchecked(enc, '\n');
-}
-
-void Buffer_AppendIndentUnchecked(JSONObjectEncoder *enc, JSINT32 value)
-{
- int i;
- if (enc->indent > 0)
- while (value-- > 0)
- for (i = 0; i < enc->indent; i++)
- Buffer_AppendCharUnchecked(enc, ' ');
-}
-
-void Buffer_AppendIntUnchecked(JSONObjectEncoder *enc, JSINT32 value)
-{
- char* wstr;
- JSUINT32 uvalue = (value < 0) ? -value : value;
-
- wstr = enc->offset;
- // Conversion. Number is reversed.
-
- do *wstr++ = (char)(48 + (uvalue % 10)); while(uvalue /= 10);
- if (value < 0) *wstr++ = '-';
-
- // Reverse string
- strreverse(enc->offset,wstr - 1);
- enc->offset += (wstr - (enc->offset));
-}
-
-void Buffer_AppendLongUnchecked(JSONObjectEncoder *enc, JSINT64 value)
-{
- char* wstr;
- JSUINT64 uvalue = (value < 0) ? -value : value;
-
- wstr = enc->offset;
- // Conversion. Number is reversed.
-
- do *wstr++ = (char)(48 + (uvalue % 10ULL)); while(uvalue /= 10ULL);
- if (value < 0) *wstr++ = '-';
-
- // Reverse string
- strreverse(enc->offset,wstr - 1);
- enc->offset += (wstr - (enc->offset));
-}
-
-void Buffer_AppendUnsignedLongUnchecked(JSONObjectEncoder *enc, JSUINT64 value)
-{
- char* wstr;
- JSUINT64 uvalue = value;
-
- wstr = enc->offset;
- // Conversion. Number is reversed.
-
- do *wstr++ = (char)(48 + (uvalue % 10ULL)); while(uvalue /= 10ULL);
-
- // Reverse string
- strreverse(enc->offset,wstr - 1);
- enc->offset += (wstr - (enc->offset));
-}
-
-int Buffer_AppendDoubleUnchecked(JSOBJ obj, JSONObjectEncoder *enc, double value)
-{
- /* if input is larger than thres_max, revert to exponential */
- const double thres_max = (double) 1e16 - 1;
- int count;
- double diff = 0.0;
- char* str = enc->offset;
- char* wstr = str;
- unsigned long long whole;
- double tmp;
- unsigned long long frac;
- int neg;
- double pow10;
-
- if (value == HUGE_VAL || value == -HUGE_VAL)
- {
- SetError (obj, enc, "Invalid Inf value when encoding double");
- return FALSE;
- }
-
- if (!(value == value))
- {
- SetError (obj, enc, "Invalid Nan value when encoding double");
- return FALSE;
- }
-
- /* we'll work in positive values and deal with the
- negative sign issue later */
- neg = 0;
- if (value < 0)
- {
- neg = 1;
- value = -value;
- }
-
- pow10 = g_pow10[enc->doublePrecision];
-
- whole = (unsigned long long) value;
- tmp = (value - whole) * pow10;
- frac = (unsigned long long)(tmp);
- diff = tmp - frac;
-
- if (diff > 0.5)
- {
- ++frac;
- /* handle rollover, e.g. case 0.99 with prec 1 is 1.0 */
- if (frac >= pow10)
- {
- frac = 0;
- ++whole;
- }
- }
- else
- if (diff == 0.5 && ((frac == 0) || (frac & 1)))
- {
- /* if halfway, round up if odd, OR
- if last digit is 0. That last part is strange */
- ++frac;
- }
-
- /* for very large numbers switch back to native sprintf for exponentials.
- anyone want to write code to replace this? */
- /*
- normal printf behavior is to print EVERY whole number digit
- which can be 100s of characters overflowing your buffers == bad
- */
- if (value > thres_max)
- {
- enc->offset += snprintf(str, enc->end - enc->offset, "%.15e", neg ? -value : value);
- return TRUE;
- }
-
- if (enc->doublePrecision == 0)
- {
- diff = value - whole;
-
- if (diff > 0.5)
- {
- /* greater than 0.5, round up, e.g. 1.6 -> 2 */
- ++whole;
- }
- else
- if (diff == 0.5 && (whole & 1))
- {
- /* exactly 0.5 and ODD, then round up */
- /* 1.5 -> 2, but 2.5 -> 2 */
- ++whole;
- }
-
- //vvvvvvvvvvvvvvvvvvv Diff from modp_dto2
- }
- else
- if (frac)
- {
- count = enc->doublePrecision;
- // now do fractional part, as an unsigned number
- // we know it is not 0 but we can have leading zeros, these
- // should be removed
- while (!(frac % 10))
- {
- --count;
- frac /= 10;
- }
- //^^^^^^^^^^^^^^^^^^^ Diff from modp_dto2
-
- // now do fractional part, as an unsigned number
- do
- {
- --count;
- *wstr++ = (char)(48 + (frac % 10));
- } while (frac /= 10);
- // add extra 0s
- while (count-- > 0)
- {
- *wstr++ = '0';
- }
- // add decimal
- *wstr++ = '.';
- }
- else
- {
- *wstr++ = '0';
- *wstr++ = '.';
- }
-
- // do whole part
- // Take care of sign
- // Conversion. Number is reversed.
- do *wstr++ = (char)(48 + (whole % 10)); while (whole /= 10);
-
- if (neg)
- {
- *wstr++ = '-';
- }
- strreverse(str, wstr-1);
- enc->offset += (wstr - (enc->offset));
-
- return TRUE;
-}
-
-/*
-FIXME:
-Handle integration functions returning NULL here */
-
-/*
-FIXME:
-Perhaps implement recursion detection */
-
-void encode(JSOBJ obj, JSONObjectEncoder *enc, const char *name, size_t cbName)
-{
- const char *value;
- char *objName;
- int count;
- JSOBJ iterObj;
- size_t szlen;
- JSONTypeContext tc;
-
- if (enc->level > enc->recursionMax)
- {
- SetError (obj, enc, "Maximum recursion level reached");
- return;
- }
-
- if (enc->errorMsg)
- {
- return;
- }
-
- if (name)
- {
- // 2 extra for the colon and optional space after it
- Buffer_Reserve(enc, RESERVE_STRING(cbName) + 2);
- Buffer_AppendCharUnchecked(enc, '\"');
-
- if (enc->forceASCII)
- {
- if (!Buffer_EscapeStringValidated(obj, enc, name, name + cbName))
- {
- return;
- }
- }
- else
- {
- if (!Buffer_EscapeStringUnvalidated(enc, name, name + cbName))
- {
- return;
- }
- }
-
- Buffer_AppendCharUnchecked(enc, '\"');
-
- Buffer_AppendCharUnchecked (enc, ':');
- }
-
- tc.encoder_prv = enc->prv;
- enc->beginTypeContext(obj, &tc, enc);
-
- /*
- This reservation covers any additions on non-variable parts below, specifically:
- - Opening brackets for JT_ARRAY and JT_OBJECT
- - Number representation for JT_LONG, JT_ULONG, JT_INT, and JT_DOUBLE
- - Constant value for JT_TRUE, JT_FALSE, JT_NULL
- The length of 128 is the worst case length of the Buffer_AppendDoubleDconv addition.
- The other types above all have smaller representations.
- */
- Buffer_Reserve (enc, 128);
-
- switch (tc.type)
- {
- case JT_INVALID:
- {
- return;
- }
-
- case JT_ARRAY:
- {
- count = 0;
-
- Buffer_AppendCharUnchecked (enc, '[');
- Buffer_AppendIndentNewlineUnchecked (enc);
-
- while (enc->iterNext(obj, &tc))
- {
- // The extra 2 bytes cover the comma and (optional) newline.
- Buffer_Reserve (enc, enc->indent * (enc->level + 1) + 2);
-
- if (count > 0)
- {
- Buffer_AppendCharUnchecked (enc, ',');
- Buffer_AppendIndentNewlineUnchecked (enc);
- }
-
- iterObj = enc->iterGetValue(obj, &tc);
-
- enc->level ++;
- Buffer_AppendIndentUnchecked (enc, enc->level);
- encode (iterObj, enc, NULL, 0);
- count ++;
- }
-
- enc->iterEnd(obj, &tc);
- // Reserve space for the indentation plus the newline.
- Buffer_Reserve (enc, enc->indent * enc->level + 1);
- Buffer_AppendIndentNewlineUnchecked (enc);
- Buffer_AppendIndentUnchecked (enc, enc->level);
- Buffer_Reserve (enc, 1);
- Buffer_AppendCharUnchecked (enc, ']');
- break;
- }
-
- case JT_OBJECT:
- {
- count = 0;
-
- Buffer_AppendCharUnchecked (enc, '{');
- Buffer_AppendIndentNewlineUnchecked (enc);
-
- while (enc->iterNext(obj, &tc))
- {
- // The extra 2 bytes cover the comma and optional newline.
- Buffer_Reserve (enc, enc->indent * (enc->level + 1) + 2);
-
- if (count > 0)
- {
- Buffer_AppendCharUnchecked (enc, ',');
- Buffer_AppendIndentNewlineUnchecked (enc);
- }
-
- iterObj = enc->iterGetValue(obj, &tc);
- objName = enc->iterGetName(obj, &tc, &szlen);
-
- enc->level ++;
- Buffer_AppendIndentUnchecked (enc, enc->level);
- encode (iterObj, enc, objName, szlen);
- count ++;
- }
-
- enc->iterEnd(obj, &tc);
- Buffer_Reserve (enc, enc->indent * enc->level + 1);
- Buffer_AppendIndentNewlineUnchecked (enc);
- Buffer_AppendIndentUnchecked (enc, enc->level);
- Buffer_Reserve (enc, 1);
- Buffer_AppendCharUnchecked (enc, '}');
- break;
- }
-
- case JT_LONG:
- {
- Buffer_AppendLongUnchecked (enc, enc->getLongValue(obj, &tc));
- break;
- }
-
- case JT_ULONG:
- {
- Buffer_AppendUnsignedLongUnchecked (enc, enc->getUnsignedLongValue(obj, &tc));
- break;
- }
-
- case JT_INT:
- {
- Buffer_AppendIntUnchecked (enc, enc->getIntValue(obj, &tc));
- break;
- }
-
- case JT_TRUE:
- {
- Buffer_AppendCharUnchecked (enc, 't');
- Buffer_AppendCharUnchecked (enc, 'r');
- Buffer_AppendCharUnchecked (enc, 'u');
- Buffer_AppendCharUnchecked (enc, 'e');
- break;
- }
-
- case JT_FALSE:
- {
- Buffer_AppendCharUnchecked (enc, 'f');
- Buffer_AppendCharUnchecked (enc, 'a');
- Buffer_AppendCharUnchecked (enc, 'l');
- Buffer_AppendCharUnchecked (enc, 's');
- Buffer_AppendCharUnchecked (enc, 'e');
- break;
- }
-
-
- case JT_NULL:
- {
- Buffer_AppendCharUnchecked (enc, 'n');
- Buffer_AppendCharUnchecked (enc, 'u');
- Buffer_AppendCharUnchecked (enc, 'l');
- Buffer_AppendCharUnchecked (enc, 'l');
- break;
- }
-
- case JT_DOUBLE:
- {
- if (!Buffer_AppendDoubleUnchecked (obj, enc, enc->getDoubleValue(obj, &tc)))
- {
- enc->endTypeContext(obj, &tc);
- enc->level --;
- return;
- }
- break;
- }
-
- case JT_UTF8:
- {
- value = enc->getStringValue(obj, &tc, &szlen);
- if(!value)
- {
- SetError(obj, enc, "utf-8 encoding error");
- return;
- }
-
- Buffer_Reserve(enc, RESERVE_STRING(szlen));
- if (enc->errorMsg)
- {
- enc->endTypeContext(obj, &tc);
- return;
- }
- Buffer_AppendCharUnchecked (enc, '\"');
-
- if (enc->forceASCII)
- {
- if (!Buffer_EscapeStringValidated(obj, enc, value, value + szlen))
- {
- enc->endTypeContext(obj, &tc);
- enc->level --;
- return;
- }
- }
- else
- {
- if (!Buffer_EscapeStringUnvalidated(enc, value, value + szlen))
- {
- enc->endTypeContext(obj, &tc);
- enc->level --;
- return;
- }
- }
-
- Buffer_AppendCharUnchecked (enc, '\"');
- break;
- }
-
- case JT_RAW:
- {
- value = enc->getStringValue(obj, &tc, &szlen);
- if(!value)
- {
- SetError(obj, enc, "utf-8 encoding error");
- return;
- }
-
- Buffer_Reserve(enc, szlen);
- if (enc->errorMsg)
- {
- enc->endTypeContext(obj, &tc);
- return;
- }
-
- memcpy(enc->offset, value, szlen);
- enc->offset += szlen;
-
- break;
- }
- }
-
- enc->endTypeContext(obj, &tc);
- enc->level --;
-}
-
-char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *_buffer, size_t _cbBuffer)
-{
- enc->malloc = enc->malloc ? enc->malloc : malloc;
- enc->free = enc->free ? enc->free : free;
- enc->realloc = enc->realloc ? enc->realloc : realloc;
- enc->errorMsg = NULL;
- enc->errorObj = NULL;
- enc->level = 0;
-
- if (enc->recursionMax < 1)
- {
- enc->recursionMax = JSON_MAX_RECURSION_DEPTH;
- }
-
- if (enc->doublePrecision < 0 ||
- enc->doublePrecision > JSON_DOUBLE_MAX_DECIMALS)
- {
- enc->doublePrecision = JSON_DOUBLE_MAX_DECIMALS;
- }
-
- if (_buffer == NULL)
- {
- _cbBuffer = 32768;
- enc->start = (char *) enc->malloc (_cbBuffer);
- if (!enc->start)
- {
- SetError(obj, enc, "Could not reserve memory block");
- return NULL;
- }
- enc->heap = 1;
- }
- else
- {
- enc->start = _buffer;
- enc->heap = 0;
- }
-
- enc->end = enc->start + _cbBuffer;
- enc->offset = enc->start;
-
- encode (obj, enc, NULL, 0);
-
- Buffer_Reserve(enc, 1);
- if (enc->errorMsg)
- {
- return NULL;
- }
- Buffer_AppendCharUnchecked(enc, '\0');
-
- return enc->start;
-}
diff --git a/srsly/ujson/objToJSON.c b/srsly/ujson/objToJSON.c
deleted file mode 100644
index c22f861..0000000
--- a/srsly/ujson/objToJSON.c
+++ /dev/null
@@ -1,984 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
-* Copyright (c) 1988-1993 The Regents of the University of California.
-* Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#include "py_defines.h"
-#include
-#include
-#include
-
-#define EPOCH_ORD 719163
-static PyObject* type_decimal = NULL;
-
-typedef void *(*PFN_PyTypeToJSON)(JSOBJ obj, JSONTypeContext *ti, void *outValue, size_t *_outLen);
-
-#if (PY_VERSION_HEX < 0x02050000)
-typedef ssize_t Py_ssize_t;
-#endif
-
-typedef struct __TypeContext
-{
- JSPFN_ITEREND iterEnd;
- JSPFN_ITERNEXT iterNext;
- JSPFN_ITERGETNAME iterGetName;
- JSPFN_ITERGETVALUE iterGetValue;
- PFN_PyTypeToJSON PyTypeToJSON;
- PyObject *newObj;
- PyObject *dictObj;
- Py_ssize_t index;
- Py_ssize_t size;
- PyObject *itemValue;
- PyObject *itemName;
- PyObject *attrList;
- PyObject *iterator;
-
- union
- {
- PyObject *rawJSONValue;
- JSINT64 longValue;
- JSUINT64 unsignedLongValue;
- };
-} TypeContext;
-
-#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv))
-
-struct PyDictIterState
-{
- PyObject *keys;
- size_t i;
- size_t sz;
-};
-
-//#define PRINTMARK() fprintf(stderr, "%s: MARK(%d)\n", __FILE__, __LINE__)
-#define PRINTMARK()
-
-void initObjToJSON(void)
-{
- PyObject* mod_decimal = PyImport_ImportModule("decimal");
- if (mod_decimal)
- {
- type_decimal = PyObject_GetAttrString(mod_decimal, "Decimal");
- Py_INCREF(type_decimal);
- Py_DECREF(mod_decimal);
- }
- else
- PyErr_Clear();
-
- PyDateTime_IMPORT;
-}
-
-#ifdef _LP64
-static void *PyIntToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- *((JSINT64 *) outValue) = PyInt_AS_LONG (obj);
- return NULL;
-}
-#else
-static void *PyIntToINT32(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- *((JSINT32 *) outValue) = PyInt_AS_LONG (obj);
- return NULL;
-}
-#endif
-
-static void *PyLongToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- *((JSINT64 *) outValue) = GET_TC(tc)->longValue;
- return NULL;
-}
-
-static void *PyLongToUINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- *((JSUINT64 *) outValue) = GET_TC(tc)->unsignedLongValue;
- return NULL;
-}
-
-static void *PyFloatToDOUBLE(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- *((double *) outValue) = PyFloat_AsDouble (obj);
- return NULL;
-}
-
-static void *PyStringToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- *_outLen = PyString_GET_SIZE(obj);
- return PyString_AS_STRING(obj);
-}
-
-static void *PyUnicodeToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- PyObject *newObj;
-#if (PY_VERSION_HEX >= 0x03030000)
- if(PyUnicode_IS_COMPACT_ASCII(obj))
- {
- Py_ssize_t len;
- char *data = PyUnicode_AsUTF8AndSize(obj, &len);
- *_outLen = len;
- return data;
- }
-#endif
- newObj = PyUnicode_AsUTF8String(obj);
- if(!newObj)
- {
- return NULL;
- }
-
- GET_TC(tc)->newObj = newObj;
-
- *_outLen = PyString_GET_SIZE(newObj);
- return PyString_AS_STRING(newObj);
-}
-
-static void *PyRawJSONToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = GET_TC(tc)->rawJSONValue;
- if (PyUnicode_Check(obj)) {
- return PyUnicodeToUTF8(obj, tc, outValue, _outLen);
- }
- else {
- return PyStringToUTF8(obj, tc, outValue, _outLen);
- }
-}
-
-static void *PyDateTimeToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- PyObject *date, *ord, *utcoffset;
- int y, m, d, h, mn, s, days;
-
- utcoffset = PyObject_CallMethod(obj, "utcoffset", NULL);
- if(utcoffset != Py_None){
- obj = PyNumber_Subtract(obj, utcoffset);
- }
-
- y = PyDateTime_GET_YEAR(obj);
- m = PyDateTime_GET_MONTH(obj);
- d = PyDateTime_GET_DAY(obj);
- h = PyDateTime_DATE_GET_HOUR(obj);
- mn = PyDateTime_DATE_GET_MINUTE(obj);
- s = PyDateTime_DATE_GET_SECOND(obj);
-
- date = PyDate_FromDate(y, m, 1);
- ord = PyObject_CallMethod(date, "toordinal", NULL);
- days = PyInt_AS_LONG(ord) - EPOCH_ORD + d - 1;
- Py_DECREF(date);
- Py_DECREF(ord);
- *( (JSINT64 *) outValue) = (((JSINT64) ((days * 24 + h) * 60 + mn)) * 60 + s);
- return NULL;
-}
-
-static void *PyDateToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
-{
- PyObject *obj = (PyObject *) _obj;
- PyObject *date, *ord;
- int y, m, d, days;
-
- y = PyDateTime_GET_YEAR(obj);
- m = PyDateTime_GET_MONTH(obj);
- d = PyDateTime_GET_DAY(obj);
-
- date = PyDate_FromDate(y, m, 1);
- ord = PyObject_CallMethod(date, "toordinal", NULL);
- days = PyInt_AS_LONG(ord) - EPOCH_ORD + d - 1;
- Py_DECREF(date);
- Py_DECREF(ord);
- *( (JSINT64 *) outValue) = ((JSINT64) days * 86400);
-
- return NULL;
-}
-
-int Tuple_iterNext(JSOBJ obj, JSONTypeContext *tc)
-{
- PyObject *item;
-
- if (GET_TC(tc)->index >= GET_TC(tc)->size)
- {
- return 0;
- }
-
- item = PyTuple_GET_ITEM (obj, GET_TC(tc)->index);
-
- GET_TC(tc)->itemValue = item;
- GET_TC(tc)->index ++;
- return 1;
-}
-
-void Tuple_iterEnd(JSOBJ obj, JSONTypeContext *tc)
-{
-}
-
-JSOBJ Tuple_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
-{
- return GET_TC(tc)->itemValue;
-}
-
-char *Tuple_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
-{
- return NULL;
-}
-
-int List_iterNext(JSOBJ obj, JSONTypeContext *tc)
-{
- if (GET_TC(tc)->index >= GET_TC(tc)->size)
- {
- PRINTMARK();
- return 0;
- }
-
- GET_TC(tc)->itemValue = PyList_GET_ITEM (obj, GET_TC(tc)->index);
- GET_TC(tc)->index ++;
- return 1;
-}
-
-void List_iterEnd(JSOBJ obj, JSONTypeContext *tc)
-{
-}
-
-JSOBJ List_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
-{
- return GET_TC(tc)->itemValue;
-}
-
-char *List_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
-{
- return NULL;
-}
-
-//=============================================================================
-// Dict iteration functions
-// itemName might converted to string (Python_Str). Do refCounting
-// itemValue is borrowed from object (which is dict). No refCounting
-//=============================================================================
-
-int Dict_iterNext(JSOBJ obj, JSONTypeContext *tc)
-{
- PyObject* itemNameTmp;
-
- if (GET_TC(tc)->itemName)
- {
- Py_DECREF(GET_TC(tc)->itemName);
- GET_TC(tc)->itemName = NULL;
- }
-
- if (!(GET_TC(tc)->itemName = PyIter_Next(GET_TC(tc)->iterator)))
- {
- PRINTMARK();
- return 0;
- }
-
- if (GET_TC(tc)->itemValue) {
- Py_DECREF(GET_TC(tc)->itemValue);
- GET_TC(tc)->itemValue = NULL;
- }
-
- if (!(GET_TC(tc)->itemValue = PyObject_GetItem(GET_TC(tc)->dictObj, GET_TC(tc)->itemName))) {
- PRINTMARK();
- return 0;
- }
-
- if (PyUnicode_Check(GET_TC(tc)->itemName))
- {
- itemNameTmp = GET_TC(tc)->itemName;
- GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp);
- Py_DECREF(itemNameTmp);
- }
- else
- if (!PyString_Check(GET_TC(tc)->itemName))
- {
- itemNameTmp = GET_TC(tc)->itemName;
- GET_TC(tc)->itemName = PyObject_Str(itemNameTmp);
- Py_DECREF(itemNameTmp);
-#if PY_MAJOR_VERSION >= 3
- itemNameTmp = GET_TC(tc)->itemName;
- GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp);
- Py_DECREF(itemNameTmp);
-#endif
- }
- PRINTMARK();
- return 1;
-}
-
-void Dict_iterEnd(JSOBJ obj, JSONTypeContext *tc)
-{
- if (GET_TC(tc)->itemName) {
- Py_DECREF(GET_TC(tc)->itemName);
- GET_TC(tc)->itemName = NULL;
- }
- if (GET_TC(tc)->itemValue) {
- Py_DECREF(GET_TC(tc)->itemValue);
- GET_TC(tc)->itemValue = NULL;
- }
- Py_CLEAR(GET_TC(tc)->iterator);
- Py_DECREF(GET_TC(tc)->dictObj);
- PRINTMARK();
-}
-
-JSOBJ Dict_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
-{
- return GET_TC(tc)->itemValue;
-}
-
-char *Dict_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
-{
- *outLen = PyString_GET_SIZE(GET_TC(tc)->itemName);
- return PyString_AS_STRING(GET_TC(tc)->itemName);
-}
-
-int SortedDict_iterNext(JSOBJ obj, JSONTypeContext *tc)
-{
- PyObject *items = NULL, *item = NULL, *key = NULL, *value = NULL;
- Py_ssize_t i, nitems;
-#if PY_MAJOR_VERSION >= 3
- PyObject* keyTmp;
-#endif
-
- // Upon first call, obtain a list of the keys and sort them. This follows the same logic as the
- // stanard library's _json.c sort_keys handler.
- if (GET_TC(tc)->newObj == NULL)
- {
- // Obtain the list of keys from the dictionary.
- items = PyMapping_Keys(GET_TC(tc)->dictObj);
- if (items == NULL)
- {
- goto error;
- }
- else if (!PyList_Check(items))
- {
- PyErr_SetString(PyExc_ValueError, "keys must return list");
- goto error;
- }
-
- // Sort the list.
- if (PyList_Sort(items) < 0)
- {
- goto error;
- }
-
- // Obtain the value for each key, and pack a list of (key, value) 2-tuples.
- nitems = PyList_GET_SIZE(items);
- for (i = 0; i < nitems; i++)
- {
- key = PyList_GET_ITEM(items, i);
- value = PyDict_GetItem(GET_TC(tc)->dictObj, key);
-
- // Subject the key to the same type restrictions and conversions as in Dict_iterGetValue.
- if (PyUnicode_Check(key))
- {
- key = PyUnicode_AsUTF8String(key);
- }
- else if (!PyString_Check(key))
- {
- key = PyObject_Str(key);
-#if PY_MAJOR_VERSION >= 3
- keyTmp = key;
- key = PyUnicode_AsUTF8String(key);
- Py_DECREF(keyTmp);
-#endif
- }
- else
- {
- Py_INCREF(key);
- }
-
- item = PyTuple_Pack(2, key, value);
- if (item == NULL)
- {
- goto error;
- }
- if (PyList_SetItem(items, i, item))
- {
- goto error;
- }
- Py_DECREF(key);
- }
-
- // Store the sorted list of tuples in the newObj slot.
- GET_TC(tc)->newObj = items;
- GET_TC(tc)->size = nitems;
- }
-
- if (GET_TC(tc)->index >= GET_TC(tc)->size)
- {
- PRINTMARK();
- return 0;
- }
-
- item = PyList_GET_ITEM(GET_TC(tc)->newObj, GET_TC(tc)->index);
- GET_TC(tc)->itemName = PyTuple_GET_ITEM(item, 0);
- GET_TC(tc)->itemValue = PyTuple_GET_ITEM(item, 1);
- GET_TC(tc)->index++;
- return 1;
-
-error:
- Py_XDECREF(item);
- Py_XDECREF(key);
- Py_XDECREF(value);
- Py_XDECREF(items);
- return -1;
-}
-
-void SortedDict_iterEnd(JSOBJ obj, JSONTypeContext *tc)
-{
- GET_TC(tc)->itemName = NULL;
- GET_TC(tc)->itemValue = NULL;
- Py_DECREF(GET_TC(tc)->newObj);
- Py_DECREF(GET_TC(tc)->dictObj);
- PRINTMARK();
-}
-
-JSOBJ SortedDict_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
-{
- return GET_TC(tc)->itemValue;
-}
-
-char *SortedDict_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
-{
- *outLen = PyString_GET_SIZE(GET_TC(tc)->itemName);
- return PyString_AS_STRING(GET_TC(tc)->itemName);
-}
-
-
-void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder *enc)
-{
- pc->dictObj = dictObj;
- if (enc->sortKeys)
- {
- pc->iterEnd = SortedDict_iterEnd;
- pc->iterNext = SortedDict_iterNext;
- pc->iterGetValue = SortedDict_iterGetValue;
- pc->iterGetName = SortedDict_iterGetName;
- pc->index = 0;
- }
- else
- {
- pc->iterEnd = Dict_iterEnd;
- pc->iterNext = Dict_iterNext;
- pc->iterGetValue = Dict_iterGetValue;
- pc->iterGetName = Dict_iterGetName;
- pc->iterator = PyObject_GetIter(dictObj);
- }
-}
-
-void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc)
-{
- PyObject *obj, *objRepr, *exc;
- TypeContext *pc;
- PRINTMARK();
- if (!_obj)
- {
- tc->type = JT_INVALID;
- return;
- }
-
- obj = (PyObject*) _obj;
-
- tc->prv = PyObject_Malloc(sizeof(TypeContext));
- pc = (TypeContext *) tc->prv;
- if (!pc)
- {
- tc->type = JT_INVALID;
- PyErr_NoMemory();
- return;
- }
- pc->newObj = NULL;
- pc->dictObj = NULL;
- pc->itemValue = NULL;
- pc->itemName = NULL;
- pc->iterator = NULL;
- pc->attrList = NULL;
- pc->index = 0;
- pc->size = 0;
- pc->longValue = 0;
- pc->rawJSONValue = NULL;
-
- if (PyIter_Check(obj))
- {
- PRINTMARK();
- goto ISITERABLE;
- }
-
- if (PyBool_Check(obj))
- {
- PRINTMARK();
- tc->type = (obj == Py_True) ? JT_TRUE : JT_FALSE;
- return;
- }
- else
- if (PyLong_Check(obj))
- {
- PRINTMARK();
- pc->PyTypeToJSON = PyLongToINT64;
- tc->type = JT_LONG;
- GET_TC(tc)->longValue = PyLong_AsLongLong(obj);
-
- exc = PyErr_Occurred();
- if (!exc)
- {
- return;
- }
-
- if (exc && PyErr_ExceptionMatches(PyExc_OverflowError))
- {
- PyErr_Clear();
- pc->PyTypeToJSON = PyLongToUINT64;
- tc->type = JT_ULONG;
- GET_TC(tc)->unsignedLongValue = PyLong_AsUnsignedLongLong(obj);
-
- exc = PyErr_Occurred();
- if (exc && PyErr_ExceptionMatches(PyExc_OverflowError))
- {
- PRINTMARK();
- goto INVALID;
- }
- }
-
- return;
- }
- else
- if (PyInt_Check(obj))
- {
- PRINTMARK();
-#ifdef _LP64
- pc->PyTypeToJSON = PyIntToINT64; tc->type = JT_LONG;
-#else
- pc->PyTypeToJSON = PyIntToINT32; tc->type = JT_INT;
-#endif
- return;
- }
- else
- if (PyString_Check(obj) && !PyObject_HasAttrString(obj, "__json__"))
- {
- PRINTMARK();
- pc->PyTypeToJSON = PyStringToUTF8; tc->type = JT_UTF8;
- return;
- }
- else
- if (PyUnicode_Check(obj))
- {
- PRINTMARK();
- pc->PyTypeToJSON = PyUnicodeToUTF8; tc->type = JT_UTF8;
- return;
- }
- else
- if (PyFloat_Check(obj) || (type_decimal && PyObject_IsInstance(obj, type_decimal)))
- {
- PRINTMARK();
- pc->PyTypeToJSON = PyFloatToDOUBLE; tc->type = JT_DOUBLE;
- return;
- }
- else
- if (PyDateTime_Check(obj))
- {
- PRINTMARK();
- pc->PyTypeToJSON = PyDateTimeToINT64; tc->type = JT_LONG;
- return;
- }
- else
- if (PyDate_Check(obj))
- {
- PRINTMARK();
- pc->PyTypeToJSON = PyDateToINT64; tc->type = JT_LONG;
- return;
- }
- else
- if (obj == Py_None)
- {
- PRINTMARK();
- tc->type = JT_NULL;
- return;
- }
-
-ISITERABLE:
- if (PyDict_Check(obj))
- {
- PRINTMARK();
- tc->type = JT_OBJECT;
- SetupDictIter(obj, pc, enc);
- Py_INCREF(obj);
- return;
- }
- else
- if (PyList_Check(obj))
- {
- PRINTMARK();
- tc->type = JT_ARRAY;
- pc->iterEnd = List_iterEnd;
- pc->iterNext = List_iterNext;
- pc->iterGetValue = List_iterGetValue;
- pc->iterGetName = List_iterGetName;
- GET_TC(tc)->index = 0;
- GET_TC(tc)->size = PyList_GET_SIZE( (PyObject *) obj);
- return;
- }
- else
- if (PyTuple_Check(obj))
- {
- PRINTMARK();
- tc->type = JT_ARRAY;
- pc->iterEnd = Tuple_iterEnd;
- pc->iterNext = Tuple_iterNext;
- pc->iterGetValue = Tuple_iterGetValue;
- pc->iterGetName = Tuple_iterGetName;
- GET_TC(tc)->index = 0;
- GET_TC(tc)->size = PyTuple_GET_SIZE( (PyObject *) obj);
- GET_TC(tc)->itemValue = NULL;
-
- return;
- }
-
- if (PyObject_HasAttrString(obj, "toDict"))
- {
- PyObject* toDictFunc = PyObject_GetAttrString(obj, "toDict");
- PyObject* tuple = PyTuple_New(0);
- PyObject* toDictResult = PyObject_Call(toDictFunc, tuple, NULL);
- Py_DECREF(tuple);
- Py_DECREF(toDictFunc);
-
- if (toDictResult == NULL)
- {
- goto INVALID;
- }
-
- if (!PyDict_Check(toDictResult))
- {
- Py_DECREF(toDictResult);
- tc->type = JT_NULL;
- return;
- }
-
- PRINTMARK();
- tc->type = JT_OBJECT;
- SetupDictIter(toDictResult, pc, enc);
- return;
- }
- else
- if (PyObject_HasAttrString(obj, "__json__"))
- {
- PyObject* toJSONFunc = PyObject_GetAttrString(obj, "__json__");
- PyObject* tuple = PyTuple_New(0);
- PyObject* toJSONResult = PyObject_Call(toJSONFunc, tuple, NULL);
- Py_DECREF(tuple);
- Py_DECREF(toJSONFunc);
-
- if (toJSONResult == NULL)
- {
- goto INVALID;
- }
-
- if (PyErr_Occurred())
- {
- Py_DECREF(toJSONResult);
- goto INVALID;
- }
-
- if (!PyString_Check(toJSONResult) && !PyUnicode_Check(toJSONResult))
- {
- Py_DECREF(toJSONResult);
- PyErr_Format (PyExc_TypeError, "expected string");
- goto INVALID;
- }
-
- PRINTMARK();
- pc->PyTypeToJSON = PyRawJSONToUTF8;
- tc->type = JT_RAW;
- GET_TC(tc)->rawJSONValue = toJSONResult;
- return;
- }
-
- PRINTMARK();
- PyErr_Clear();
-
- objRepr = PyObject_Repr(obj);
-#if PY_MAJOR_VERSION >= 3
- PyObject* str = PyUnicode_AsEncodedString(objRepr, "utf-8", "~E~");
- PyErr_Format (PyExc_TypeError, "%s is not JSON serializable", PyString_AS_STRING(str));
- Py_XDECREF(str);
-#else
- PyErr_Format (PyExc_TypeError, "%s is not JSON serializable", PyString_AS_STRING(objRepr));
-#endif
- Py_DECREF(objRepr);
-
-INVALID:
- PRINTMARK();
- tc->type = JT_INVALID;
- PyObject_Free(tc->prv);
- tc->prv = NULL;
- return;
-}
-
-void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)
-{
- Py_XDECREF(GET_TC(tc)->newObj);
-
- if (tc->type == JT_RAW)
- {
- Py_XDECREF(GET_TC(tc)->rawJSONValue);
- }
- PyObject_Free(tc->prv);
- tc->prv = NULL;
-}
-
-const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen)
-{
- return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen);
-}
-
-JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
-{
- JSINT64 ret;
- GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
- return ret;
-}
-
-JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
-{
- JSUINT64 ret;
- GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
- return ret;
-}
-
-JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
-{
- JSINT32 ret;
- GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
- return ret;
-}
-
-double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc)
-{
- double ret;
- GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
- return ret;
-}
-
-static void Object_releaseObject(JSOBJ _obj)
-{
- Py_DECREF( (PyObject *) _obj);
-}
-
-int Object_iterNext(JSOBJ obj, JSONTypeContext *tc)
-{
- return GET_TC(tc)->iterNext(obj, tc);
-}
-
-void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc)
-{
- GET_TC(tc)->iterEnd(obj, tc);
-}
-
-JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
-{
- return GET_TC(tc)->iterGetValue(obj, tc);
-}
-
-char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
-{
- return GET_TC(tc)->iterGetName(obj, tc, outLen);
-}
-
-PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
-{
- static char *kwlist[] = { "obj", "ensure_ascii", "double_precision", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", NULL };
-
- char buffer[65536];
- char *ret;
- PyObject *newobj;
- PyObject *oinput = NULL;
- PyObject *oensureAscii = NULL;
- PyObject *oencodeHTMLChars = NULL;
- PyObject *oescapeForwardSlashes = NULL;
- PyObject *osortKeys = NULL;
-
- JSONObjectEncoder encoder =
- {
- Object_beginTypeContext,
- Object_endTypeContext,
- Object_getStringValue,
- Object_getLongValue,
- Object_getUnsignedLongValue,
- Object_getIntValue,
- Object_getDoubleValue,
- Object_iterNext,
- Object_iterEnd,
- Object_iterGetValue,
- Object_iterGetName,
- Object_releaseObject,
- PyObject_Malloc,
- PyObject_Realloc,
- PyObject_Free,
- -1, //recursionMax
- 10, // default double precision setting
- 1, //forceAscii
- 0, //encodeHTMLChars
- 1, //escapeForwardSlashes
- 0, //sortKeys
- 0, //indent
- NULL, //prv
- };
-
-
- PRINTMARK();
-
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OiOOOi", kwlist, &oinput, &oensureAscii, &encoder.doublePrecision, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent))
- {
- return NULL;
- }
-
- if (oensureAscii != NULL && !PyObject_IsTrue(oensureAscii))
- {
- encoder.forceASCII = 0;
- }
-
- if (oencodeHTMLChars != NULL && PyObject_IsTrue(oencodeHTMLChars))
- {
- encoder.encodeHTMLChars = 1;
- }
-
- if (oescapeForwardSlashes != NULL && !PyObject_IsTrue(oescapeForwardSlashes))
- {
- encoder.escapeForwardSlashes = 0;
- }
-
- if (osortKeys != NULL && PyObject_IsTrue(osortKeys))
- {
- encoder.sortKeys = 1;
- }
-
- PRINTMARK();
- ret = JSON_EncodeObject (oinput, &encoder, buffer, sizeof (buffer));
- PRINTMARK();
-
- if (PyErr_Occurred())
- {
- return NULL;
- }
-
- if (encoder.errorMsg)
- {
- if (ret != buffer)
- {
- encoder.free (ret);
- }
-
- PyErr_Format (PyExc_OverflowError, "%s", encoder.errorMsg);
- return NULL;
- }
-
- newobj = PyString_FromString (ret);
-
- if (ret != buffer)
- {
- encoder.free (ret);
- }
-
- PRINTMARK();
-
- return newobj;
-}
-
-PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs)
-{
- PyObject *data;
- PyObject *file;
- PyObject *string;
- PyObject *write;
- PyObject *argtuple;
- PyObject *write_result;
-
- PRINTMARK();
-
- if (!PyArg_ParseTuple (args, "OO", &data, &file))
- {
- return NULL;
- }
-
- if (!PyObject_HasAttrString (file, "write"))
- {
- PyErr_Format (PyExc_TypeError, "expected file");
- return NULL;
- }
-
- write = PyObject_GetAttrString (file, "write");
-
- if (!PyCallable_Check (write))
- {
- Py_XDECREF(write);
- PyErr_Format (PyExc_TypeError, "expected file");
- return NULL;
- }
-
- argtuple = PyTuple_Pack(1, data);
-
- string = objToJSON (self, argtuple, kwargs);
-
- if (string == NULL)
- {
- Py_XDECREF(write);
- Py_XDECREF(argtuple);
- return NULL;
- }
-
- Py_XDECREF(argtuple);
-
- argtuple = PyTuple_Pack (1, string);
- if (argtuple == NULL)
- {
- Py_XDECREF(write);
- return NULL;
- }
- write_result = PyObject_CallObject (write, argtuple);
- if (write_result == NULL)
- {
- Py_XDECREF(write);
- Py_XDECREF(argtuple);
- return NULL;
- }
-
- Py_DECREF(write_result);
- Py_XDECREF(write);
- Py_DECREF(argtuple);
- Py_XDECREF(string);
-
- PRINTMARK();
-
- Py_RETURN_NONE;
-}
diff --git a/srsly/ujson/py_defines.h b/srsly/ujson/py_defines.h
deleted file mode 100644
index 2b38b41..0000000
--- a/srsly/ujson/py_defines.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
- * Copyright (c) 1988-1993 The Regents of the University of California.
- * Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#include
-
-#if PY_MAJOR_VERSION >= 3
-
-#define PyInt_Check PyLong_Check
-#define PyInt_AS_LONG PyLong_AsLong
-#define PyInt_FromLong PyLong_FromLong
-
-#define PyString_Check PyBytes_Check
-#define PyString_GET_SIZE PyBytes_GET_SIZE
-#define PyString_AS_STRING PyBytes_AS_STRING
-
-#define PyString_FromString PyUnicode_FromString
-
-#endif
diff --git a/srsly/ujson/ujson.c b/srsly/ujson/ujson.c
deleted file mode 100644
index d0b15c6..0000000
--- a/srsly/ujson/ujson.c
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
-* Copyright (c) 1988-1993 The Regents of the University of California.
-* Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#include "py_defines.h"
-#include "version.h"
-
-/* objToJSON */
-PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs);
-void initObjToJSON(void);
-
-/* JSONToObj */
-PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs);
-
-/* objToJSONFile */
-PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs);
-
-/* JSONFileToObj */
-PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs);
-
-
-#define ENCODER_HELP_TEXT "Use ensure_ascii=false to output UTF-8. Pass in double_precision to alter the maximum digit precision of doubles. Set encode_html_chars=True to encode < > & as unicode escape sequences. Set escape_forward_slashes=False to prevent escaping / characters."
-
-static PyMethodDef ujsonMethods[] = {
- {"encode", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT},
- {"decode", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."},
- {"dumps", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT},
- {"loads", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."},
- {"dump", (PyCFunction) objToJSONFile, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON file. " ENCODER_HELP_TEXT},
- {"load", (PyCFunction) JSONFileToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as file to dict object structure. Use precise_float=True to use high precision float decoder."},
- {NULL, NULL, 0, NULL} /* Sentinel */
-};
-
-#if PY_MAJOR_VERSION >= 3
-
-static struct PyModuleDef moduledef = {
- PyModuleDef_HEAD_INIT,
- "ujson",
- 0, /* m_doc */
- -1, /* m_size */
- ujsonMethods, /* m_methods */
- NULL, /* m_reload */
- NULL, /* m_traverse */
- NULL, /* m_clear */
- NULL /* m_free */
-};
-
-#define PYMODINITFUNC PyObject *PyInit_ujson(void)
-#define PYMODULE_CREATE() PyModule_Create(&moduledef)
-#define MODINITERROR return NULL
-
-#else
-
-#define PYMODINITFUNC PyMODINIT_FUNC initujson(void)
-#define PYMODULE_CREATE() Py_InitModule("ujson", ujsonMethods)
-#define MODINITERROR return
-
-#endif
-
-PYMODINITFUNC
-{
- PyObject *module;
- PyObject *version_string;
-
- initObjToJSON();
- module = PYMODULE_CREATE();
-
- if (module == NULL)
- {
- MODINITERROR;
- }
-
- version_string = PyString_FromString (UJSON_VERSION);
- PyModule_AddObject (module, "__version__", version_string);
-
-#if PY_MAJOR_VERSION >= 3
- return module;
-#endif
-}
diff --git a/srsly/ujson/version.h b/srsly/ujson/version.h
deleted file mode 100644
index f0ce6bb..0000000
--- a/srsly/ujson/version.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
-Developed by ESN, an Electronic Arts Inc. studio.
-Copyright (c) 2014, Electronic Arts Inc.
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-* Redistributions of source code must retain the above copyright
-notice, this list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright
-notice, this list of conditions and the following disclaimer in the
-documentation and/or other materials provided with the distribution.
-* Neither the name of ESN, Electronic Arts Inc. nor the
-names of its contributors may be used to endorse or promote products
-derived from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
-Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
-http://code.google.com/p/stringencoders/
-Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
-
-Numeric decoder derived from from TCL library
-http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
- * Copyright (c) 1988-1993 The Regents of the University of California.
- * Copyright (c) 1994 Sun Microsystems, Inc.
-*/
-
-#define UJSON_VERSION "1.35"