diff --git a/pyproject.toml b/pyproject.toml index 32a2074f..cf3a9510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,7 @@ extend_skip_glob = ["dist/*"] [tool.pytest.ini_options] pythonpath = [ "src/py", + "tests/py", ] testpaths = [ "tests/py" diff --git a/src/py/mat3ra/wode/context/providers/base/__init__.py b/src/py/mat3ra/wode/context/providers/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/py/mat3ra/wode/context/providers/base/context_provider.py b/src/py/mat3ra/wode/context/providers/base/context_provider.py new file mode 100644 index 00000000..ab06aa6d --- /dev/null +++ b/src/py/mat3ra/wode/context/providers/base/context_provider.py @@ -0,0 +1,13 @@ +from typing import Any, Dict + +from mat3ra.ade.context.context_provider import ContextProvider as AdeContextProvider + +# TODO: Remove context provider from Ade -- sync with JS implementation fully +class ContextProvider(AdeContextProvider): + def get_context_item_data(self) -> Dict[str, Any]: + return { + "name": self.name_str, + "isEdited": self.is_edited, + "data": self.get_data(), + "extraData": self.extra_data or {}, + } diff --git a/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py b/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py index c1487278..d0140284 100644 --- a/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py +++ b/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py @@ -1,11 +1,12 @@ from typing import Any, Dict, Optional -from mat3ra.ade.context.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.planewave_cutoffs_context_provider import ( PlanewaveCutoffsContextProviderSchema, ) from pydantic import Field +from .base.context_provider import ContextProvider + class PlanewaveCutoffsContextProvider(PlanewaveCutoffsContextProviderSchema, ContextProvider): name: str = Field(default="cutoffs") diff --git a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py index a451fb5a..85ae7eed 100644 --- a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py +++ b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py @@ -1,12 +1,12 @@ from typing import Any, Dict, List, Optional -from mat3ra.ade.context.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import ( GridMetricType, PointsGridDataProviderSchema, ) from pydantic import Field +from .base.context_provider import ContextProvider DEFAULT_KPPRA = -1 diff --git a/src/py/mat3ra/wode/context/providers/points_path_data_provider.py b/src/py/mat3ra/wode/context/providers/points_path_data_provider.py index 0a45699b..b3a222f7 100644 --- a/src/py/mat3ra/wode/context/providers/points_path_data_provider.py +++ b/src/py/mat3ra/wode/context/providers/points_path_data_provider.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from mat3ra.ade.context.context_provider import ContextProvider +from .base.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.points_path_data_provider import ( PointsPathDataProviderSchemaItem, ) diff --git a/src/py/mat3ra/wode/mixins/flowchart_units_manager.py b/src/py/mat3ra/wode/mixins/flowchart_units_manager.py index a7105963..f57ac041 100644 --- a/src/py/mat3ra/wode/mixins/flowchart_units_manager.py +++ b/src/py/mat3ra/wode/mixins/flowchart_units_manager.py @@ -1,42 +1,43 @@ -from typing import List, Optional, TypeVar +from typing import Generic, List, Optional, TypeVar from mat3ra.utils import find_by_key_or_regex from ..units import Unit +UnitT = TypeVar("UnitT", bound=Unit) T = TypeVar("T") -class FlowchartUnitsManager: +class FlowchartUnitsManager(Generic[UnitT]): """ Mixin class providing common unit operations for flowchart units. - This mixin expects the class to have a `units: List[Unit]` attribute. + This mixin expects the class to have a `units: List[UnitT]` attribute. It provides common methods for managing units in both Workflow and Subworkflow classes. """ - units: List[Unit] + units: List[UnitT] - def set_units(self, units: List[Unit]) -> None: + def set_units(self, units: List[UnitT]) -> None: self.units = units def _set_units_order_in_place(self) -> None: self.set_units_head(self.units) self.set_next_links(self.units) - def get_unit(self, flowchart_id: str) -> Optional[Unit]: + def get_unit(self, flowchart_id: str) -> Optional[UnitT]: for unit in self.units: if unit.flowchartId == flowchart_id: return unit return None - def find_unit_by_id(self, id: str) -> Optional[Unit]: + def find_unit_by_id(self, id: str) -> Optional[UnitT]: for unit in self.units: if getattr(unit, "id", None) == id: return unit return None - def find_unit_with_tag(self, tag: str) -> Optional[Unit]: + def find_unit_with_tag(self, tag: str) -> Optional[UnitT]: for unit in self.units: if hasattr(unit, "tags") and unit.tags is not None and tag in unit.tags: return unit @@ -46,7 +47,7 @@ def get_unit_by_name( self, name: Optional[str] = None, name_regex: Optional[str] = None, - ) -> Optional[Unit]: + ) -> Optional[UnitT]: return find_by_key_or_regex(self.units, key="name", value=name, value_regex=name_regex) @staticmethod @@ -68,7 +69,7 @@ def _add_to_list(items: List[T], item: T, head: bool = False, index: int = -1) - else: items.append(item) - def set_units_head(self, units: List[Unit]) -> List[Unit]: + def set_units_head(self, units: List[UnitT]) -> List[UnitT]: """ Set the head flag on the first unit and unset it on all others. @@ -84,7 +85,7 @@ def set_units_head(self, units: List[Unit]) -> List[Unit]: unit.head = False return units - def set_next_links(self, units: List[Unit]) -> List[Unit]: + def set_next_links(self, units: List[UnitT]) -> List[UnitT]: """ Re-establishes the linked next => flowchartId logic in an array of units. @@ -118,7 +119,7 @@ def _clear_link_to_unit(self, flowchart_id: str) -> None: unit.next = None break - def add_unit(self, unit: Unit, head: bool = False, index: int = -1) -> None: + def add_unit(self, unit: UnitT, head: bool = False, index: int = -1) -> None: """ Add a unit to the units list. @@ -175,8 +176,8 @@ def replace_unit(self, index: int, unit: Unit) -> None: def set_unit( self, - new_unit: Unit, - unit: Optional[Unit] = None, + new_unit: UnitT, + unit: Optional[UnitT] = None, unit_flowchart_id: Optional[str] = None, ) -> bool: """ diff --git a/src/py/mat3ra/wode/subworkflows/convergence/factory.py b/src/py/mat3ra/wode/subworkflows/convergence/factory.py index 398659cc..4a55f1b5 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/factory.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/factory.py @@ -1,6 +1,7 @@ from typing import Any, List, Optional from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum + from .non_uniform_kgrid import NonUniformKGridConvergence from .non_uniform_kgrid_2d import NonUniformKGridConvergence2D from .parameter import ConvergenceParameter diff --git a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py index 929d8f74..05369634 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py @@ -8,11 +8,12 @@ class UniformKGridConvergence(ConvergenceParameter): def _points_grid_context( self, dimensions: List[str], reciprocal_vector_ratios: Optional[List[float]] = None ) -> Dict[str, Any]: - return PointsGridDataProvider().yield_data_with_overrides( + provider = PointsGridDataProvider(isEdited=True) + provider.data = provider.build_data( dimensions=dimensions, reciprocal_vector_ratios=reciprocal_vector_ratios, - is_using_jinja_variables=True, ) + return provider.get_context_item_data() @property def increment(self) -> str: diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index b3124524..a6ccc1a5 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -2,7 +2,6 @@ from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum from mat3ra.utils.extra.jinja import JINJA_EXPRESSION_PATTERN, NUMERIC_VALUE_PATTERN, wrap_text_in_raw_block - from .convergence.factory import create_convergence_parameter from ..context.providers import PointsGridDataProvider from ..units import Unit @@ -67,15 +66,6 @@ def _find_unit_for_convergence(self, result: str): return unit return None - @staticmethod - def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context: Dict[str, Any]) -> Dict[str, Any]: - merged_context = dict(unit_context) - merged_kgrid_context = dict(unit_context.get("kgrid") or {}) - merged_kgrid_context.update(convergence_context.get("kgrid") or {}) - merged_context.update(convergence_context) - if merged_kgrid_context: - merged_context["kgrid"] = merged_kgrid_context - return merged_context def _build_convergence_units( self, @@ -153,6 +143,17 @@ def _build_convergence_units( host.add_unit(next_step) host.add_unit(exit_unit) + param_init.next = prev_result_init.flowchartId + prev_result_init.next = iter_init.flowchartId + iter_init.next = execution_unit_flowchart_id + + execution_unit = host.get_unit(execution_unit_flowchart_id) + if execution_unit is not None: + execution_unit.next = store_result.flowchartId + + store_result.next = condition_unit.flowchartId + store_prev_result.next = next_iter.flowchartId + next_iter.next = next_step.flowchartId next_step.next = execution_unit_flowchart_id def add_convergence( @@ -185,9 +186,12 @@ def add_convergence( ) and reciprocal_vector_ratios is None ): - reciprocal_vector_ratios = PointsGridDataProvider( - context=unit_for_convergence.context - ).get_reciprocal_vector_ratios() + kgrid_item = unit_for_convergence.get_context_item("kgrid") + provider = PointsGridDataProvider( + data=kgrid_item.get("data"), + is_edited=kgrid_item.get("isEdited"), + ) + reciprocal_vector_ratios = provider.get_reciprocal_vector_ratios() if reciprocal_vector_ratios is None: raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") @@ -198,11 +202,7 @@ def add_convergence( reciprocal_vector_ratios=reciprocal_vector_ratios, ) - merged_context = self._merge_convergence_context( - unit_for_convergence.context, - parameter.unit_context, - ) - unit_for_convergence.set_context(merged_context) + unit_for_convergence.add_context(parameter.unit_context) self._build_convergence_units( parameter_name=parameter.name, @@ -266,7 +266,6 @@ def add_template_parameter_convergence( execution_unit.replace_in_input_content( pattern, f"{parameter_name} = {scope_reference}", input_name=input_name ) - execution_unit.add_context({parameter_name: parameter_initial}) self._build_convergence_units( parameter_name=parameter_name, diff --git a/src/py/mat3ra/wode/subworkflows/subworkflow.py b/src/py/mat3ra/wode/subworkflows/subworkflow.py index aa4167e2..46f0bd67 100644 --- a/src/py/mat3ra/wode/subworkflows/subworkflow.py +++ b/src/py/mat3ra/wode/subworkflows/subworkflow.py @@ -9,7 +9,7 @@ from mat3ra.mode.model import Model from mat3ra.mode.models.factory import ModelFactory from mat3ra.utils.uuid import get_uuid -from pydantic import Field, field_validator +from pydantic import ConfigDict, Field, SerializeAsAny, field_validator from .convergence_mixin import ConvergenceMixin from ..mixins import FlowchartUnitsManager @@ -22,7 +22,7 @@ class Subworkflow( SubworkflowSchema, HashedEntityMixin, InMemoryEntitySnakeCase, - FlowchartUnitsManager, + FlowchartUnitsManager[Unit], ): """ Subworkflow class representing a logical collection of workflow units. @@ -35,19 +35,21 @@ class Subworkflow( properties: List of properties extracted by the subworkflow """ + model_config = ConfigDict(validate_assignment=True) + id: str = Field(default_factory=get_uuid, alias="_id") application: Application = Field( default_factory=lambda: Application(name="", version="", build="", shortName="", summary="") ) properties: List[str] = Field(default_factory=list) - model: Model = Field(default_factory=DFTModel) - units: List[Union[Unit, ExecutionUnit, SubworkflowUnit]] = Field(default_factory=list) + model: SerializeAsAny[Model] = Field(default_factory=DFTModel) + units: List[SerializeAsAny[Union[Unit, ExecutionUnit, SubworkflowUnit]]] = Field(default_factory=list) @field_validator("model", mode="before") @classmethod def _instantiate_model(cls, value: Any) -> Any: if isinstance(value, Model): - return value + value = value.to_dict() if isinstance(value, dict): return ModelFactory.create(value) return value diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 8ad43f52..7c346228 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional from mat3ra.ade import Application, Executable, Flavor from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema @@ -8,7 +8,7 @@ remove_empty_lines_from_string, ) from mat3ra.utils.extra.jinja import replace_in_text -from pydantic import Field, field_validator +from pydantic import Field, SerializeAsAny, field_validator from .execution_unit_input import ExecutionUnitInput from .unit import Unit @@ -16,10 +16,11 @@ class ExecutionUnit(Unit, ExecutionUnitSchema): type: Literal["execution"] = "execution" - executable: Executable = None - flavor: Flavor = None - application: Application = None + application: Application + executable: Executable + flavor: Flavor input: List[ExecutionUnitInput] = Field(default_factory=list) + context: List[SerializeAsAny[Dict[str, Any]]] = Field(default_factory=list) @field_validator("input", mode="before") @classmethod @@ -34,6 +35,38 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]: instantiated.append(ExecutionUnitInput(**item)) return instantiated + + @staticmethod + def _context_item_name(item: Any) -> Optional[str]: + if isinstance(item, dict): + return item.get("name") + return str(item.name) + + def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: + for item in self.context: + if self._context_item_name(item) == name: + return item + return None + + def get_context_item_data(self, name: str, default: Any = None) -> Any: + if default is None: + default = {} + item = self.get_context_item(name) + return item.get("data", default) if item else default + + def add_context(self, item: Dict[str, Any]) -> None: + existing_item = self.get_context_item(item["name"]) + if existing_item: + existing_item.update(item) + else: + self.context.append(item) + + def remove_context(self, name: str) -> None: + self.context = [item for item in self.context if self._context_item_name(item) != name] + + def clear_context(self) -> None: + self.context = [] + def replace_in_input_content(self, pattern: str, replacement: str, input_name=None) -> None: for item in self.input: if input_name is None or item.template.name == input_name: diff --git a/src/py/mat3ra/wode/units/unit.py b/src/py/mat3ra/wode/units/unit.py index 38686342..1abd4afa 100644 --- a/src/py/mat3ra/wode/units/unit.py +++ b/src/py/mat3ra/wode/units/unit.py @@ -4,7 +4,7 @@ from mat3ra.code.mixins import HashedEntityMixin from mat3ra.esse.models.workflow.unit.base import WorkflowBaseUnitSchema from mat3ra.utils.uuid import get_uuid -from pydantic import Field, field_validator +from pydantic import Field class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): @@ -18,7 +18,6 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): head: Whether this unit is the head of the workflow next: Flowchart ID of the next unit tags: List of tags for the unit - context: Context data dictionary for the unit """ id: str = Field(default_factory=get_uuid, alias="_id") flowchartId: str = Field(default_factory=get_uuid) @@ -27,15 +26,6 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): postProcessors: List[Any] = Field(default_factory=list) monitors: List[Any] = Field(default_factory=list) results: List[Any] = Field(default_factory=list) - context: Dict[str, Any] = Field(default_factory=dict) - - @field_validator("context", mode="before") - @classmethod - def _coerce_context(cls, value: Any) -> Dict[str, Any]: - if value is None or value == []: - return {} - return value - def get_hash_object(self) -> Dict[str, Any]: return { @@ -47,18 +37,3 @@ def get_hash_object(self) -> Dict[str, Any]: def is_in_status(self, status: str) -> bool: return self.status == status - - def add_context(self, new_context: Dict[str, Any]): - self.context.update(new_context) - - def set_context(self, new_context: Dict[str, Any]): - self.context = new_context - - def get_context(self, key: str, default: Any = None) -> Any: - return self.context.get(key, default) - - def remove_context(self, key: str): - self.context.pop(key, None) - - def clear_context(self): - self.context = {} diff --git a/src/py/mat3ra/wode/workflows/workflow.py b/src/py/mat3ra/wode/workflows/workflow.py index e3ad9b8f..8478dda9 100644 --- a/src/py/mat3ra/wode/workflows/workflow.py +++ b/src/py/mat3ra/wode/workflows/workflow.py @@ -12,7 +12,7 @@ from ..units import Unit -class Workflow(WorkflowSchema, HashedEntityMixin, InMemoryEntitySnakeCase, FlowchartUnitsManager): +class Workflow(WorkflowSchema, HashedEntityMixin, InMemoryEntitySnakeCase, FlowchartUnitsManager[Unit]): """ Workflow class representing a complete workflow configuration. @@ -62,7 +62,7 @@ def set_context_to_unit( new_context: Optional[Dict[str, Any]] = None, ): target_unit = self.get_unit_by_name(name=unit_name, name_regex=unit_name_regex) - target_unit.context = new_context + target_unit.set_context(new_context or []) def _find_relaxation_subworkflow(self) -> Optional[Subworkflow]: target_name = self.relaxation_subworkflow.name @@ -99,5 +99,5 @@ def to_dict_without_special_keys(self, special_keys=["context"]) -> Dict[str, An for unit in swf.get("units", []): for key in special_keys: if key in unit: - unit[key] = {} + unit[key] = [] return workflow_dict diff --git a/tests/py/fixtures/__init__.py b/tests/py/fixtures/__init__.py new file mode 100644 index 00000000..8775c11b --- /dev/null +++ b/tests/py/fixtures/__init__.py @@ -0,0 +1,17 @@ +from typing import Any, Dict + +from mat3ra.standata.workflows import WorkflowStandata + +WORKFLOW_STANDATA = WorkflowStandata() + + +def get_execution_unit_config_by_application_workflow_unit(application: str, workflow_name: str, unit_name: str) -> \ +Dict[str, Any]: + workflows = WORKFLOW_STANDATA.get_by_categories(application, workflow_name) + if not workflows: + raise ValueError(f"No workflow {workflow_name!r} for application {application!r}") + for subworkflow in workflows[0]["subworkflows"]: + for unit in subworkflow["units"]: + if unit.get("type") == "execution" and unit.get("name") == unit_name: + return unit + raise ValueError(f"No execution unit {unit_name!r} in workflow {workflow_name!r}") diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index 8520d6ee..6e1e20a2 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -12,6 +12,29 @@ def _build_total_energy_subworkflow(): return workflow.subworkflows[0] +def _assert_convergence_flowchart_links(subworkflow, execution_unit_name): + def unit(name): + return subworkflow.get_unit_by_name(name=name) + + next_links = [ + ("init parameter", "init result"), + ("init result", "init counter"), + ("init counter", execution_unit_name), + (execution_unit_name, "update result"), + ("update result", "check convergence"), + ("check convergence", "store result"), + ("store result", "update counter"), + ("update counter", "update parameter"), + ("update parameter", execution_unit_name), + ] + for from_name, to_name in next_links: + assert unit(from_name).next == unit(to_name).flowchartId + + check_convergence = unit("check convergence") + assert check_convergence.then == unit("exit").flowchartId + assert getattr(check_convergence, "else") == unit("store result").flowchartId + + def test_add_uniform_energy_convergence(): subworkflow = _build_total_energy_subworkflow() @@ -41,19 +64,19 @@ def test_add_uniform_energy_convergence(): ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context["kgrid"]["dimensions"] == ["{{N_k}}", "{{N_k}}", "{{N_k}}"] - assert pw_scf.context["kgrid"]["shifts"] == [0, 0, 0] - assert pw_scf.context["isKgridEdited"] is True - assert pw_scf.context["isUsingJinjaVariables"] is True + assert pw_scf.get_context_item_data("kgrid")["dimensions"] == ["{{N_k}}", "{{N_k}}", "{{N_k}}"] + assert pw_scf.get_context_item_data("kgrid")["shifts"] == [0, 0, 0] + assert any(item.get("name") == "kgrid" and item.get("isEdited") for item in pw_scf.context) assert subworkflow.convergence_parameter == ConvergenceParameterNameEnum.N_k.value assert subworkflow.convergence_result == "total_energy" assert subworkflow.has_convergence is True + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k.value assert update_parameter.value == f"{ConvergenceParameterNameEnum.N_k.value} + 1" - assert update_parameter.next == pw_scf.flowchartId check_convergence = subworkflow.get_unit_by_name(name="check convergence") assert check_convergence.input == [] @@ -82,20 +105,22 @@ def test_add_non_uniform_energy_convergence(): ) pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context["kgrid"]["dimensions"] == [ + assert pw_scf.get_context_item_data("kgrid")["dimensions"] == [ f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[0]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[1]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[2]}}}}", ] - assert pw_scf.context["kgrid"]["reciprocalVectorRatios"] == reciprocal_vector_ratios + assert pw_scf.get_context_item_data("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios + + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform.value assert update_parameter.input == [{"scope": pw_scf.flowchartId, "name": "context"}] assert ( - update_parameter.value - == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " - "for i in range(3)]" + update_parameter.value + == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " + "for i in range(3)]" ) @@ -117,20 +142,22 @@ def test_add_non_uniform_2d_energy_convergence(): ) pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context["kgrid"]["dimensions"] == [ + assert pw_scf.get_context_item_data("kgrid")["dimensions"] == [ f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[0]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[1]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[2]}}}}", ] - assert pw_scf.context["kgrid"]["reciprocalVectorRatios"] == reciprocal_vector_ratios + assert pw_scf.get_context_item_data("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios + + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform_2D.value assert update_parameter.input == [{"scope": pw_scf.flowchartId, "name": "context"}] assert ( - update_parameter.value - == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " - "for i in range(2)] + [1]" + update_parameter.value + == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " + "for i in range(2)] + [1]" ) @@ -219,7 +246,9 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context[param_name] == param_initial + init_parameter = subworkflow.get_unit_by_name(name="init parameter") + assert init_parameter.operand == param_name + assert init_parameter.value == param_initial input_item = pw_scf.input[0] template_content = input_item.template.content assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content @@ -229,11 +258,12 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme assert subworkflow.convergence_result == result_name assert subworkflow.has_convergence is True + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == param_name assert update_parameter.value == f"{param_name} + {param_increment}" assert update_parameter.input == [] - assert update_parameter.next == pw_scf.flowchartId exit_unit = subworkflow.get_unit_by_name(name="exit") assert exit_unit.operand == param_name @@ -258,12 +288,17 @@ def test_add_template_param_convergence_multi_unit(): pw_scf = subworkflow.get_unit_by_name("pw_scf") pw_bands = subworkflow.get_unit_by_name("pw_bands") + init_parameter = subworkflow.get_unit_by_name(name="init parameter") + assert init_parameter.operand == "ecutwfc" + assert init_parameter.value == 20 + for unit in [pw_scf, pw_bands]: - assert unit.context["ecutwfc"] == 20 input_item = unit.input[0] template_content = input_item.template.content assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content assert "ecutwfc = {{ cutoffs.wavefunction }}" not in template_content + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + assert subworkflow.convergence_parameter == "ecutwfc" assert subworkflow.convergence_result == "total_energy" diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 4b9e0451..72be7904 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -2,10 +2,11 @@ from mat3ra.ade.application import Application from mat3ra.mode.method import Method from mat3ra.mode.model import Model +from mat3ra.mode.models.dft import DFTModel from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata -from mat3ra.wode import Subworkflow, Unit, Workflow +from mat3ra.wode import Subworkflow, Unit, Workflow, ExecutionUnit from mat3ra.wode.context.providers import PointsGridDataProvider SUBWORKFLOW_NAME = "Total Energy" @@ -23,6 +24,8 @@ "flowchartId": "unit-flowchart-id", "head": True, } +DFT_METHOD_CONFIG = {"type": "pseudopotential", "subtype": "us"} +DFT_MODEL_CONFIG_WITHOUT_FUNCTIONAL = {"type": "dft", "subtype": "gga", "method": DFT_METHOD_CONFIG} def test_creation(): @@ -54,6 +57,14 @@ def test_model(model_type, model_subtype): assert sw.model.subtype == model_subtype +@pytest.mark.parametrize("config", [DFT_MODEL_CONFIG_WITHOUT_FUNCTIONAL]) +def test_model_assignment_is_coerced_to_dft_model_with_default_functional(config): + subworkflow = Subworkflow(name=SUBWORKFLOW_NAME) + subworkflow.model = Model(**config) + assert isinstance(subworkflow.model, DFTModel) + assert subworkflow.model.to_dict().get("functional") == "pbe" + + def test_with_units(): unit = Unit(**UNIT_CONFIG) sw = Subworkflow(name=SUBWORKFLOW_NAME, units=[unit]) @@ -89,13 +100,14 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): relaxation_subworkflow = wf.subworkflows[0] unit_to_modify = relaxation_subworkflow.get_unit_by_name(name_regex="relax") assert unit_to_modify is not None + assert isinstance(unit_to_modify, ExecutionUnit) original_rendered = unit_to_modify.input[0].rendered - unit_to_modify.add_context({"test_key": "test_value", "another_key": 42}) - unit_to_modify.add_context( - PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True).yield_data() - ) + unit_to_modify.add_context({"name": "test_key", "data": "test_value"}) + unit_to_modify.add_context({"name": "another_key", "data": 42}) + points_grid_provider = PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True) + unit_to_modify.add_context(points_grid_provider.get_context_item_data()) if method == "only_new_unit": success = relaxation_subworkflow.set_unit(unit_to_modify) @@ -109,7 +121,7 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): assert success is True updated_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") - assert updated_unit.context["test_key"] == "test_value" - assert updated_unit.context["another_key"] == 42 - assert updated_unit.context["kgrid"]["dimensions"] == [2, 2, 1] + assert updated_unit.get_context_item_data("test_key") == "test_value" + assert updated_unit.get_context_item_data("another_key") == 42 + assert updated_unit.get_context_item_data("kgrid")["dimensions"] == [2, 2, 1] assert updated_unit.input[0].rendered == original_rendered diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index 22f23c76..8e35a8c5 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -1,7 +1,9 @@ import pytest from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata -from mat3ra.wode import Unit + +from fixtures import get_execution_unit_config_by_application_workflow_unit +from mat3ra.wode import ExecutionUnit, Unit WORKFLOW_STANDATA = WorkflowStandata() APPLICATION_STANDATA = ApplicationStandata() @@ -15,8 +17,7 @@ NEW_CONTEXT_RELAX = {"kgrid": {"density": 0.5}, "convergence": {"threshold": 1e-6}} UNIT_CONFIG_EXECUTION = { - "type": "execution", - "name": "pw_scf", + **get_execution_unit_config_by_application_workflow_unit(APPLICATION_ESPRESSO, "total_energy", "pw_scf"), "flowchartId": UNIT_FLOWCHART_ID, "head": True, } @@ -55,14 +56,18 @@ def test_next_property(): def test_add_context(): - unit = Unit(**{**UNIT_CONFIG_EXECUTION, "name": "relaxation step"}) + unit = ExecutionUnit(**{**UNIT_CONFIG_EXECUTION, "name": "relaxation step"}) assert unit is not None assert "relax" in unit.name.lower() + assert unit.context == [] - unit.add_context(NEW_CONTEXT_RELAX) + unit.add_context({"name": "kgrid", "data": NEW_CONTEXT_RELAX["kgrid"], "isEdited": False}) + unit.add_context({"name": "convergence", "data": NEW_CONTEXT_RELAX["convergence"], "isEdited": False}) - assert "kgrid" in unit.context - assert "convergence" in unit.context - assert unit.context["kgrid"] == NEW_CONTEXT_RELAX["kgrid"] - assert unit.context["convergence"] == NEW_CONTEXT_RELAX["convergence"] + assert unit.get_context_item_data("kgrid") == NEW_CONTEXT_RELAX["kgrid"] + assert unit.get_context_item_data("convergence") == NEW_CONTEXT_RELAX["convergence"] + assert unit.to_dict()["context"] == [ + {"name": "kgrid", "isEdited": False, "data": NEW_CONTEXT_RELAX["kgrid"]}, + {"name": "convergence", "isEdited": False, "data": NEW_CONTEXT_RELAX["convergence"]}, + ] diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 17500ddb..1d67ed7c 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -32,6 +32,20 @@ "head": True, } +BAND_STRUCTURE_SEARCH_NAME = "band_structure" +BAND_GAP_SEARCH_NAME = "espresso/band_gap\\.json$" +TOTAL_ENERGY_SEARCH_NAME = "total_energy" + +EXPECTED_MODEL_FUNCTIONAL = "pbe" +EXECUTION_UNIT_TYPE = "execution" +CONTEXT_ITEM_REQUIRED_KEYS = ("name", "isEdited", "data", "extraData") + +WEBAPP_COMPATIBLE_WORKFLOW_SEARCH_NAMES = [ + BAND_STRUCTURE_SEARCH_NAME, + BAND_GAP_SEARCH_NAME, + TOTAL_ENERGY_SEARCH_NAME, +] + def test_creation(): wf = Workflow(name=WORKFLOW_NAME) @@ -162,28 +176,29 @@ def test_set_unit(method): wf.add_relaxation() - unit_to_modify = wf.get_unit_by_name(name_regex="relax") + relaxation_subworkflow = wf._find_relaxation_subworkflow() + assert relaxation_subworkflow is not None + + unit_to_modify = relaxation_subworkflow.get_unit_by_name(name_regex="relax") assert unit_to_modify is not None - new_context = {"test_key": "test_value", "another_key": 42} - unit_to_modify.add_context(new_context) + unit_to_modify.add_context({"name": "test_key", "data": "test_value"}) + unit_to_modify.add_context({"name": "another_key", "data": 42}) if method == "only_new_unit": - success = wf.set_unit(unit_to_modify) + success = relaxation_subworkflow.set_unit(unit_to_modify) elif method == "with_unit_instance": - original_unit = wf.get_unit_by_name(name_regex="relax") - success = wf.set_unit(unit_to_modify, unit=original_unit) + original_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") + success = relaxation_subworkflow.set_unit(unit_to_modify, unit=original_unit) elif method == "with_flowchart_id": flowchart_id = unit_to_modify.flowchartId - success = wf.set_unit(unit_to_modify, unit_flowchart_id=flowchart_id) + success = relaxation_subworkflow.set_unit(unit_to_modify, unit_flowchart_id=flowchart_id) assert success is True - updated_unit = wf.get_unit_by_name(name_regex="relax") - assert "test_key" in updated_unit.context - assert "another_key" in updated_unit.context - assert updated_unit.context["test_key"] == "test_value" - assert updated_unit.context["another_key"] == 42 + updated_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") + assert updated_unit.get_context_item_data("test_key") == "test_value" + assert updated_unit.get_context_item_data("another_key") == 42 @pytest.mark.parametrize("workflow, app", [("band_gap", "espresso")]) @@ -198,3 +213,4 @@ def test_calculate_hash(workflow, app): fixture = next(w for w in workflows if w["name"] == BAND_GAP_WORKFLOW_NAME) wf = Workflow(**{k: v for k, v in fixture.items() if k != "hash"}) assert wf.hash == expected_hash + diff --git a/tests/py/units/test_execution_unit.py b/tests/py/units/test_execution_unit.py index a0724e95..01ff5d6a 100644 --- a/tests/py/units/test_execution_unit.py +++ b/tests/py/units/test_execution_unit.py @@ -2,10 +2,10 @@ import pytest from mat3ra.wode.units.execution import ExecutionUnit +from fixtures import get_execution_unit_config_by_application_workflow_unit UNIT_CONFIG = { - "type": "execution", - "name": "pw_scf", + **get_execution_unit_config_by_application_workflow_unit("espresso", "total_energy", "pw_scf"), "flowchartId": "abc-123", "head": True, }