From c1d557b3f277382c226f7fe62a569d0358140cd5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 3 Jun 2026 04:10:01 +0000 Subject: [PATCH 1/2] Initial plan From 203de8dcd9c0ccfd8f5ce99718878407fe131bbf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 3 Jun 2026 04:17:16 +0000 Subject: [PATCH 2/2] refactor: extract execution unit context item and collection logic --- src/py/mat3ra/wode/units/__init__.py | 4 + src/py/mat3ra/wode/units/context_item.py | 80 +++++++++++++++++ src/py/mat3ra/wode/units/execution.py | 108 +++-------------------- src/py/mat3ra/wode/units/unit_context.py | 61 +++++++++++++ tests/py/test_unit.py | 26 +++++- 5 files changed, 182 insertions(+), 97 deletions(-) create mode 100644 src/py/mat3ra/wode/units/context_item.py create mode 100644 src/py/mat3ra/wode/units/unit_context.py diff --git a/src/py/mat3ra/wode/units/__init__.py b/src/py/mat3ra/wode/units/__init__.py index fec5c294..96801d22 100644 --- a/src/py/mat3ra/wode/units/__init__.py +++ b/src/py/mat3ra/wode/units/__init__.py @@ -1,11 +1,15 @@ +from .context_item import ContextItem from .execution import ExecutionUnit from .execution_unit_input import ExecutionUnitInput from .subworkflow import SubworkflowUnit from .unit import Unit +from .unit_context import UnitContext __all__ = [ "Unit", "ExecutionUnit", "ExecutionUnitInput", "SubworkflowUnit", + "ContextItem", + "UnitContext", ] diff --git a/src/py/mat3ra/wode/units/context_item.py b/src/py/mat3ra/wode/units/context_item.py new file mode 100644 index 00000000..b7de830c --- /dev/null +++ b/src/py/mat3ra/wode/units/context_item.py @@ -0,0 +1,80 @@ +from collections.abc import Mapping +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class ContextItem(BaseModel): + name: str + isEdited: bool = True + data: Any = Field(default_factory=dict) + extraData: Dict[str, Any] = Field(default_factory=dict) + + model_config = ConfigDict(extra="ignore") + + @field_validator("data", mode="before") + @classmethod + def _normalize_data(cls, value: Any) -> Dict[str, Any]: + return value if isinstance(value, dict) else {"value": value} + + @field_validator("extraData", mode="before") + @classmethod + def _normalize_extra_data(cls, value: Any) -> Dict[str, Any]: + return value if isinstance(value, dict) else {} + + @classmethod + def from_persisted(cls, item: Mapping[str, Any], *, default_is_edited: bool = True) -> "ContextItem": + name = item.get("name") + if not name: + raise ValueError("Context item must contain a name") + return cls( + name=str(name), + isEdited=bool(item.get("isEdited", default_is_edited)), + data=item.get("data"), + extraData=item.get("extraData"), + ) + + @classmethod + def from_provider_yield(cls, yielded: Mapping[str, Any]) -> "ContextItem": + name = None + data = None + is_edited = True + extra_data: Dict[str, Any] = {} + for key, value in yielded.items(): + if key == "isUsingJinjaVariables": + continue + if key.startswith("is") and key.endswith("Edited"): + is_edited = bool(value) + continue + if key.endswith("ExtraData"): + extra_data = value if isinstance(value, dict) else {} + continue + if name is not None: + raise ValueError("yield_data() must contain a single provider data key") + name = key + data = value + if name is None: + raise ValueError("yield_data() must contain a provider data key") + return cls(name=name, isEdited=is_edited, data=data, extraData=extra_data) + + @classmethod + def from_value(cls, value: Any, *, default_is_edited: bool = True) -> "ContextItem": + if isinstance(value, cls): + return value + if isinstance(value, Mapping): + if "name" in value: + return cls.from_persisted(value, default_is_edited=default_is_edited) + return cls.from_provider_yield(value) + raise TypeError("Context item must be a mapping or ContextItem") + + def get(self, key: str, default: Any = None) -> Any: + return getattr(self, key, default) + + def read_data(self, default: Any = None) -> Any: + data = self.data if self.data is not None else default + if isinstance(data, dict) and set(data) == {"value"}: + return data["value"] + return data + + def as_dict(self) -> Dict[str, Any]: + return self.model_dump() diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 8a02e789..04e39cf1 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -2,7 +2,7 @@ from mat3ra.ade import Application, Executable, Flavor from mat3ra.ade.context.context_provider import ContextProvider -from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema +from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema from mat3ra.utils import ( calculate_hash_from_object, remove_comments_from_source_code, @@ -13,8 +13,8 @@ from .execution_unit_input import ExecutionUnitInput from .unit import Unit +from .unit_context import UnitContext -Context = List[ContextItemSchema] class ExecutionUnit(Unit, ExecutionUnitSchema): type: Literal["execution"] = "execution" @@ -22,7 +22,7 @@ class ExecutionUnit(Unit, ExecutionUnitSchema): executable: Executable flavor: Flavor input: List[ExecutionUnitInput] = Field(default_factory=list) - context: Context = Field(default_factory=list) + context: UnitContext = Field(default_factory=UnitContext) @field_validator("input", mode="before") @classmethod @@ -37,114 +37,32 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]: instantiated.append(ExecutionUnitInput(**item)) return instantiated - @field_validator("context", mode="before") @classmethod - def _validate_context(cls, value: Any) -> List[Dict[str, Any]]: - if value is None: - return [] - if not isinstance(value, list): - return value - return [{ - "name": item["name"], - "isEdited": bool(item.get("isEdited", False)), - "data": item.get("data", {}), - "extraData": item.get("extraData") or {}, - } for item in value if isinstance(item, dict) and item.get("name")] - - @staticmethod - def _context_item_name(item: Any) -> Optional[str]: - if isinstance(item, dict): - return item.get("name") - name = getattr(item, "name", None) - return str(name) if name is not None else None + def _validate_context(cls, value: Any) -> UnitContext: + return UnitContext.from_value(value, default_is_edited=False) def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: - for item in self.context: - if self._context_item_name(item) == name: - return item if isinstance(item, dict) else item.model_dump() - return None - - @staticmethod - def _read_context_data(item: Dict[str, Any], default: Any = None) -> Any: - data = item.get("data", default) - if isinstance(data, dict) and set(data) == {"value"}: - return data["value"] - return data - - @staticmethod - def context_item( - name: str, - data: Any, - *, - is_edited: bool = True, - extra_data: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - payload = data if isinstance(data, dict) else {"value": data} - return { - "name": name, - "isEdited": is_edited, - "data": payload, - "extraData": extra_data or {}, - } - - def _replace_context_item(self, name: str, item: Dict[str, Any]) -> None: - rest = [entry for entry in self.context if self._context_item_name(entry) != name] - self.context = rest + [item] - - @staticmethod - def _normalized_context_item(item: Dict[str, Any]) -> Dict[str, Any]: - if "name" in item: - return ExecutionUnit.context_item( - item["name"], - item.get("data"), - is_edited=bool(item.get("isEdited", True)), - extra_data=item.get("extraData") or {}, - ) - return ExecutionUnit._context_item_from_provider_yield(item) - - @staticmethod - def _context_item_from_provider_yield(yielded: Dict[str, Any]) -> Dict[str, Any]: - name = None - data = None - is_edited = True - extra_data: Dict[str, Any] = {} - for key, value in yielded.items(): - if key == "isUsingJinjaVariables": - continue - if key.startswith("is") and key.endswith("Edited"): - is_edited = bool(value) - continue - if key.endswith("ExtraData"): - extra_data = value or {} - continue - if name is not None: - raise ValueError("yield_data() must contain a single provider data key") - name = key - data = value - if name is None: - raise ValueError("yield_data() must contain a provider data key") - return ExecutionUnit.context_item(name, data, is_edited=is_edited, extra_data=extra_data) + item = self.context.get(name) + return item.as_dict() if item else None def add_context(self, item: Dict[str, Any]) -> None: - normalized = self._normalized_context_item(item) - self._replace_context_item(normalized["name"], normalized) + self.context.upsert(item) def add_context_provider(self, provider: ContextProvider) -> None: self.add_context(provider.yield_data()) - def set_context(self, items: Context) -> None: - self.context = items + def set_context(self, items: Any) -> None: + self.context = UnitContext.from_value(items, default_is_edited=False) def get_context(self, name: str, default: Any = None) -> Any: - item = self.get_context_item(name) - return self._read_context_data(item, default) if item else default + return self.context.get_data(name, default) def remove_context(self, name: str) -> None: - self.context = [item for item in self.context if self._context_item_name(item) != name] + self.context.remove(name) def clear_context(self) -> None: - self.context = [] + self.context.clear() def replace_in_input_content(self, pattern: str, replacement: str, input_name=None) -> None: for item in self.input: diff --git a/src/py/mat3ra/wode/units/unit_context.py b/src/py/mat3ra/wode/units/unit_context.py new file mode 100644 index 00000000..d0fddae0 --- /dev/null +++ b/src/py/mat3ra/wode/units/unit_context.py @@ -0,0 +1,61 @@ +from typing import Any, Iterable, List, Optional + +from pydantic import Field, RootModel + +from .context_item import ContextItem + + +class UnitContext(RootModel[List[ContextItem]]): + root: List[ContextItem] = Field(default_factory=list) + + @classmethod + def from_value(cls, value: Any, *, default_is_edited: bool = False) -> "UnitContext": + if isinstance(value, cls): + return value + if value is None: + return cls([]) + if isinstance(value, list): + items = [] + for entry in value: + try: + items.append(ContextItem.from_value(entry, default_is_edited=default_is_edited)) + except (TypeError, ValueError): + continue + return cls(items) + if isinstance(value, ContextItem): + return cls([value]) + return cls.model_validate(value) + + def get(self, name: str) -> Optional[ContextItem]: + return next((item for item in self.root if item.name == name), None) + + def get_data(self, name: str, default: Any = None) -> Any: + item = self.get(name) + return item.read_data(default) if item else default + + def add(self, item: Any) -> None: + self.upsert(item) + + def upsert(self, item: Any) -> None: + context_item = ContextItem.from_value(item, default_is_edited=True) + self.root = [entry for entry in self.root if entry.name != context_item.name] + [context_item] + + def set_items(self, items: Iterable[Any]) -> None: + self.root = UnitContext.from_value(list(items), default_is_edited=False).root + + def remove(self, name: str) -> None: + self.root = [item for item in self.root if item.name != name] + + def clear(self) -> None: + self.root = [] + + def __iter__(self): + return iter(self.root) + + def __len__(self) -> int: + return len(self.root) + + def __eq__(self, other: object) -> bool: + if isinstance(other, list): + return [item.as_dict() for item in self.root] == other + return super().__eq__(other) diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index 9eb07db4..ba7a9201 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -1,10 +1,9 @@ import pytest +from fixtures import execution_unit_config from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata from mat3ra.wode import ExecutionUnit, Unit -from fixtures import execution_unit_config - WORKFLOW_STANDATA = WorkflowStandata() APPLICATION_STANDATA = ApplicationStandata() @@ -83,3 +82,26 @@ def test_add_context(): "extraData": {}, }, ] + + +def test_add_context_from_provider_yield_wraps_scalar_data(): + config = execution_unit_config(APPLICATION_ESPRESSO, "band_gap", "pw_scf") + unit = ExecutionUnit(**{**config, "name": "relaxation step"}) + + unit.add_context( + { + "degauss": 0.001, + "isDegaussEdited": False, + "degaussExtraData": {"units": "Ry"}, + "isUsingJinjaVariables": True, + } + ) + + assert unit.get_context("degauss") == 0.001 + assert unit.get_context_item("degauss") == { + "name": "degauss", + "isEdited": False, + "data": {"value": 0.001}, + "extraData": {"units": "Ry"}, + } + assert any(item.get("name") == "degauss" for item in unit.context)