Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/py/mat3ra/wode/units/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from .context_item import ContextItem
from .execution import ExecutionUnit
from .execution_unit_input import ExecutionUnitInput
from .subworkflow import SubworkflowUnit
from .unit import Unit
from .unit_context import UnitContext

__all__ = [
"Unit",
"ExecutionUnit",
"ExecutionUnitInput",
"SubworkflowUnit",
"ContextItem",
"UnitContext",
]
80 changes: 80 additions & 0 deletions src/py/mat3ra/wode/units/context_item.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from collections.abc import Mapping
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Field, field_validator


class ContextItem(BaseModel):
name: str
isEdited: bool = True
data: Any = Field(default_factory=dict)
extraData: Dict[str, Any] = Field(default_factory=dict)

model_config = ConfigDict(extra="ignore")

@field_validator("data", mode="before")
@classmethod
def _normalize_data(cls, value: Any) -> Dict[str, Any]:
return value if isinstance(value, dict) else {"value": value}

@field_validator("extraData", mode="before")
@classmethod
def _normalize_extra_data(cls, value: Any) -> Dict[str, Any]:
return value if isinstance(value, dict) else {}

@classmethod
def from_persisted(cls, item: Mapping[str, Any], *, default_is_edited: bool = True) -> "ContextItem":
name = item.get("name")
if not name:
raise ValueError("Context item must contain a name")
return cls(
name=str(name),
isEdited=bool(item.get("isEdited", default_is_edited)),
data=item.get("data"),
extraData=item.get("extraData"),
)

@classmethod
def from_provider_yield(cls, yielded: Mapping[str, Any]) -> "ContextItem":
name = None
data = None
is_edited = True
extra_data: Dict[str, Any] = {}
for key, value in yielded.items():
if key == "isUsingJinjaVariables":
continue
if key.startswith("is") and key.endswith("Edited"):
is_edited = bool(value)
continue
if key.endswith("ExtraData"):
extra_data = value if isinstance(value, dict) else {}
continue
if name is not None:
raise ValueError("yield_data() must contain a single provider data key")
name = key
data = value
if name is None:
raise ValueError("yield_data() must contain a provider data key")
return cls(name=name, isEdited=is_edited, data=data, extraData=extra_data)

@classmethod
def from_value(cls, value: Any, *, default_is_edited: bool = True) -> "ContextItem":
if isinstance(value, cls):
return value
if isinstance(value, Mapping):
if "name" in value:
return cls.from_persisted(value, default_is_edited=default_is_edited)
return cls.from_provider_yield(value)
raise TypeError("Context item must be a mapping or ContextItem")

def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)

def read_data(self, default: Any = None) -> Any:
data = self.data if self.data is not None else default
if isinstance(data, dict) and set(data) == {"value"}:
return data["value"]
return data

def as_dict(self) -> Dict[str, Any]:
return self.model_dump()
108 changes: 13 additions & 95 deletions src/py/mat3ra/wode/units/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from mat3ra.ade import Application, Executable, Flavor
from mat3ra.ade.context.context_provider import ContextProvider
from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema
from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema
from mat3ra.utils import (
calculate_hash_from_object,
remove_comments_from_source_code,
Expand All @@ -13,16 +13,16 @@

from .execution_unit_input import ExecutionUnitInput
from .unit import Unit
from .unit_context import UnitContext

Context = List[ContextItemSchema]

class ExecutionUnit(Unit, ExecutionUnitSchema):
type: Literal["execution"] = "execution"
application: Application
executable: Executable
flavor: Flavor
input: List[ExecutionUnitInput] = Field(default_factory=list)
context: Context = Field(default_factory=list)
context: UnitContext = Field(default_factory=UnitContext)

@field_validator("input", mode="before")
@classmethod
Expand All @@ -37,114 +37,32 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]:
instantiated.append(ExecutionUnitInput(**item))
return instantiated


@field_validator("context", mode="before")
@classmethod
def _validate_context(cls, value: Any) -> List[Dict[str, Any]]:
if value is None:
return []
if not isinstance(value, list):
return value
return [{
"name": item["name"],
"isEdited": bool(item.get("isEdited", False)),
"data": item.get("data", {}),
"extraData": item.get("extraData") or {},
} for item in value if isinstance(item, dict) and item.get("name")]

@staticmethod
def _context_item_name(item: Any) -> Optional[str]:
if isinstance(item, dict):
return item.get("name")
name = getattr(item, "name", None)
return str(name) if name is not None else None
def _validate_context(cls, value: Any) -> UnitContext:
return UnitContext.from_value(value, default_is_edited=False)

def get_context_item(self, name: str) -> Optional[Dict[str, Any]]:
for item in self.context:
if self._context_item_name(item) == name:
return item if isinstance(item, dict) else item.model_dump()
return None

@staticmethod
def _read_context_data(item: Dict[str, Any], default: Any = None) -> Any:
data = item.get("data", default)
if isinstance(data, dict) and set(data) == {"value"}:
return data["value"]
return data

