diff --git a/docs/concepts/context.rst b/docs/concepts/context.rst new file mode 100644 index 000000000..c678f8fb5 --- /dev/null +++ b/docs/concepts/context.rst @@ -0,0 +1,175 @@ +``ProtocolUnit`` execution ``Context`` +====================================== + +:class:`.Context` instances carry the execution environment for individual :class:`.ProtocolUnit` executions. +They are created by the execution engine just before a unit is executed and discarded once the unit returns. +The class acts as a thin wrapper around two :class:`.StorageManager` objects (shared and permanent) and a scratch directory. + + +Why Context exists +------------------ + +``ProtocolUnit`` code frequently needs a few shared facilities: + +``scratch`` + A temporary directory that the unit can freely write to while it runs. + Files written here are considered ephemeral; the engine may delete them as + soon as the unit finishes. + +``shared`` + + A :class:`.StorageManager` backed by :class:`.ExternalStorage`. + The specific lifetime of ``shared`` is determined by the execution engine, but standard convention it that ``shared`` is to be used to hand large files to downstream units without serializing the payloads through Python return values. + + +``permanent`` + Another :class:`.StorageManager` targeting long–term storage. Results saved here + survive beyond the life of the :class:`.ProtocolDAG` run (for example for + inspection or for reuse in future extensions). + +``stdout`` / ``stderr`` + Optional directories where the engine captures subprocess output triggered + by the ``ProtocolUnit``. The directories are removed automatically when the context + closes. + +Keeping these handles bundled together and managed by a context manager lets +``ProtocolUnit`` implementers focus on domain logic while the engine ensures +storage gets flushed and temporary directories are cleaned. + + +Lifecycle +--------- + +``Context`` implements a `Python context manager`_. +When the context is exited, ``shared`` and ``permanent`` storage managers flushed tracked files back to their underlying :class:`ExternalStorage`. +Any ``stdout`` or ``stderr`` capture directories are also removed. + +.. _Python context manager: https://docs.python.org/3/reference/datamodel.html#context-managers + +This means each ``ProtocolUnit``'s ``shared`` and ``permanent`` objects are not paths, and should not be treated as such. +Both of these are registries that track if a file should be transferred from its location in ``scratch`` to its final location after completing a unit. + +To access data from ``shared`` or ``permanent``, you can use ``ctx.shared.load`` or ``ctx.permanent.load``. +This will allow your unit to fetch those objects from their storage for use. + + +Using Context inside ProtocolUnits +---------------------------------- + +Every :meth:`.ProtocolUnit._execute` definition must accept ``ctx: Context`` as its +first argument. Typical usage looks like the example below. + +.. code-block:: python + + from gufe import ProtocolUnit, Context + + class SimulationUnit(ProtocolUnit): + + @staticmethod + def _execute(ctx: Context, *, setup_result, lambda_window, settings): + scratch_path = ctx.scratch / f"lambda_{lambda_window}" + scratch_path.mkdir(exist_ok=True) + + # Read upstream artifacts from ctx.shared + system_file = ctx.shared.load(setup_result.outputs["system_file"]) + topology_file = ctx.shared.load(setup_result.outputs["topology_file"]) + + + result_path = ctx.scratch / "some_output.pdb" + # When you register the filename doesn't matter, + # just as long as you do it before you return + result_path_final_location = ctx.permanent.register(result_path) + # This is an example of running something that you want to save + simulate(output=result_path) + + # Return only lightweight metadata + return { + "lambda_window": lambda_window, + # We use this because it is already namespaced and can be used between units. + "result_path": result_path_final_location, + } + +The example above showcases how you are to register files. +It is important to note that the ``result_path`` is different from the ``result_path_final_location``. +The ``result_path`` exists as a normal path, but ``result_path_final_location`` is a handle used to be passed between units. +The final location is for the execution engine to correctly namespace files and provide them back to subsequent units. + + +Choosing between shared and permanent storage +--------------------------------------------- + +Both ``ctx.shared`` and ``ctx.permanent`` expose the same :class:`.StorageManager` +API but they serve different audiences: + +``ctx.shared`` + Optimized for communication between units in the same DAG execution. The + execution backend is free to prune these assets once no downstream unit + references them. + +``ctx.permanent`` + Intended for outputs that should survive beyond the immediate DAG, such as + user-facing reports or artifacts that will seed future runs. + +As a rule of thumb, prefer ``ctx.shared`` unless you have a clear requirement +to keep the data after the ``Protocol`` run concludes. Small scalar values or +lightweight metadata should still be returned directly from ``_execute`` so +they become part of the ``ProtocolUnitResult`` record. + + +Interaction with Protocols +-------------------------- + +``Protocol`` instances do not instantiate ``Context`` directly; they declare ``ProtocolUnit`` objects via ``Protocol._create``. +When an execution backend walks the resulting +``ProtocolDAG`` it constructs a ``Context`` for each unit using the DAG label, unit label, scratch directory, and the configured ``ExternalStorage`` implementations. +The backend might provide different ``ExternalStorage`` implementations (e.g., local filesystem, object store, cluster scratch) depending on where the work runs, but the ``Context`` API seen by ``ProtocolUnit`` authors stays consistent. + +Because the execution backend is in charge of creating the contexts, protocol authors can rely on ``ctx`` always being populated with valid storage managers and paths that are safe to write to from distributed workers. + + +Migrating from legacy Context usage +----------------------------------- + +Before ``Context`` was rewritten, it was a simple data class with two ``pathlib.Path`` handles: ``scratch`` and ``shared``. +Existing protocols adopted a variety of implicit conventions around those attributes. +Follow this checklist when migrating old protocols: + +1. **Swap file paths to StorageManager APIs.** Calls like ``ctx.shared / + "filename"`` should be replaced with the helper methods offered by + :class:`StorageManager`. For example: + +.. code-block:: python + + path = ctx.shared.scratch_dir / "myfile.dat" + +becomes: + +.. code-block:: python + + ctx.shared.register("myfile.dat") + + +2. **Avoid storing heavy objects in Python outputs.** Older protocols often + returned raw ``Path`` objects pointing at scratch files. Instead, register + the file with the storage manager and return the storage key (a string) from + ``_execute`` as shown in the example above. Downstream units can then call + :meth:`.StorageManager.load`. + +3. **Handle ctx.permanent.** There was no equivalent in the legacy API. + Decide which results must persist between DAG executions and write them via + ``ctx.permanent``. For migration you can start by mirroring whatever used + to live in ``ctx.shared`` and refine later. + +4. **Expect automatic cleanup.** Old contexts typically left stdout/stderr + directories around. The new context removes these when the unit finishes. + If your code tried to re-read the capture directories after ``_execute`` + returned, move that logic earlier or rely on the logged data captured in the + ``ProtocolUnitResult``. + +5. **Stop constructing Context manually.** + Some bespoke execution scripts once instantiated ``Context(scratch=..., shared=...)`` by hand. + That pattern is obsolete because the constructor now requires ``ExternalStorage`` objects. + Instead, rely on the execution backend to build contexts. + For unit tests use the helpers in ``gufe.storage.externalresource`` (e.g., ``MemoryStorage``) to create the necessary storage instances. + +If you need more information on how to use these concepts, checkout out: :doc:`../how-tos/protocol`. diff --git a/docs/concepts/index.rst b/docs/concepts/index.rst index c52ebaca0..a67db6e6a 100644 --- a/docs/concepts/index.rst +++ b/docs/concepts/index.rst @@ -11,5 +11,7 @@ utilize **gufe** APIs. tokenizables included_models - serialization + storage + context logging + serialization diff --git a/docs/concepts/storage.rst b/docs/concepts/storage.rst new file mode 100644 index 000000000..c4ebc885d --- /dev/null +++ b/docs/concepts/storage.rst @@ -0,0 +1,215 @@ +.. _concepts-storage: + +How storage is handled in **gufe** +================================== + +**gufe** abstracts storage into a reusable interface using the :class:`.ExternalStorage` abstract base class. +This abstraction enables the storage of any file or byte stream using various storage backends without changing application code. + +Overview +-------- + +The storage system is designed to handle files (or byte data) that need to be stored in some location. +Instead of embedding the data, objects can store a reference (a unique string indicating the object's location, such as a path) to where the data is stored externally. This approach provides several benefits: + +* **Efficiency**: Large objects don't need to be serialized multiple times +* **Flexibility**: Different storage backends (local filesystem, cloud storage, in-memory) can be used interchangeably +* **Deduplication**: The same data can be referenced by multiple objects +* **Lazy Loading**: Data is only loaded when needed + +The Storage Architecture +------------------------- + +The storage system consists of several key components: + +``ExternalStorage`` Base Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :class:`.ExternalStorage` abstract base class defines the interface that all storage implementations must provide. This class provides: + +* **Store operations**: ``store_bytes()`` and ``store_path()`` to store data +* **Load operations**: ``load_stream()`` to retrieve data as a stream +* **Management**: ``exists()``, ``delete()``, ``transfer_file()``, and ``iter_contents()`` for managing stored data + +All storage operations use a string as a location identifier for the stored data. +Convention is often to use path-like formatting, such as `"datasets/sample1.txt"`, but any unique string can be used. + +Storage Implementations +----------------------- + +**gufe** provides several built-in implementations of :class:`.ExternalStorage`: + +``FileStorage`` +~~~~~~~~~~~~~~~ + +The :class:`.FileStorage` implementation stores data on the local filesystem. It requires a root directory path and organizes stored files using the location string as a relative path: + +.. code-block:: python + + from pathlib import Path + from gufe.storage.externalresource import FileStorage + + # Create a file storage backend + storage = FileStorage(root_dir=Path("/path/to/storage")) + + # Store some data + data = b"Hello, World!" + storage.store_bytes("datasets/sample1.txt", data) + + # Check if data exists + if storage.exists("datasets/sample1.txt"): + # Load the data + with storage.load_stream("datasets/sample1.txt") as stream: + loaded_data = stream.read() + assert loaded_data == data + + # Delete the data + storage.delete("datasets/sample1.txt") + +``FileStorage`` automatically creates any necessary parent directories when storing files. + +``MemoryStorage`` +~~~~~~~~~~~~~~~~~ + +The :class:`.MemoryStorage` implementation stores data in a Python dictionary. This is primarily useful for testing and prototyping: + +.. code-block:: python + + from gufe.storage.externalresource import MemoryStorage + + # Create an in-memory storage backend + storage = MemoryStorage() + + # Store some data + data = b"Hello, World!" + storage.store_bytes("datasets/sample1.txt", data) + + # Load the data back + with storage.load_stream("datasets/sample1.txt") as stream: + loaded_data = stream.read() + +.. warning:: + ``MemoryStorage`` is not intended for production use and all data is lost when the Python process exits. + +Implementing Custom Storage Backends +------------------------------------- + +To create a custom storage backend, subclass :class:`.ExternalStorage` and implement all the abstract methods: + +.. code-block:: python + + from gufe.storage.externalresource.base import ExternalStorage + from typing import ContextManager + + class MyCustomStorage(ExternalStorage): + """A custom storage implementation.""" + + def _store_bytes(self, location: str, byte_data: bytes): + """Store bytes at the given location.""" + # Implement storage logic + pass + + def _store_path(self, location: str, path): + """Store a file at the given path.""" + # Implement storage logic + pass + + def _load_stream(self, location: str) -> ContextManager: + """Return a context manager that yields a bytes-like object.""" + # Implement loading logic + pass + + def _exists(self, location: str) -> bool: + """Check if data exists at the location.""" + # Implement existence check + pass + + def _delete(self, location: str): + """Delete data at the location.""" + # Implement deletion logic + pass + + def _get_filename(self, location: str) -> str: + """Return a filename for the location.""" + # Implement filename generation + pass + + def _iter_contents(self, prefix: str = ""): + """Iterate over stored locations matching the prefix.""" + # Implement iteration logic + pass + + def _get_hexdigest(self, location: str) -> str: + """Return MD5 hexdigest of the data (optional override).""" + # Can override for performance improvements + pass + +.. note:: + All storage methods should be blocking operations, even if the underlying storage backend supports asynchronous operations. + +StorageManager +-------------- + +The :class:`.StorageManager` class provides a higher-level interface for managing storage operations within a computational workflow. + +.. note:: + ``StorageManager`` is largely used by the :class:`.Context` class and should not be instantiated in protocols. + In general, protocol developers will only use the ``register`` and ``load`` functions. + +It handles the transfer of files between a scratch directory and external storage (such as shared or permanent storage): + +.. code-block:: python + + from pathlib import Path + from gufe.storage import StorageManager + from gufe.storage.externalresource import FileStorage + + # Set up storage + storage = FileStorage(root_dir=Path("/path/to/storage")) + scratch_dir = Path("/path/to/scratch") + + # Create a storage manager for a specific DAG and unit + manager = StorageManager( + scratch_dir=scratch_dir, + storage=storage, + dag_label="my_experiment", + unit_label="transformation_1" + ) + + # Register files for later transfer + out = manager.register("trajectory.dcd") + out2 = manager.register("results.json") + # Note: out and out2 are pre-namespaced values that allow storage items to be passed around + + # Transfer all registered files to external storage + manager._transfer() + # Explicitly transfer a specific file. This is useful for checkpointing and other instances where you want to ensure a file is transferred to storage at a specific time in the protocol. + manager.transfer_file("trajectory.dcd") + # Load files from external storage + trajectory_data = manager.load(out) + results_json = manager.load(out2) + +The ``StorageManager`` uses a namespace combining the ``dag_label`` and ``unit_label`` to organize files in the external storage backend. +To see how these work in practice see our documentation on :doc:`context`. + + +Error Handling +-------------- + +The storage system defines several exceptions in :mod:`gufe.storage.errors`: + +* :class:`.ExternalResourceError`: Base class for storage-related errors +* :class:`.MissingExternalResourceError`: Raised when attempting to access non-existent data +* :class:`.ChangedExternalResourceError`: Raised when metadata verification fails + +These exceptions can be caught and handled appropriately in application code: + +.. code-block:: python + + from gufe.storage.errors import MissingExternalResourceError + + try: + with storage.load_stream("nonexistent_file.txt") as stream: + data = stream.read() + except MissingExternalResourceError: + print("File not found in storage") diff --git a/news/rework-context.rst b/news/rework-context.rst new file mode 100644 index 000000000..49b0b5c4a --- /dev/null +++ b/news/rework-context.rst @@ -0,0 +1,25 @@ +**Added:** + +* ``StorageManager`` for managing files during unit execution. Allows for auto transfer to storage mediums following execution of a unit. + +**Changed:** + +* ``Context`` now uses ``StorageManager`` to handle operations in protocol units. Units must now register files in scratch to be transferred shared or permanent storage. +* ``ProtocolUnit``s now have permanent storage which can be used to outlast a running ``ProtocolDAG``. + + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/gufe/protocols/protocoldag.py b/src/gufe/protocols/protocoldag.py index f285b8fea..f6e5676f9 100644 --- a/src/gufe/protocols/protocoldag.py +++ b/src/gufe/protocols/protocoldag.py @@ -11,6 +11,8 @@ import networkx as nx +from gufe.storage.externalresource.base import ExternalStorage + from ..tokenization import GufeKey, GufeTokenizable from .errors import MissingUnitResultError, ProtocolDAGError, ProtocolDAGExecutionError, ProtocolUnitFailureError from .protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult @@ -403,7 +405,8 @@ def _get_valid_unit_results( def execute_DAG( protocoldag: ProtocolDAG, *, - shared_basedir: Path, + shared_storage: ExternalStorage, + perm_storage: ExternalStorage, scratch_basedir: Path, cache_basedir: Path | None = None, stderr_basedir: Path | None = None, @@ -420,12 +423,15 @@ def execute_DAG( Parameters ---------- protocoldag : ProtocolDAG - The :class:``ProtocolDAG`` to execute. - shared_basedir : Path - Filesystem path to use for shared space that persists across whole DAG - execution. Used by a `ProtocolUnit` to pass file contents to dependent - class:``ProtocolUnit`` instances. + The :class:`ProtocolDAG` to execute. + shared_storage : ExternalStorage + Storage for shared files that persist across the entire DAG execution. + Used by ProtocolUnits to pass file contents to dependent ProtocolUnits. + perm_storage : ExternalStorage + Permanent storage for files that should persist after DAG execution. scratch_basedir : Path + Base directory for ProtocolUnit scratch space. Each ProtocolUnit gets + its own scratch directory under this base directory. Filesystem path to use for `ProtocolUnit` `scratch` space. cache_basedir : Path | None = None Filesystem path to use for `ProtocolUnitResult` caching during @@ -433,29 +439,41 @@ def execute_DAG( and it will not be able to resume DAG execution from the last successfully finished `ProtocolUnit`. stderr_basedir : Path | None - Filesystem path to use for `ProtocolUnit` `stderr` archiving. + Base directory for ProtocolUnit stderr archiving. If None, stderr + is not archived. stdout_basedir : Path | None - Filesystem path to use for `ProtocolUnit` `stdout` archiving. + Base directory for ProtocolUnit stdout archiving. If None, stdout + is not archived. keep_shared : bool - If True, don't remove shared directories for `ProtocolUnit`s after - the `ProtocolDAG` is executed. + If True, shared directories are not removed after DAG execution. + Default is False. keep_scratch : bool + If True, scratch directories are not removed after each ProtocolUnit + execution. Default is False. If True, don't remove scratch directories for a `ProtocolUnit` after it is executed. keep_cache : bool If True, don't remove the cache directory which contains the serialized `ProtocolUnitResult` for all executed `ProtocolUnit`/s. raise_error : bool - If True, raise an exception if a ProtocolUnit fails, default True - if False, any exceptions will be stored as `ProtocolUnitFailure` - objects inside the returned `ProtocolDAGResult` + If True, raises an exception when a ProtocolUnit fails. If False, + failures are stored as ProtocolUnitFailure objects in the returned + ProtocolDAGResult. Default is True. n_retries : int - the number of times to attempt, default 0, i.e. try once and only once + Number of times to retry failed ProtocolUnits. Default is 0 (no retries). Returns ------- ProtocolDAGResult - The result of executing the `ProtocolDAG`. + Result object containing the execution results of all ProtocolUnits + in the DAG, including both successes and failures. + + Notes + ----- + The function executes ProtocolUnits in DAG-dependency order, ensuring that + each ProtocolUnit's dependencies are executed before the ProtocolUnit itself. + If a ProtocolUnit fails and raise_error is True, execution stops immediately. + Otherwise, execution continues with the next ProtocolUnit. Raises ------ @@ -492,11 +510,8 @@ def execute_DAG( inputs = _pu_to_pur(unit.inputs, results) attempt = 0 + result = None while attempt <= n_retries: - shared = shared_basedir / f"shared_{str(unit.key)}_attempt_{attempt}" - shared_paths.append(shared) - shared.mkdir(exist_ok=True) - scratch = scratch_basedir / f"scratch_{str(unit.key)}_attempt_{attempt}" scratch.mkdir(exist_ok=True) @@ -510,17 +525,16 @@ def execute_DAG( stdout = stdout_basedir / f"stdout_{str(unit.key)}_attempt_{attempt}" stdout.mkdir(exist_ok=True) - context = Context(shared=shared, scratch=scratch, stderr=stderr, stdout=stdout) - # execute - result = unit.execute(context=context, raise_error=raise_error, **inputs) - all_results.append(result) - - # clean up outputs - if stderr: - shutil.rmtree(stderr) - if stdout: - shutil.rmtree(stdout) + with Context( + dag_label=str(protocoldag.key), + unit_label=str(unit.key), + shared_storage=shared_storage, + permanent_storage=perm_storage, + scratch=scratch, + ) as ctx: + result = unit.execute(context=ctx, raise_error=raise_error, **inputs) + all_results.append(result) if not keep_scratch: shutil.rmtree(scratch) @@ -535,7 +549,7 @@ def execute_DAG( break attempt += 1 - if not result.ok(): + if result is not None and not result.ok(): break if not keep_shared: diff --git a/src/gufe/protocols/protocolunit.py b/src/gufe/protocols/protocolunit.py index c70d0ac38..12550ccc5 100644 --- a/src/gufe/protocols/protocolunit.py +++ b/src/gufe/protocols/protocolunit.py @@ -10,28 +10,112 @@ import abc import datetime +import shutil import traceback import uuid from copy import copy -from dataclasses import dataclass from pathlib import Path from typing import Any +from gufe.storage.externalresource.base import ExternalStorage + +from ..storage.storagemanager import StorageManager from ..tokenization import GufeKey, GufeTokenizable from .errors import ExecutionInterrupt -@dataclass class Context: - """Data class for passing around execution context components to + """ + Class for passing around execution context components to `ProtocolUnit._execute`. + This class provides execution context information to ProtocolUnit subclasses + when executing their `_execute` method. + + Parameters + ---------- + dag_label : str + Label for the ProtocolDAG this unit belongs to. + unit_label : str + Label for this specific ProtocolUnit. + scratch_dir : Path + Path to the scratch directory for temporary files. + shared_storage : ExternalStorage + Storage manager for shared resources that can be accessed by other units. + permanent_storage : ExternalStorage + Storage manager for permanent resources that persist after execution. + stderr : Path, optional + Path to directory for capturing stderr output. + stdout : Path, optional + Path to directory for capturing stdout output. + + Attributes + ---------- + scratch : Path + Path to the scratch directory for temporary files. + shared : StorageManager + Storage manager for shared resources. + permanent : StorageManager + Storage manager for permanent resources. + stderr : Path or None + Path to directory for capturing stderr output. + stdout : Path or None + Path to directory for capturing stdout output. + + Notes + ----- + This class implements the context manager protocol, automatically transferring + shared and permanent storage resources when exiting the context. + + Examples + -------- + >>> with Context( + ... dag_label="my_dag", + ... unit_label="unit1", + ... scratch="/tmp/scratch", + ... shared_storage=shared_store, + ... permanent_storage=perm_store, + ... ) as ctx: + ... # use ctx within the ProtocolUnit._execute method + ... pass """ - scratch: Path - shared: Path - stderr: Path | None = None - stdout: Path | None = None + def __init__( + self, + dag_label: str, + unit_label: str, + scratch: Path, + shared_storage: ExternalStorage, + permanent_storage: ExternalStorage, + stderr: Path | None = None, + stdout: Path | None = None, + ): + self.scratch = scratch + self.shared = StorageManager( + scratch_dir=scratch, + storage=shared_storage, + unit_label=unit_label, + dag_label=dag_label, + ) + self.permanent = StorageManager( + scratch_dir=scratch, + storage=permanent_storage, + unit_label=unit_label, + dag_label=dag_label, + ) + self.stderr = stderr + self.stdout = stdout + + def __enter__(self) -> Context: + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shared._transfer() + self.permanent._transfer() + if self.stderr: + shutil.rmtree(self.stderr) + if self.stdout: + shutil.rmtree(self.stdout) def _list_dependencies(inputs, cls): diff --git a/src/gufe/storage/__init__.py b/src/gufe/storage/__init__.py index 853c55a2a..a81434523 100644 --- a/src/gufe/storage/__init__.py +++ b/src/gufe/storage/__init__.py @@ -1 +1,3 @@ """How to store objects across simulation campaigns""" + +from .storagemanager import StorageManager diff --git a/src/gufe/storage/storagemanager.py b/src/gufe/storage/storagemanager.py new file mode 100644 index 000000000..cd496f252 --- /dev/null +++ b/src/gufe/storage/storagemanager.py @@ -0,0 +1,133 @@ +from pathlib import Path + +from .externalresource import ExternalStorage + + +class StorageManager: + """Manage storage operations for files in a DAG. + + This class provides a context manager for working with storage systems, + allowing registration, loading, and transfer of files between scratch + directory and external storage. + + Parameters + ---------- + scratch_dir : Path + Path to the scratch directory where files are temporarily stored. + storage : ExternalStorage + External storage system for persistent file storage. + dag_label : str + Label for the directed acyclic graph (DAG) this storage manager belongs to. + unit_label : str + Label for the specific unit within the DAG. + + Attributes + ---------- + scratch_dir : Path + Path to the scratch directory. + storage : ExternalStorage + External storage system. + registry : set[str] + Set of registered filenames to be transferred. + namespace : str + Namespace combining dag_label and unit_label for file organization. + """ + + def __init__( + self, + scratch_dir: Path, + storage: ExternalStorage, + dag_label: str, + unit_label: str, + ): + self.scratch_dir = scratch_dir + self.storage = storage + self.registry: set[str] = set() + self.namespace = f"{dag_label}/{unit_label}" + + @staticmethod + def append_to_namespace(namespace: str, filename: str) -> str: + """Append a filenmae to a namespace, mainly used to + make testing easier. + + Parameters + ---------- + namespace : str + The namespace prefix for the file. + filename : str + The filename to be appended to the namespace. + + Returns + ------- + str + Combined namespace and filename as a storage path. + """ + # We opt _not_ to use Paths because these aren't actually path objects + return f"{namespace}/{filename}" + + def register(self, filename: str) -> str: + """Register a filename for later transfer to external storage. + + Parameters + ---------- + filename : str + The filename to register for transfer. + + Returns + ------- + str + The globally namespaced path to be used by other units + """ + self.registry.add(filename) + return self.append_to_namespace(self.namespace, filename) + + def load(self, filename: str) -> bytes: + """Load an item from external storage. + + Parameters + ---------- + filename : str + The filename to load from external storage. + + Returns + ------- + bytes + The content of the loaded file. + """ + with self.storage.load_stream(filename) as f: + stored = f.read() + return stored + + def __contains__(self, filename: str) -> bool: + return filename in self.registry + + def _transfer(self): + """Transfer all registered files to external storage. + + Transfers each filename currently registered on this manager into the + configured external storage backend using this manager's namespace. + + Raises + ------ + FileNotFoundError + If any registered file does not exist in ``scratch_dir``. + """ + for filename in self.registry: + self.transfer_file(filename) + + def transfer_file(self, filename: str) -> None: + """Transfer a single file from scratch space to external storage. + + Parameters + ---------- + filename : str + Relative filename (within ``scratch_dir``) to transfer. + + Raises + ------ + FileNotFoundError + If ``filename`` does not exist in ``scratch_dir``. + """ + path = self.scratch_dir / filename + location = self.append_to_namespace(self.namespace, filename) + self.storage.store_path(location, path) diff --git a/src/gufe/tests/storage/test_storagemanager.py b/src/gufe/tests/storage/test_storagemanager.py new file mode 100644 index 000000000..e9711e642 --- /dev/null +++ b/src/gufe/tests/storage/test_storagemanager.py @@ -0,0 +1,281 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/gufe + +import pathlib +import tempfile + +import pytest + +from gufe.storage.externalresource import FileStorage, MemoryStorage +from gufe.storage.externalresource.base import ExternalStorage +from gufe.storage.storagemanager import StorageManager + + +@pytest.fixture +def tmp_scratch_dir(): + """Create a temporary scratch directory.""" + with tempfile.TemporaryDirectory() as tmp_dir: + yield pathlib.Path(tmp_dir) + + +@pytest.fixture +def memory_storage(): + """Create a MemoryStorage instance for testing.""" + return MemoryStorage() + + +@pytest.fixture +def file_storage(tmp_path): + """Create a FileStorage instance for testing.""" + return FileStorage(tmp_path) + + +@pytest.fixture +def storage_manager(tmp_scratch_dir, memory_storage): + """Create a StorageManager instance with MemoryStorage.""" + return StorageManager(tmp_scratch_dir, memory_storage, dag_label="MEM", unit_label="1") + + +@pytest.fixture +def storage_manager_file_storage(tmp_scratch_dir, file_storage): + """Create a StorageManager instance with FileStorage.""" + return StorageManager(tmp_scratch_dir, file_storage, dag_label="FILE", unit_label="1") + + +class TestStorageManager: + """Test the StorageManager class.""" + + def test_init(self, tmp_scratch_dir, memory_storage): + """Test StorageManager initialization.""" + manager = StorageManager(tmp_scratch_dir, memory_storage, dag_label="MEM", unit_label="1") + + assert manager.scratch_dir == tmp_scratch_dir + assert manager.storage == memory_storage + assert isinstance(manager.registry, set) + assert len(manager.registry) == 0 + assert manager.namespace == "MEM/1" + + def test_init_with_file_storage(self, tmp_scratch_dir, file_storage): + """Test StorageManager initialization with FileStorage.""" + manager = StorageManager(tmp_scratch_dir, file_storage, dag_label="FILE", unit_label="1") + + assert manager.scratch_dir == tmp_scratch_dir + assert manager.storage == file_storage + assert isinstance(manager.registry, set) + + def test_convert_to_namespace(self): + filename = "test_file.txt" + + out = StorageManager.append_to_namespace("MEM/1", filename) + assert out == "MEM/1/test_file.txt" + + def test_register(self, storage_manager): + """Test registering a filename.""" + filename = "test_file.txt" + + # Initially registry should be empty + assert filename not in storage_manager.registry + assert filename not in storage_manager + + # Register the file + out = storage_manager.register(filename) + assert out == storage_manager.append_to_namespace(storage_manager.namespace, filename) + + # Check it's now in the registry + assert filename in storage_manager.registry + assert filename in storage_manager + + def test_register_multiple_files(self, storage_manager): + """Test registering multiple filenames.""" + files = ["file1.txt", "file2.txt", "dir/file3.txt"] + + for filename in files: + storage_manager.register(filename) + + # Check all files are registered + for filename in files: + assert filename in storage_manager.registry + assert filename in storage_manager + + # Check registry size + assert len(storage_manager.registry) == 3 + + def test_register_duplicate_file(self, storage_manager): + """Test registering the same file multiple times.""" + filename = "duplicate.txt" + + # Register twice + storage_manager.register(filename) + storage_manager.register(filename) + + # Should only appear once (set behavior) + assert filename in storage_manager.registry + assert len(storage_manager.registry) == 1 + + def test_contains(self, storage_manager): + """Test the __contains__ method.""" + filename = "contains_test.txt" + + # Initially should not contain + assert filename not in storage_manager + + # Register and check again + storage_manager.register(filename) + assert filename in storage_manager + + # Check a non-existent file + assert "nonexistent.txt" not in storage_manager + + def test_transfer_file(self, tmp_scratch_dir, storage_manager): + """Test the transfer_file method.""" + filename = "somefile.txt" + data = b"Hello World!" + + path = tmp_scratch_dir / filename + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + + storage_manager.register(filename) + + assert filename in storage_manager + storage_manager.transfer_file(filename) + + # Check for this file + namespaced_filename = StorageManager.append_to_namespace(storage_manager.namespace, filename) + assert storage_manager.storage.exists(namespaced_filename) + + with storage_manager.storage.load_stream(namespaced_filename) as f: + stored_content = f.read() + assert stored_content == data + + def test_transfer_file_not_found(self, tmp_scratch_dir, storage_manager): + """Test _transfer when registered file doesn't exist.""" + filename = "nonexistent.txt" + + # Register non-existent file + storage_manager.register(filename) + + # Should raise FileNotFoundError when trying to transfer + with pytest.raises(FileNotFoundError): + storage_manager.transfer_file(filename) + + def test_transfer_empty_registry(self, storage_manager): + """Test _transfer with empty registry.""" + # Should not raise any errors + storage_manager._transfer() + + # Storage should remain empty + assert list(storage_manager.storage) == [] + + def test_transfer_with_files(self, storage_manager, tmp_scratch_dir): + """Test _transfer with actual files.""" + # Create test files in scratch directory + test_files = {"test1.txt": b"Hello World", "test2.txt": b"Another file", "subdir/test3.txt": b"Nested file"} + + # Create files and register them + for filename, content in test_files.items(): + file_path = tmp_scratch_dir / filename + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_bytes(content) + storage_manager.register(filename) + + # Transfer files + storage_manager._transfer() + + # Check files are now in storage + for filename in test_files: + namespaced_filename = StorageManager.append_to_namespace(storage_manager.namespace, filename) + assert storage_manager.storage.exists(namespaced_filename) + + # Verify content + with storage_manager.storage.load_stream(namespaced_filename) as f: + stored_content = f.read() + assert stored_content == test_files[filename] + + def test_transfer_with_file_storage(self, storage_manager_file_storage, tmp_scratch_dir): + """Test _transfer with FileStorage.""" + # Create test file + filename = "transfer_test.txt" + content = b"Test content for file storage" + + # Create file in scratch directory + file_path = tmp_scratch_dir / filename + file_path.write_bytes(content) + storage_manager_file_storage.register(filename) + + # Transfer file + storage_manager_file_storage._transfer() + + # Check file exists in file storage + namespaced_filename = StorageManager.append_to_namespace(storage_manager_file_storage.namespace, filename) + assert storage_manager_file_storage.storage.exists(namespaced_filename) + + # Verify content + expected_path = storage_manager_file_storage.storage.root_dir / namespaced_filename + assert expected_path.exists() + assert expected_path.read_bytes() == content + + def test_transfer_not_found(self, storage_manager): + """Test _transfer when registered file doesn't exist.""" + filename = "nonexistent.txt" + + # Register non-existent file + storage_manager.register(filename) + + # Should raise FileNotFoundError when trying to transfer + with pytest.raises(FileNotFoundError): + storage_manager._transfer() + + def test_registry_persistence(self, storage_manager, tmp_scratch_dir): + """Test that registry persists across operations.""" + filename = "persistent.txt" + + # Create the file first + file_path = tmp_scratch_dir / filename + file_path.write_bytes(b"test content") + + # Register file + storage_manager.register(filename) + assert filename in storage_manager.registry + + # Perform other operations + storage_manager._transfer() # Should not clear registry + + # Registry should still contain the file + assert filename in storage_manager.registry + + def test_load(self, storage_manager, tmp_scratch_dir): + filename = "item.txt" + file_path = tmp_scratch_dir / filename + content = b"test content" + file_path.write_bytes(content) + + # Register file + storage_manager.register(filename) + assert filename in storage_manager.registry + + # Perform other operations + storage_manager._transfer() # Should not clear registry + + # Registry should still contain the file + assert filename in storage_manager.registry + namespaced_filename = StorageManager.append_to_namespace(storage_manager.namespace, filename) + out = storage_manager.load(namespaced_filename) + assert out == content + + def test_prepopulated_storage_load(self, file_storage: ExternalStorage, tmp_scratch_dir): + # This test highlights a case where a unit wants to load something + # from shared storage. We basically load content into the medium that we want to fetch later. + + # Store content + contents = b"test content" + name = "test.txt" + file_storage.store_bytes(name, contents) + # Provide this prepopulated storage to a manager + manager = StorageManager(tmp_scratch_dir, storage=file_storage, dag_label="TEST", unit_label="1") + # Validate that content is loaded + out = manager.load(name) + assert out == contents + + # Validate that the registry has not loaded the item + assert len(manager.registry) == 0 diff --git a/src/gufe/tests/test_protocol.py b/src/gufe/tests/test_protocol.py index f58e4b503..fa5179be8 100644 --- a/src/gufe/tests/test_protocol.py +++ b/src/gufe/tests/test_protocol.py @@ -2,7 +2,6 @@ # For details, see https://github.com/OpenFreeEnergy/gufe import datetime import itertools -import pathlib from collections import defaultdict from collections.abc import Iterable, Sized from typing import Any @@ -29,6 +28,7 @@ ) from gufe.protocols.errors import ProtocolValidationError from gufe.protocols.protocoldag import execute_DAG +from gufe.storage.externalresource import MemoryStorage from .test_tokenization import GufeTokenizableTestsMixin @@ -210,7 +210,7 @@ def instance(self): return DummyProtocol(settings=DummyProtocol.default_settings()) @pytest.fixture - def protocol_dag(self, solvated_ligand, vacuum_ligand, tmp_path): + def protocol_dag(self, tmp_path, solvated_ligand, vacuum_ligand): protocol = DummyProtocol(settings=DummyProtocol.default_settings()) dag = protocol.create( stateA=solvated_ligand, @@ -218,21 +218,28 @@ def protocol_dag(self, solvated_ligand, vacuum_ligand, tmp_path): name="a dummy run", mapping=None, ) + # shared = tmp_path / "shared" + # shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) - - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stdout" stdout.mkdir(parents=True) dagresult: ProtocolDAGResult = execute_DAG( - dag, shared_basedir=shared, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout + dag, + shared_storage=shared, + perm_storage=perm, + scratch_basedir=scratch, + stderr_basedir=stderr, + stdout_basedir=stdout, + keep_scratch=True, ) return protocol, dag, dagresult @@ -246,21 +253,22 @@ def protocol_dag_broken(self, solvated_ligand, vacuum_ligand, tmp_path): name="a broken dummy run", mapping=None, ) - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stdout" stdout.mkdir(parents=True) dagfailure: ProtocolDAGResult = execute_DAG( dag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout, @@ -388,16 +396,17 @@ def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, t name="a broken dummy run", mapping=None, ) - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) - - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() + with pytest.raises(ValueError, match="I have failed my mission"): execute_DAG( dag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, raise_error=True, ) @@ -484,7 +493,7 @@ def instance(self, protocol_dag): protocol, dag, dagresult = protocol_dag return dag - class TestProtocolDAGResult(ProtocolDAGTestsMixin): + class TestProtocolDAGResult: cls = ProtocolDAGResult repr = None @@ -534,7 +543,7 @@ def test_protocol_unit_successes(self, instance: ProtocolDAGResult): assert len(instance.protocol_unit_successes) == 23 assert all(isinstance(i, ProtocolUnitResult) for i in instance.protocol_unit_successes) - class TestProtocolDAGResultFailure(ProtocolDAGTestsMixin): + class TestProtocolDAGResultFailure: cls = ProtocolDAGResult repr = None @@ -650,13 +659,13 @@ def test_create(self, dag): assert len(dag.protocol_units) == 3 def test_gather(self, protocol, dag, tmp_path): - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) + dag_result = execute_DAG(dag, shared_storage=shared, scratch_basedir=scratch, perm_storage=perm) assert dag_result.ok() @@ -666,14 +675,14 @@ def test_gather(self, protocol, dag, tmp_path): assert result.get_uncertainty() == 3 def test_terminal_units(self, protocol, dag, tmp_path): - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) # we have no dependencies, so this should be all three Unit results - dag_result = execute_DAG(dag, shared_basedir=shared, scratch_basedir=scratch) + dag_result = execute_DAG(dag, shared_storage=shared, perm_storage=perm, scratch_basedir=scratch) terminal_results = dag_result.terminal_protocol_unit_results @@ -802,14 +811,16 @@ def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmp_path): mapping=None, ) - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) - scratch = pathlib.Path(tmp_path / "scratch") + shared = MemoryStorage() + perm = MemoryStorage() + + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) r = execute_DAG( dag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, keep_shared=True, keep_scratch=True, @@ -821,13 +832,13 @@ def test_execute_DAG_retries(solvated_ligand, vacuum_ligand, tmp_path): number_unit_failures = len(r.protocol_unit_failures) number_unit_results = len(r.protocol_unit_results) - number_dirs = len(list(shared.iterdir())) + number_dirs = len(list(shared)) # failed first attempt of BrokenSimulationUnit, failed 3 retries assert number_unit_failures == 4 # InitializeUnit and 21 SimulationUnits run before guaranteed # final failure - assert number_unit_results == number_dirs == 26 + assert number_unit_results == 26 def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmp_path): @@ -838,15 +849,17 @@ def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmp_path): mapping=None, ) - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) - scratch = pathlib.Path(tmp_path / "scratch") + shared = MemoryStorage() + perm = MemoryStorage() + + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) with pytest.raises(ValueError): r = execute_DAG( dag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, keep_shared=True, keep_scratch=True, diff --git a/src/gufe/tests/test_protocoldag.py b/src/gufe/tests/test_protocoldag.py index 8391bf01a..f1c22c5f3 100644 --- a/src/gufe/tests/test_protocoldag.py +++ b/src/gufe/tests/test_protocoldag.py @@ -1,25 +1,33 @@ # This code is part of gufe and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe import os -import pathlib import pytest from openff.units import unit import gufe -from gufe.protocols import execute_DAG, protocoldag +from gufe.protocols import ProtocolDAG, execute_DAG, protocoldag +from gufe.protocols.protocolunit import Context +from gufe.storage.externalresource.filestorage import FileStorage +from gufe.storage.storagemanager import StorageManager class WriterUnit(gufe.ProtocolUnit): @staticmethod - def _execute(ctx, **inputs): + def _execute(ctx: Context, **inputs): my_id = inputs["identity"] - with open(os.path.join(ctx.shared, f"unit_{my_id}_shared.txt"), "w") as out: - out.write(f"unit {my_id} existed!\n") - with open(os.path.join(ctx.scratch, f"unit_{my_id}_scratch.txt"), "w") as out: - out.write(f"unit {my_id} was here\n") + unit_shared_name = f"unit_{my_id}_shared.txt" + ctx.shared.register(unit_shared_name) + unit_shared = ctx.scratch / unit_shared_name + + unit_scratch_name = f"unit_{my_id}_scratch.txt" + unit_scratch = ctx.scratch / unit_scratch_name + with open(unit_shared, "w") as out: + out.write(f"unit {my_id} existed\n") + with open(unit_scratch, "w") as out: + out.write(f"unit {my_id} was here\n") if ctx.stderr: with open(os.path.join(ctx.stderr, f"unit_{my_id}_stderr"), "w") as out: out.write(f"unit {my_id} wrote to stderr") @@ -79,28 +87,37 @@ def writefile_dag(): @pytest.mark.parametrize("keep_scratch", [False, True]) @pytest.mark.parametrize("keep_cache", [False, True]) @pytest.mark.parametrize("capture_stderr_stdout", [False, True]) -def test_execute_dag(tmp_path, keep_shared, keep_scratch, keep_cache, writefile_dag, capture_stderr_stdout): - shared = pathlib.Path(tmp_path / "shared") +def test_execute_dag( + tmp_path, keep_shared, keep_scratch, keep_cache, writefile_dag: ProtocolDAG, capture_stderr_stdout +): + scratch = tmp_path / "scratch" + scratch.mkdir(parents=True) + + shared = tmp_path / "shared" shared.mkdir(parents=True) - scratch = pathlib.Path(tmp_path / "scratch") - scratch.mkdir(parents=True) + shared_storage = FileStorage(shared) + + perm = tmp_path / "perm" + perm.mkdir(parents=True) + perm_storage = FileStorage(perm) - cache_basedir = pathlib.Path(tmp_path / "openfe_cache") + cache_basedir = tmp_path / "openfe_cache" cache_basedir.mkdir(parents=True) stderr = None stdout = None if capture_stderr_stdout: - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stoud" stdout.mkdir(parents=True) # run dag execute_DAG( writefile_dag, - shared_basedir=shared, + shared_storage=shared_storage, + perm_storage=perm_storage, scratch_basedir=scratch, cache_basedir=cache_basedir, stderr_basedir=stderr, @@ -109,12 +126,15 @@ def test_execute_dag(tmp_path, keep_shared, keep_scratch, keep_cache, writefile_ keep_scratch=keep_scratch, keep_cache=keep_cache, ) + # check outputs are as expected # will have produced 4 files in scratch and shared directory + dag_label = str(writefile_dag.key) for pu in writefile_dag.protocol_units: - id = pu.inputs["identity"] - shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{id}_shared.txt") - scratch_file = os.path.join(scratch, f"scratch_{str(pu.key)}_attempt_0", f"unit_{id}_scratch.txt") + identity = pu.inputs["identity"] + # shared_file = os.path.join(shared, f"shared_{str(pu.key)}_attempt_0", f"unit_{identity}_shared.txt") + shared_file = StorageManager.append_to_namespace(f"{dag_label}/{pu.key}", f"unit_{identity}_shared.txt") + scratch_file = os.path.join(scratch, f"scratch_{str(pu.key)}_attempt_0", f"unit_{identity}_scratch.txt") unit_result_file = os.path.join( cache_basedir, f"{str(writefile_dag.key)}-results_cache", f"{str(pu.key)}_unitresults.json" ) @@ -123,7 +143,12 @@ def test_execute_dag(tmp_path, keep_shared, keep_scratch, keep_cache, writefile_ stderr_file = os.path.join( stderr, f"stderr_{str(pu.key)}_attempt_0", - f"unit_{id}_stderr", + f"unit_{identity}_stderr", + ) + stdout_file = os.path.join( + stdout, + f"stdout_{str(pu.key)}_attempt_0", + f"unit_{identity}_stdout", ) stdout_file = os.path.join(stdout, f"stdout_{str(pu.key)}_attempt_0", f"unit_{id}_stdout") @@ -133,7 +158,7 @@ def test_execute_dag(tmp_path, keep_shared, keep_scratch, keep_cache, writefile_ assert not os.path.exists(stdout_file) if keep_shared: - assert os.path.exists(shared_file) + assert shared_storage.exists(shared_file) else: assert not os.path.exists(shared_file) if keep_scratch: @@ -186,16 +211,23 @@ def test_execute_DAG_cached_unitresults(tmp_path): ) # run all unit_results - shared = pathlib.Path(tmp_path / "shared") + shared = tmp_path / "shared" shared.mkdir(parents=True) - scratch = pathlib.Path(tmp_path / "scratch") + shared_storage = FileStorage(shared) + + perm = tmp_path / "perm" + perm.mkdir(parents=True) + perm_storage = FileStorage(perm) + + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - unit_results_dir = pathlib.Path(tmp_path / "unitresults_cache") + unit_results_dir = tmp_path / "unitresults_cache" protocol_result = execute_DAG( dep_dag, - shared_basedir=shared, + shared_storage=shared_storage, + perm_storage=perm_storage, scratch_basedir=scratch, cache_basedir=unit_results_dir, stderr_basedir=None, @@ -222,7 +254,8 @@ def test_execute_DAG_cached_unitresults(tmp_path): with pytest.warns(UserWarning, match="Unable to read file, skipping"): protocol_result_rerun = execute_DAG( dep_dag, - shared_basedir=shared, + shared_storage=shared_storage, + perm_storage=perm_storage, scratch_basedir=scratch, cache_basedir=unit_results_dir, stderr_basedir=None, @@ -269,16 +302,23 @@ def test_get_valid_unit_results(tmp_path): protocol_units=all_protocol_units, transformation_key=None, ) - shared = pathlib.Path(tmp_path / "shared") + shared = tmp_path / "shared" shared.mkdir(parents=True) - scratch = pathlib.Path(tmp_path / "scratch") + shared_storage = FileStorage(shared) + + perm = tmp_path / "perm" + perm.mkdir(parents=True) + perm_storage = FileStorage(perm) + + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - unit_results_dir = pathlib.Path(tmp_path / "unitresults_cache") + unit_results_dir = tmp_path / "unitresults_cache" protocol_result = execute_DAG( dep_dag, - shared_basedir=shared, + shared_storage=shared_storage, + perm_storage=perm_storage, scratch_basedir=scratch, cache_basedir=unit_results_dir, stderr_basedir=None, diff --git a/src/gufe/tests/test_protocolunit.py b/src/gufe/tests/test_protocolunit.py index c72b4039a..1ba94f56d 100644 --- a/src/gufe/tests/test_protocolunit.py +++ b/src/gufe/tests/test_protocolunit.py @@ -5,9 +5,29 @@ from gufe.protocols.errors import ExecutionInterrupt from gufe.protocols.protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult +from gufe.storage.externalresource import MemoryStorage from gufe.tests.test_tokenization import GufeTokenizableTestsMixin +@pytest.fixture +def scratch_storage(tmpdir): + """Fixture to provide a scratch directory for ProtocolUnit tests.""" + with tmpdir.as_cwd(): + scratch = Path("scratch") + scratch.mkdir(parents=True) + yield scratch + + +@pytest.fixture +def shared_storage(): + yield MemoryStorage() + + +@pytest.fixture +def permanent_storage(): + yield MemoryStorage() + + class DummyUnit(ProtocolUnit): @staticmethod def _execute(ctx: Context, an_input=2, **inputs): @@ -62,25 +82,33 @@ def test_key_differs(self): assert u1.key != u2.key @pytest.mark.parametrize("capture_stderr_stdout", [False, True]) - def test_execute(self, tmp_path, capture_stderr_stdout): + def test_execute(self, tmpdir, scratch_storage, shared_storage, permanent_storage, capture_stderr_stdout): unit = DummyUnit() - shared = Path(tmp_path / "shared") / str(unit.key) - shared.mkdir(parents=True) - - scratch = Path(tmp_path / "scratch") / str(unit.key) - scratch.mkdir(parents=True) - if capture_stderr_stdout: - stderr = Path(tmp_path / "stderr") / str(unit.key) + stderr = Path("stderr") / str(unit.key) stderr.mkdir(parents=True) - stdout = Path(tmp_path / "stdout") / str(unit.key) + stdout = Path("stdout") / str(unit.key) stdout.mkdir(parents=True) - ctx = Context(shared=shared, scratch=scratch, stderr=stderr, stdout=stdout) + ctx = Context( + scratch=scratch_storage, + dag_label="test", + unit_label=unit.key, + stderr=stderr, + stdout=stdout, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + ) else: - ctx = Context(shared=shared, scratch=scratch) + ctx = Context( + scratch=scratch_storage, + dag_label="test", + unit_label=unit.key, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + ) u: ProtocolUnitFailure = unit.execute(context=ctx, an_input=3) assert u.exception[0] == "ValueError" @@ -99,16 +127,22 @@ def test_execute(self, tmp_path, capture_stderr_stdout): with pytest.raises(ValueError, match="should always be 2"): unit.execute(context=ctx, raise_error=True, an_input=3) - def test_execute_ExecutionInterrupt(self, tmp_path): + def test_execute_ExecutionInterrupt(self, scratch_storage, shared_storage, permanent_storage): unit = DummyExecutionInterruptUnit() - shared = Path(tmp_path / "shared") / str(unit.key) + shared = Path("shared") / str(unit.key) shared.mkdir(parents=True) - scratch = Path(tmp_path / "scratch") / str(unit.key) + scratch = Path("scratch") / str(unit.key) scratch.mkdir(parents=True) - ctx = Context(shared=shared, scratch=scratch, stderr=None, stdout=None) + ctx = Context( + shared_storage=shared_storage, + permanent_storage=permanent_storage, + dag_label="test", + unit_label=unit.key, + scratch=scratch_storage, + ) with pytest.raises(ExecutionInterrupt): unit.execute(context=ctx, an_input=3) @@ -117,16 +151,22 @@ def test_execute_ExecutionInterrupt(self, tmp_path): assert u.outputs == {"foo": "bar"} - def test_execute_KeyboardInterrupt(self, tmp_path): + def test_execute_KeyboardInterrupt(self, scratch_storage, permanent_storage, shared_storage): unit = DummyKeyboardInterruptUnit() - shared = Path(tmp_path / "shared") / str(unit.key) + shared = Path("shared") / str(unit.key) shared.mkdir(parents=True) - scratch = Path(tmp_path / "scratch") / str(unit.key) + scratch = Path("scratch") / str(unit.key) scratch.mkdir(parents=True) - ctx = Context(shared=shared, scratch=scratch, stderr=None, stdout=None) + ctx = Context( + shared_storage=shared_storage, + permanent_storage=permanent_storage, + dag_label="test", + unit_label=unit.key, + scratch=scratch_storage, + ) with pytest.raises(KeyboardInterrupt): unit.execute(context=ctx, an_input=3) @@ -140,3 +180,79 @@ def test_normalize(self, instance): assert thingy.startswith("DummyUnit-") assert all(t in string.hexdigits for t in thingy.partition("-")[-1]) + + +class TestContext: + """Test the Context class context manager functionality.""" + + def test_context_manager_enter_exit( + self, scratch_storage, shared_storage: MemoryStorage, permanent_storage: MemoryStorage + ): + """Test that Context can be used as a context manager.""" + ctx = Context( + dag_label="test", + unit_label="test_unit", + scratch=scratch_storage, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + ) + file_text = b"Hello World!" + + # Test __enter__ + with ctx as context: + assert context is ctx + assert ctx.shared.scratch_dir == scratch_storage + assert ctx.permanent.scratch_dir == scratch_storage + filename = "test.txt" + test_file = context.scratch / filename + context.shared.register(filename) + context.permanent.register(filename) + with test_file.open("b+w") as f: + f.write(file_text) + + # Test __exit__ - should transfer the file to a namespaced location in shared_storage + assert shared_storage.exists("test/test_unit/test.txt") + with shared_storage.load_stream("test/test_unit/test.txt") as item: + out = item.read() + assert out == file_text + + assert permanent_storage.exists("test/test_unit/test.txt") + with permanent_storage.load_stream("test/test_unit/test.txt") as item: + out = item.read() + assert out == file_text + + def test_context_manager_cleanup_stdout_stderr( + self, scratch_storage, shared_storage: MemoryStorage, permanent_storage: MemoryStorage, tmp_path + ): + stdout: Path = tmp_path / "stdout" + stdout.mkdir() + # We write something into stdout + stderr: Path = tmp_path / "stderr" + stderr.mkdir() + + ctx = Context( + dag_label="test", + unit_label="test_unit", + scratch=scratch_storage, + shared_storage=shared_storage, + permanent_storage=permanent_storage, + stdout=stdout, + stderr=stderr, + ) + # Validate the directory is empty + assert any(Path(stdout).iterdir()) == False + assert any(Path(stderr).iterdir()) == False + file_text = b"Hello world" + with ctx as context: + filename = "test.txt" + stdout_file = context.stdout / filename + stderr_file = context.stderr / filename + with stdout_file.open("b+w") as f: + f.write(file_text) + with stderr_file.open("b+w") as f: + f.write(file_text) + assert any(Path(stdout).iterdir()) == True + assert any(Path(stderr).iterdir()) == True + # Validate we cleanup + assert not Path(stderr).exists() + assert not Path(stdout).exists() diff --git a/src/gufe/tests/test_transformation.py b/src/gufe/tests/test_transformation.py index c78c2f309..be4d85e7b 100644 --- a/src/gufe/tests/test_transformation.py +++ b/src/gufe/tests/test_transformation.py @@ -9,6 +9,7 @@ import gufe from gufe.protocols.errors import ProtocolValidationError from gufe.protocols.protocoldag import execute_DAG +from gufe.storage.externalresource import MemoryStorage from gufe.transformations import NonTransformation, Transformation from .test_protocol import DummyProtocol, DummyProtocolResult @@ -54,21 +55,22 @@ def test_protocol(self, absolute_transformation, tmp_path): protocoldag = tnf.create() - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") - scratch.mkdir(parents=True) + scratch = pathlib.Path("scratch") + scratch.mkdir(parents=True, exist_ok=True) - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stdout" stdout.mkdir(parents=True) protocoldagresult = execute_DAG( protocoldag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout, @@ -120,22 +122,23 @@ def test_protocol_extend(self, absolute_transformation, tmp_path): assert isinstance(tnf.protocol, DummyProtocol) - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stdout" stdout.mkdir(parents=True) protocoldag = tnf.create() protocoldagresult = execute_DAG( protocoldag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout, @@ -144,7 +147,8 @@ def test_protocol_extend(self, absolute_transformation, tmp_path): protocoldag2 = tnf.create(extends=protocoldagresult) protocoldagresult2 = execute_DAG( protocoldag2, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout, @@ -224,21 +228,22 @@ def test_protocol(self, complex_equilibrium, tmp_path): protocoldag = ntnf.create() - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stdout" stdout.mkdir(parents=True) protocoldagresult = execute_DAG( protocoldag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout, @@ -279,22 +284,23 @@ def test_protocol_extend(self, complex_equilibrium, tmp_path): assert isinstance(ntnf.protocol, DummyProtocol) - shared = pathlib.Path(tmp_path / "shared") - shared.mkdir(parents=True) + shared = MemoryStorage() + perm = MemoryStorage() - scratch = pathlib.Path(tmp_path / "scratch") + scratch = tmp_path / "scratch" scratch.mkdir(parents=True) - stderr = pathlib.Path(tmp_path / "stderr") + stderr = tmp_path / "stderr" stderr.mkdir(parents=True) - stdout = pathlib.Path(tmp_path / "stdout") + stdout = tmp_path / "stdout" stdout.mkdir(parents=True) protocoldag = ntnf.create() protocoldagresult = execute_DAG( protocoldag, - shared_basedir=shared, + shared_storage=shared, + perm_storage=perm, scratch_basedir=scratch, stderr_basedir=stderr, stdout_basedir=stdout, @@ -303,8 +309,9 @@ def test_protocol_extend(self, complex_equilibrium, tmp_path): protocoldag2 = ntnf.create(extends=protocoldagresult) protocoldagresult2 = execute_DAG( protocoldag2, - shared_basedir=shared, scratch_basedir=scratch, + shared_storage=shared, + perm_storage=perm, ) protocolresult = ntnf.gather([protocoldagresult, protocoldagresult2])