diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c019f73..c84ab7e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## vTBD +- [BP-1719](https://movai.atlassian.net/browse/BP-1719): Allow for both in-memory and file-system project sources for metadata import. + ## v3.25.5 - [BP-1714](https://movai.atlassian.net/browse/BP-1714): Validate file size and safeguard for noneviction raised errors - Add validation for file size is lower than available memory before writing to Redis. diff --git a/dal/tools/backup.py b/dal/tools/backup.py index 0d939550..6d8bca86 100644 --- a/dal/tools/backup.py +++ b/dal/tools/backup.py @@ -17,7 +17,8 @@ import sys from importlib import import_module import warnings -from pathlib import Path +from abc import ABC, abstractmethod +from typing import Iterator, List, Tuple from dal.movaidb import MovaiDB @@ -89,6 +90,196 @@ def get_class(scope): return Factory.CLASSES_CACHE[scope] +class ProjectSource(ABC): + """Project metadata source addressed by project-relative POSIX paths. + + Importer historically reads from a project folder on disk. This interface + keeps the importer talking in relative paths like ``Callback/foo.json`` so + another source, such as a gRPC payload, can provide the same files without + recreating that folder on the filesystem. + """ + + @staticmethod + def normalize_path(relative_path: str) -> str: + """Normalize and validate a source-relative path.""" + if relative_path is None: + raise ValueError("relative_path cannot be None") + + path = str(relative_path).replace("\\", "/") + if os.path.isabs(path): + raise ValueError(f"absolute paths are not allowed: {relative_path}") + + normalized = os.path.normpath(path).replace("\\", "/") + if normalized == ".": + return "" + if normalized == ".." or normalized.startswith("../"): + raise ValueError(f"path escapes project source: {relative_path}") + + return normalized + + @staticmethod + def join(*parts: str) -> str: + """Join path parts into a normalized source-relative path.""" + return ProjectSource.normalize_path( + "/".join(str(part).strip("/") for part in parts if part) + ) + + @abstractmethod + def list_dir(self, relative_dir: str) -> List[str]: + """Return direct child names for a source-relative directory.""" + + @abstractmethod + def is_file(self, relative_path: str) -> bool: + """Return whether the source-relative path is a regular file.""" + + @abstractmethod + def is_dir(self, relative_path: str) -> bool: + """Return whether the source-relative path is a directory.""" + + @abstractmethod + def read_text(self, relative_path: str) -> str: + """Read a source-relative file as UTF-8 text.""" + + @abstractmethod + def read_bytes(self, relative_path: str) -> bytes: + """Read a source-relative file as bytes.""" + + @abstractmethod + def walk(self, relative_dir: str) -> Iterator[Tuple[str, List[str], List[str]]]: + """Yield ``(relative_root, dirs, files)`` for a source-relative directory.""" + + @abstractmethod + def display_path(self, relative_path: str) -> str: + """Return a human-readable path for logs and imported InstallPath.""" + + +class FilesystemProjectSource(ProjectSource): + """Filesystem-backed project source.""" + + def __init__(self, root: str): + self.root = os.path.abspath(root) + + def _resolve(self, relative_path: str) -> str: + normalized = self.normalize_path(relative_path) + resolved = os.path.abspath(os.path.join(self.root, normalized)) + + try: + common_path = os.path.commonpath([self.root, resolved]) + except ValueError as exc: + raise ValueError(f"path escapes project source: {relative_path}") from exc + + if common_path != self.root: + raise ValueError(f"path escapes project source: {relative_path}") + + return resolved + + def list_dir(self, relative_dir: str) -> List[str]: + return os.listdir(self._resolve(relative_dir)) + + def is_file(self, relative_path: str) -> bool: + return os.path.isfile(self._resolve(relative_path)) + + def is_dir(self, relative_path: str) -> bool: + return os.path.isdir(self._resolve(relative_path)) + + def read_text(self, relative_path: str) -> str: + with open(self._resolve(relative_path), encoding="utf-8") as file: + return file.read() + + def read_bytes(self, relative_path: str) -> bytes: + with open(self._resolve(relative_path), "rb") as file: + return file.read() + + def walk(self, relative_dir: str) -> Iterator[Tuple[str, List[str], List[str]]]: + start = self._resolve(relative_dir) + for root, dirs, files in os.walk(start): + relative_root = os.path.relpath(root, self.root) + if relative_root == ".": + relative_root = "" + yield relative_root.replace("\\", "/"), dirs, files + + def display_path(self, relative_path: str) -> str: + return self._resolve(relative_path) + + +class InMemoryProjectSource(ProjectSource): + """In-memory project source for gRPC payloads.""" + + def __init__(self, files: dict): + self.files = {} + for path, content in files.items(): + normalized = self.normalize_path(path) + if isinstance(content, str): + content = content.encode("utf-8") + self.files[normalized] = content + + def list_dir(self, relative_dir: str) -> List[str]: + prefix = self.normalize_path(relative_dir) + if prefix and not prefix.endswith("/"): + prefix += "/" + + children = set() + for path in self.files: + if not path.startswith(prefix): + continue + children.add(path[len(prefix) :].split("/", 1)[0]) + + return sorted(children) + + def is_file(self, relative_path: str) -> bool: + normalized = self.normalize_path(relative_path) + return normalized in self.files + + def is_dir(self, relative_path: str) -> bool: + prefix = self.normalize_path(relative_path) + if prefix and not prefix.endswith("/"): + prefix += "/" + return any(path.startswith(prefix) for path in self.files) + + def read_text(self, relative_path: str) -> str: + normalized = self.normalize_path(relative_path) + if normalized not in self.files: + raise FileNotFoundError(f"File not found: {relative_path}") + return self.read_bytes(normalized).decode("utf-8") + + def read_bytes(self, relative_path: str) -> bytes: + normalized = self.normalize_path(relative_path) + if normalized not in self.files: + raise FileNotFoundError(f"File not found: {relative_path}") + return self.files[normalized] + + def walk(self, relative_dir: str): + """ + Walk the in-memory file structure, yielding (relative_root, dirs, files). + """ + root = self.normalize_path(relative_dir) + stack = [root] + + while stack: + current = stack.pop() + dirs = [] + files = [] + + for path in self.files: + if not path.startswith(current): + continue + relative_path = path[len(current) :].lstrip("/") + if "/" in relative_path: + dir_name = relative_path.split("/", 1)[0] + if dir_name not in dirs: + dirs.append(dir_name) + else: + files.append(relative_path) + + yield current, dirs, files + + for dir_name in dirs: + stack.append(self.join(current, dir_name)) + + def display_path(self, relative_path: str) -> str: + return self.normalize_path(relative_path) + + class Backup: """Base class for Importer and Exporter.""" @@ -115,7 +306,7 @@ class Backup: def __init__(self, project, debug: bool = False, recursive=True): self.recursive = recursive - self.project_path = os.path.abspath(project) + self.project_path = os.path.abspath(project) if project else "" if debug: self.log = print @@ -126,6 +317,47 @@ def __init__(self, project, debug: bool = False, recursive=True): # Connect to Redis MovaiDB(db="global") + @staticmethod + def parse_manifest(manifest_lines: List[str], all_default=[None]) -> dict: + """ + Parses manifest lines and returns the declared objects. + + Args: + manifest_lines (List[str]): Lines of the manifest file. + all_default: Default value for all objects, applied when '*' is found. + + Returns: + dict: Parsed objects from the manifest. + """ + objects = {} + + for line in manifest_lines: + line = line.split("#", 1)[0].strip() + if not line or ":" not in line: + continue + _type, _name = [part.strip() for part in line.split(":", 1)] + if not _type or not _name: + # what? + continue + if _type not in objects.keys(): + objects[_type] = [] + if _name != "*": + objects[_type] += [_name] + # force it + continue + # else, * + names = all_default + try: + all_default.__getattribute__("__call__") + names = all_default(_type) + except AttributeError: + pass + finally: + objects[_type] += names + # endfor line in manifest + # close manifest + return objects + @staticmethod def read_manifest(manifest: str, all_default=[None]) -> dict: """Reads a manifest file and returns the declared objects. @@ -135,35 +367,20 @@ def read_manifest(manifest: str, all_default=[None]) -> dict: all_default: Default value for all objects, applied when '*' is found. """ - objects = {} # let it blow with open(manifest) as manifest_file: - for line in manifest_file: - line = line.split("#", 1)[0].strip() - if not line or ":" not in line: - continue - _type, _name = [part.strip() for part in line.split(":", 1)] - if not _type or not _name: - # what? - continue - if _type not in objects.keys(): - objects[_type] = [] - if _name != "*": - objects[_type] += [_name] - # force it - continue - # else, * - names = all_default - try: - all_default.__getattribute__("__call__") - names = all_default(_type) - except AttributeError: - pass - finally: - objects[_type] += names - # endfor line in manifest - # close manifest - return objects + return Backup.parse_manifest(manifest_file.readlines(), all_default) + + @staticmethod + def read_manifest_content(manifest_content: str, all_default=[None]) -> dict: + """Reads a manifest content string and returns the declared objects. + + Args: + manifest_content (str): Content of the manifest file. + all_default: Default value for all objects, applied when '*' is found. + + """ + return Backup.parse_manifest(manifest_content.splitlines(), all_default) def run(self, objects: dict = {}): raise NotImplementedError @@ -179,6 +396,7 @@ class Importer(Backup): def __init__( self, project, + source: ProjectSource = None, force: bool = False, dry: bool = False, clean_old_data: bool = False, @@ -186,12 +404,15 @@ def __init__( ): super().__init__(project, **kwargs) + self.project_path = os.path.abspath(project) if project else "" + self.source = source or FilesystemProjectSource(self.project_path) + self.force = force self.dry_run = dry self.validate = not force self._delete = clean_old_data - if not os.path.isdir(self.project_path): + if not self.source.is_dir(""): raise ImportException("Project path does not exist") self._imported = {} @@ -199,14 +420,22 @@ def __init__( if self.dry_run: # override import_data to not import data self._import_data = lambda scope, name, _, __: self.set_imported(scope, name) - # remove project root dir from it, plus an extra '/' (+1) - self.dry_print = lambda *paths: [ - print(path[len(self.project_path) + 1 :]) for path in paths - ] + self.dry_print = lambda *paths: [print(path) for path in paths] else: self._db = MovaiDB() self.dry_print = lambda *paths: None + def _read_json(self, relative_path: str) -> dict: + """Read a JSON file from the configured project source.""" + return json.loads(self.source.read_text(relative_path)) + + def _read_json_or_pickle(self, relative_path: str) -> dict: + """Read a JSON file, falling back to pickle for legacy metadata.""" + try: + return self._read_json(relative_path) + except (json.JSONDecodeError, UnicodeError): + return pickle.loads(self.source.read_bytes(relative_path)) + def get_objs(self, scope): """Get all objects of a given scope. @@ -273,12 +502,10 @@ def default_extractor(file): def matcher(file): return True - scope_path = os.path.join(self.project_path, scope) - try: return [ - (extractor(file), os.path.join(scope_path, file)) - for file in os.listdir(scope_path) + (extractor(file), ProjectSource.join(scope, file)) + for file in self.source.list_dir(scope) if matcher(file) ] except FileNotFoundError: @@ -303,14 +530,9 @@ def _get_files(self, scope, names, build=None, match=None): def builder(name): return name - matcher = match - if matcher is None: - # maybe default to os.path.isfile(file) ? - def matcher(file): - return True + matcher = self.source.is_file if match is None else match - scope_path = os.path.join(self.project_path, scope) - if not os.path.isdir(scope_path): + if not self.source.is_dir(scope): if self.validate: raise ImportException(f"{scope} directory not found") else: @@ -318,7 +540,7 @@ def matcher(file): files = [] for name in names: - file_path = os.path.join(scope_path, builder(name)) + file_path = ProjectSource.join(scope, builder(name)) if not matcher(file_path): _msg = f"{scope}:{name} not found" self.log(_msg) @@ -338,7 +560,7 @@ def get_files( extract=lambda file: os.path.splitext(file)[0], build=lambda name: f"{name}.json", list_match=lambda file: file.endswith(".json"), - get_match=os.path.isfile, + get_match=None, ): if names is None: return self._list_files(scope, extract, list_match) @@ -354,7 +576,7 @@ def _import_data(self, scope, name, data, path): # pylint: disable=method-hidde path (str): Path of origin of the data. """ - data[scope][name]["InstallPath"] = path + data[scope][name]["InstallPath"] = self.source.display_path(path) try: ScopeClass = Factory.get_class(scope) @@ -412,12 +634,7 @@ def import_default(self, scope, names=None): self.dry_print(file_path) - try: - with open(file_path) as file: - data = json.load(file) - except json.JSONDecodeError: - with open(file_path, "rb") as file: - data = pickle.load(file) + data = self._read_json_or_pickle(file_path) self._import_data(scope, name, data, file_path) @@ -433,17 +650,11 @@ def import_configuration(self, names=None): self.dry_print(file_path) self.dry_print(yaml_path) - try: - with open(file_path) as file: - data = json.load(file) - except json.JSONDecodeError: - with open(file_path, "rb") as file: - data = pickle.load(file) + data = self._read_json_or_pickle(file_path) # add yaml code try: - with open(yaml_path) as file: - data["Configuration"][name]["Yaml"] = file.read() + data["Configuration"][name]["Yaml"] = self.source.read_text(yaml_path) except: # probably file not found pass @@ -465,15 +676,9 @@ def import_callback(self, names=None): # .json and .py self.dry_print(file_path, code_path) - try: - with open(file_path) as file: - data = json.load(file) - except json.JSONDecodeError: - with open(file_path, "rb") as file: - data = pickle.load(file) - if os.path.isfile(code_path): - with open(code_path) as code: - data["Callback"][name]["Code"] = code.read() + data = self._read_json_or_pickle(file_path) + if self.source.is_file(code_path): + data["Callback"][name]["Code"] = self.source.read_text(code_path) self._import_data("Callback", name, data, file_path) @@ -484,8 +689,8 @@ def import_package(self, names=None): names, extract=None, build=None, - list_match=None, - get_match=os.path.isdir, + list_match=lambda file: self.source.is_dir(ProjectSource.join("Package", file)), + get_match=self.source.is_dir, ) for name, dir_path in dirs: @@ -503,13 +708,16 @@ def import_package(self, names=None): } } - for root, _, files in os.walk(dir_path): + package_path = dir_path + imported_file_path = package_path + + for root, _, files in self.source.walk(dir_path): for file in files: - file_path = os.path.join(root, file) + file_path = ProjectSource.join(root, file) + imported_file_path = file_path self.dry_print(file_path) checksum = hashlib.md5() - with open(file_path, "rb") as fd: - contents = fd.read() + contents = self.source.read_bytes(file_path) checksum.update(contents) try: @@ -517,13 +725,13 @@ def import_package(self, names=None): except UnicodeError: pass - file_dict[file_path.replace(dir_path, "", 1)[1:]] = { + file_dict[file_path.replace(package_path, "", 1)[1:]] = { "Value": contents, "Checksum": checksum.hexdigest(), "FileLabel": file, } - self._import_data("Package", name, data, file_path) + self._import_data("Package", name, data, imported_file_path) def import_message(self, names=None): files = self.get_files("Message", names) @@ -537,21 +745,18 @@ def import_message(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - json_dict = json.load(file) + json_dict = self._read_json(file_path) messages = [(k, msg) for k in json_dict for msg in json_dict[k]] - base_path = os.path.join(self.project_path, "Message", name) + base_path = ProjectSource.join("Message", name) for pack in messages: - message_path = os.path.join(base_path, pack[1]) + message_path = ProjectSource.join(base_path, pack[1]) this_dict = {} cmp_path = f"{message_path}.compiled" src_path = f"{message_path}.source" self.dry_print(cmp_path, src_path) - with open(cmp_path) as fd: - this_dict["Compiled"] = fd.read() - with open(src_path) as fd: - this_dict["Source"] = fd.read() + this_dict["Compiled"] = self.source.read_text(cmp_path) + this_dict["Source"] = self.source.read_text(src_path) msg_dict[pack[0]][pack[1]] = this_dict # clean it @@ -579,8 +784,7 @@ def to_import(ports): files = self.get_files("Ports", packages) for package, pkg_path in files: - with open(pkg_path) as fd: - pkg_json = json.load(fd) + pkg_json = self._read_json(pkg_path) self.dry_print(pkg_path) @@ -604,8 +808,7 @@ def import_tasktemplate(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - data = json.load(file) + data = self._read_json(file_path) # dependencies if self.recursive: @@ -623,8 +826,7 @@ def import_flow(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - data = json.load(file) + data = self._read_json(file_path) # dependencies # nodes @@ -659,8 +861,7 @@ def import_node(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - data = json.load(file) + data = self._read_json(file_path) # dependencies if self.recursive: @@ -677,8 +878,7 @@ def import_shareddataentry(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - data = json.load(file) + data = self._read_json(file_path) # dependencies if self.recursive: @@ -700,8 +900,7 @@ def import_statemachine(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - data = json.load(file) + data = self._read_json(file_path) # dependencies if self.recursive: @@ -726,12 +925,7 @@ def import_graphicscene(self, names=None): self.log(file_path) - try: - with open(file_path) as file: - data = json.load(file) - except (json.JSONDecodeError, UnicodeError): - with open(file_path, "rb") as file: - data = pickle.load(file) + data = self._read_json_or_pickle(file_path) # dependencies if self.recursive: @@ -758,25 +952,26 @@ def import_translation(self, names=None): self.dry_print(file_path) - with open(file_path) as file: - data = json.load(file) + data = self._read_json(file_path) data["Translation"][name]["Translations"] = {} - parent = Path(file_path).parent - lang_pattern = re.compile(f"^{name}\.([a-z]+)\.po$") + parent = os.path.dirname(file_path) + lang_pattern = re.compile(rf"^{re.escape(name)}\.([a-z]+)\.po$") # look for po files - for file in parent.iterdir(): - if not file.is_file(): + for file in self.source.list_dir(parent): + translation_path = ProjectSource.join(parent, file) + if not self.source.is_file(translation_path): continue - lang = lang_pattern.findall(file.name) + lang = lang_pattern.findall(file) if not lang: continue - with open(file) as data_file: - data["Translation"][name]["Translations"][lang[0]] = {"po": data_file.read()} + data["Translation"][name]["Translations"][lang[0]] = { + "po": self.source.read_text(translation_path) + } self._import_data("Translation", name, data, file_path) diff --git a/pyproject.toml b/pyproject.toml index 4edd1ede..00415780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ dal = [ line-length = 100 [tool.bumpversion] -current_version = "3.25.5.1" +current_version = "3.25.6.0" parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)?(\\.(?P\\d+))?" serialize = ["{major}.{minor}.{patch}.{build}"] diff --git a/tests/unit/with_db/test_tools_backup.py b/tests/unit/with_db/test_tools_backup.py index cd8bc332..32e26487 100644 --- a/tests/unit/with_db/test_tools_backup.py +++ b/tests/unit/with_db/test_tools_backup.py @@ -3,6 +3,8 @@ import json from filecmp import cmpfiles +import pytest + class TestToolsBackup: """Test suite for the backup tool's import/export functionality. @@ -29,11 +31,40 @@ class TestToolsBackup: - Invalid data handling """ - def test_import_manifest(self, global_db, metadata_folder, manifest_file): - from dal.tools.backup import Importer + @staticmethod + def get_source(manifest_file, metadata_folder, source_type): + """Get the appropriate source based on the source_type.""" + from dal.tools.backup import FilesystemProjectSource, InMemoryProjectSource + + if source_type == "FilesystemProjectSource": + return FilesystemProjectSource(metadata_folder) + else: + files = {manifest_file.name: manifest_file.read_bytes()} + + for file in metadata_folder.rglob("*"): + if file.is_file(): + path = InMemoryProjectSource.normalize_path(file.relative_to(metadata_folder)) + if path.startswith("metadata/"): + path = path[len("metadata/") :] + + files[path] = file.read_bytes() + + return InMemoryProjectSource(files) + + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_manifest(self, global_db, metadata_folder, manifest_file, source_type): + from dal.tools.backup import Importer, Backup tool = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -41,8 +72,11 @@ def test_import_manifest(self, global_db, metadata_folder, manifest_file): clean_old_data=True, ) - objects = tool.read_manifest(manifest_file) - + objects = ( + tool.read_manifest(manifest_file) + if source_type == "FilesystemProjectSource" + else Backup.read_manifest_content(manifest_file.read_text(), tool.get_objs) + ) tool.run(objects) def test_export_manifest(self, global_db, manifest_file, tmp_path): @@ -77,12 +111,23 @@ def test_relative_import(self, global_db, metadata_folder, manifest_file): tool.run(objects) - def test_import_export_alert(self, global_db, metadata_folder, manifest_file, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_export_alert( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test alert import and export.""" from dal.tools.backup import Importer, Exporter importer = Importer( metadata_folder, + self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -110,7 +155,17 @@ def test_import_export_alert(self, global_db, metadata_folder, manifest_file, tm assert imported_content == exported_content - def test_import_export_callback(self, global_db, metadata_folder, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_export_callback( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test callback import with .py file and export.""" from dal.tools.backup import Importer, Exporter from dal.scopes.callback import Callback @@ -118,6 +173,7 @@ def test_import_export_callback(self, global_db, metadata_folder, tmp_path): # Import importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -159,7 +215,17 @@ def test_import_export_callback(self, global_db, metadata_folder, tmp_path): exported_content = json.load(callback_json) assert original_content == exported_content - def test_import_export_configuration(self, global_db, metadata_folder, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_export_configuration( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test configuration import with .yaml file and export.""" from dal.tools.backup import Importer, Exporter from dal.scopes.configuration import Configuration @@ -167,6 +233,7 @@ def test_import_export_configuration(self, global_db, metadata_folder, tmp_path) # Import importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -208,7 +275,17 @@ def test_import_export_configuration(self, global_db, metadata_folder, tmp_path) exported_content = json.load(exported_json) assert original_content == exported_content - def test_import_export_flow(self, global_db, metadata_folder, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_export_flow( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test flow import and export.""" from dal.tools.backup import Importer, Exporter from dal.scopes.flow import Flow @@ -216,6 +293,7 @@ def test_import_export_flow(self, global_db, metadata_folder, tmp_path): # Import importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -249,7 +327,17 @@ def test_import_export_flow(self, global_db, metadata_folder, tmp_path): exported_content = json.load(exported) assert imported_content == exported_content - def test_import_flow_with_dependencies(self, global_db, metadata_folder, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_flow_with_dependencies( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test flow import with recursive dependencies (nodes and subflows).""" from dal.tools.backup import Importer from dal.scopes.flow import Flow @@ -258,6 +346,7 @@ def test_import_flow_with_dependencies(self, global_db, metadata_folder, tmp_pat # Import with recursive=True to import dependencies importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -286,7 +375,17 @@ def test_import_flow_with_dependencies(self, global_db, metadata_folder, tmp_pat subflow = Flow("flow_with_duplicated_subflow") assert subflow.Label == "flow_with_duplicated_subflow" - def test_import_export_node(self, global_db, metadata_folder, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_export_node( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test node import and export.""" from dal.tools.backup import Importer, Exporter from dal.scopes.node import Node @@ -294,6 +393,7 @@ def test_import_export_node(self, global_db, metadata_folder, tmp_path): # Import importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -326,7 +426,17 @@ def test_import_export_node(self, global_db, metadata_folder, tmp_path): exported_content = json.load(exported) assert imported_content == exported_content - def test_import_node_multiple(self, global_db, metadata_folder, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_node_multiple( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test importing multiple nodes at once.""" from dal.tools.backup import Importer from dal.scopes.node import Node @@ -334,6 +444,7 @@ def test_import_node_multiple(self, global_db, metadata_folder, tmp_path): # Import importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -349,13 +460,28 @@ def test_import_node_multiple(self, global_db, metadata_folder, tmp_path): node = Node(node_name) assert node.Label == node_name - def test_import_package(self, global_db, metadata_folder, metadata2_folder): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_package( + self, global_db, metadata_folder, metadata2_folder, manifest_file, source_type + ): """Test that consecutive imports merge package contents correctly.""" from dal.tools.backup import Importer from dal.scopes.package import Package + # Clean any existing Package data from previous test runs + # (clean_old_data doesn't work for Package due to SKIP_SCOPE_DELETE) + global_db.delete_by_args("Package", Name="maps") + importer1 = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -365,6 +491,7 @@ def test_import_package(self, global_db, metadata_folder, metadata2_folder): importer2 = Importer( metadata2_folder, + source=self.get_source(manifest_file, metadata2_folder, source_type), force=True, dry=False, debug=False, @@ -387,12 +514,23 @@ def test_import_package(self, global_db, metadata_folder, metadata2_folder): "delete_me2.yaml", } - def test_import_export_translation(self, global_db, metadata_folder, manifest_file, tmp_path): + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) + def test_import_export_translation( + self, global_db, metadata_folder, manifest_file, tmp_path, source_type + ): """Test translation import and export.""" from dal.tools.backup import Importer, Exporter importer = Importer( metadata_folder, + source=self.get_source(manifest_file, metadata_folder, source_type), force=True, dry=False, debug=False, @@ -425,14 +563,30 @@ def test_import_export_translation(self, global_db, metadata_folder, manifest_fi assert not diff assert not err + @pytest.mark.parametrize( + "source_type", + [ + "FilesystemProjectSource", + "InMemoryProjectSource", + ], + ids=["FilesystemProjectSource", "InMemoryProjectSource"], + ) def test_import_invalid_data( - self, global_db, metadata_folder_invalid_data, manifest_file_invalid_data, capsys + self, + global_db, + metadata_folder_invalid_data, + manifest_file_invalid_data, + capsys, + source_type, ): """Test import validates and reports invalid data.""" from dal.tools.backup import Importer tool = Importer( metadata_folder_invalid_data, + source=self.get_source( + manifest_file_invalid_data, metadata_folder_invalid_data, source_type + ), force=True, dry=False, debug=False,