From ca614ac928062053b56193fa9f0e23fceb29c7b2 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Fri, 19 Dec 2025 11:19:12 -0700 Subject: [PATCH 1/5] `Context` rework using `ExternalStorage` (#690) * feat(warehouse): initial re-implementation of StorageManager * feat(warehouse): initial implementation with test * feat(warehouse): add support for new context object * feat(warehouse): experimental implementation of executing a DAG * feat!: Replace Context with new implementation BREAKING CHANGE: This completely breaks existing `Context` to leverage our new context management. * test(warehouse): add a context test * test: cleanup test_protocolunit.py * refactor(StorageManager)!: make _convert_to_namespace a static method * test(ProtocolDAG): fix ProtocolDAG tests * test(Protocol): cleanup testing for transformation and protocol * test(Protocol): remove reliance on mixin in two cases since StorageManager isn't tokenizable * chore: remove commented out code * Revert "chore: remove commented out code" This reverts commit 7e6bb03fe90563c6529a646284bfa8c94f85f339. 
* chore: fix comments and cleanup code * fix: satisfy typing * refactor(StorageManager)!: make _convert_to_namespace, convert_to_namespace * docs: storagemanager * docs: add docstrings to Context object * Update gufe/protocols/protocolunit.py Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * fix: indents from code review web ui * fix: remove TODO * chore: rename scratch_path to scratch_dir * chore: remove dead comment * chore: remove redundant fixtures * refactor: rename convert_to_namespace to append_to_namespace * docs: added news * Apply suggestions from code review on news item Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * fix docs env build by mocking zstandard * test: add a context test for validating cleanup --------- Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> Co-authored-by: Alyssa Travitz --- gufe/protocols/protocoldag.py | 78 ++++--- gufe/protocols/protocolunit.py | 92 +++++++- gufe/storage/__init__.py | 2 + gufe/storage/storagemanager.py | 108 ++++++++++ gufe/tests/storage/test_storagemanager.py | 249 ++++++++++++++++++++++ gufe/tests/test_protocol.py | 216 ++++++++++--------- gufe/tests/test_protocoldag.py | 155 ++++++++------ gufe/tests/test_protocolunit.py | 233 ++++++++++++++------ gufe/tests/test_transformation.py | 171 +++++++-------- news/rework-context.rst | 25 +++ 10 files changed, 972 insertions(+), 357 deletions(-) create mode 100644 gufe/storage/storagemanager.py create mode 100644 gufe/tests/storage/test_storagemanager.py create mode 100644 news/rework-context.rst diff --git a/gufe/protocols/protocoldag.py b/gufe/protocols/protocoldag.py index 538b0abe2..fc12b9d09 100644 --- a/gufe/protocols/protocoldag.py +++ b/gufe/protocols/protocoldag.py @@ -13,6 +13,8 @@ import networkx as nx +from gufe.storage.externalresource.base import 
ExternalStorage + from ..tokenization import GufeKey, GufeTokenizable from .errors import MissingUnitResultError, ProtocolUnitFailureError from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult @@ -368,7 +370,8 @@ def _from_dict(cls, dct: dict): def execute_DAG( protocoldag: ProtocolDAG, *, - shared_basedir: Path, + shared_storage: ExternalStorage, + perm_storage: ExternalStorage, scratch_basedir: Path, stderr_basedir: Path | None = None, stdout_basedir: Path | None = None, @@ -383,35 +386,46 @@ def execute_DAG( Parameters ---------- protocoldag : ProtocolDAG - The :class:``ProtocolDAG`` to execute. - shared_basedir : Path - Filesystem path to use for shared space that persists across whole DAG - execution. Used by a `ProtocolUnit` to pass file contents to dependent - class:``ProtocolUnit`` instances. + The :class:`ProtocolDAG` to execute. + shared_storage : ExternalStorage + Storage for shared files that persist across the entire DAG execution. + Used by ProtocolUnits to pass file contents to dependent ProtocolUnits. + perm_storage : ExternalStorage + Permanent storage for files that should persist after DAG execution. scratch_basedir : Path - Filesystem path to use for `ProtocolUnit` `scratch` space. + Base directory for ProtocolUnit scratch space. Each ProtocolUnit gets + its own scratch directory under this base directory. stderr_basedir : Path | None - Filesystem path to use for `ProtocolUnit` `stderr` archiving. + Base directory for ProtocolUnit stderr archiving. If None, stderr + is not archived. stdout_basedir : Path | None - Filesystem path to use for `ProtocolUnit` `stdout` archiving. + Base directory for ProtocolUnit stdout archiving. If None, stdout + is not archived. keep_shared : bool - If True, don't remove shared directories for `ProtocolUnit`s after - the `ProtocolDAG` is executed. + If True, shared directories are not removed after DAG execution. + Default is False. 
keep_scratch : bool - If True, don't remove scratch directories for a `ProtocolUnit` after - it is executed. + If True, scratch directories are not removed after each ProtocolUnit + execution. Default is False. raise_error : bool - If True, raise an exception if a ProtocolUnit fails, default True - if False, any exceptions will be stored as `ProtocolUnitFailure` - objects inside the returned `ProtocolDAGResult` + If True, raises an exception when a ProtocolUnit fails. If False, + failures are stored as ProtocolUnitFailure objects in the returned + ProtocolDAGResult. Default is True. n_retries : int - the number of times to attempt, default 0, i.e. try once and only once + Number of times to retry failed ProtocolUnits. Default is 0 (no retries). Returns ------- ProtocolDAGResult - The result of executing the `ProtocolDAG`. - + Result object containing the execution results of all ProtocolUnits + in the DAG, including both successes and failures. + + Notes + ----- + The function executes ProtocolUnits in DAG-dependency order, ensuring that + each ProtocolUnit's dependencies are executed before the ProtocolUnit itself. + If a ProtocolUnit fails and raise_error is True, execution stops immediately. + Otherwise, execution continues with the next ProtocolUnit. 
""" if n_retries < 0: raise ValueError("Must give positive number of retries") @@ -426,11 +440,8 @@ def execute_DAG( inputs = _pu_to_pur(unit.inputs, results) attempt = 0 + result = None while attempt <= n_retries: - shared = shared_basedir / f"shared_{str(unit.key)}_attempt_{attempt}" - shared_paths.append(shared) - shared.mkdir() - scratch = scratch_basedir / f"scratch_{str(unit.key)}_attempt_{attempt}" scratch.mkdir() @@ -444,17 +455,16 @@ def execute_DAG( stdout = stdout_basedir / f"stdout_{str(unit.key)}_attempt_{attempt}" stdout.mkdir() - context = Context(shared=shared, scratch=scratch, stderr=stderr, stdout=stdout) - # execute - result = unit.execute(context=context, raise_error=raise_error, **inputs) - all_results.append(result) - - # clean up outputs - if stderr: - shutil.rmtree(stderr) - if stdout: - shutil.rmtree(stdout) + with Context( + dag_label=str(protocoldag.key), + unit_label=str(unit.key), + shared_storage=shared_storage, + permanent_storage=perm_storage, + scratch=scratch, + ) as ctx: + result = unit.execute(context=ctx, raise_error=raise_error, **inputs) + all_results.append(result) if not keep_scratch: shutil.rmtree(scratch) @@ -465,7 +475,7 @@ def execute_DAG( break attempt += 1 - if not result.ok(): + if result is not None and not result.ok(): break if not keep_shared: diff --git a/gufe/protocols/protocolunit.py b/gufe/protocols/protocolunit.py index 9f3b90c4f..4452273c4 100644 --- a/gufe/protocols/protocolunit.py +++ b/gufe/protocols/protocolunit.py @@ -10,6 +10,7 @@ import abc import datetime +import shutil import sys import tempfile import traceback @@ -21,21 +22,100 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union +from gufe.storage.externalresource.base import ExternalStorage + +from ..storage.storagemanager import StorageManager from ..tokenization import TOKENIZABLE_REGISTRY, GufeKey, GufeTokenizable from .errors import ExecutionInterrupt -@dataclass class Context: - """Data class for passing 
around execution context components to + """ + Class for passing around execution context components to `ProtocolUnit._execute`. + This class provides execution context information to ProtocolUnit subclasses + when executing their `_execute` method. + + Parameters + ---------- + dag_label : str + Label for the ProtocolDAG this unit belongs to. + unit_label : str + Label for this specific ProtocolUnit. + scratch : Path + Path to the scratch directory for temporary files. + shared_storage : ExternalStorage + Storage manager for shared resources that can be accessed by other units. + permanent_storage : ExternalStorage + Storage manager for permanent resources that persist after execution. + stderr : Path, optional + Path to directory for capturing stderr output. + stdout : Path, optional + Path to directory for capturing stdout output. + + Attributes + ---------- + scratch : Path + Path to the scratch directory for temporary files. + shared : StorageManager + Storage manager for shared resources. + permanent : StorageManager + Storage manager for permanent resources. + stderr : Path or None + Path to directory for capturing stderr output. + stdout : Path or None + Path to directory for capturing stdout output. + + Notes + ----- + This class implements the context manager protocol, automatically transferring + shared and permanent storage resources when exiting the context. + + Examples + -------- + >>> with Context(dag_label="my_dag", unit_label="unit1", scratch="/tmp/scratch", + ... shared_storage=shared_store, permanent_storage=perm_store) as ctx: + ... # use ctx within the ProtocolUnit._execute method + ... 
pass """ - scratch: Path - shared: Path - stderr: Path | None = None - stdout: Path | None = None + def __init__( + self, + dag_label: str, + unit_label: str, + scratch: Path, + shared_storage: ExternalStorage, + permanent_storage: ExternalStorage, + stderr: Optional[Path] = None, + stdout: Optional[Path] = None, + ): + self.scratch = scratch + self.shared = StorageManager( + scratch_dir=scratch, + storage=shared_storage, + unit_label=unit_label, + dag_label=dag_label, + ) + self.permanent = StorageManager( + scratch_dir=scratch, + storage=permanent_storage, + unit_label=unit_label, + dag_label=dag_label, + ) + self.stderr = stderr + self.stdout = stdout + + def __enter__(self) -> Context: + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shared._transfer() + self.permanent._transfer() + if self.stderr: + shutil.rmtree(self.stderr) + if self.stdout: + shutil.rmtree(self.stdout) def _list_dependencies(inputs, cls): diff --git a/gufe/storage/__init__.py b/gufe/storage/__init__.py index 853c55a2a..a81434523 100644 --- a/gufe/storage/__init__.py +++ b/gufe/storage/__init__.py @@ -1 +1,3 @@ """How to store objects across simulation campaigns""" + +from .storagemanager import StorageManager diff --git a/gufe/storage/storagemanager.py b/gufe/storage/storagemanager.py new file mode 100644 index 000000000..5b0971b0c --- /dev/null +++ b/gufe/storage/storagemanager.py @@ -0,0 +1,108 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Literal + +from .externalresource import ExternalStorage + + +class StorageManager: + """Manage storage operations for files in a DAG. + + This class provides a context manager for working with storage systems, + allowing registration, loading, and transfer of files between scratch + directory and external storage. + + Parameters + ---------- + scratch_dir : Path + Path to the scratch directory where files are temporarily stored. 
+ storage : ExternalStorage + External storage system for persistent file storage. + dag_label : str + Label for the directed acyclic graph (DAG) this storage manager belongs to. + unit_label : str + Label for the specific unit within the DAG. + + Attributes + ---------- + scratch_dir : Path + Path to the scratch directory. + storage : ExternalStorage + External storage system. + registry : set[str] + Set of registered filenames to be transferred. + namespace : str + Namespace combining dag_label and unit_label for file organization. + """ + + def __init__( + self, + scratch_dir: Path, + storage: ExternalStorage, + dag_label: str, + unit_label: str, + ): + self.scratch_dir = scratch_dir + self.storage = storage + self.registry: set[str] = set() + self.namespace = f"{dag_label}/{unit_label}" + + @staticmethod + def append_to_namespace(namespace: str, filename: str) -> str: + """Append a filename to a namespace, mainly used to + make testing easier. + + Parameters + ---------- + namespace : str + The namespace prefix for the file. + filename : str + The filename to be appended to the namespace. + + Returns + ------- + str + Combined namespace and filename as a storage path. + """ + # We opt _not_ to use Paths because these aren't actually path objects + return f"{namespace}/{filename}" + + def register(self, filename: str): + """Register a filename for later transfer to external storage. + + Parameters + ---------- + filename : str + The filename to register for transfer. + """ + self.registry.add(filename) + + def load(self, filename: str) -> bytes: + """Load an item from external storage. + + Parameters + ---------- + filename : str + The filename to load from external storage. + + Returns + ------- + bytes + The content of the loaded file. 
+ """ + with self.storage.load_stream(filename) as f: + stored = f.read() + return stored + + def __contains__(self, filename: str) -> bool: + return filename in self.registry + + def _transfer(self): + """Transfer all the files from the files in the internal registry to its + corresponding :class:`gufe.externalresource.ExternalStorage`. + """ + for filename in self.registry: + path = self.scratch_dir / filename + with open(path, "rb") as f: + data = f.read() + self.storage.store_bytes(self.append_to_namespace(self.namespace, filename), data) diff --git a/gufe/tests/storage/test_storagemanager.py b/gufe/tests/storage/test_storagemanager.py new file mode 100644 index 000000000..8eca4a087 --- /dev/null +++ b/gufe/tests/storage/test_storagemanager.py @@ -0,0 +1,249 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/gufe + +import pathlib +import tempfile +from unittest.mock import MagicMock, patch + +import pytest +from pluggy import _manager + +from gufe.storage.externalresource import FileStorage, MemoryStorage +from gufe.storage.externalresource.base import ExternalStorage +from gufe.storage.storagemanager import StorageManager + + +@pytest.fixture +def tmp_scratch_dir(): + """Create a temporary scratch directory.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield pathlib.Path(tmp_dir) + + +@pytest.fixture +def memory_storage(): + """Create a MemoryStorage instance for testing.""" + return MemoryStorage() + + +@pytest.fixture +def file_storage(tmp_path): + """Create a FileStorage instance for testing.""" + return FileStorage(tmp_path) + + +@pytest.fixture +def storage_manager(tmp_scratch_dir, memory_storage): + """Create a StorageManager instance with MemoryStorage.""" + return StorageManager(tmp_scratch_dir, memory_storage, dag_label="MEM", unit_label="1") + + +@pytest.fixture +def storage_manager_file_storage(tmp_scratch_dir, file_storage): + """Create a StorageManager instance 
with FileStorage.""" + return StorageManager(tmp_scratch_dir, file_storage, dag_label="FILE", unit_label="1") + + +class TestStorageManager: + """Test the StorageManager class.""" + + def test_init(self, tmp_scratch_dir, memory_storage): + """Test StorageManager initialization.""" + manager = StorageManager(tmp_scratch_dir, memory_storage, dag_label="MEM", unit_label="1") + + assert manager.scratch_dir == tmp_scratch_dir + assert manager.storage == memory_storage + assert isinstance(manager.registry, set) + assert len(manager.registry) == 0 + assert manager.namespace == "MEM/1" + + def test_init_with_file_storage(self, tmp_scratch_dir, file_storage): + """Test StorageManager initialization with FileStorage.""" + manager = StorageManager(tmp_scratch_dir, file_storage, dag_label="FILE", unit_label="1") + + assert manager.scratch_dir == tmp_scratch_dir + assert manager.storage == file_storage + assert isinstance(manager.registry, set) + + def test_convert_to_namespace(self): + filename = "test_file.txt" + + out = StorageManager.append_to_namespace("MEM/1", filename) + assert out == "MEM/1/test_file.txt" + + def test_register(self, storage_manager): + """Test registering a filename.""" + filename = "test_file.txt" + + # Initially registry should be empty + assert filename not in storage_manager.registry + assert filename not in storage_manager + + # Register the file + storage_manager.register(filename) + + # Check it's now in the registry + assert filename in storage_manager.registry + assert filename in storage_manager + + def test_register_multiple_files(self, storage_manager): + """Test registering multiple filenames.""" + files = ["file1.txt", "file2.txt", "dir/file3.txt"] + + for filename in files: + storage_manager.register(filename) + + # Check all files are registered + for filename in files: + assert filename in storage_manager.registry + assert filename in storage_manager + + # Check registry size + assert len(storage_manager.registry) == 3 + + def 
test_register_duplicate_file(self, storage_manager): + """Test registering the same file multiple times.""" + filename = "duplicate.txt" + + # Register twice + storage_manager.register(filename) + storage_manager.register(filename) + + # Should only appear once (set behavior) + assert filename in storage_manager.registry + assert len(storage_manager.registry) == 1 + + def test_contains(self, storage_manager): + """Test the __contains__ method.""" + filename = "contains_test.txt" + + # Initially should not contain + assert filename not in storage_manager + + # Register and check again + storage_manager.register(filename) + assert filename in storage_manager + + # Check a non-existent file + assert "nonexistent.txt" not in storage_manager + + def test_transfer_empty_registry(self, storage_manager): + """Test _transfer with empty registry.""" + # Should not raise any errors + storage_manager._transfer() + + # Storage should remain empty + assert list(storage_manager.storage) == [] + + def test_transfer_with_files(self, storage_manager, tmp_scratch_dir): + """Test _transfer with actual files.""" + # Create test files in scratch directory + test_files = {"test1.txt": b"Hello World", "test2.txt": b"Another file", "subdir/test3.txt": b"Nested file"} + + # Create files and register them + for filename, content in test_files.items(): + file_path = tmp_scratch_dir / filename + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_bytes(content) + storage_manager.register(filename) + + # Transfer files + storage_manager._transfer() + + # Check files are now in storage + for filename in test_files: + namespaced_filename = StorageManager.append_to_namespace(storage_manager.namespace, filename) + assert storage_manager.storage.exists(namespaced_filename) + + # Verify content + with storage_manager.storage.load_stream(namespaced_filename) as f: + stored_content = f.read() + assert stored_content == test_files[filename] + + def test_transfer_with_file_storage(self, 
storage_manager_file_storage, tmp_scratch_dir): + """Test _transfer with FileStorage.""" + # Create test file + filename = "transfer_test.txt" + content = b"Test content for file storage" + + # Create file in scratch directory + file_path = tmp_scratch_dir / filename + file_path.write_bytes(content) + storage_manager_file_storage.register(filename) + + # Transfer file + storage_manager_file_storage._transfer() + + # Check file exists in file storage + namespaced_filename = StorageManager.append_to_namespace(storage_manager_file_storage.namespace, filename) + assert storage_manager_file_storage.storage.exists(namespaced_filename) + + # Verify content + expected_path = storage_manager_file_storage.storage.root_dir / namespaced_filename + assert expected_path.exists() + assert expected_path.read_bytes() == content + + def test_transfer_file_not_found(self, storage_manager): + """Test _transfer when registered file doesn't exist.""" + filename = "nonexistent.txt" + + # Register non-existent file + storage_manager.register(filename) + + # Should raise FileNotFoundError when trying to transfer + with pytest.raises(FileNotFoundError): + storage_manager._transfer() + + def test_registry_persistence(self, storage_manager, tmp_scratch_dir): + """Test that registry persists across operations.""" + filename = "persistent.txt" + + # Create the file first + file_path = tmp_scratch_dir / filename + file_path.write_bytes(b"test content") + + # Register file + storage_manager.register(filename) + assert filename in storage_manager.registry + + # Perform other operations + storage_manager._transfer() # Should not clear registry + + # Registry should still contain the file + assert filename in storage_manager.registry + + def test_load(self, storage_manager, tmp_scratch_dir): + filename = "item.txt" + file_path = tmp_scratch_dir / filename + content = b"test content" + file_path.write_bytes(content) + + # Register file + storage_manager.register(filename) + assert filename in 
storage_manager.registry + + # Perform other operations + storage_manager._transfer() # Should not clear registry + + # Registry should still contain the file + assert filename in storage_manager.registry + namespaced_filename = StorageManager.append_to_namespace(storage_manager.namespace, filename) + out = storage_manager.load(namespaced_filename) + assert out == content + + def test_prepopulated_storage_load(self, file_storage: ExternalStorage, tmp_scratch_dir): + # This test highlights a case where a unit wants to load something + # from shared storage. We basically load content into the medium that we want to fetch later. + + # Store content + contents = b"test content" + name = "test.txt" + file_storage.store_bytes(name, contents) + # Provide this prepopulated storage to a manager + manager = StorageManager(tmp_scratch_dir, storage=file_storage, dag_label="TEST", unit_label="1") + # Validate that content is loaded + out = manager.load(name) + assert out == contents + + # Validate that the registry has not loaded the item + assert len(manager.registry) == 0 diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index 4de6241c3..65d0dcd41 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -29,6 +29,7 @@ ) from gufe.protocols.errors import ProtocolValidationError from gufe.protocols.protocoldag import execute_DAG +from gufe.storage.externalresource import MemoryStorage from .test_tokenization import GufeTokenizableTestsMixin @@ -210,7 +211,7 @@ def instance(self): return DummyProtocol(settings=DummyProtocol.default_settings()) @pytest.fixture - def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir): + def protocol_dag(self, tmp_path, solvated_ligand, vacuum_ligand): protocol = DummyProtocol(settings=DummyProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -218,27 +219,34 @@ def protocol_dag(self, solvated_ligand, vacuum_ligand, tmpdir): name="a dummy run", mapping=None, ) - with 
tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + # shared = tmp_path / "shared" + # shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - stderr = pathlib.Path("stderr") - stderr.mkdir(parents=True) + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) + stdout = tmp_path / "stdout" + stdout.mkdir(parents=True) - dagresult: ProtocolDAGResult = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout - ) + dagresult: ProtocolDAGResult = execute_DAG( + dag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + keep_scratch=True, + ) return protocol, dag, dagresult @pytest.fixture - def protocol_dag_broken(self, solvated_ligand, vacuum_ligand, tmpdir): + def protocol_dag_broken(self, solvated_ligand, vacuum_ligand, tmp_path): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -246,27 +254,27 @@ def protocol_dag_broken(self, solvated_ligand, vacuum_ligand, tmpdir): name="a broken dummy run", mapping=None, ) - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - stderr = pathlib.Path("stderr") - stderr.mkdir(parents=True) + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) + stdout = tmp_path / "stdout" + stdout.mkdir(parents=True) - dagfailure: ProtocolDAGResult = execute_DAG( - dag, - shared_basedir=shared, - scratch_basedir=scratch, - 
stderr_basedir=stderr, - stdout_basedir=stdout, - raise_error=False, - ) + dagfailure: ProtocolDAGResult = execute_DAG( + dag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + raise_error=False, + ) return protocol, dag, dagfailure @@ -381,7 +389,7 @@ def test_dag_execute_failure(self, protocol_dag_broken): assert len(succeeded_units) > 0 - def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmpdir): + def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmp_path): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -389,20 +397,20 @@ def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, t name="a broken dummy run", mapping=None, ) - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) - - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) - - with pytest.raises(ValueError, match="I have failed my mission"): - execute_DAG( - dag, - shared_basedir=shared, - scratch_basedir=scratch, - raise_error=True, - ) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) + + shared = MemoryStorage() + perm = MemoryStorage() + + with pytest.raises(ValueError, match="I have failed my mission"): + execute_DAG( + dag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + raise_error=True, + ) def test_create_execute_gather(self, protocol_dag): protocol, dag, dagresult = protocol_dag @@ -486,7 +494,7 @@ def instance(self, protocol_dag): protocol, dag, dagresult = protocol_dag return dag - class TestProtocolDAGResult(ProtocolDAGTestsMixin): + class TestProtocolDAGResult: cls = ProtocolDAGResult repr = None @@ -536,7 +544,7 @@ def test_protocol_unit_successes(self, instance: ProtocolDAGResult): assert len(instance.protocol_unit_successes) == 23 assert all(isinstance(i, ProtocolUnitResult) for i 
in instance.protocol_unit_successes) - class TestProtocolDAGResultFailure(ProtocolDAGTestsMixin): + class TestProtocolDAGResultFailure: cls = ProtocolDAGResult repr = None @@ -651,15 +659,14 @@ def dag(self, protocol): def test_create(self, dag): assert len(dag.protocol_units) == 3 - def test_gather(self, protocol, dag, tmpdir): - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + def test_gather(self, protocol, dag, tmp_path): + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) + dag_result = execute_DAG(dag, shared_storage=shared, scratch_basedir=scratch, perm_storage=perm) assert dag_result.ok() @@ -668,16 +675,15 @@ def test_gather(self, protocol, dag, tmpdir): assert result.get_estimate() == 0 + 1 + 4 assert result.get_uncertainty() == 3 - def test_terminal_units(self, protocol, dag, tmpdir): - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + def test_terminal_units(self, protocol, dag, tmp_path): + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - # we have no dependencies, so this should be all three Unit results - dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) + # we have no dependencies, so this should be all three Unit results + dag_result = execute_DAG(dag, shared_storage=shared, perm_storage=perm, scratch_basedir=scratch) terminal_results = dag_result.terminal_protocol_unit_results @@ -798,7 +804,7 @@ def test_foreign_objects(self, units, successes): dagresult.result_to_unit(successes[2]) -def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmpdir): +def test_execute_DAG_retries(solvated_ligand, 
vacuum_ligand, tmp_path): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -806,36 +812,37 @@ def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmpdir): mapping=None, ) - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) - - r = execute_DAG( - dag, - shared_basedir=shared, - scratch_basedir=scratch, - keep_shared=True, - keep_scratch=True, - raise_error=False, - n_retries=3, - ) + shared = MemoryStorage() + perm = MemoryStorage() + + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) + + r = execute_DAG( + dag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + keep_shared=True, + keep_scratch=True, + raise_error=False, + n_retries=3, + ) - assert not r.ok() + assert not r.ok() - number_unit_failures = len(r.protocol_unit_failures) - number_unit_results = len(r.protocol_unit_results) - number_dirs = len(list(shared.iterdir())) + number_unit_failures = len(r.protocol_unit_failures) + number_unit_results = len(r.protocol_unit_results) + number_dirs = len(list(shared)) - # failed first attempt of BrokenSimulationUnit, failed 3 retries - assert number_unit_failures == 4 - # InitializeUnit and 21 SimulationUnits run before guaranteed - # final failure - assert number_unit_results == number_dirs == 26 + # failed first attempt of BrokenSimulationUnit, failed 3 retries + assert number_unit_failures == 4 + # InitializeUnit and 21 SimulationUnits run before guaranteed + # final failure + assert number_unit_results == 26 -def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmpdir): +def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmp_path): protocol = BrokenProtocol(settings=BrokenProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -843,22 +850,23 @@ def test_execute_DAG_bad_nretries(solvated_ligand, 
vacuum_ligand, tmpdir): mapping=None, ) - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - with pytest.raises(ValueError): - r = execute_DAG( - dag, - shared_basedir=shared, - scratch_basedir=scratch, - keep_shared=True, - keep_scratch=True, - raise_error=False, - n_retries=-1, - ) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) + + with pytest.raises(ValueError): + r = execute_DAG( + dag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + keep_shared=True, + keep_scratch=True, + raise_error=False, + n_retries=-1, + ) def test_settings_readonly(): diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py index 5d37f120a..7375ebc5c 100644 --- a/gufe/tests/test_protocoldag.py +++ b/gufe/tests/test_protocoldag.py @@ -7,19 +7,28 @@ from openff.units import unit import gufe -from gufe.protocols import execute_DAG +from gufe.protocols import ProtocolDAG, execute_DAG +from gufe.protocols.protocolunit import Context +from gufe.storage.externalresource.filestorage import FileStorage +from gufe.storage.storagemanager import StorageManager class WriterUnit(gufe.ProtocolUnit): @staticmethod - def _execute(ctx, **inputs): + def _execute(ctx: Context, **inputs): my_id = inputs["identity"] - with open(os.path.join(ctx.shared, f"unit_{my_id}_shared.txt"), "w") as out: - out.write(f"unit {my_id} existed!\n") - with open(os.path.join(ctx.scratch, f"unit_{my_id}_scratch.txt"), "w") as out: - out.write(f"unit {my_id} was here\n") + unit_shared_name = f"unit_{my_id}_shared.txt" + ctx.shared.register(unit_shared_name) + unit_shared = ctx.scratch / unit_shared_name + + unit_scratch_name = f"unit_{my_id}_scratch.txt" + unit_scratch = ctx.scratch / unit_scratch_name + with open(unit_shared, "w") as out: + out.write(f"unit {my_id} existed\n") + with open(unit_scratch, "w") as out: 
+ out.write(f"unit {my_id} was here\n") if ctx.stderr: with open(os.path.join(ctx.stderr, f"unit_{my_id}_stderr"), "w") as out: out.write(f"unit {my_id} wrote to stderr") @@ -78,70 +87,78 @@ def writefile_dag(): @pytest.mark.parametrize("keep_shared", [False, True]) @pytest.mark.parametrize("keep_scratch", [False, True]) @pytest.mark.parametrize("capture_stderr_stdout", [False, True]) -def test_execute_dag(tmpdir, keep_shared, keep_scratch, writefile_dag, capture_stderr_stdout): - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) - - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) - - stderr = None - stdout = None - if capture_stderr_stdout: - stderr = pathlib.Path("stderr") - stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) - - # run dag - execute_DAG( - writefile_dag, - shared_basedir=shared, - scratch_basedir=scratch, - stderr_basedir=stderr, - stdout_basedir=stdout, - keep_shared=keep_shared, - keep_scratch=keep_scratch, +def test_execute_dag(tmp_path, keep_shared, keep_scratch, writefile_dag: ProtocolDAG, capture_stderr_stdout): + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) + + shared = tmp_path / "shared" + shared.mkdir(parents=True) + + shared_storage = FileStorage(shared) + + perm = tmp_path / "perm" + perm.mkdir(parents=True) + perm_storage = FileStorage(perm) + + stderr = None + stdout = None + if capture_stderr_stdout: + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) + stdout = tmp_path / "stoud" + stdout.mkdir(parents=True) + + # run dag + execute_DAG( + writefile_dag, + shared_storage=shared_storage, + perm_storage=perm_storage, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + keep_shared=keep_shared, + keep_scratch=keep_scratch, + ) + + # check outputs are as expected + # will have produced 4 files in scratch and shared directory + dag_label = str(writefile_dag.key) + for pu in 
writefile_dag.protocol_units: + identity = pu.inputs["identity"] + # shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt") + shared_file = StorageManager.append_to_namespace(f"{dag_label}/{pu.key}", f"unit_{identity}_shared.txt") + scratch_file = os.path.join( + scratch, + f"scratch_{str(pu.key)}_attempt_0", + f"unit_{identity}_scratch.txt", ) - # check outputs are as expected - # will have produced 4 files in scratch and shared directory - for pu in writefile_dag.protocol_units: - identity = pu.inputs["identity"] - shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt") - scratch_file = os.path.join( - scratch, - f"scratch_{str(pu.key)}_attempt_0", - f"unit_{identity}_scratch.txt", + if capture_stderr_stdout: + stderr_file = os.path.join( + stderr, + f"stderr_{str(pu.key)}_attempt_0", + f"unit_{identity}_stderr", + ) + stdout_file = os.path.join( + stdout, + f"stdout_{str(pu.key)}_attempt_0", + f"unit_{identity}_stdout", ) - if capture_stderr_stdout: - stderr_file = os.path.join( - stderr, - f"stderr_{str(pu.key)}_attempt_0", - f"unit_{identity}_stderr", - ) - stdout_file = os.path.join( - stdout, - f"stdout_{str(pu.key)}_attempt_0", - f"unit_{identity}_stdout", - ) - - # stderr and stdout are always removed since their - # contents are included in the unit results - assert not os.path.exists(stderr_file) - assert not os.path.exists(stdout_file) - - if keep_shared: - assert os.path.exists(shared_file) - else: - assert not os.path.exists(shared_file) - if keep_scratch: - assert os.path.exists(scratch_file) - else: - assert not os.path.exists(scratch_file) - - # check that our shared and scratch basedirs are left behind - assert shared.exists() - assert scratch.exists() + # stderr and stdout are always removed since their + # contents are included in the unit results + assert not os.path.exists(stderr_file) + assert not os.path.exists(stdout_file) + + if keep_shared: + assert 
shared_storage.exists(shared_file) + else: + assert not os.path.exists(shared_file) + if keep_scratch: + assert os.path.exists(scratch_file) + else: + assert not os.path.exists(scratch_file) + + # check that our shared and scratch basedirs are left behind + assert shared.exists() + assert scratch.exists() diff --git a/gufe/tests/test_protocolunit.py b/gufe/tests/test_protocolunit.py index 1df3e7905..1ba94f56d 100644 --- a/gufe/tests/test_protocolunit.py +++ b/gufe/tests/test_protocolunit.py @@ -5,9 +5,29 @@ from gufe.protocols.errors import ExecutionInterrupt from gufe.protocols.protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult +from gufe.storage.externalresource import MemoryStorage from gufe.tests.test_tokenization import GufeTokenizableTestsMixin +@pytest.fixture +def scratch_storage(tmpdir): + """Fixture to provide a scratch directory for ProtocolUnit tests.""" + with tmpdir.as_cwd(): + scratch = Path("scratch") + scratch.mkdir(parents=True) + yield scratch + + +@pytest.fixture +def shared_storage(): + yield MemoryStorage() + + +@pytest.fixture +def permanent_storage(): + yield MemoryStorage() + + class DummyUnit(ProtocolUnit): @staticmethod def _execute(ctx: Context, an_input=2, **inputs): @@ -62,84 +82,177 @@ def test_key_differs(self): assert u1.key != u2.key @pytest.mark.parametrize("capture_stderr_stdout", [False, True]) - def test_execute(self, tmpdir, capture_stderr_stdout): - with tmpdir.as_cwd(): - unit = DummyUnit() - - shared = Path("shared") / str(unit.key) - shared.mkdir(parents=True) - - scratch = Path("scratch") / str(unit.key) - scratch.mkdir(parents=True) - - if capture_stderr_stdout: - stderr = Path("stderr") / str(unit.key) - stderr.mkdir(parents=True) - - stdout = Path("stdout") / str(unit.key) - stdout.mkdir(parents=True) - - ctx = Context(shared=shared, scratch=scratch, stderr=stderr, stdout=stdout) - else: - ctx = Context(shared=shared, scratch=scratch) - - u: ProtocolUnitFailure = 
unit.execute(context=ctx, an_input=3) - assert u.exception[0] == "ValueError" - - for output_type in ("stderr", "stdout"): - data = getattr(u, output_type) - if not capture_stderr_stdout: - assert data == {} - continue - for process_number in range(1, 3): - entry = f"dummy_execute_{output_type}_process_{process_number}" - output = f"Sample {output_type} from process {process_number}".encode() - assert data[entry] == output + def test_execute(self, tmpdir, scratch_storage, shared_storage, permanent_storage, capture_stderr_stdout): + unit = DummyUnit() + + if capture_stderr_stdout: + stderr = Path("stderr") / str(unit.key) + stderr.mkdir(parents=True) + + stdout = Path("stdout") / str(unit.key) + stdout.mkdir(parents=True) + + ctx = Context( + scratch=scratch_storage, + dag_label="test", + unit_label=unit.key, + stderr=stderr, + stdout=stdout, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + ) + else: + ctx = Context( + scratch=scratch_storage, + dag_label="test", + unit_label=unit.key, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + ) + + u: ProtocolUnitFailure = unit.execute(context=ctx, an_input=3) + assert u.exception[0] == "ValueError" + + for output_type in ("stderr", "stdout"): + data = getattr(u, output_type) + if not capture_stderr_stdout: + assert data == {} + continue + for process_number in range(1, 3): + entry = f"dummy_execute_{output_type}_process_{process_number}" + output = f"Sample {output_type} from process {process_number}".encode() + assert data[entry] == output - # now try actually letting the error raise on execute - with pytest.raises(ValueError, match="should always be 2"): - unit.execute(context=ctx, raise_error=True, an_input=3) + # now try actually letting the error raise on execute + with pytest.raises(ValueError, match="should always be 2"): + unit.execute(context=ctx, raise_error=True, an_input=3) - def test_execute_ExecutionInterrupt(self, tmpdir): - with tmpdir.as_cwd(): - unit = 
DummyExecutionInterruptUnit() + def test_execute_ExecutionInterrupt(self, scratch_storage, shared_storage, permanent_storage): + unit = DummyExecutionInterruptUnit() - shared = Path("shared") / str(unit.key) - shared.mkdir(parents=True) + shared = Path("shared") / str(unit.key) + shared.mkdir(parents=True) - scratch = Path("scratch") / str(unit.key) - scratch.mkdir(parents=True) + scratch = Path("scratch") / str(unit.key) + scratch.mkdir(parents=True) - ctx = Context(shared=shared, scratch=scratch, stderr=None, stdout=None) + ctx = Context( + shared_storage=shared_storage, + permanent_storage=permanent_storage, + dag_label="test", + unit_label=unit.key, + scratch=scratch_storage, + ) - with pytest.raises(ExecutionInterrupt): - unit.execute(context=ctx, an_input=3) + with pytest.raises(ExecutionInterrupt): + unit.execute(context=ctx, an_input=3) - u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2) + u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2) - assert u.outputs == {"foo": "bar"} + assert u.outputs == {"foo": "bar"} - def test_execute_KeyboardInterrupt(self, tmpdir): - with tmpdir.as_cwd(): - unit = DummyKeyboardInterruptUnit() + def test_execute_KeyboardInterrupt(self, scratch_storage, permanent_storage, shared_storage): + unit = DummyKeyboardInterruptUnit() - shared = Path("shared") / str(unit.key) - shared.mkdir(parents=True) + shared = Path("shared") / str(unit.key) + shared.mkdir(parents=True) - scratch = Path("scratch") / str(unit.key) - scratch.mkdir(parents=True) + scratch = Path("scratch") / str(unit.key) + scratch.mkdir(parents=True) - ctx = Context(shared=shared, scratch=scratch, stderr=None, stdout=None) + ctx = Context( + shared_storage=shared_storage, + permanent_storage=permanent_storage, + dag_label="test", + unit_label=unit.key, + scratch=scratch_storage, + ) - with pytest.raises(KeyboardInterrupt): - unit.execute(context=ctx, an_input=3) + with pytest.raises(KeyboardInterrupt): + unit.execute(context=ctx, an_input=3) - 
u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2) + u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2) - assert u.outputs == {"foo": "bar"} + assert u.outputs == {"foo": "bar"} def test_normalize(self, instance): thingy = instance.key assert thingy.startswith("DummyUnit-") assert all(t in string.hexdigits for t in thingy.partition("-")[-1]) + + +class TestContext: + """Test the Context class context manager functionality.""" + + def test_context_manager_enter_exit( + self, scratch_storage, shared_storage: MemoryStorage, permanent_storage: MemoryStorage + ): + """Test that Context can be used as a context manager.""" + ctx = Context( + dag_label="test", + unit_label="test_unit", + scratch=scratch_storage, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + ) + file_text = b"Hello World!" + + # Test __enter__ + with ctx as context: + assert context is ctx + assert ctx.shared.scratch_dir == scratch_storage + assert ctx.permanent.scratch_dir == scratch_storage + filename = "test.txt" + test_file = context.scratch / filename + context.shared.register(filename) + context.permanent.register(filename) + with test_file.open("b+w") as f: + f.write(file_text) + + # Test __exit__ - should transfer the file to a namespaced location in shared_storage + assert shared_storage.exists("test/test_unit/test.txt") + with shared_storage.load_stream("test/test_unit/test.txt") as item: + out = item.read() + assert out == file_text + + assert permanent_storage.exists("test/test_unit/test.txt") + with permanent_storage.load_stream("test/test_unit/test.txt") as item: + out = item.read() + assert out == file_text + + def test_context_manager_cleanup_stdout_stderr( + self, scratch_storage, shared_storage: MemoryStorage, permanent_storage: MemoryStorage, tmp_path + ): + stdout: Path = tmp_path / "stdout" + stdout.mkdir() + # We write something into stdout + stderr: Path = tmp_path / "stderr" + stderr.mkdir() + + ctx = Context( + dag_label="test", + 
unit_label="test_unit", + scratch=scratch_storage, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + stdout=stdout, + stderr=stderr, + ) + # Validate the directory is empty + assert any(Path(stdout).iterdir()) == False + assert any(Path(stderr).iterdir()) == False + file_text = b"Hello world" + with ctx as context: + filename = "test.txt" + stdout_file = context.stdout / filename + stderr_file = context.stderr / filename + with stdout_file.open("b+w") as f: + f.write(file_text) + with stderr_file.open("b+w") as f: + f.write(file_text) + assert any(Path(stdout).iterdir()) == True + assert any(Path(stderr).iterdir()) == True + # Validate we cleanup + assert not Path(stderr).exists() + assert not Path(stdout).exists() diff --git a/gufe/tests/test_transformation.py b/gufe/tests/test_transformation.py index 410861075..3acc03635 100644 --- a/gufe/tests/test_transformation.py +++ b/gufe/tests/test_transformation.py @@ -9,6 +9,7 @@ import gufe from gufe.protocols.errors import ProtocolValidationError from gufe.protocols.protocoldag import execute_DAG +from gufe.storage.externalresource import MemoryStorage from gufe.transformations import NonTransformation, Transformation from .test_protocol import DummyProtocol, DummyProtocolResult @@ -47,33 +48,33 @@ def test_init(self, absolute_transformation, solvated_ligand, solvated_complex): assert tnf.stateA is solvated_ligand assert tnf.stateB is solvated_complex - def test_protocol(self, absolute_transformation, tmpdir): + def test_protocol(self, absolute_transformation, tmp_path): tnf = absolute_transformation assert isinstance(tnf.protocol, DummyProtocol) protocoldag = tnf.create() - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = pathlib.Path("scratch") + scratch.mkdir(parents=True, exist_ok=True) - stderr = pathlib.Path("stderr") - 
stderr.mkdir(parents=True) + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) + stdout = tmp_path / "stdout" + stdout.mkdir(parents=True) - protocoldagresult = execute_DAG( - protocoldag, - shared_basedir=shared, - scratch_basedir=scratch, - stderr_basedir=stderr, - stdout_basedir=stdout, - ) + protocoldagresult = execute_DAG( + protocoldag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + ) protocolresult = tnf.gather([protocoldagresult]) @@ -112,41 +113,42 @@ def test_validation_bad_state(self, solvated_ligand, solvated_complex, tmpdir): validate=True, ) - def test_protocol_extend(self, absolute_transformation, tmpdir): + def test_protocol_extend(self, absolute_transformation, tmp_path): tnf = absolute_transformation assert isinstance(tnf.protocol, DummyProtocol) - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - stderr = pathlib.Path("stderr") - stderr.mkdir(parents=True) + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) + stdout = tmp_path / "stdout" + stdout.mkdir(parents=True) - protocoldag = tnf.create() - protocoldagresult = execute_DAG( - protocoldag, - shared_basedir=shared, - scratch_basedir=scratch, - stderr_basedir=stderr, - stdout_basedir=stdout, - ) + protocoldag = tnf.create() + protocoldagresult = execute_DAG( + protocoldag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + ) - protocoldag2 = tnf.create(extends=protocoldagresult) - protocoldagresult2 = execute_DAG( - protocoldag2, - shared_basedir=shared, - scratch_basedir=scratch, - 
stderr_basedir=stderr, - stdout_basedir=stdout, - ) + protocoldag2 = tnf.create(extends=protocoldagresult) + protocoldagresult2 = execute_DAG( + protocoldag2, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + ) protocolresult = tnf.gather([protocoldagresult, protocoldagresult2]) @@ -215,33 +217,33 @@ def test_init(self, complex_equilibrium, solvated_complex): assert ntnf.system is solvated_complex - def test_protocol(self, complex_equilibrium, tmpdir): + def test_protocol(self, complex_equilibrium, tmp_path): ntnf = complex_equilibrium assert isinstance(ntnf.protocol, DummyProtocol) protocoldag = ntnf.create() - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - stderr = pathlib.Path("stderr") - stderr.mkdir(parents=True) + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) + stdout = tmp_path / "stdout" + stdout.mkdir(parents=True) - protocoldagresult = execute_DAG( - protocoldag, - shared_basedir=shared, - scratch_basedir=scratch, - stderr_basedir=stderr, - stdout_basedir=stdout, - ) + protocoldagresult = execute_DAG( + protocoldag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + ) protocolresult = ntnf.gather([protocoldagresult]) @@ -273,39 +275,40 @@ def test_validation_bad_state(self, solvated_ligand, solvated_complex, tmpdir): validate=True, ) - def test_protocol_extend(self, complex_equilibrium, tmpdir): + def test_protocol_extend(self, complex_equilibrium, tmp_path): ntnf = complex_equilibrium assert isinstance(ntnf.protocol, DummyProtocol) - with tmpdir.as_cwd(): - shared = pathlib.Path("shared") - shared.mkdir(parents=True) 
+ shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path("scratch") - scratch.mkdir(parents=True) + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) - stderr = pathlib.Path("stderr") - stderr.mkdir(parents=True) + stderr = tmp_path / "stderr" + stderr.mkdir(parents=True) - stdout = pathlib.Path("stdout") - stdout.mkdir(parents=True) + stdout = tmp_path / "stdout" + stdout.mkdir(parents=True) - protocoldag = ntnf.create() - protocoldagresult = execute_DAG( - protocoldag, - shared_basedir=shared, - scratch_basedir=scratch, - stderr_basedir=stderr, - stdout_basedir=stdout, - ) + protocoldag = ntnf.create() + protocoldagresult = execute_DAG( + protocoldag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + ) - protocoldag2 = ntnf.create(extends=protocoldagresult) - protocoldagresult2 = execute_DAG( - protocoldag2, - shared_basedir=shared, - scratch_basedir=scratch, - ) + protocoldag2 = ntnf.create(extends=protocoldagresult) + protocoldagresult2 = execute_DAG( + protocoldag2, + scratch_basedir=scratch, + shared_storage=shared, + perm_storage=perm, + ) protocolresult = ntnf.gather([protocoldagresult, protocoldagresult2]) diff --git a/news/rework-context.rst b/news/rework-context.rst new file mode 100644 index 000000000..49b0b5c4a --- /dev/null +++ b/news/rework-context.rst @@ -0,0 +1,25 @@ +**Added:** + +* ``StorageManager`` for managing files during unit execution. Allows for auto transfer to storage mediums following execution of a unit. + +**Changed:** + +* ``Context`` now uses ``StorageManager`` to handle operations in protocol units. Units must now register files in scratch to be transferred shared or permanent storage. +* ``ProtocolUnit``s now have permanent storage which can be used to outlast a running ``ProtocolDAG``. 
+ + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* From 518a86b4e4ac83d4e22233ab323e33f6aaacfaa2 Mon Sep 17 00:00:00 2001 From: Alyssa Travitz Date: Fri, 23 Jan 2026 14:37:31 -0800 Subject: [PATCH 2/5] apply precommit --- gufe/storage/storagemanager.py | 2 -- gufe/tests/storage/test_storagemanager.py | 2 -- gufe/tests/test_protocoldag.py | 1 - 3 files changed, 5 deletions(-) diff --git a/gufe/storage/storagemanager.py b/gufe/storage/storagemanager.py index 5b0971b0c..afedaa48e 100644 --- a/gufe/storage/storagemanager.py +++ b/gufe/storage/storagemanager.py @@ -1,6 +1,4 @@ -from contextlib import contextmanager from pathlib import Path -from typing import Literal from .externalresource import ExternalStorage diff --git a/gufe/tests/storage/test_storagemanager.py b/gufe/tests/storage/test_storagemanager.py index 8eca4a087..ecd797f54 100644 --- a/gufe/tests/storage/test_storagemanager.py +++ b/gufe/tests/storage/test_storagemanager.py @@ -3,10 +3,8 @@ import pathlib import tempfile -from unittest.mock import MagicMock, patch import pytest -from pluggy import _manager from gufe.storage.externalresource import FileStorage, MemoryStorage from gufe.storage.externalresource.base import ExternalStorage diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py index 7375ebc5c..a933e1b77 100644 --- a/gufe/tests/test_protocoldag.py +++ b/gufe/tests/test_protocoldag.py @@ -1,7 +1,6 @@ # This code is part of gufe and is licensed under the MIT license. 
# For details, see https://github.com/OpenFreeEnergy/gufe import os -import pathlib import pytest from openff.units import unit From c932c8b3c3cd058024777d62232ac9ba18105d85 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 28 Jan 2026 09:00:12 -0700 Subject: [PATCH 3/5] feat: return storage handles from StorageManager (#718) * feat: return namespaced handles so other untis can ingest these handles * test: add a test to validate * Update gufe/storage/storagemanager.py Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> --------- Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> --- gufe/storage/storagemanager.py | 8 +++++++- gufe/tests/storage/test_storagemanager.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/gufe/storage/storagemanager.py b/gufe/storage/storagemanager.py index afedaa48e..6a548545d 100644 --- a/gufe/storage/storagemanager.py +++ b/gufe/storage/storagemanager.py @@ -65,15 +65,21 @@ def append_to_namespace(namespace: str, filename: str) -> str: # We opt _not_ to use Paths because these aren't actually path objects return f"{namespace}/{filename}" - def register(self, filename: str): + def register(self, filename: str) -> str: """Register a filename for later transfer to external storage. Parameters ---------- filename : str The filename to register for transfer. + + Returns + ------- + str + The globally namespaced path to be used by other units """ self.registry.add(filename) + return self.append_to_namespace(self.namespace, filename) def load(self, filename: str) -> bytes: """Load an item from external storage. 
diff --git a/gufe/tests/storage/test_storagemanager.py b/gufe/tests/storage/test_storagemanager.py index ecd797f54..2cab5f8c1 100644 --- a/gufe/tests/storage/test_storagemanager.py +++ b/gufe/tests/storage/test_storagemanager.py @@ -78,7 +78,8 @@ def test_register(self, storage_manager): assert filename not in storage_manager # Register the file - storage_manager.register(filename) + out = storage_manager.register(filename) + assert out == storage_manager.append_to_namespace(storage_manager.namespace, filename) # Check it's now in the registry assert filename in storage_manager.registry From 5d79a7c0614282cc7180f626fe814902fea1d5f7 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Wed, 25 Feb 2026 11:47:11 -0700 Subject: [PATCH 4/5] feat: add the ability to trigger a file sync immediately to a storage backend (#745) * refactor: split out transfer functionality * test: split out transfer functionality tests Signed-off-by: Ethan Holz --------- Signed-off-by: Ethan Holz Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> --- gufe/storage/storagemanager.py | 33 +++++++++++++++++---- gufe/tests/storage/test_storagemanager.py | 35 ++++++++++++++++++++++- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/gufe/storage/storagemanager.py b/gufe/storage/storagemanager.py index 6a548545d..cd496f252 100644 --- a/gufe/storage/storagemanager.py +++ b/gufe/storage/storagemanager.py @@ -102,11 +102,32 @@ def __contains__(self, filename: str) -> bool: return filename in self.registry def _transfer(self): - """Transfer all the files from the files in the internal registry to its - corresponding :class:`gufe.externalresource.ExternalStorage`. + """Transfer all registered files to external storage. + + Transfers each filename currently registered on this manager into the + configured external storage backend using this manager's namespace. + + Raises + ------ + FileNotFoundError + If any registered file does not exist in ``scratch_dir``. 
""" for filename in self.registry: - path = self.scratch_dir / filename - with open(path, "rb") as f: - data = f.read() - self.storage.store_bytes(self.append_to_namespace(self.namespace, filename), data) + self.transfer_file(filename) + + def transfer_file(self, filename: str) -> None: + """Transfer a single file from scratch space to external storage. + + Parameters + ---------- + filename : str + Relative filename (within ``scratch_dir``) to transfer. + + Raises + ------ + FileNotFoundError + If ``filename`` does not exist in ``scratch_dir``. + """ + path = self.scratch_dir / filename + location = self.append_to_namespace(self.namespace, filename) + self.storage.store_path(location, path) diff --git a/gufe/tests/storage/test_storagemanager.py b/gufe/tests/storage/test_storagemanager.py index 2cab5f8c1..e9711e642 100644 --- a/gufe/tests/storage/test_storagemanager.py +++ b/gufe/tests/storage/test_storagemanager.py @@ -126,6 +126,39 @@ def test_contains(self, storage_manager): # Check a non-existent file assert "nonexistent.txt" not in storage_manager + def test_transfer_file(self, tmp_scratch_dir, storage_manager): + """Test the transfer_file method.""" + filename = "somefile.txt" + data = b"Hello World!" 
+ + path = tmp_scratch_dir / filename + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + + storage_manager.register(filename) + + assert filename in storage_manager + storage_manager.transfer_file(filename) + + # Check for this file + namespaced_filename = StorageManager.append_to_namespace(storage_manager.namespace, filename) + assert storage_manager.storage.exists(namespaced_filename) + + with storage_manager.storage.load_stream(namespaced_filename) as f: + stored_content = f.read() + assert stored_content == data + + def test_transfer_file_not_found(self, tmp_scratch_dir, storage_manager): + """Test _transfer when registered file doesn't exist.""" + filename = "nonexistent.txt" + + # Register non-existent file + storage_manager.register(filename) + + # Should raise FileNotFoundError when trying to transfer + with pytest.raises(FileNotFoundError): + storage_manager.transfer_file(filename) + def test_transfer_empty_registry(self, storage_manager): """Test _transfer with empty registry.""" # Should not raise any errors @@ -182,7 +215,7 @@ def test_transfer_with_file_storage(self, storage_manager_file_storage, tmp_scra assert expected_path.exists() assert expected_path.read_bytes() == content - def test_transfer_file_not_found(self, storage_manager): + def test_transfer_not_found(self, storage_manager): """Test _transfer when registered file doesn't exist.""" filename = "nonexistent.txt" From e0a9e49a8537d1630d201bdf28c87240355cf9c4 Mon Sep 17 00:00:00 2001 From: Ethan Holz Date: Thu, 5 Mar 2026 14:47:09 -0700 Subject: [PATCH 5/5] docs: Context and Storage (#712) * docs: add information on storage Signed-off-by: Ethan Holz * docs: initial draft on context * docs: improve context docs * docs: migrate how to write a custom implementation * doc: update with comments from review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * fix docs build * docs: Add information on changes in feat/return-storage-handles This adds 
new information on how storage is handled, and how storage gets passed around by units. * docs: add information on pre-namespaced objects from other branch * docs: add suggestion to back link to Context docs * docs: add becomes in migration guide per suggestion * docs: add link to protocol how to * docs: add more informaton on StorageManagers * docs: clean up storage code example * Apply suggestions from code review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * fix typo * Apply suggestions from code review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * docs: add more context on context * fix note formatting * move serialization constraints to the end * Apply suggestions from code review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * Apply suggestion from @atravitz Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * fix formatting for docs build * Apply suggestions from code review Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> * clean up --------- Signed-off-by: Ethan Holz Co-authored-by: Alyssa Travitz <31974495+atravitz@users.noreply.github.com> Co-authored-by: Alyssa Travitz --- docs/concepts/context.rst | 175 +++++++++++++++++++++++++++++++ docs/concepts/index.rst | 4 +- docs/concepts/storage.rst | 215 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 393 insertions(+), 1 deletion(-) create mode 100644 docs/concepts/context.rst create mode 100644 docs/concepts/storage.rst diff --git a/docs/concepts/context.rst b/docs/concepts/context.rst new file mode 100644 index 000000000..c678f8fb5 --- /dev/null +++ b/docs/concepts/context.rst @@ -0,0 +1,175 @@ +``ProtocolUnit`` execution ``Context`` +====================================== + +:class:`.Context` instances carry the execution environment for 
individual :class:`.ProtocolUnit` executions. +They are created by the execution engine just before a unit is executed and discarded once the unit returns. +The class acts as a thin wrapper around two :class:`.StorageManager` objects (shared and permanent) and a scratch directory. + + +Why Context exists +------------------ + +``ProtocolUnit`` code frequently needs a few shared facilities: + +``scratch`` + A temporary directory that the unit can freely write to while it runs. + Files written here are considered ephemeral; the engine may delete them as + soon as the unit finishes. + +``shared`` + + A :class:`.StorageManager` backed by :class:`.ExternalStorage`. + The specific lifetime of ``shared`` is determined by the execution engine, but standard convention is that ``shared`` is to be used to hand large files to downstream units without serializing the payloads through Python return values. + + +``permanent`` + Another :class:`.StorageManager` targeting long-term storage. Results saved here + survive beyond the life of the :class:`.ProtocolDAG` run (for example for + inspection or for reuse in future extensions). + +``stdout`` / ``stderr`` + Optional directories where the engine captures subprocess output triggered + by the ``ProtocolUnit``. The directories are removed automatically when the context + closes. + +Keeping these handles bundled together and managed by a context manager lets +``ProtocolUnit`` implementers focus on domain logic while the engine ensures +storage gets flushed and temporary directories are cleaned. + + +Lifecycle +--------- + +``Context`` implements a `Python context manager`_. +When the context is exited, ``shared`` and ``permanent`` storage managers flush tracked files back to their underlying :class:`.ExternalStorage`. +Any ``stdout`` or ``stderr`` capture directories are also removed. + +.. 
_Python context manager: https://docs.python.org/3/reference/datamodel.html#context-managers + +This means each ``ProtocolUnit``'s ``shared`` and ``permanent`` objects are not paths, and should not be treated as such. +Both of these are registries that track if a file should be transferred from its location in ``scratch`` to its final location after completing a unit. + +To access data from ``shared`` or ``permanent``, you can use ``ctx.shared.load`` or ``ctx.permanent.load``. +This will allow your unit to fetch those objects from their storage for use. + + +Using Context inside ProtocolUnits +---------------------------------- + +Every :meth:`.ProtocolUnit._execute` definition must accept ``ctx: Context`` as its +first argument. Typical usage looks like the example below. + +.. code-block:: python + + from gufe import ProtocolUnit, Context + + class SimulationUnit(ProtocolUnit): + + @staticmethod + def _execute(ctx: Context, *, setup_result, lambda_window, settings): + scratch_path = ctx.scratch / f"lambda_{lambda_window}" + scratch_path.mkdir(exist_ok=True) + + # Read upstream artifacts from ctx.shared + system_file = ctx.shared.load(setup_result.outputs["system_file"]) + topology_file = ctx.shared.load(setup_result.outputs["topology_file"]) + + + result_path = ctx.scratch / "some_output.pdb" + # When you register the filename doesn't matter, + # just as long as you do it before you return + result_path_final_location = ctx.permanent.register(result_path) + # This is an example of running something that you want to save + simulate(output=result_path) + + # Return only lightweight metadata + return { + "lambda_window": lambda_window, + # We use this because it is already namespaced and can be used between units. + "result_path": result_path_final_location, + } + +The example above showcases how you are to register files. +It is important to note that the ``result_path`` is different from the ``result_path_final_location``. 
+The ``result_path`` exists as a normal path, but ``result_path_final_location`` is a handle that is meant to be passed between units. +The execution engine uses the final location to correctly namespace files and provide them back to subsequent units. + + +Choosing between shared and permanent storage +--------------------------------------------- + +Both ``ctx.shared`` and ``ctx.permanent`` expose the same :class:`.StorageManager` +API but they serve different audiences: + +``ctx.shared`` + Optimized for communication between units in the same DAG execution. The + execution backend is free to prune these assets once no downstream unit + references them. + +``ctx.permanent`` + Intended for outputs that should survive beyond the immediate DAG, such as + user-facing reports or artifacts that will seed future runs. + +As a rule of thumb, prefer ``ctx.shared`` unless you have a clear requirement +to keep the data after the ``Protocol`` run concludes. Small scalar values or +lightweight metadata should still be returned directly from ``_execute`` so +they become part of the ``ProtocolUnitResult`` record. + + +Interaction with Protocols +-------------------------- + +``Protocol`` instances do not instantiate ``Context`` directly; they declare ``ProtocolUnit`` objects via ``Protocol._create``. +When an execution backend walks the resulting +``ProtocolDAG``, it constructs a ``Context`` for each unit using the DAG label, unit label, scratch directory, and the configured ``ExternalStorage`` implementations. +The backend might provide different ``ExternalStorage`` implementations (e.g., local filesystem, object store, cluster scratch) depending on where the work runs, but the ``Context`` API seen by ``ProtocolUnit`` authors stays consistent. + +Because the execution backend is in charge of creating the contexts, protocol authors can rely on ``ctx`` always being populated with valid storage managers and paths that are safe to write to from distributed workers. 
+ + +Migrating from legacy Context usage +----------------------------------- + +Before ``Context`` was rewritten, it was a simple data class with two ``pathlib.Path`` handles: ``scratch`` and ``shared``. +Existing protocols adopted a variety of implicit conventions around those attributes. +Follow this checklist when migrating old protocols: + +1. **Swap file paths to StorageManager APIs.** Calls like ``ctx.shared / + "filename"`` should be replaced with the helper methods offered by + :class:`.StorageManager`. For example: + +.. code-block:: python + + path = ctx.shared.scratch_dir / "myfile.dat" + +becomes: + +.. code-block:: python + + ctx.shared.register("myfile.dat") + + +2. **Avoid storing heavy objects in Python outputs.** Older protocols often + returned raw ``Path`` objects pointing at scratch files. Instead, register + the file with the storage manager and return the storage key (a string) from + ``_execute`` as shown in the example above. Downstream units can then call + :meth:`.StorageManager.load`. + +3. **Handle ctx.permanent.** There was no equivalent in the legacy API. + Decide which results must persist between DAG executions and write them via + ``ctx.permanent``. For migration you can start by mirroring whatever used + to live in ``ctx.shared`` and refine later. + +4. **Expect automatic cleanup.** Old contexts typically left stdout/stderr + directories around. The new context removes these when the unit finishes. + If your code tried to re-read the capture directories after ``_execute`` + returned, move that logic earlier or rely on the logged data captured in the + ``ProtocolUnitResult``. + +5. **Stop constructing Context manually.** + Some bespoke execution scripts once instantiated ``Context(scratch=..., shared=...)`` by hand. + That pattern is obsolete because the constructor now requires ``ExternalStorage`` objects. + Instead, rely on the execution backend to build contexts. 
+ For unit tests, use the helpers in ``gufe.storage.externalresource`` (e.g., ``MemoryStorage``) to create the necessary storage instances. + +If you need more information on how to use these concepts, check out: :doc:`../how-tos/protocol`. diff --git a/docs/concepts/index.rst b/docs/concepts/index.rst index c52ebaca0..a67db6e6a 100644 --- a/docs/concepts/index.rst +++ b/docs/concepts/index.rst @@ -11,5 +11,7 @@ utilize **gufe** APIs. tokenizables included_models - serialization + storage + context logging + serialization diff --git a/docs/concepts/storage.rst b/docs/concepts/storage.rst new file mode 100644 index 000000000..c4ebc885d --- /dev/null +++ b/docs/concepts/storage.rst @@ -0,0 +1,215 @@ +.. _concepts-storage: + +How storage is handled in **gufe** +================================== + +**gufe** abstracts storage into a reusable interface using the :class:`.ExternalStorage` abstract base class. +This abstraction enables the storage of any file or byte stream using various storage backends without changing application code. + +Overview +-------- + +The storage system is designed to handle files (or byte data) that need to be stored in some location. +Instead of embedding the data, objects can store a reference (a unique string indicating the object's location, such as a path) to where the data is stored externally. 
This approach provides several benefits: + +* **Efficiency**: Large objects don't need to be serialized multiple times +* **Flexibility**: Different storage backends (local filesystem, cloud storage, in-memory) can be used interchangeably +* **Deduplication**: The same data can be referenced by multiple objects +* **Lazy Loading**: Data is only loaded when needed + +The Storage Architecture +------------------------- + +The storage system consists of several key components: + +``ExternalStorage`` Base Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :class:`.ExternalStorage` abstract base class defines the interface that all storage implementations must provide. This class provides: + +* **Store operations**: ``store_bytes()`` and ``store_path()`` to store data +* **Load operations**: ``load_stream()`` to retrieve data as a stream +* **Management**: ``exists()``, ``delete()``, ``transfer_file()``, and ``iter_contents()`` for managing stored data + +All storage operations use a string as a location identifier for the stored data. +The convention is often to use path-like formatting, such as ``"datasets/sample1.txt"``, but any unique string can be used. + +Storage Implementations +----------------------- + +**gufe** provides several built-in implementations of :class:`.ExternalStorage`: + +``FileStorage`` +~~~~~~~~~~~~~~~ + +The :class:`.FileStorage` implementation stores data on the local filesystem. It requires a root directory path and organizes stored files using the location string as a relative path: + +.. code-block:: python + + from pathlib import Path + from gufe.storage.externalresource import FileStorage + + # Create a file storage backend + storage = FileStorage(root_dir=Path("/path/to/storage")) + + # Store some data + data = b"Hello, World!" 
+ storage.store_bytes("datasets/sample1.txt", data) + + # Check if data exists + if storage.exists("datasets/sample1.txt"): + # Load the data + with storage.load_stream("datasets/sample1.txt") as stream: + loaded_data = stream.read() + assert loaded_data == data + + # Delete the data + storage.delete("datasets/sample1.txt") + +``FileStorage`` automatically creates any necessary parent directories when storing files. + +``MemoryStorage`` +~~~~~~~~~~~~~~~~~ + +The :class:`.MemoryStorage` implementation stores data in a Python dictionary. This is primarily useful for testing and prototyping: + +.. code-block:: python + + from gufe.storage.externalresource import MemoryStorage + + # Create an in-memory storage backend + storage = MemoryStorage() + + # Store some data + data = b"Hello, World!" + storage.store_bytes("datasets/sample1.txt", data) + + # Load the data back + with storage.load_stream("datasets/sample1.txt") as stream: + loaded_data = stream.read() + +.. warning:: + ``MemoryStorage`` is not intended for production use and all data is lost when the Python process exits. + +Implementing Custom Storage Backends +------------------------------------- + +To create a custom storage backend, subclass :class:`.ExternalStorage` and implement all the abstract methods: + +.. 
code-block:: python + + from gufe.storage.externalresource.base import ExternalStorage + from typing import ContextManager + + class MyCustomStorage(ExternalStorage): + """A custom storage implementation.""" + + def _store_bytes(self, location: str, byte_data: bytes): + """Store bytes at the given location.""" + # Implement storage logic + pass + + def _store_path(self, location: str, path): + """Store a file at the given path.""" + # Implement storage logic + pass + + def _load_stream(self, location: str) -> ContextManager: + """Return a context manager that yields a bytes-like object.""" + # Implement loading logic + pass + + def _exists(self, location: str) -> bool: + """Check if data exists at the location.""" + # Implement existence check + pass + + def _delete(self, location: str): + """Delete data at the location.""" + # Implement deletion logic + pass + + def _get_filename(self, location: str) -> str: + """Return a filename for the location.""" + # Implement filename generation + pass + + def _iter_contents(self, prefix: str = ""): + """Iterate over stored locations matching the prefix.""" + # Implement iteration logic + pass + + def _get_hexdigest(self, location: str) -> str: + """Return MD5 hexdigest of the data (optional override).""" + # Can override for performance improvements + pass + +.. note:: + All storage methods should be blocking operations, even if the underlying storage backend supports asynchronous operations. + +StorageManager +-------------- + +The :class:`.StorageManager` class provides a higher-level interface for managing storage operations within a computational workflow. + +.. note:: + ``StorageManager`` is largely used by the :class:`.Context` class and should not be instantiated in protocols. + In general, protocol developers will only use the ``register`` and ``load`` functions. + +It handles the transfer of files between a scratch directory and external storage (such as shared or permanent storage): + +.. 
code-block:: python + + from pathlib import Path + from gufe.storage import StorageManager + from gufe.storage.externalresource import FileStorage + + # Set up storage + storage = FileStorage(root_dir=Path("/path/to/storage")) + scratch_dir = Path("/path/to/scratch") + + # Create a storage manager for a specific DAG and unit + manager = StorageManager( + scratch_dir=scratch_dir, + storage=storage, + dag_label="my_experiment", + unit_label="transformation_1" + ) + + # Register files for later transfer + out = manager.register("trajectory.dcd") + out2 = manager.register("results.json") + # Note: out and out2 are pre-namespaced values that allow storage items to be passed around + + # Transfer all registered files to external storage + manager._transfer() + # Explicitly transfer a specific file. This is useful for checkpointing and other instances where you want to ensure a file is transferred to storage at a specific time in the protocol. + manager.transfer_file("trajectory.dcd") + # Load files from external storage + trajectory_data = manager.load(out) + results_json = manager.load(out2) + +The ``StorageManager`` uses a namespace combining the ``dag_label`` and ``unit_label`` to organize files in the external storage backend. +To see how these work in practice see our documentation on :doc:`context`. + + +Error Handling +-------------- + +The storage system defines several exceptions in :mod:`gufe.storage.errors`: + +* :class:`.ExternalResourceError`: Base class for storage-related errors +* :class:`.MissingExternalResourceError`: Raised when attempting to access non-existent data +* :class:`.ChangedExternalResourceError`: Raised when metadata verification fails + +These exceptions can be caught and handled appropriately in application code: + +.. 
code-block:: python + + from gufe.storage.errors import MissingExternalResourceError + + try: + with storage.load_stream("nonexistent_file.txt") as stream: + data = stream.read() + except MissingExternalResourceError: + print("File not found in storage")