@staticmethod
def context_item(
name: str,
data: Any,
*,
is_edited: bool = True,
extra_data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
payload = data if isinstance(data, dict) else {"value": data}
return {
"name": name,
"isEdited": is_edited,
"data": payload,
"extraData": extra_data or {},
}

def _replace_context_item(self, name: str, item: Dict[str, Any]) -> None:
rest = [entry for entry in self.context if self._context_item_name(entry) != name]
self.context = rest + [item]

@staticmethod
def _normalized_context_item(item: Dict[str, Any]) -> Dict[str, Any]:
if "name" in item:
return ExecutionUnit.context_item(
item["name"],
item.get("data"),
is_edited=bool(item.get("isEdited", True)),
extra_data=item.get("extraData") or {},
)
return ExecutionUnit._context_item_from_provider_yield(item)

@staticmethod
def _context_item_from_provider_yield(yielded: Dict[str, Any]) -> Dict[str, Any]:
name = None
data = None
is_edited = True
extra_data: Dict[str, Any] = {}
for key, value in yielded.items():
if key == "isUsingJinjaVariables":
continue
if key.startswith("is") and key.endswith("Edited"):
is_edited = bool(value)
continue
if key.endswith("ExtraData"):
extra_data = value or {}
continue
if name is not None:
raise ValueError("yield_data() must contain a single provider data key")
name = key
data = value
if name is None:
raise ValueError("yield_data() must contain a provider data key")
return ExecutionUnit.context_item(name, data, is_edited=is_edited, extra_data=extra_data)
item = self.context.get(name)
return item.as_dict() if item else None

def add_context(self, item: Dict[str, Any]) -> None:
normalized = self._normalized_context_item(item)
self._replace_context_item(normalized["name"], normalized)
self.context.upsert(item)

def add_context_provider(self, provider: ContextProvider) -> None:
self.add_context(provider.yield_data())

def set_context(self, items: Context) -> None:
self.context = items
def set_context(self, items: Any) -> None:
self.context = UnitContext.from_value(items, default_is_edited=False)

def get_context(self, name: str, default: Any = None) -> Any:
item = self.get_context_item(name)
return self._read_context_data(item, default) if item else default
return self.context.get_data(name, default)

def remove_context(self, name: str) -> None:
self.context = [item for item in self.context if self._context_item_name(item) != name]
self.context.remove(name)

def clear_context(self) -> None:
self.context = []
self.context.clear()

def replace_in_input_content(self, pattern: str, replacement: str, input_name=None) -> None:
for item in self.input:
Expand Down
61 changes: 61 additions & 0 deletions src/py/mat3ra/wode/units/unit_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any, Iterable, List, Optional

from pydantic import Field, RootModel

from .context_item import ContextItem


class UnitContext(RootModel[List[ContextItem]]):
root: List[ContextItem] = Field(default_factory=list)

@classmethod
def from_value(cls, value: Any, *, default_is_edited: bool = False) -> "UnitContext":
if isinstance(value, cls):
return value
if value is None:
return cls([])
if isinstance(value, list):
items = []
for entry in value:
try:
items.append(ContextItem.from_value(entry, default_is_edited=default_is_edited))
except (TypeError, ValueError):
continue
return cls(items)
if isinstance(value, ContextItem):
return cls([value])
return cls.model_validate(value)

def get(self, name: str) -> Optional[ContextItem]:
return next((item for item in self.root if item.name == name), None)

def get_data(self, name: str, default: Any = None) -> Any:
item = self.get(name)
return item.read_data(default) if item else default

def add(self, item: Any) -> None:
self.upsert(item)

def upsert(self, item: Any) -> None:
context_item = ContextItem.from_value(item, default_is_edited=True)
self.root = [entry for entry in self.root if entry.name != context_item.name] + [context_item]

def set_items(self, items: Iterable[Any]) -> None:
self.root = UnitContext.from_value(list(items), default_is_edited=False).root

def remove(self, name: str) -> None:
self.root = [item for item in self.root if item.name != name]

def clear(self) -> None:
self.root = []

def __iter__(self):
return iter(self.root)

def __len__(self) -> int:
return len(self.root)

def __eq__(self, other: object) -> bool:
if isinstance(other, list):
return [item.as_dict() for item in self.root] == other
return super().__eq__(other)
26 changes: 24 additions & 2 deletions tests/py/test_unit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pytest
from fixtures import execution_unit_config
from mat3ra.standata.applications import ApplicationStandata
from mat3ra.standata.workflows import WorkflowStandata
from mat3ra.wode import ExecutionUnit, Unit

from fixtures import execution_unit_config

WORKFLOW_STANDATA = WorkflowStandata()
APPLICATION_STANDATA = ApplicationStandata()

Expand Down Expand Up @@ -83,3 +82,26 @@ def test_add_context():
"extraData": {},
},
]


def test_add_context_from_provider_yield_wraps_scalar_data():
config = execution_unit_config(APPLICATION_ESPRESSO, "band_gap", "pw_scf")
unit = ExecutionUnit(**{**config, "name": "relaxation step"})

unit.add_context(
{
"degauss": 0.001,
"isDegaussEdited": False,
"degaussExtraData": {"units": "Ry"},
"isUsingJinjaVariables": True,
}
)

assert unit.get_context("degauss") == 0.001
assert unit.get_context_item("degauss") == {
"name": "degauss",
"isEdited": False,
"data": {"value": 0.001},
"extraData": {"units": "Ry"},
}
assert any(item.get("name") == "degauss" for item in unit.context)