From dd5c3c35a08465b4d823e1bd93e15997813d9ea8 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 1 Jun 2026 20:51:07 -0700 Subject: [PATCH 01/33] update: adjust {} -> [] --- .../wode/subworkflows/convergence_mixin.py | 4 +- src/py/mat3ra/wode/units/unit.py | 49 ++++++++++++++----- src/py/mat3ra/wode/workflows/workflow.py | 4 +- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index b3124524..f9f76fca 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -186,7 +186,7 @@ def add_convergence( and reciprocal_vector_ratios is None ): reciprocal_vector_ratios = PointsGridDataProvider( - context=unit_for_convergence.context + context={item["name"]: item["data"] for item in unit_for_convergence.context} ).get_reciprocal_vector_ratios() if reciprocal_vector_ratios is None: raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") @@ -199,7 +199,7 @@ def add_convergence( ) merged_context = self._merge_convergence_context( - unit_for_convergence.context, + {item["name"]: item["data"] for item in unit_for_convergence.context}, parameter.unit_context, ) unit_for_convergence.set_context(merged_context) diff --git a/src/py/mat3ra/wode/units/unit.py b/src/py/mat3ra/wode/units/unit.py index 38686342..4bf64dbf 100644 --- a/src/py/mat3ra/wode/units/unit.py +++ b/src/py/mat3ra/wode/units/unit.py @@ -18,7 +18,7 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): head: Whether this unit is the head of the workflow next: Flowchart ID of the next unit tags: List of tags for the unit - context: Context data dictionary for the unit + context: Persisted context provider items for the unit """ id: str = Field(default_factory=get_uuid, alias="_id") flowchartId: str = Field(default_factory=get_uuid) @@ -27,16 +27,15 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): postProcessors: List[Any] = Field(default_factory=list) monitors: List[Any] = Field(default_factory=list) results: List[Any] = Field(default_factory=list) - context: Dict[str, Any] = Field(default_factory=dict) + context: List[Dict[str, Any]] = Field(default_factory=list) @field_validator("context", mode="before") @classmethod - def _coerce_context(cls, value: Any) -> Dict[str, Any]: - if value is None or value == []: - return {} + def _coerce_context(cls, value: Any) -> List[Dict[str, Any]]: + if value in (None, {}): + return [] return value - def get_hash_object(self) -> Dict[str, Any]: return { "results": self.results or [], @@ -49,16 +48,42 @@ def is_in_status(self, status: str) -> bool: return self.status == status def add_context(self, new_context: Dict[str, Any]): - self.context.update(new_context) + if "name" in new_context and "data" in new_context: + self._upsert_context_item(new_context) + return + for key, value in new_context.items(): + if key.startswith("is") and key.endswith("Edited"): + continue + edited = new_context.get(f"is{key[0].upper()}{key[1:]}Edited", False) + data = value if isinstance(value, dict) else {"value": value} + self._upsert_context_item({"name": key, "isEdited": bool(edited), "data": data}) - def set_context(self, new_context: Dict[str, Any]): - self.context = new_context + def set_context(self, new_context: Dict[str, Any] | List[Dict[str, Any]] | None): + if new_context in (None, {}, []): + self.context = [] + return + if isinstance(new_context, list): + self.context = new_context + return + self.context = [] + self.add_context(new_context) def get_context(self, key: str, default: Any = None) -> Any: - return self.context.get(key, default) + for item in self.context: + if item.get("name") != key: + continue + data = item.get("data", default) + if isinstance(data, dict) and set(data) == {"value"}: + return data["value"] + return data + return default def remove_context(self, key: str): - self.context.pop(key, None) + self.context = [item for item in self.context if item.get("name") != key] def clear_context(self): - self.context = {} + self.context = [] + + def _upsert_context_item(self, item: Dict[str, Any]): + name = item["name"] + self.context = [existing for existing in self.context if existing.get("name") != name] + [item] diff --git a/src/py/mat3ra/wode/workflows/workflow.py b/src/py/mat3ra/wode/workflows/workflow.py index e3ad9b8f..ad9829dd 100644 --- a/src/py/mat3ra/wode/workflows/workflow.py +++ b/src/py/mat3ra/wode/workflows/workflow.py @@ -62,7 +62,7 @@ def set_context_to_unit( new_context: Optional[Dict[str, Any]] = None, ): target_unit = self.get_unit_by_name(name=unit_name, name_regex=unit_name_regex) - target_unit.context = new_context + target_unit.set_context(new_context or []) def _find_relaxation_subworkflow(self) -> Optional[Subworkflow]: target_name = self.relaxation_subworkflow.name @@ -99,5 +99,5 @@ def to_dict_without_special_keys(self, special_keys=["context"]) -> Dict[str, An for unit in swf.get("units", []): for key in special_keys: if key in unit: - unit[key] = {} + unit[key] = [] return workflow_dict From d1cbc3970ea0d53688000d1a112c02f3062bfdf8 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 1 Jun 2026 20:51:22 -0700 Subject: [PATCH 02/33] update: tests: --- tests/py/test_convergence.py | 19 +++++++++---------- tests/py/test_subworkflow.py | 6 +++--- tests/py/test_unit.py | 11 +++++++---- tests/py/test_workflow.py | 6 ++---- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index 8520d6ee..1715b92e 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -41,10 +41,9 @@ def test_add_uniform_energy_convergence(): ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context["kgrid"]["dimensions"] == ["{{N_k}}", "{{N_k}}", "{{N_k}}"] - assert pw_scf.context["kgrid"]["shifts"] == [0, 0, 0] - assert pw_scf.context["isKgridEdited"] is True - assert pw_scf.context["isUsingJinjaVariables"] is True + assert pw_scf.get_context("kgrid")["dimensions"] == ["{{N_k}}", "{{N_k}}", "{{N_k}}"] + assert pw_scf.get_context("kgrid")["shifts"] == [0, 0, 0] + assert any(item.get("name") == "kgrid" and item.get("isEdited") for item in pw_scf.context) assert subworkflow.convergence_parameter == ConvergenceParameterNameEnum.N_k.value assert subworkflow.convergence_result == "total_energy" @@ -82,12 +81,12 @@ def test_add_non_uniform_energy_convergence(): ) pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context["kgrid"]["dimensions"] == [ + assert pw_scf.get_context("kgrid")["dimensions"] == [ f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[0]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[1]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[2]}}}}", ] - assert pw_scf.context["kgrid"]["reciprocalVectorRatios"] == reciprocal_vector_ratios + assert pw_scf.get_context("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform.value @@ -117,12 +116,12 @@ def test_add_non_uniform_2d_energy_convergence(): ) pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context["kgrid"]["dimensions"] == [ + assert pw_scf.get_context("kgrid")["dimensions"] == [ f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[0]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[1]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[2]}}}}", ] - assert pw_scf.context["kgrid"]["reciprocalVectorRatios"] == reciprocal_vector_ratios + assert pw_scf.get_context("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform_2D.value @@ -219,7 +218,7 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.context[param_name] == param_initial + assert pw_scf.get_context(param_name) == param_initial input_item = pw_scf.input[0] template_content = input_item.template.content assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content @@ -259,7 +258,7 @@ def test_add_template_param_convergence_multi_unit(): pw_bands = subworkflow.get_unit_by_name("pw_bands") for unit in [pw_scf, pw_bands]: - assert unit.context["ecutwfc"] == 20 + assert unit.get_context("ecutwfc") == 20 input_item = unit.input[0] template_content = input_item.template.content assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 4b9e0451..1f52f49b 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -109,7 +109,7 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): assert success is True updated_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") - assert updated_unit.context["test_key"] == "test_value" - assert updated_unit.context["another_key"] == 42 - assert updated_unit.context["kgrid"]["dimensions"] == [2, 2, 1] + assert updated_unit.get_context("test_key") == "test_value" + assert updated_unit.get_context("another_key") == 42 + assert updated_unit.get_context("kgrid")["dimensions"] == [2, 2, 1] assert updated_unit.input[0].rendered == original_rendered diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index 22f23c76..e0a9a6a5 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -59,10 +59,13 @@ def test_add_context(): assert unit is not None assert "relax" in unit.name.lower() + assert unit.context == [] unit.add_context(NEW_CONTEXT_RELAX) - assert "kgrid" in unit.context - assert "convergence" in unit.context - assert unit.context["kgrid"] == NEW_CONTEXT_RELAX["kgrid"] - assert unit.context["convergence"] == NEW_CONTEXT_RELAX["convergence"] + assert unit.get_context("kgrid") == NEW_CONTEXT_RELAX["kgrid"] + assert unit.get_context("convergence") == NEW_CONTEXT_RELAX["convergence"] + assert unit.to_dict()["context"] == [ + {"name": "kgrid", "isEdited": False, "data": NEW_CONTEXT_RELAX["kgrid"]}, + {"name": "convergence", "isEdited": False, "data": NEW_CONTEXT_RELAX["convergence"]}, + ] diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 17500ddb..aecb1df3 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -180,10 +180,8 @@ def test_set_unit(method): assert success is True updated_unit = wf.get_unit_by_name(name_regex="relax") - assert "test_key" in updated_unit.context - assert "another_key" in updated_unit.context - assert updated_unit.context["test_key"] == "test_value" - assert updated_unit.context["another_key"] == 42 + assert updated_unit.get_context("test_key") == "test_value" + assert updated_unit.get_context("another_key") == 42 @pytest.mark.parametrize("workflow, app", [("band_gap", "espresso")]) From a1af9dbf0ba5e746463b8c93213c1f69563bd95b Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 15:42:31 -0700 Subject: [PATCH 03/33] update: tests --- tests/py/test_subworkflow.py | 11 ++++++ tests/py/test_unit.py | 14 ++++++-- tests/py/test_workflow.py | 69 ++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 1f52f49b..244edd73 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -2,6 +2,7 @@ from mat3ra.ade.application import Application from mat3ra.mode.method import Method from mat3ra.mode.model import Model +from mat3ra.mode.models.dft import DFTModel from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata @@ -23,6 +24,8 @@ "flowchartId": "unit-flowchart-id", "head": True, } +DFT_METHOD_CONFIG = {"type": "pseudopotential", "subtype": "us"} +DFT_MODEL_CONFIG_WITHOUT_FUNCTIONAL = {"type": "dft", "subtype": "gga", "method": DFT_METHOD_CONFIG} def test_creation(): @@ -54,6 +57,14 @@ def test_model(model_type, model_subtype): assert sw.model.subtype == model_subtype +@pytest.mark.parametrize("config", [DFT_MODEL_CONFIG_WITHOUT_FUNCTIONAL]) +def test_model_assignment_is_coerced_to_dft_model_with_default_functional(config): + subworkflow = Subworkflow(name=SUBWORKFLOW_NAME) + subworkflow.model = Model(**config) + assert isinstance(subworkflow.model, DFTModel) + assert subworkflow.model.to_dict().get("functional") == "pbe" + + def test_with_units(): unit = Unit(**UNIT_CONFIG) sw = Subworkflow(name=SUBWORKFLOW_NAME, units=[unit]) diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index e0a9a6a5..1db891fc 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -66,6 +66,16 @@ def test_add_context(): assert unit.get_context("kgrid") == NEW_CONTEXT_RELAX["kgrid"] assert unit.get_context("convergence") == NEW_CONTEXT_RELAX["convergence"] assert unit.to_dict()["context"] == [ - {"name": "kgrid", "isEdited": False, "data": NEW_CONTEXT_RELAX["kgrid"]}, - {"name": "convergence", "isEdited": False, "data": NEW_CONTEXT_RELAX["convergence"]}, + { + "name": "kgrid", + "isEdited": False, + "data": NEW_CONTEXT_RELAX["kgrid"], + "extraData": {}, + }, + { + "name": "convergence", + "isEdited": False, + "data": NEW_CONTEXT_RELAX["convergence"], + "extraData": {}, + }, ] diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index aecb1df3..0817944a 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -2,6 +2,8 @@ import os import pytest +from mat3ra.mode.methods.factory import MethodFactory +from mat3ra.mode.model import Model from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.subworkflows import SubworkflowStandata from mat3ra.standata.workflows import WorkflowStandata @@ -32,6 +34,20 @@ "head": True, } +BAND_STRUCTURE_SEARCH_NAME = "band_structure" +BAND_GAP_SEARCH_NAME = "espresso/band_gap\\.json$" +TOTAL_ENERGY_SEARCH_NAME = "total_energy" + +EXPECTED_MODEL_FUNCTIONAL = "pbe" +EXECUTION_UNIT_TYPE = "execution" +CONTEXT_ITEM_REQUIRED_KEYS = ("name", "isEdited", "data", "extraData") + +WEBAPP_COMPATIBLE_WORKFLOW_SEARCH_NAMES = [ + BAND_STRUCTURE_SEARCH_NAME, + BAND_GAP_SEARCH_NAME, + TOTAL_ENERGY_SEARCH_NAME, +] + def test_creation(): wf = Workflow(name=WORKFLOW_NAME) @@ -196,3 +212,56 @@ def test_calculate_hash(workflow, app): fixture = next(w for w in workflows if w["name"] == BAND_GAP_WORKFLOW_NAME) wf = Workflow(**{k: v for k, v in fixture.items() if k != "hash"}) assert wf.hash == expected_hash + + +def _execution_units_from_payload(workflow_payload): + for subworkflow in workflow_payload.get("subworkflows", []): + for unit in subworkflow.get("units", []): + if unit.get("type") == EXECUTION_UNIT_TYPE: + yield unit + + +def _assert_subworkflow_models_have_functional(workflow_payload, expected_functional): + for subworkflow in workflow_payload.get("subworkflows", []): + model = subworkflow.get("model", {}) + assert model.get("functional") == expected_functional + + +def _assert_execution_unit_context_is_webapp_shaped(unit): + context = unit.get("context") + assert isinstance(context, list) + for item in context: + for key in CONTEXT_ITEM_REQUIRED_KEYS: + assert key in item + + +@pytest.mark.parametrize( + "workflow_search_name,expected_functional", + [(name, EXPECTED_MODEL_FUNCTIONAL) for name in WEBAPP_COMPATIBLE_WORKFLOW_SEARCH_NAMES], + ids=WEBAPP_COMPATIBLE_WORKFLOW_SEARCH_NAMES, +) +def test_workflow_to_dict_is_webapp_compatible(workflow_search_name, expected_functional): + workflow_config = WORKFLOW_STANDATA.get_by_name_first_match(workflow_search_name) + workflow = Workflow.create(workflow_config) + payload = workflow.to_dict() + + _assert_subworkflow_models_have_functional(payload, expected_functional) + + execution_units = list(_execution_units_from_payload(payload)) + assert execution_units + + for unit in execution_units: + _assert_execution_unit_context_is_webapp_shaped(unit) + + +def test_workflow_to_dict_is_json_serializable_after_model_assignment(): + workflow_config = WORKFLOW_STANDATA.get_by_name_first_match(BAND_STRUCTURE_SEARCH_NAME) + workflow = Workflow.create(workflow_config) + method = MethodFactory.create( + {"type": "pseudopotential", "subtype": "us", "data": {}}, + ) + assigned_model = Model(type="dft", subtype="gga", method=method, functional=EXPECTED_MODEL_FUNCTIONAL) + for subworkflow in workflow.subworkflows: + subworkflow.model = assigned_model + + json.dumps(workflow.to_dict()) From 13ff5ebc6f8871d9284e99ccc521e37e06c17067 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 15:42:49 -0700 Subject: [PATCH 04/33] update: restructure unit context --- src/py/mat3ra/wode/units/unit.py | 64 +++++++++++---------- src/py/mat3ra/wode/units/utils.py | 93 +++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 33 deletions(-) create mode 100644 src/py/mat3ra/wode/units/utils.py diff --git a/src/py/mat3ra/wode/units/unit.py b/src/py/mat3ra/wode/units/unit.py index 4bf64dbf..bd0bcad2 100644 --- a/src/py/mat3ra/wode/units/unit.py +++ b/src/py/mat3ra/wode/units/unit.py @@ -1,11 +1,23 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +from mat3ra.ade.context.context_provider import ContextProvider from mat3ra.code.entity import InMemoryEntitySnakeCase from mat3ra.code.mixins import HashedEntityMixin from mat3ra.esse.models.workflow.unit.base import WorkflowBaseUnitSchema from mat3ra.utils.uuid import get_uuid from pydantic import Field, field_validator +from .utils import ( + context_item_from_provider, + context_items_from_input, + parse_persisted_context, + read_context_data, + upsert_context_item, +) + +ContextInput = Union[Dict[str, Any], ContextProvider] +ContextPayload = Union[Dict[str, Any], List[Dict[str, Any]], None] + class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): """ @@ -31,10 +43,8 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): @field_validator("context", mode="before") @classmethod - def _coerce_context(cls, value: Any) -> List[Dict[str, Any]]: - if value in (None, {}): - return [] - return value + def _validate_context(cls, value: Any) -> List[Dict[str, Any]]: + return parse_persisted_context(value) def get_hash_object(self) -> Dict[str, Any]: return { @@ -47,43 +57,31 @@ def get_hash_object(self) -> Dict[str, Any]: def is_in_status(self, status: str) -> bool: return self.status == status - def add_context(self, new_context: Dict[str, Any]): - if "name" in new_context and "data" in new_context: - self._upsert_context_item(new_context) - return - for key, value in new_context.items(): - if key.startswith("is") and key.endswith("Edited"): - continue - edited = new_context.get(f"is{key[0].upper()}{key[1:]}Edited", False) - data = value if isinstance(value, dict) else {"value": value} - self._upsert_context_item({"name": key, "isEdited": bool(edited), "data": data}) + def add_context(self, payload: ContextInput) -> None: + if isinstance(payload, ContextProvider): + items = [context_item_from_provider(payload)] + else: + items = context_items_from_input(payload) + for item in items: + self.context = upsert_context_item(self.context, item) - def set_context(self, new_context: Dict[str, Any] | List[Dict[str, Any]] | None): - if new_context in (None, {}, []): + def set_context(self, payload: ContextPayload) -> None: + if not payload: self.context = [] return - if isinstance(new_context, list): - self.context = new_context + if isinstance(payload, list): + self.context = parse_persisted_context(payload) return - self.context = [] - self.add_context(new_context) + self.context = context_items_from_input(payload) def get_context(self, key: str, default: Any = None) -> Any: for item in self.context: - if item.get("name") != key: - continue - data = item.get("data", default) - if isinstance(data, dict) and set(data) == {"value"}: - return data["value"] - return data + if item.get("name") == key: + return read_context_data(item, default) return default - def remove_context(self, key: str): + def remove_context(self, key: str) -> None: self.context = [item for item in self.context if item.get("name") != key] - def clear_context(self): + def clear_context(self) -> None: self.context = [] - - def _upsert_context_item(self, item: Dict[str, Any]): - name = item["name"] - self.context = [existing for existing in self.context if existing.get("name") != name] + [item] diff --git a/src/py/mat3ra/wode/units/utils.py b/src/py/mat3ra/wode/units/utils.py new file mode 100644 index 00000000..6329d353 --- /dev/null +++ b/src/py/mat3ra/wode/units/utils.py @@ -0,0 +1,93 @@ +from typing import Any, Dict, List, Union + +from mat3ra.ade.context.context_provider import ContextProvider + +PersistedContextItem = Dict[str, Any] +ContextInput = Union[Dict[str, Any], ContextProvider] + +_EXTRA_DATA_SUFFIX = "ExtraData" + + +def to_persisted_context_item(item: Dict[str, Any]) -> PersistedContextItem: + return { + "name": item["name"], + "isEdited": bool(item.get("isEdited", False)), + "data": item.get("data", {}), + "extraData": item.get("extraData") or {}, + } + + +def parse_persisted_context(value: Any) -> List[PersistedContextItem]: + if value in (None, {}): + return [] + if not isinstance(value, list): + return value + return [to_persisted_context_item(item) for item in value if isinstance(item, dict) and item.get("name")] + + +def is_persisted_context_item(payload: Dict[str, Any]) -> bool: + return "name" in payload and "data" in payload + + +def _edited_flag_key(name: str) -> str: + return f"is{name[0].upper()}{name[1:]}Edited" + + +def _extra_data_key(name: str) -> str: + return f"{name}{_EXTRA_DATA_SUFFIX}" + + +def _data_names(flat: Dict[str, Any]) -> List[str]: + names = [key for key in flat if not key.endswith(_EXTRA_DATA_SUFFIX)] + edited_flag_names = {_edited_flag_key(name) for name in names if _edited_flag_key(name) in flat} + return [name for name in names if name not in edited_flag_names] + + +def context_item_from_provider(provider: ContextProvider) -> PersistedContextItem: + data = provider.yield_data() + name = provider.name_str + value = data[name] + return to_persisted_context_item( + { + "name": name, + "isEdited": bool(data.get(provider.is_edited_key, False)), + "data": value if isinstance(value, dict) else {"value": value}, + "extraData": data.get(provider.extra_data_key) or {}, + } + ) + + +def context_items_from_flat_context(flat: Dict[str, Any]) -> List[PersistedContextItem]: + items: List[PersistedContextItem] = [] + for name in _data_names(flat): + value = flat[name] + items.append( + to_persisted_context_item( + { + "name": name, + "isEdited": bool(flat.get(_edited_flag_key(name), False)), + "data": value if isinstance(value, dict) else {"value": value}, + "extraData": flat.get(_extra_data_key(name)) or {}, + } + ) + ) + return items + + +def context_items_from_input(payload: Dict[str, Any]) -> List[PersistedContextItem]: + if is_persisted_context_item(payload): + return [to_persisted_context_item(payload)] + return context_items_from_flat_context(payload) + + +def read_context_data(item: PersistedContextItem, default: Any = None) -> Any: + data = item.get("data", default) + if isinstance(data, dict) and set(data) == {"value"}: + return data["value"] + return data + + +def upsert_context_item(items: List[PersistedContextItem], item: Dict[str, Any]) -> List[PersistedContextItem]: + persisted_item = to_persisted_context_item(item) + name = persisted_item["name"] + return [entry for entry in items if entry.get("name") != name] + [persisted_item] From 876aab1631ec70546a2f9724ae124289cb1cabc5 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 15:43:00 -0700 Subject: [PATCH 05/33] update: model serialization --- src/py/mat3ra/wode/subworkflows/subworkflow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/subworkflow.py b/src/py/mat3ra/wode/subworkflows/subworkflow.py index aa4167e2..b412a72c 100644 --- a/src/py/mat3ra/wode/subworkflows/subworkflow.py +++ b/src/py/mat3ra/wode/subworkflows/subworkflow.py @@ -9,7 +9,7 @@ from mat3ra.mode.model import Model from mat3ra.mode.models.factory import ModelFactory from mat3ra.utils.uuid import get_uuid -from pydantic import Field, field_validator +from pydantic import ConfigDict, Field, SerializeAsAny, field_validator from .convergence_mixin import ConvergenceMixin from ..mixins import FlowchartUnitsManager @@ -35,19 +35,21 @@ class Subworkflow( properties: List of properties extracted by the subworkflow """ + model_config = ConfigDict(validate_assignment=True) + id: str = Field(default_factory=get_uuid, alias="_id") application: Application = Field( default_factory=lambda: Application(name="", version="", build="", shortName="", summary="") ) properties: List[str] = Field(default_factory=list) - model: Model = Field(default_factory=DFTModel) + model: SerializeAsAny[Model] = Field(default_factory=DFTModel) units: List[Union[Unit, ExecutionUnit, SubworkflowUnit]] = Field(default_factory=list) @field_validator("model", mode="before") @classmethod def _instantiate_model(cls, value: Any) -> Any: if isinstance(value, Model): - return value + return ModelFactory.create(value.to_dict()) if isinstance(value, dict): return ModelFactory.create(value) return value From a11e6940cbc5b1aaaf362f63957a67de0c45ea6d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 15:43:09 -0700 Subject: [PATCH 06/33] chore: mode --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 32a2074f..ac5e8960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "mat3ra-code", "mat3ra-utils", "mat3ra-esse", - "mat3ra-mode", + "mat3ra-mode @ git+https://github.com/Exabyte-io/mode.git@f9136d966dd7f558153f4a02ec8070a8059bd4ef", "mat3ra-ade", "mat3ra-made", "mat3ra-standata", From 9ac6a90ac51d1d9ffad0e2476f8ae38c0a0facd5 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 19:32:30 -0700 Subject: [PATCH 07/33] update: unit has no context --- src/py/mat3ra/wode/units/unit.py | 52 ++------------------------------ 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/src/py/mat3ra/wode/units/unit.py b/src/py/mat3ra/wode/units/unit.py index bd0bcad2..1abd4afa 100644 --- a/src/py/mat3ra/wode/units/unit.py +++ b/src/py/mat3ra/wode/units/unit.py @@ -1,22 +1,10 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List -from mat3ra.ade.context.context_provider import ContextProvider from mat3ra.code.entity import InMemoryEntitySnakeCase from mat3ra.code.mixins import HashedEntityMixin from mat3ra.esse.models.workflow.unit.base import WorkflowBaseUnitSchema from mat3ra.utils.uuid import get_uuid -from pydantic import Field, field_validator - -from .utils import ( - context_item_from_provider, - context_items_from_input, - parse_persisted_context, - read_context_data, - upsert_context_item, -) - -ContextInput = Union[Dict[str, Any], ContextProvider] -ContextPayload = Union[Dict[str, Any], List[Dict[str, Any]], None] +from pydantic import Field class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): @@ -30,7 +18,6 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): head: Whether this unit is the head of the workflow next: Flowchart ID of the next unit tags: List of tags for the unit - context: Persisted context provider items for the unit """ id: str = Field(default_factory=get_uuid, alias="_id") flowchartId: str = Field(default_factory=get_uuid) @@ -39,12 +26,6 @@ class Unit(WorkflowBaseUnitSchema, HashedEntityMixin, InMemoryEntitySnakeCase): postProcessors: List[Any] = Field(default_factory=list) monitors: List[Any] = Field(default_factory=list) results: List[Any] = Field(default_factory=list) - context: List[Dict[str, Any]] = Field(default_factory=list) - - @field_validator("context", mode="before") - @classmethod - def _validate_context(cls, value: Any) -> List[Dict[str, Any]]: - return parse_persisted_context(value) def get_hash_object(self) -> Dict[str, Any]: return { @@ -56,32 +37,3 @@ def get_hash_object(self) -> Dict[str, Any]: def is_in_status(self, status: str) -> bool: return self.status == status - - def add_context(self, payload: ContextInput) -> None: - if isinstance(payload, ContextProvider): - items = [context_item_from_provider(payload)] - else: - items = context_items_from_input(payload) - for item in items: - self.context = upsert_context_item(self.context, item) - - def set_context(self, payload: ContextPayload) -> None: - if not payload: - self.context = [] - return - if isinstance(payload, list): - self.context = parse_persisted_context(payload) - return - self.context = context_items_from_input(payload) - - def get_context(self, key: str, default: Any = None) -> Any: - for item in self.context: - if item.get("name") == key: - return read_context_data(item, default) - return default - - def remove_context(self, key: str) -> None: - self.context = [item for item in self.context if item.get("name") != key] - - def clear_context(self) -> None: - self.context = [] From c84862e2f2f59a241c601e7eb13b2e165c4f871c Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 19:33:01 -0700 Subject: [PATCH 08/33] update: possible type --- .../wode/mixins/flowchart_units_manager.py | 29 ++++++++++--------- .../mat3ra/wode/subworkflows/subworkflow.py | 2 +- src/py/mat3ra/wode/workflows/workflow.py | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/py/mat3ra/wode/mixins/flowchart_units_manager.py b/src/py/mat3ra/wode/mixins/flowchart_units_manager.py index a7105963..f57ac041 100644 --- a/src/py/mat3ra/wode/mixins/flowchart_units_manager.py +++ b/src/py/mat3ra/wode/mixins/flowchart_units_manager.py @@ -1,42 +1,43 @@ -from typing import List, Optional, TypeVar +from typing import Generic, List, Optional, TypeVar from mat3ra.utils import find_by_key_or_regex from ..units import Unit +UnitT = TypeVar("UnitT", bound=Unit) T = TypeVar("T") -class FlowchartUnitsManager: +class FlowchartUnitsManager(Generic[UnitT]): """ Mixin class providing common unit operations for flowchart units. - This mixin expects the class to have a `units: List[Unit]` attribute. + This mixin expects the class to have a `units: List[UnitT]` attribute. It provides common methods for managing units in both Workflow and Subworkflow classes. """ - units: List[Unit] + units: List[UnitT] - def set_units(self, units: List[Unit]) -> None: + def set_units(self, units: List[UnitT]) -> None: self.units = units def _set_units_order_in_place(self) -> None: self.set_units_head(self.units) self.set_next_links(self.units) - def get_unit(self, flowchart_id: str) -> Optional[Unit]: + def get_unit(self, flowchart_id: str) -> Optional[UnitT]: for unit in self.units: if unit.flowchartId == flowchart_id: return unit return None - def find_unit_by_id(self, id: str) -> Optional[Unit]: + def find_unit_by_id(self, id: str) -> Optional[UnitT]: for unit in self.units: if getattr(unit, "id", None) == id: return unit return None - def find_unit_with_tag(self, tag: str) -> Optional[Unit]: + def find_unit_with_tag(self, tag: str) -> Optional[UnitT]: for unit in self.units: if hasattr(unit, "tags") and unit.tags is not None and tag in unit.tags: return unit @@ -46,7 +47,7 @@ def get_unit_by_name( self, name: Optional[str] = None, name_regex: Optional[str] = None, - ) -> Optional[Unit]: + ) -> Optional[UnitT]: return find_by_key_or_regex(self.units, key="name", value=name, value_regex=name_regex) @staticmethod @@ -68,7 +69,7 @@ def _add_to_list(items: List[T], item: T, head: bool = False, index: int = -1) - else: items.append(item) - def set_units_head(self, units: List[Unit]) -> List[Unit]: + def set_units_head(self, units: List[UnitT]) -> List[UnitT]: """ Set the head flag on the first unit and unset it on all others. @@ -84,7 +85,7 @@ def set_units_head(self, units: List[Unit]) -> List[Unit]: unit.head = False return units - def set_next_links(self, units: List[Unit]) -> List[Unit]: + def set_next_links(self, units: List[UnitT]) -> List[UnitT]: """ Re-establishes the linked next => flowchartId logic in an array of units. @@ -118,7 +119,7 @@ def _clear_link_to_unit(self, flowchart_id: str) -> None: unit.next = None break - def add_unit(self, unit: Unit, head: bool = False, index: int = -1) -> None: + def add_unit(self, unit: UnitT, head: bool = False, index: int = -1) -> None: """ Add a unit to the units list. @@ -175,8 +176,8 @@ def replace_unit(self, index: int, unit: Unit) -> None: def set_unit( self, - new_unit: Unit, - unit: Optional[Unit] = None, + new_unit: UnitT, + unit: Optional[UnitT] = None, unit_flowchart_id: Optional[str] = None, ) -> bool: """ diff --git a/src/py/mat3ra/wode/subworkflows/subworkflow.py b/src/py/mat3ra/wode/subworkflows/subworkflow.py index b412a72c..65c623cc 100644 --- a/src/py/mat3ra/wode/subworkflows/subworkflow.py +++ b/src/py/mat3ra/wode/subworkflows/subworkflow.py @@ -22,7 +22,7 @@ class Subworkflow( SubworkflowSchema, HashedEntityMixin, InMemoryEntitySnakeCase, - FlowchartUnitsManager, + FlowchartUnitsManager[Unit], ): """ Subworkflow class representing a logical collection of workflow units. diff --git a/src/py/mat3ra/wode/workflows/workflow.py b/src/py/mat3ra/wode/workflows/workflow.py index ad9829dd..8478dda9 100644 --- a/src/py/mat3ra/wode/workflows/workflow.py +++ b/src/py/mat3ra/wode/workflows/workflow.py @@ -12,7 +12,7 @@ from ..units import Unit -class Workflow(WorkflowSchema, HashedEntityMixin, InMemoryEntitySnakeCase, FlowchartUnitsManager): +class Workflow(WorkflowSchema, HashedEntityMixin, InMemoryEntitySnakeCase, FlowchartUnitsManager[Unit]): """ Workflow class representing a complete workflow configuration. From 5b714e96f5ee62290d4a1f7b7b879034736b1b4c Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:07:56 -0700 Subject: [PATCH 09/33] update: execution to have context --- src/py/mat3ra/wode/units/execution.py | 112 +++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 2 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 8ad43f52..20f2540f 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Literal +from typing import Any, Dict, List, Literal, Optional, Union from mat3ra.ade import Application, Executable, Flavor -from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema +from mat3ra.ade.context.context_provider import ContextProvider +from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema from mat3ra.utils import ( calculate_hash_from_object, remove_comments_from_source_code, @@ -13,6 +14,7 @@ from .execution_unit_input import ExecutionUnitInput from .unit import Unit +Context = List[ContextItemSchema] class ExecutionUnit(Unit, ExecutionUnitSchema): type: Literal["execution"] = "execution" @@ -20,6 +22,7 @@ class ExecutionUnit(Unit, ExecutionUnitSchema): flavor: Flavor = None application: Application = None input: List[ExecutionUnitInput] = Field(default_factory=list) + context: Context = Field(default_factory=list) @field_validator("input", mode="before") @classmethod @@ -34,6 +37,111 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]: instantiated.append(ExecutionUnitInput(**item)) return instantiated + + @field_validator("context", mode="before") + @classmethod + def _validate_context(cls, value: Any) -> List[Dict[str, Any]]: + if value is None: + return [] + if not isinstance(value, list): + return value + return [{ + "name": item["name"], + "isEdited": bool(item.get("isEdited", False)), + "data": item.get("data", {}), + "extraData": item.get("extraData") or {}, + } for item in value if isinstance(item, dict) and item.get("name")] + + @staticmethod + def _context_item_name(item: Any) -> Optional[str]: + if isinstance(item, dict): + return item.get("name") + name = getattr(item, "name", None) + return str(name) if name is not None else None + + def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: + for item in self.context: + if self._context_item_name(item) == name: + return item if isinstance(item, dict) else item.model_dump() + return None + + @staticmethod + def _read_context_data(item: Dict[str, Any], default: Any = None) -> Any: + data = item.get("data", default) + if isinstance(data, dict) and set(data) == {"value"}: + return data["value"] + return data + + @staticmethod + def context_item( + name: str, + data: Any, + *, + is_edited: bool = True, + extra_data: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + payload = data if isinstance(data, dict) else {"value": data} + return { + "name": name, + "isEdited": is_edited, + "data": payload, + "extraData": extra_data or {}, + } + + def _upsert_context_item(self, item: Dict[str, Any]) -> Context: + name = item["name"] + existing = self.get_context_item(name) + data = dict(item.get("data") or {}) + if existing: + merged = dict(existing.get("data") or {}) + merged.update(data) + data = merged + normalized: Dict[str, Any] = { + "name": name, + "isEdited": bool(item.get("isEdited", existing.get("isEdited", True) if existing else True)), + "data": data, + "extraData": item.get("extraData") or (existing.get("extraData") if existing else {}) or {}, + } + rest = [entry for entry in self.context if self._context_item_name(entry) != name] + return rest + [normalized] + + def add_context( + self, + name_or_item: Union[str, Dict[str, Any]], + data: Any = None, + *, + is_edited: bool = True, + extra_data: Optional[Dict[str, Any]] = None, + ) -> None: + if isinstance(name_or_item, dict): + item = name_or_item + else: + item = self.context_item(name_or_item, data, is_edited=is_edited, extra_data=extra_data) + self.context = self._upsert_context_item(item) + + def add_context_provider(self, provider: ContextProvider) -> None: + yielded = provider.yield_data() + name = provider.name_str + self.add_context( + name, + yielded[name], + is_edited=bool(yielded.get(provider.is_edited_key, False)), + extra_data=yielded.get(provider.extra_data_key) or {}, + ) + + def set_context(self, items: Context) -> None: + self.context = items + + def get_context(self, name: str, default: Any = None) -> Any: + item = self.get_context_item(name) + return self._read_context_data(item, default) if item else default + + def remove_context(self, name: str) -> None: + self.context = [item for item in self.context if self._context_item_name(item) != name] + + def clear_context(self) -> None: + self.context = [] + def replace_in_input_content(self, pattern: str, replacement: str, input_name=None) -> None: for item in self.input: if input_name is None or item.template.name == input_name: From 872bc1df18e8805ba4523466284cdcd602ac8df1 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:08:57 -0700 Subject: [PATCH 10/33] update: adjsut convergence --- .../wode/subworkflows/convergence_mixin.py | 26 +++++-------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index f9f76fca..30b90b3d 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -67,16 +67,6 @@ def _find_unit_for_convergence(self, result: str): return unit return None - @staticmethod - def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context: Dict[str, Any]) -> Dict[str, Any]: - merged_context = dict(unit_context) - merged_kgrid_context = dict(unit_context.get("kgrid") or {}) - merged_kgrid_context.update(convergence_context.get("kgrid") or {}) - merged_context.update(convergence_context) - if merged_kgrid_context: - merged_context["kgrid"] = merged_kgrid_context - return merged_context - def _build_convergence_units( self, parameter_name: str, @@ -185,9 +175,11 @@ def add_convergence( ) and reciprocal_vector_ratios is None ): - reciprocal_vector_ratios = PointsGridDataProvider( - context={item["name"]: item["data"] for item in unit_for_convergence.context} - ).get_reciprocal_vector_ratios() + kgrid_item = unit_for_convergence.get_context_item("kgrid") + provider_context = {"kgrid": kgrid_item["data"]} if kgrid_item else None + reciprocal_vector_ratios = PointsGridDataProvider().get_reciprocal_vector_ratios( + context=provider_context, + ) if reciprocal_vector_ratios is None: raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") @@ -198,11 +190,7 @@ def add_convergence( reciprocal_vector_ratios=reciprocal_vector_ratios, ) - merged_context = self._merge_convergence_context( - {item["name"]: item["data"] for item in unit_for_convergence.context}, - parameter.unit_context, - ) - unit_for_convergence.set_context(merged_context) + unit_for_convergence.add_context(parameter.unit_context) self._build_convergence_units( parameter_name=parameter.name, @@ -266,7 +254,7 @@ def add_template_parameter_convergence( execution_unit.replace_in_input_content( pattern, f"{parameter_name} = {scope_reference}", input_name=input_name ) - execution_unit.add_context({parameter_name: parameter_initial}) + execution_unit.add_context(parameter_name, parameter_initial, is_edited=True) self._build_convergence_units( parameter_name=parameter_name, From f13167802cef56861bb7829f9027d6711a507933 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:09:19 -0700 Subject: [PATCH 11/33] update: use context item --- .../wode/subworkflows/convergence/non_uniform_kgrid.py | 9 ++++++++- .../wode/subworkflows/convergence/uniform_kgrid.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py index dd7e185c..5d9d3bbf 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +from ...units.execution import ExecutionUnit from .uniform_kgrid import UniformKGridConvergence @@ -23,7 +24,7 @@ def increment(self) -> str: @property def unit_context(self) -> Dict[str, Any]: - return self._points_grid_context( + yielded = self._points_grid_context( dimensions=[ f"{{{{{self.name}[0]}}}}", f"{{{{{self.name}[1]}}}}", @@ -31,6 +32,12 @@ def unit_context(self) -> Dict[str, Any]: ], reciprocal_vector_ratios=self._reciprocal_vector_ratios, ) + return ExecutionUnit.context_item( + "kgrid", + yielded["kgrid"], + is_edited=True, + extra_data=yielded.get("kgridExtraData") or {}, + ) def use_variables_from_unit_context(self, flowchart_id: str) -> List[Dict[str, str]]: return [{"scope": flowchart_id, "name": "context"}] diff --git a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py index 929d8f74..917e5f35 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional from ...context.providers import PointsGridDataProvider +from ...units.execution import ExecutionUnit from .parameter import ConvergenceParameter @@ -20,9 +21,15 @@ def increment(self) -> str: @property def unit_context(self) -> Dict[str, Any]: - return self._points_grid_context( + yielded = self._points_grid_context( dimensions=[f"{{{{{self.name}}}}}", f"{{{{{self.name}}}}}", f"{{{{{self.name}}}}}"], ) + return ExecutionUnit.context_item( + "kgrid", + yielded["kgrid"], + is_edited=True, + extra_data=yielded.get("kgridExtraData") or {}, + ) @property def final_value(self) -> str: From 5409a7ec89b9dfefa9ce0660eed109b24f736a60 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:09:41 -0700 Subject: [PATCH 12/33] chore: cleanup --- src/py/mat3ra/wode/units/utils.py | 93 ------------------------------- 1 file changed, 93 deletions(-) delete mode 100644 src/py/mat3ra/wode/units/utils.py diff --git a/src/py/mat3ra/wode/units/utils.py b/src/py/mat3ra/wode/units/utils.py deleted file mode 100644 index 6329d353..00000000 --- a/src/py/mat3ra/wode/units/utils.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any, Dict, List, Union - -from mat3ra.ade.context.context_provider import ContextProvider - -PersistedContextItem = Dict[str, Any] -ContextInput = Union[Dict[str, Any], ContextProvider] - -_EXTRA_DATA_SUFFIX = "ExtraData" - - -def to_persisted_context_item(item: Dict[str, Any]) -> PersistedContextItem: - return { - "name": item["name"], - "isEdited": bool(item.get("isEdited", False)), - "data": item.get("data", {}), - "extraData": item.get("extraData") or {}, - } - - -def parse_persisted_context(value: Any) -> List[PersistedContextItem]: - if value in (None, {}): - return [] - if not isinstance(value, list): - return value - return [to_persisted_context_item(item) for item in value if isinstance(item, dict) and item.get("name")] - - -def is_persisted_context_item(payload: Dict[str, Any]) -> bool: - return "name" in payload and "data" in payload - - -def _edited_flag_key(name: str) -> str: - return f"is{name[0].upper()}{name[1:]}Edited" - - -def _extra_data_key(name: str) -> str: - return f"{name}{_EXTRA_DATA_SUFFIX}" - - -def _data_names(flat: Dict[str, Any]) -> List[str]: - names = [key for key in flat if not key.endswith(_EXTRA_DATA_SUFFIX)] - edited_flag_names = {_edited_flag_key(name) for name in names if _edited_flag_key(name) in flat} - return [name for name in names if name not in edited_flag_names] - - -def context_item_from_provider(provider: ContextProvider) -> PersistedContextItem: - data = provider.yield_data() - name = provider.name_str - value = data[name] - return to_persisted_context_item( - { - "name": name, - "isEdited": bool(data.get(provider.is_edited_key, False)), - "data": value if isinstance(value, dict) else {"value": value}, - "extraData": data.get(provider.extra_data_key) or {}, - } - ) - - -def context_items_from_flat_context(flat: Dict[str, Any]) -> List[PersistedContextItem]: - items: List[PersistedContextItem] = [] - for name in _data_names(flat): - value = flat[name] - items.append( - to_persisted_context_item( - { - "name": name, - "isEdited": bool(flat.get(_edited_flag_key(name), False)), - "data": value if isinstance(value, dict) else {"value": value}, - "extraData": flat.get(_extra_data_key(name)) or {}, - } - ) - ) - return items - - -def context_items_from_input(payload: Dict[str, Any]) -> List[PersistedContextItem]: - if is_persisted_context_item(payload): - return [to_persisted_context_item(payload)] - return context_items_from_flat_context(payload) - - -def read_context_data(item: PersistedContextItem, default: Any = None) -> Any: - data = item.get("data", default) - if isinstance(data, dict) and set(data) == {"value"}: - return data["value"] - return data - - -def upsert_context_item(items: List[PersistedContextItem], item: Dict[str, Any]) -> List[PersistedContextItem]: - persisted_item = to_persisted_context_item(item) - name = persisted_item["name"] - return [entry for entry in items if entry.get("name") != name] + [persisted_item] From 9c36a3653533eccc87f2dacea15527b2c93ad88d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:10:20 -0700 Subject: [PATCH 13/33] update: execution unit test --- tests/py/test_unit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index 1db891fc..c807e56f 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -1,7 +1,7 @@ import pytest from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata -from mat3ra.wode import Unit +from mat3ra.wode import ExecutionUnit, Unit WORKFLOW_STANDATA = WorkflowStandata() APPLICATION_STANDATA = ApplicationStandata() @@ -55,13 +55,14 @@ def test_next_property(): def test_add_context(): - unit = Unit(**{**UNIT_CONFIG_EXECUTION, "name": "relaxation step"}) + unit = ExecutionUnit(**{**UNIT_CONFIG_EXECUTION, "name": "relaxation step"}) assert unit is not None assert "relax" in unit.name.lower() assert unit.context == [] - unit.add_context(NEW_CONTEXT_RELAX) + unit.add_context("kgrid", NEW_CONTEXT_RELAX["kgrid"], is_edited=False) + unit.add_context("convergence", NEW_CONTEXT_RELAX["convergence"], is_edited=False) assert unit.get_context("kgrid") == NEW_CONTEXT_RELAX["kgrid"] assert unit.get_context("convergence") == NEW_CONTEXT_RELAX["convergence"] From 06646aacdcd1b36051c219f6dbedff22134e01a6 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:15:39 -0700 Subject: [PATCH 14/33] update: check for execution unit --- tests/py/test_subworkflow.py | 10 +++++----- tests/py/test_workflow.py | 19 +++++++++++-------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 244edd73..295c38e4 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -6,7 +6,7 @@ from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata -from mat3ra.wode import Subworkflow, Unit, Workflow +from mat3ra.wode import Subworkflow, Unit, Workflow, ExecutionUnit from mat3ra.wode.context.providers import PointsGridDataProvider SUBWORKFLOW_NAME = "Total Energy" @@ -100,13 +100,13 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): relaxation_subworkflow = wf.subworkflows[0] unit_to_modify = relaxation_subworkflow.get_unit_by_name(name_regex="relax") assert unit_to_modify is not None + assert isinstance(unit_to_modify, ExecutionUnit) original_rendered = unit_to_modify.input[0].rendered - unit_to_modify.add_context({"test_key": "test_value", "another_key": 42}) - unit_to_modify.add_context( - PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True).yield_data() - ) + unit_to_modify.add_context({"name": "test_key", "data": "test_value"}) + unit_to_modify.add_context({"name": "another_key", "data": 42}) + unit_to_modify.add_context_provider(PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True)) if method == "only_new_unit": success = relaxation_subworkflow.set_unit(unit_to_modify) diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 0817944a..94a1aa64 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -178,24 +178,27 @@ def test_set_unit(method): wf.add_relaxation() - unit_to_modify = wf.get_unit_by_name(name_regex="relax") + relaxation_subworkflow = wf._find_relaxation_subworkflow() + assert relaxation_subworkflow is not None + + unit_to_modify = relaxation_subworkflow.get_unit_by_name(name_regex="relax") assert unit_to_modify is not None - new_context = {"test_key": "test_value", "another_key": 42} - unit_to_modify.add_context(new_context) + unit_to_modify.add_context("test_key", "test_value") + unit_to_modify.add_context("another_key", 42) if method == "only_new_unit": - success = wf.set_unit(unit_to_modify) + success = relaxation_subworkflow.set_unit(unit_to_modify) elif method == "with_unit_instance": - original_unit = wf.get_unit_by_name(name_regex="relax") - success = wf.set_unit(unit_to_modify, unit=original_unit) + original_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") + success = relaxation_subworkflow.set_unit(unit_to_modify, unit=original_unit) elif method == "with_flowchart_id": flowchart_id = unit_to_modify.flowchartId - success = wf.set_unit(unit_to_modify, unit_flowchart_id=flowchart_id) + success = relaxation_subworkflow.set_unit(unit_to_modify, unit_flowchart_id=flowchart_id) assert success is True - updated_unit = wf.get_unit_by_name(name_regex="relax") + updated_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") assert updated_unit.get_context("test_key") == "test_value" assert updated_unit.get_context("another_key") == 42 From 97d48997ac18455e7f41eba89a8c4147f70d7157 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:16:16 -0700 Subject: [PATCH 15/33] chore: simplify --- src/py/mat3ra/wode/units/execution.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 20f2540f..3de350e8 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -88,22 +88,9 @@ def context_item( "extraData": extra_data or {}, } - def _upsert_context_item(self, item: Dict[str, Any]) -> Context: - name = item["name"] - existing = self.get_context_item(name) - data = dict(item.get("data") or {}) - if existing: - merged = dict(existing.get("data") or {}) - merged.update(data) - data = merged - normalized: Dict[str, Any] = { - "name": name, - "isEdited": bool(item.get("isEdited", existing.get("isEdited", True) if existing else True)), - "data": data, - "extraData": item.get("extraData") or (existing.get("extraData") if existing else {}) or {}, - } + def _replace_context_item(self, name: str, item: Dict[str, Any]) -> None: rest = [entry for entry in self.context if self._context_item_name(entry) != name] - return rest + [normalized] + self.context = rest + [item] def add_context( self, @@ -115,9 +102,14 @@ def add_context( ) -> None: if isinstance(name_or_item, dict): item = name_or_item + name = item["name"] + data = item.get("data") + is_edited = bool(item.get("isEdited", True)) + extra_data = item.get("extraData") or {} else: - item = self.context_item(name_or_item, data, is_edited=is_edited, extra_data=extra_data) - self.context = self._upsert_context_item(item) + name = name_or_item + item = self.context_item(name, data, is_edited=is_edited, extra_data=extra_data) + self._replace_context_item(name, item) def add_context_provider(self, provider: ContextProvider) -> None: yielded = provider.yield_data() From aab26fe25666ca3f440317178daa276de81c3a1f Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 2 Jun 2026 20:25:51 -0700 Subject: [PATCH 16/33] update: context to match NBs --- .../convergence/non_uniform_kgrid.py | 12 ++-- .../subworkflows/convergence/uniform_kgrid.py | 12 ++-- .../wode/subworkflows/convergence_mixin.py | 2 +- src/py/mat3ra/wode/units/execution.py | 66 +++++++++++-------- 4 files changed, 50 insertions(+), 42 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py index 5d9d3bbf..d2cef753 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Optional -from ...units.execution import ExecutionUnit from .uniform_kgrid import UniformKGridConvergence @@ -32,12 +31,11 @@ def unit_context(self) -> Dict[str, Any]: ], reciprocal_vector_ratios=self._reciprocal_vector_ratios, ) - return ExecutionUnit.context_item( - "kgrid", - yielded["kgrid"], - is_edited=True, - extra_data=yielded.get("kgridExtraData") or {}, - ) + return { + "name": "kgrid", + "data": yielded["kgrid"], + "extraData": yielded.get("kgridExtraData") or {}, + } def use_variables_from_unit_context(self, flowchart_id: str) -> List[Dict[str, str]]: return [{"scope": flowchart_id, "name": "context"}] diff --git a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py index 917e5f35..7636a90a 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional from ...context.providers import PointsGridDataProvider -from ...units.execution import ExecutionUnit from .parameter import ConvergenceParameter @@ -24,12 +23,11 @@ def unit_context(self) -> Dict[str, Any]: yielded = self._points_grid_context( dimensions=[f"{{{{{self.name}}}}}", f"{{{{{self.name}}}}}", f"{{{{{self.name}}}}}"], ) - return ExecutionUnit.context_item( - "kgrid", - yielded["kgrid"], - is_edited=True, - extra_data=yielded.get("kgridExtraData") or {}, - ) + return { + "name": "kgrid", + "data": yielded["kgrid"], + "extraData": yielded.get("kgridExtraData") or {}, + } @property def final_value(self) -> str: diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 30b90b3d..357f2767 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -254,7 +254,7 @@ def add_template_parameter_convergence( execution_unit.replace_in_input_content( pattern, f"{parameter_name} = {scope_reference}", input_name=input_name ) - execution_unit.add_context(parameter_name, parameter_initial, is_edited=True) + execution_unit.add_context({"name": parameter_name, "data": parameter_initial}) self._build_convergence_units( parameter_name=parameter_name, diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 3de350e8..88b7a9d8 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional from mat3ra.ade import Application, Executable, Flavor from mat3ra.ade.context.context_provider import ContextProvider @@ -92,34 +92,46 @@ def _replace_context_item(self, name: str, item: Dict[str, Any]) -> None: rest = [entry for entry in self.context if self._context_item_name(entry) != name] self.context = rest + [item] - def add_context( - self, - name_or_item: Union[str, Dict[str, Any]], - data: Any = None, - *, - is_edited: bool = True, - extra_data: Optional[Dict[str, Any]] = None, - ) -> None: - if isinstance(name_or_item, dict): - item = name_or_item - name = item["name"] - data = item.get("data") - is_edited = bool(item.get("isEdited", True)) - extra_data = item.get("extraData") or {} - else: - name = name_or_item - item = self.context_item(name, data, is_edited=is_edited, extra_data=extra_data) - self._replace_context_item(name, item) + @staticmethod + def _normalized_context_item(item: Dict[str, Any]) -> Dict[str, Any]: + if "name" in item: + return ExecutionUnit.context_item( + item["name"], + item.get("data"), + is_edited=bool(item.get("isEdited", True)), + extra_data=item.get("extraData") or {}, + ) + return ExecutionUnit._context_item_from_provider_yield(item) + + @staticmethod + def _context_item_from_provider_yield(yielded: Dict[str, Any]) -> Dict[str, Any]: + name = None + data = None + is_edited = True + extra_data: Dict[str, Any] = {} + for key, value in yielded.items(): + if key == "isUsingJinjaVariables": + continue + if key.startswith("is") and key.endswith("Edited"): + is_edited = bool(value) + continue + if key.endswith("ExtraData"): + extra_data = value or {} + continue + if name is not None: + raise ValueError("yield_data() must contain a single provider data key") + name = key + data = value + if name is None: + raise ValueError("yield_data() must contain a provider data key") + return ExecutionUnit.context_item(name, data, is_edited=is_edited, extra_data=extra_data) + + def add_context(self, item: Dict[str, Any]) -> None: + normalized = self._normalized_context_item(item) + self._replace_context_item(normalized["name"], normalized) def add_context_provider(self, provider: ContextProvider) -> None: - yielded = provider.yield_data() - name = provider.name_str - self.add_context( - name, - yielded[name], - is_edited=bool(yielded.get(provider.is_edited_key, False)), - extra_data=yielded.get(provider.extra_data_key) or {}, - ) + self.add_context(provider.yield_data()) def set_context(self, items: Context) -> None: self.context = items From c3d2cb294e17028cf32230c315c16aa8cce52c1d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 10:13:46 -0700 Subject: [PATCH 17/33] update: validate context item --- src/py/mat3ra/wode/units/execution.py | 97 ++++++--------------------- 1 file changed, 19 insertions(+), 78 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 88b7a9d8..af46991c 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Literal, Optional from mat3ra.ade import Application, Executable, Flavor -from mat3ra.ade.context.context_provider import ContextProvider from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema from mat3ra.utils import ( calculate_hash_from_object, @@ -18,9 +17,9 @@ class ExecutionUnit(Unit, ExecutionUnitSchema): type: Literal["execution"] = "execution" - executable: Executable = None - flavor: Flavor = None - application: Application = None + application: Application + executable: Executable + flavor: Flavor input: List[ExecutionUnitInput] = Field(default_factory=list) context: Context = Field(default_factory=list) @@ -45,12 +44,13 @@ def _validate_context(cls, value: Any) -> List[Dict[str, Any]]: return [] if not isinstance(value, list): return value - return [{ - "name": item["name"], - "isEdited": bool(item.get("isEdited", False)), - "data": item.get("data", {}), - "extraData": item.get("extraData") or {}, - } for item in value if isinstance(item, dict) and item.get("name")] + validated_items: List[Dict[str, Any]] = [] + for item in value: + if not isinstance(item, dict): + continue + validated_item = ContextItemSchema(**item) + validated_items.append(validated_item.model_dump(exclude_none=True)) + return validated_items @staticmethod def _context_item_name(item: Any) -> Optional[str]: @@ -66,79 +66,20 @@ def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: return None @staticmethod - def _read_context_data(item: Dict[str, Any], default: Any = None) -> Any: - data = item.get("data", default) - if isinstance(data, dict) and set(data) == {"value"}: - return data["value"] - return data - - @staticmethod - def context_item( - name: str, - data: Any, - *, - is_edited: bool = True, - extra_data: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - payload = data if isinstance(data, dict) else {"value": data} - return { - "name": name, - "isEdited": is_edited, - "data": payload, - "extraData": extra_data or {}, - } - - def _replace_context_item(self, name: str, item: Dict[str, Any]) -> None: - rest = [entry for entry in self.context if self._context_item_name(entry) != name] - self.context = rest + [item] - - @staticmethod - def _normalized_context_item(item: Dict[str, Any]) -> Dict[str, Any]: - if "name" in item: - return ExecutionUnit.context_item( - item["name"], - item.get("data"), - is_edited=bool(item.get("isEdited", True)), - extra_data=item.get("extraData") or {}, - ) - return ExecutionUnit._context_item_from_provider_yield(item) - - @staticmethod - def _context_item_from_provider_yield(yielded: Dict[str, Any]) -> Dict[str, Any]: - name = None - data = None - is_edited = True - extra_data: Dict[str, Any] = {} - for key, value in yielded.items(): - if key == "isUsingJinjaVariables": - continue - if key.startswith("is") and key.endswith("Edited"): - is_edited = bool(value) - continue - if key.endswith("ExtraData"): - extra_data = value or {} - continue - if name is not None: - raise ValueError("yield_data() must contain a single provider data key") - name = key - data = value - if name is None: - raise ValueError("yield_data() must contain a provider data key") - return ExecutionUnit.context_item(name, data, is_edited=is_edited, extra_data=extra_data) + def create_context_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + return ContextItemSchema(**item).model_dump(exclude_none=True) def add_context(self, item: Dict[str, Any]) -> None: - normalized = self._normalized_context_item(item) - self._replace_context_item(normalized["name"], normalized) - - def add_context_provider(self, provider: ContextProvider) -> None: - self.add_context(provider.yield_data()) - - def set_context(self, items: Context) -> None: - self.context = items + item = self.create_context_item(item) + existing_item = self.get_context_item(item.get("name")) + if existing_item: + existing_item.update(item) + else: + self.context.append(item) def get_context(self, name: str, default: Any = None) -> Any: item = self.get_context_item(name) - return self._read_context_data(item, default) if item else default + return item.get("data", default) if item else default def remove_context(self, name: str) -> None: self.context = [item for item in self.context if self._context_item_name(item) != name] From 5993c0c02dcad9ff52955ffc231088f3661d26f6 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 10:28:35 -0700 Subject: [PATCH 18/33] update: use standata for tests --- pyproject.toml | 1 + tests/py/fixtures/__init__.py | 16 ++++++++++++++++ tests/py/test_unit.py | 4 ++-- tests/py/units/test_execution_unit.py | 4 ++-- 4 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 tests/py/fixtures/__init__.py diff --git a/pyproject.toml b/pyproject.toml index ac5e8960..d9686a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,7 @@ extend_skip_glob = ["dist/*"] [tool.pytest.ini_options] pythonpath = [ "src/py", + "tests/py", ] testpaths = [ "tests/py" diff --git a/tests/py/fixtures/__init__.py b/tests/py/fixtures/__init__.py new file mode 100644 index 00000000..14c7ed5a --- /dev/null +++ b/tests/py/fixtures/__init__.py @@ -0,0 +1,16 @@ +from typing import Any, Dict + +from mat3ra.standata.workflows import WorkflowStandata + +WORKFLOW_STANDATA = WorkflowStandata() + + +def execution_unit_config(application: str, workflow_name: str, unit_name: str) -> Dict[str, Any]: + workflows = WORKFLOW_STANDATA.get_by_categories(application, workflow_name) + if not workflows: + raise ValueError(f"No workflow {workflow_name!r} for application {application!r}") + for subworkflow in workflows[0]["subworkflows"]: + for unit in subworkflow["units"]: + if unit.get("type") == "execution" and unit.get("name") == unit_name: + return unit + raise ValueError(f"No execution unit {unit_name!r} in workflow {workflow_name!r}") diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index c807e56f..abaea7bc 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -2,6 +2,7 @@ from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata from mat3ra.wode import ExecutionUnit, Unit +from fixtures import execution_unit_config WORKFLOW_STANDATA = WorkflowStandata() APPLICATION_STANDATA = ApplicationStandata() @@ -15,8 +16,7 @@ NEW_CONTEXT_RELAX = {"kgrid": {"density": 0.5}, "convergence": {"threshold": 1e-6}} UNIT_CONFIG_EXECUTION = { - "type": "execution", - "name": "pw_scf", + **execution_unit_config(APPLICATION_ESPRESSO, "total_energy", "pw_scf"), "flowchartId": UNIT_FLOWCHART_ID, "head": True, } diff --git a/tests/py/units/test_execution_unit.py b/tests/py/units/test_execution_unit.py index a0724e95..18baa6d9 100644 --- a/tests/py/units/test_execution_unit.py +++ b/tests/py/units/test_execution_unit.py @@ -2,10 +2,10 @@ import pytest from mat3ra.wode.units.execution import ExecutionUnit +from fixtures import execution_unit_config UNIT_CONFIG = { - "type": "execution", - "name": "pw_scf", + **execution_unit_config("espresso", "total_energy", "pw_scf"), "flowchartId": "abc-123", "head": True, } From 8a17b420cb1c66e41c85313d8c95945838dd9a7a Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 10:49:05 -0700 Subject: [PATCH 19/33] update: context shape in tests --- .../convergence/non_uniform_kgrid.py | 1 + .../subworkflows/convergence/uniform_kgrid.py | 1 + src/py/mat3ra/wode/units/execution.py | 19 +++---------------- tests/py/test_subworkflow.py | 11 ++++++++++- tests/py/test_unit.py | 18 ++++-------------- tests/py/test_workflow.py | 4 ++-- 6 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py index d2cef753..70365172 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py @@ -33,6 +33,7 @@ def unit_context(self) -> Dict[str, Any]: ) return { "name": "kgrid", + "isEdited": bool(yielded.get("isKgridEdited", True)), "data": yielded["kgrid"], "extraData": yielded.get("kgridExtraData") or {}, } diff --git a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py index 7636a90a..2ead6da6 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py @@ -25,6 +25,7 @@ def unit_context(self) -> Dict[str, Any]: ) return { "name": "kgrid", + "isEdited": bool(yielded.get("isKgridEdited", True)), "data": yielded["kgrid"], "extraData": yielded.get("kgridExtraData") or {}, } diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index af46991c..73d737da 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -39,18 +39,10 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]: @field_validator("context", mode="before") @classmethod - def _validate_context(cls, value: Any) -> List[Dict[str, Any]]: + def _coerce_context(cls, value: Any) -> Any: if value is None: return [] - if not isinstance(value, list): - return value - validated_items: List[Dict[str, Any]] = [] - for item in value: - if not isinstance(item, dict): - continue - validated_item = ContextItemSchema(**item) - validated_items.append(validated_item.model_dump(exclude_none=True)) - return validated_items + return value @staticmethod def _context_item_name(item: Any) -> Optional[str]: @@ -65,13 +57,8 @@ def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: return item if isinstance(item, dict) else item.model_dump() return None - @staticmethod - def create_context_item(self, item: Dict[str, Any]) -> Dict[str, Any]: - return ContextItemSchema(**item).model_dump(exclude_none=True) - def add_context(self, item: Dict[str, Any]) -> None: - item = self.create_context_item(item) - existing_item = self.get_context_item(item.get("name")) + existing_item = self.get_context_item(item["name"]) if existing_item: existing_item.update(item) else: diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 295c38e4..6cbd6f6b 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -106,7 +106,16 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): unit_to_modify.add_context({"name": "test_key", "data": "test_value"}) unit_to_modify.add_context({"name": "another_key", "data": 42}) - unit_to_modify.add_context_provider(PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True)) + points_grid_provider = PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True) + points_grid_context = points_grid_provider.yield_data() + unit_to_modify.add_context( + { + "name": "kgrid", + "isEdited": bool(points_grid_context.get("isKgridEdited", True)), + "data": points_grid_context.get("kgrid"), + "extraData": points_grid_context.get("kgridExtraData") or {}, + } + ) if method == "only_new_unit": success = relaxation_subworkflow.set_unit(unit_to_modify) diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index abaea7bc..3a70e825 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -61,22 +61,12 @@ def test_add_context(): assert "relax" in unit.name.lower() assert unit.context == [] - unit.add_context("kgrid", NEW_CONTEXT_RELAX["kgrid"], is_edited=False) - unit.add_context("convergence", NEW_CONTEXT_RELAX["convergence"], is_edited=False) + unit.add_context({"name": "kgrid", "data": NEW_CONTEXT_RELAX["kgrid"], "isEdited": False}) + unit.add_context({"name": "convergence", "data": NEW_CONTEXT_RELAX["convergence"], "isEdited": False}) assert unit.get_context("kgrid") == NEW_CONTEXT_RELAX["kgrid"] assert unit.get_context("convergence") == NEW_CONTEXT_RELAX["convergence"] assert unit.to_dict()["context"] == [ - { - "name": "kgrid", - "isEdited": False, - "data": NEW_CONTEXT_RELAX["kgrid"], - "extraData": {}, - }, - { - "name": "convergence", - "isEdited": False, - "data": NEW_CONTEXT_RELAX["convergence"], - "extraData": {}, - }, + {"name": "kgrid", "isEdited": False, "data": NEW_CONTEXT_RELAX["kgrid"]}, + {"name": "convergence", "isEdited": False, "data": NEW_CONTEXT_RELAX["convergence"]}, ] diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 94a1aa64..b2e7b12a 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -184,8 +184,8 @@ def test_set_unit(method): unit_to_modify = relaxation_subworkflow.get_unit_by_name(name_regex="relax") assert unit_to_modify is not None - unit_to_modify.add_context("test_key", "test_value") - unit_to_modify.add_context("another_key", 42) + unit_to_modify.add_context({"name": "test_key", "data": "test_value"}) + unit_to_modify.add_context({"name": "another_key", "data": 42}) if method == "only_new_unit": success = relaxation_subworkflow.set_unit(unit_to_modify) From dd0f48e9660edf6f59c0990925b514d32e8eb84f Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 10:51:20 -0700 Subject: [PATCH 20/33] update: preserve yield data --- .../mat3ra/wode/context/providers/base/__init__.py | 0 .../wode/context/providers/base/context_provider.py | 13 +++++++++++++ .../providers/planewave_cutoffs_context_provider.py | 2 +- .../context/providers/points_grid_data_provider.py | 2 +- .../context/providers/points_path_data_provider.py | 2 +- tests/py/test_subworkflow.py | 10 +--------- 6 files changed, 17 insertions(+), 12 deletions(-) create mode 100644 src/py/mat3ra/wode/context/providers/base/__init__.py create mode 100644 src/py/mat3ra/wode/context/providers/base/context_provider.py diff --git a/src/py/mat3ra/wode/context/providers/base/__init__.py b/src/py/mat3ra/wode/context/providers/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/py/mat3ra/wode/context/providers/base/context_provider.py b/src/py/mat3ra/wode/context/providers/base/context_provider.py new file mode 100644 index 00000000..62cce8c9 --- /dev/null +++ b/src/py/mat3ra/wode/context/providers/base/context_provider.py @@ -0,0 +1,13 @@ +from typing import Any, Dict + +from mat3ra.ade.context.context_provider import ContextProvider as AdeContextProvider + + +class ContextProvider(AdeContextProvider): + def get_context_item_data(self) -> Dict[str, Any]: + return { + "name": self.name_str, + "isEdited": self.is_edited, + "data": self.get_data(), + "extraData": self.extra_data or {}, + } diff --git a/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py b/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py index c1487278..54b33606 100644 --- a/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py +++ b/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py @@ -1,6 +1,6 @@ from typing import Any, Dict, Optional -from mat3ra.ade.context.context_provider import ContextProvider +from .base.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.planewave_cutoffs_context_provider import ( PlanewaveCutoffsContextProviderSchema, ) diff --git a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py index a451fb5a..99029f9a 100644 --- a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py +++ b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from mat3ra.ade.context.context_provider import ContextProvider +from .base.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import ( GridMetricType, PointsGridDataProviderSchema, diff --git a/src/py/mat3ra/wode/context/providers/points_path_data_provider.py b/src/py/mat3ra/wode/context/providers/points_path_data_provider.py index 0a45699b..b3a222f7 100644 --- a/src/py/mat3ra/wode/context/providers/points_path_data_provider.py +++ b/src/py/mat3ra/wode/context/providers/points_path_data_provider.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from mat3ra.ade.context.context_provider import ContextProvider +from .base.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.points_path_data_provider import ( PointsPathDataProviderSchemaItem, ) diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 6cbd6f6b..537ef8a7 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -107,15 +107,7 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): unit_to_modify.add_context({"name": "test_key", "data": "test_value"}) unit_to_modify.add_context({"name": "another_key", "data": 42}) points_grid_provider = PointsGridDataProvider(dimensions=[2, 2, 1], isEdited=True) - points_grid_context = points_grid_provider.yield_data() - unit_to_modify.add_context( - { - "name": "kgrid", - "isEdited": bool(points_grid_context.get("isKgridEdited", True)), - "data": points_grid_context.get("kgrid"), - "extraData": points_grid_context.get("kgridExtraData") or {}, - } - ) + unit_to_modify.add_context(points_grid_provider.get_context_item_data()) if method == "only_new_unit": success = relaxation_subworkflow.set_unit(unit_to_modify) From 066c5469a87c6b735f61a639943adb386d7feee0 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 11:46:55 -0700 Subject: [PATCH 21/33] update: cleanup --- src/py/mat3ra/wode/units/execution.py | 14 ++-------- tests/py/test_workflow.py | 40 --------------------------- 2 files changed, 3 insertions(+), 51 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 73d737da..e55fa5e5 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from mat3ra.ade import Application, Executable, Flavor from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema @@ -37,24 +37,16 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]: return instantiated - @field_validator("context", mode="before") - @classmethod - def _coerce_context(cls, value: Any) -> Any: - if value is None: - return [] - return value - @staticmethod def _context_item_name(item: Any) -> Optional[str]: if isinstance(item, dict): return item.get("name") - name = getattr(item, "name", None) - return str(name) if name is not None else None + return str(item.name) def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: for item in self.context: if self._context_item_name(item) == name: - return item if isinstance(item, dict) else item.model_dump() + return item return None def add_context(self, item: Dict[str, Any]) -> None: diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index b2e7b12a..53056867 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -217,46 +217,6 @@ def test_calculate_hash(workflow, app): assert wf.hash == expected_hash -def _execution_units_from_payload(workflow_payload): - for subworkflow in workflow_payload.get("subworkflows", []): - for unit in subworkflow.get("units", []): - if unit.get("type") == EXECUTION_UNIT_TYPE: - yield unit - - -def _assert_subworkflow_models_have_functional(workflow_payload, expected_functional): - for subworkflow in workflow_payload.get("subworkflows", []): - model = subworkflow.get("model", {}) - assert model.get("functional") == expected_functional - - -def _assert_execution_unit_context_is_webapp_shaped(unit): - context = unit.get("context") - assert isinstance(context, list) - for item in context: - for key in CONTEXT_ITEM_REQUIRED_KEYS: - assert key in item - - -@pytest.mark.parametrize( - "workflow_search_name,expected_functional", - [(name, EXPECTED_MODEL_FUNCTIONAL) for name in WEBAPP_COMPATIBLE_WORKFLOW_SEARCH_NAMES], - ids=WEBAPP_COMPATIBLE_WORKFLOW_SEARCH_NAMES, -) -def test_workflow_to_dict_is_webapp_compatible(workflow_search_name, expected_functional): - workflow_config = WORKFLOW_STANDATA.get_by_name_first_match(workflow_search_name) - workflow = Workflow.create(workflow_config) - payload = workflow.to_dict() - - _assert_subworkflow_models_have_functional(payload, expected_functional) - - execution_units = list(_execution_units_from_payload(payload)) - assert execution_units - - for unit in execution_units: - _assert_execution_unit_context_is_webapp_shaped(unit) - - def test_workflow_to_dict_is_json_serializable_after_model_assignment(): workflow_config = WORKFLOW_STANDATA.get_by_name_first_match(BAND_STRUCTURE_SEARCH_NAME) workflow = Workflow.create(workflow_config) From f2325082226c97ac471480e0c699a87905983ecf Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 12:23:02 -0700 Subject: [PATCH 22/33] chore: rename --- src/py/mat3ra/wode/units/execution.py | 12 +++++++----- tests/py/test_convergence.py | 16 ++++++++-------- tests/py/test_subworkflow.py | 6 +++--- tests/py/test_unit.py | 4 ++-- tests/py/test_workflow.py | 4 ++-- 5 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index e55fa5e5..94d447cb 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional from mat3ra.ade import Application, Executable, Flavor from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema @@ -49,6 +49,12 @@ def get_context_item(self, name: str) -> Optional[Dict[str, Any]]: return item return None + def get_context_item_data(self, name: str, default: Any = None) -> Any: + if default is None: + default = {} + item = self.get_context_item(name) + return item.get("data", default) if item else default + def add_context(self, item: Dict[str, Any]) -> None: existing_item = self.get_context_item(item["name"]) if existing_item: @@ -56,10 +62,6 @@ def add_context(self, item: Dict[str, Any]) -> None: else: self.context.append(item) - def get_context(self, name: str, default: Any = None) -> Any: - item = self.get_context_item(name) - return item.get("data", default) if item else default - def remove_context(self, name: str) -> None: self.context = [item for item in self.context if self._context_item_name(item) != name] diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index 1715b92e..e21add03 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -41,8 +41,8 @@ def test_add_uniform_energy_convergence(): ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.get_context("kgrid")["dimensions"] == ["{{N_k}}", "{{N_k}}", "{{N_k}}"] - assert pw_scf.get_context("kgrid")["shifts"] == [0, 0, 0] + assert pw_scf.get_context_item_data("kgrid")["dimensions"] == ["{{N_k}}", "{{N_k}}", "{{N_k}}"] + assert pw_scf.get_context_item_data("kgrid")["shifts"] == [0, 0, 0] assert any(item.get("name") == "kgrid" and item.get("isEdited") for item in pw_scf.context) assert subworkflow.convergence_parameter == ConvergenceParameterNameEnum.N_k.value @@ -81,12 +81,12 @@ def test_add_non_uniform_energy_convergence(): ) pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.get_context("kgrid")["dimensions"] == [ + assert pw_scf.get_context_item_data("kgrid")["dimensions"] == [ f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[0]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[1]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform.value}[2]}}}}", ] - assert pw_scf.get_context("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios + assert pw_scf.get_context_item_data("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform.value @@ -116,12 +116,12 @@ def test_add_non_uniform_2d_energy_convergence(): ) pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.get_context("kgrid")["dimensions"] == [ + assert pw_scf.get_context_item_data("kgrid")["dimensions"] == [ f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[0]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[1]}}}}", f"{{{{{ConvergenceParameterNameEnum.N_k_nonuniform_2D.value}[2]}}}}", ] - assert pw_scf.get_context("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios + assert pw_scf.get_context_item_data("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform_2D.value @@ -218,7 +218,7 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.get_context(param_name) == param_initial + assert pw_scf.get_context_item_data(param_name) == param_initial input_item = pw_scf.input[0] template_content = input_item.template.content assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content @@ -258,7 +258,7 @@ def test_add_template_param_convergence_multi_unit(): pw_bands = subworkflow.get_unit_by_name("pw_bands") for unit in [pw_scf, pw_bands]: - assert unit.get_context("ecutwfc") == 20 + assert unit.get_context_item_data("ecutwfc") == 20 input_item = unit.input[0] template_content = input_item.template.content assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 537ef8a7..72be7904 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -121,7 +121,7 @@ def test_set_unit_keeps_rendered_input_for_context_only_update(method): assert success is True updated_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") - assert updated_unit.get_context("test_key") == "test_value" - assert updated_unit.get_context("another_key") == 42 - assert updated_unit.get_context("kgrid")["dimensions"] == [2, 2, 1] + assert updated_unit.get_context_item_data("test_key") == "test_value" + assert updated_unit.get_context_item_data("another_key") == 42 + assert updated_unit.get_context_item_data("kgrid")["dimensions"] == [2, 2, 1] assert updated_unit.input[0].rendered == original_rendered diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index 3a70e825..8bb82185 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -64,8 +64,8 @@ def test_add_context(): unit.add_context({"name": "kgrid", "data": NEW_CONTEXT_RELAX["kgrid"], "isEdited": False}) unit.add_context({"name": "convergence", "data": NEW_CONTEXT_RELAX["convergence"], "isEdited": False}) - assert unit.get_context("kgrid") == NEW_CONTEXT_RELAX["kgrid"] - assert unit.get_context("convergence") == NEW_CONTEXT_RELAX["convergence"] + assert unit.get_context_item_data("kgrid") == NEW_CONTEXT_RELAX["kgrid"] + assert unit.get_context_item_data("convergence") == NEW_CONTEXT_RELAX["convergence"] assert unit.to_dict()["context"] == [ {"name": "kgrid", "isEdited": False, "data": NEW_CONTEXT_RELAX["kgrid"]}, {"name": "convergence", "isEdited": False, "data": NEW_CONTEXT_RELAX["convergence"]}, diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 53056867..6da5115d 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -199,8 +199,8 @@ def test_set_unit(method): assert success is True updated_unit = relaxation_subworkflow.get_unit_by_name(name_regex="relax") - assert updated_unit.get_context("test_key") == "test_value" - assert updated_unit.get_context("another_key") == 42 + assert updated_unit.get_context_item_data("test_key") == "test_value" + assert updated_unit.get_context_item_data("another_key") == 42 @pytest.mark.parametrize("workflow, app", [("band_gap", "espresso")]) From 3e871d38d40119162d2ec04dc605369acb9c9091 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 12:45:40 -0700 Subject: [PATCH 23/33] update: kgrid context --- .../subworkflows/convergence/non_uniform_kgrid.py | 8 +------- .../wode/subworkflows/convergence/uniform_kgrid.py | 13 ++++--------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py index 70365172..dd7e185c 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/non_uniform_kgrid.py @@ -23,7 +23,7 @@ def increment(self) -> str: @property def unit_context(self) -> Dict[str, Any]: - yielded = self._points_grid_context( + return self._points_grid_context( dimensions=[ f"{{{{{self.name}[0]}}}}", f"{{{{{self.name}[1]}}}}", @@ -31,12 +31,6 @@ def unit_context(self) -> Dict[str, Any]: ], reciprocal_vector_ratios=self._reciprocal_vector_ratios, ) - return { - "name": "kgrid", - "isEdited": bool(yielded.get("isKgridEdited", True)), - "data": yielded["kgrid"], - "extraData": yielded.get("kgridExtraData") or {}, - } def use_variables_from_unit_context(self, flowchart_id: str) -> List[Dict[str, str]]: return [{"scope": flowchart_id, "name": "context"}] diff --git a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py index 2ead6da6..05369634 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py @@ -8,11 +8,12 @@ class UniformKGridConvergence(ConvergenceParameter): def _points_grid_context( self, dimensions: List[str], reciprocal_vector_ratios: Optional[List[float]] = None ) -> Dict[str, Any]: - return PointsGridDataProvider().yield_data_with_overrides( + provider = PointsGridDataProvider(isEdited=True) + provider.data = provider.build_data( dimensions=dimensions, reciprocal_vector_ratios=reciprocal_vector_ratios, - is_using_jinja_variables=True, ) + return provider.get_context_item_data() @property def increment(self) -> str: @@ -20,15 +21,9 @@ def increment(self) -> str: @property def unit_context(self) -> Dict[str, Any]: - yielded = self._points_grid_context( + return self._points_grid_context( dimensions=[f"{{{{{self.name}}}}}", f"{{{{{self.name}}}}}", f"{{{{{self.name}}}}}"], ) - return { - "name": "kgrid", - "isEdited": bool(yielded.get("isKgridEdited", True)), - "data": yielded["kgrid"], - "extraData": yielded.get("kgridExtraData") or {}, - } @property def final_value(self) -> str: From 7ccf85acd1d990cadf2382597fc4805ac6fd508e Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 3 Jun 2026 13:10:52 -0700 Subject: [PATCH 24/33] update: correction for serialization --- src/py/mat3ra/wode/subworkflows/subworkflow.py | 2 +- src/py/mat3ra/wode/units/execution.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/subworkflow.py b/src/py/mat3ra/wode/subworkflows/subworkflow.py index 65c623cc..8cb941e5 100644 --- a/src/py/mat3ra/wode/subworkflows/subworkflow.py +++ b/src/py/mat3ra/wode/subworkflows/subworkflow.py @@ -43,7 +43,7 @@ class Subworkflow( ) properties: List[str] = Field(default_factory=list) model: SerializeAsAny[Model] = Field(default_factory=DFTModel) - units: List[Union[Unit, ExecutionUnit, SubworkflowUnit]] = Field(default_factory=list) + units: List[SerializeAsAny[Union[Unit, ExecutionUnit, SubworkflowUnit]]] = Field(default_factory=list) @field_validator("model", mode="before") @classmethod diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 94d447cb..918748a6 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -1,27 +1,25 @@ from typing import Any, Dict, List, Literal, Optional from mat3ra.ade import Application, Executable, Flavor -from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema, ContextItemSchema +from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema from mat3ra.utils import ( calculate_hash_from_object, remove_comments_from_source_code, remove_empty_lines_from_string, ) from mat3ra.utils.extra.jinja import replace_in_text -from pydantic import Field, field_validator +from pydantic import Field, SerializeAsAny, field_validator from .execution_unit_input import ExecutionUnitInput from .unit import Unit -Context = List[ContextItemSchema] - class ExecutionUnit(Unit, ExecutionUnitSchema): type: Literal["execution"] = "execution" application: Application executable: Executable flavor: Flavor input: List[ExecutionUnitInput] = Field(default_factory=list) - context: Context = Field(default_factory=list) + context: List[SerializeAsAny[Dict[str, Any]]] = Field(default_factory=list) @field_validator("input", mode="before") @classmethod From f5d946b37a86a0f8dc22fcac31d32fffa6aab199 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 13:25:26 -0700 Subject: [PATCH 25/33] update: set convergence --- .../wode/subworkflows/convergence_mixin.py | 44 ++++++++++++++++--- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 357f2767..73230b66 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -2,7 +2,6 @@ from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum from mat3ra.utils.extra.jinja import JINJA_EXPRESSION_PATTERN, NUMERIC_VALUE_PATTERN, wrap_text_in_raw_block - from .convergence.factory import create_convergence_parameter from ..context.providers import PointsGridDataProvider from ..units import Unit @@ -67,6 +66,32 @@ def _find_unit_for_convergence(self, result: str): return unit return None + @staticmethod + def _wire_convergence_flow( + host: ConvergenceHost, + param_init: Unit, + prev_result_init: Unit, + iter_init: Unit, + store_result: Unit, + condition_unit: Unit, + store_prev_result: Unit, + next_iter: Unit, + next_step: Unit, + execution_unit_flowchart_id: str, + ) -> None: + param_init.next = prev_result_init.flowchartId + prev_result_init.next = iter_init.flowchartId + iter_init.next = execution_unit_flowchart_id + + execution_unit = host.get_unit(execution_unit_flowchart_id) + if execution_unit is not None: + execution_unit.next = store_result.flowchartId + + store_result.next = condition_unit.flowchartId + store_prev_result.next = next_iter.flowchartId + next_iter.next = next_step.flowchartId + next_step.next = execution_unit_flowchart_id + def _build_convergence_units( self, parameter_name: str, @@ -143,7 +168,18 @@ def _build_convergence_units( host.add_unit(next_step) host.add_unit(exit_unit) - next_step.next = execution_unit_flowchart_id + self._wire_convergence_flow( + host=host, + param_init=param_init, + prev_result_init=prev_result_init, + iter_init=iter_init, + store_result=store_result, + condition_unit=condition_unit, + store_prev_result=store_prev_result, + next_iter=next_iter, + next_step=next_step, + execution_unit_flowchart_id=execution_unit_flowchart_id, + ) def add_convergence( self, @@ -176,9 +212,8 @@ def add_convergence( and reciprocal_vector_ratios is None ): kgrid_item = unit_for_convergence.get_context_item("kgrid") - provider_context = {"kgrid": kgrid_item["data"]} if kgrid_item else None reciprocal_vector_ratios = PointsGridDataProvider().get_reciprocal_vector_ratios( - context=provider_context, + context=kgrid_item, ) if reciprocal_vector_ratios is None: raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") @@ -254,7 +289,6 @@ def add_template_parameter_convergence( execution_unit.replace_in_input_content( pattern, f"{parameter_name} = {scope_reference}", input_name=input_name ) - execution_unit.add_context({"name": parameter_name, "data": parameter_initial}) self._build_convergence_units( parameter_name=parameter_name, From 347589588287672158ae87f9ebf1d54bc53a4157 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 13:26:57 -0700 Subject: [PATCH 26/33] update: adjsut test --- tests/py/test_convergence.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index e21add03..b1a964d2 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -218,7 +218,9 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme ] pw_scf = subworkflow.get_unit_by_name(name="pw_scf") - assert pw_scf.get_context_item_data(param_name) == param_initial + init_parameter = subworkflow.get_unit_by_name(name="init parameter") + assert init_parameter.operand == param_name + assert init_parameter.value == param_initial input_item = pw_scf.input[0] template_content = input_item.template.content assert f"{param_name} = {{% raw %}}{{{{ {param_name} }}}}{{% endraw %}}" in template_content @@ -257,8 +259,11 @@ def test_add_template_param_convergence_multi_unit(): pw_scf = subworkflow.get_unit_by_name("pw_scf") pw_bands = subworkflow.get_unit_by_name("pw_bands") + init_parameter = subworkflow.get_unit_by_name(name="init parameter") + assert init_parameter.operand == "ecutwfc" + assert init_parameter.value == 20 + for unit in [pw_scf, pw_bands]: - assert unit.get_context_item_data("ecutwfc") == 20 input_item = unit.input[0] template_content = input_item.template.content assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content From a59b743a649aa7108c6abbd23c90c77ed5783dcb Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 17:02:57 -0700 Subject: [PATCH 27/33] update: model --- src/py/mat3ra/wode/subworkflows/subworkflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/mat3ra/wode/subworkflows/subworkflow.py b/src/py/mat3ra/wode/subworkflows/subworkflow.py index 8cb941e5..46f0bd67 100644 --- a/src/py/mat3ra/wode/subworkflows/subworkflow.py +++ b/src/py/mat3ra/wode/subworkflows/subworkflow.py @@ -49,7 +49,7 @@ class Subworkflow( @classmethod def _instantiate_model(cls, value: Any) -> Any: if isinstance(value, Model): - return ModelFactory.create(value.to_dict()) + value = value.to_dict() if isinstance(value, dict): return ModelFactory.create(value) return value From eb4145453ea1c6dea900e3587ae7318f960a17ab Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 17:17:44 -0700 Subject: [PATCH 28/33] chore: simplify --- .../wode/subworkflows/convergence_mixin.py | 49 +++++-------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index 73230b66..cff2b175 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -66,31 +66,6 @@ def _find_unit_for_convergence(self, result: str): return unit return None - @staticmethod - def _wire_convergence_flow( - host: ConvergenceHost, - param_init: Unit, - prev_result_init: Unit, - iter_init: Unit, - store_result: Unit, - condition_unit: Unit, - store_prev_result: Unit, - next_iter: Unit, - next_step: Unit, - execution_unit_flowchart_id: str, - ) -> None: - param_init.next = prev_result_init.flowchartId - prev_result_init.next = iter_init.flowchartId - iter_init.next = execution_unit_flowchart_id - - execution_unit = host.get_unit(execution_unit_flowchart_id) - if execution_unit is not None: - execution_unit.next = store_result.flowchartId - - store_result.next = condition_unit.flowchartId - store_prev_result.next = next_iter.flowchartId - next_iter.next = next_step.flowchartId - next_step.next = execution_unit_flowchart_id def _build_convergence_units( self, @@ -168,18 +143,18 @@ def _build_convergence_units( host.add_unit(next_step) host.add_unit(exit_unit) - self._wire_convergence_flow( - host=host, - param_init=param_init, - prev_result_init=prev_result_init, - iter_init=iter_init, - store_result=store_result, - condition_unit=condition_unit, - store_prev_result=store_prev_result, - next_iter=next_iter, - next_step=next_step, - execution_unit_flowchart_id=execution_unit_flowchart_id, - ) + param_init.next = prev_result_init.flowchartId + prev_result_init.next = iter_init.flowchartId + iter_init.next = execution_unit_flowchart_id + + execution_unit = host.get_unit(execution_unit_flowchart_id) + if execution_unit is not None: + execution_unit.next = store_result.flowchartId + + store_result.next = condition_unit.flowchartId + store_prev_result.next = next_iter.flowchartId + next_iter.next = next_step.flowchartId + next_step.next = execution_unit_flowchart_id def add_convergence( self, From e5c2a74a27c27fcbe7082991b5be5303abf18dda Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 17:28:58 -0700 Subject: [PATCH 29/33] update: test convergence --- tests/py/test_convergence.py | 47 ++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/tests/py/test_convergence.py b/tests/py/test_convergence.py index b1a964d2..6e1e20a2 100644 --- a/tests/py/test_convergence.py +++ b/tests/py/test_convergence.py @@ -12,6 +12,29 @@ def _build_total_energy_subworkflow(): return workflow.subworkflows[0] +def _assert_convergence_flowchart_links(subworkflow, execution_unit_name): + def unit(name): + return subworkflow.get_unit_by_name(name=name) + + next_links = [ + ("init parameter", "init result"), + ("init result", "init counter"), + ("init counter", execution_unit_name), + (execution_unit_name, "update result"), + ("update result", "check convergence"), + ("check convergence", "store result"), + ("store result", "update counter"), + ("update counter", "update parameter"), + ("update parameter", execution_unit_name), + ] + for from_name, to_name in next_links: + assert unit(from_name).next == unit(to_name).flowchartId + + check_convergence = unit("check convergence") + assert check_convergence.then == unit("exit").flowchartId + assert getattr(check_convergence, "else") == unit("store result").flowchartId + + def test_add_uniform_energy_convergence(): subworkflow = _build_total_energy_subworkflow() @@ -49,10 +72,11 @@ def test_add_uniform_energy_convergence(): assert subworkflow.convergence_result == "total_energy" assert subworkflow.has_convergence is True + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k.value assert update_parameter.value == f"{ConvergenceParameterNameEnum.N_k.value} + 1" - assert update_parameter.next == pw_scf.flowchartId check_convergence = subworkflow.get_unit_by_name(name="check convergence") assert check_convergence.input == [] @@ -88,13 +112,15 @@ def test_add_non_uniform_energy_convergence(): ] assert pw_scf.get_context_item_data("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform.value assert update_parameter.input == [{"scope": pw_scf.flowchartId, "name": "context"}] assert ( - update_parameter.value - == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " - "for i in range(3)]" + update_parameter.value + == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " + "for i in range(3)]" ) @@ -123,13 +149,15 @@ def test_add_non_uniform_2d_energy_convergence(): ] assert pw_scf.get_context_item_data("kgrid")["reciprocalVectorRatios"] == reciprocal_vector_ratios + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == ConvergenceParameterNameEnum.N_k_nonuniform_2D.value assert update_parameter.input == [{"scope": pw_scf.flowchartId, "name": "context"}] assert ( - update_parameter.value - == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " - "for i in range(2)] + [1]" + update_parameter.value + == "[[2,2,1][i] + math.floor(iteration * 2 * float(context['kgrid']['reciprocalVectorRatios'][i])) " + "for i in range(2)] + [1]" ) @@ -230,11 +258,12 @@ def test_add_template_param_convergence(param_name, param_initial, param_increme assert subworkflow.convergence_result == result_name assert subworkflow.has_convergence is True + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + update_parameter = subworkflow.get_unit_by_name(name="update parameter") assert update_parameter.operand == param_name assert update_parameter.value == f"{param_name} + {param_increment}" assert update_parameter.input == [] - assert update_parameter.next == pw_scf.flowchartId exit_unit = subworkflow.get_unit_by_name(name="exit") assert exit_unit.operand == param_name @@ -269,5 +298,7 @@ def test_add_template_param_convergence_multi_unit(): assert "ecutwfc = {% raw %}{{ ecutwfc }}{% endraw %}" in template_content assert "ecutwfc = {{ cutoffs.wavefunction }}" not in template_content + _assert_convergence_flowchart_links(subworkflow, execution_unit_name="pw_scf") + assert subworkflow.convergence_parameter == "ecutwfc" assert subworkflow.convergence_result == "total_energy" From 84c413794da6c5a5af3fc4bed39a0beaea367b84 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 20:15:03 -0700 Subject: [PATCH 30/33] chore: address pr comment --- .../wode/context/providers/base/context_provider.py | 2 +- tests/py/fixtures/__init__.py | 2 +- tests/py/test_unit.py | 4 ++-- tests/py/test_workflow.py | 12 ------------ tests/py/units/test_execution_unit.py | 4 ++-- 5 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/py/mat3ra/wode/context/providers/base/context_provider.py b/src/py/mat3ra/wode/context/providers/base/context_provider.py index 62cce8c9..ab06aa6d 100644 --- a/src/py/mat3ra/wode/context/providers/base/context_provider.py +++ b/src/py/mat3ra/wode/context/providers/base/context_provider.py @@ -2,7 +2,7 @@ from mat3ra.ade.context.context_provider import ContextProvider as AdeContextProvider - +# TODO: Remove context provider from Ade -- sync with JS implementation fully class ContextProvider(AdeContextProvider): def get_context_item_data(self) -> Dict[str, Any]: return { diff --git a/tests/py/fixtures/__init__.py b/tests/py/fixtures/__init__.py index 14c7ed5a..e8bc2443 100644 --- a/tests/py/fixtures/__init__.py +++ b/tests/py/fixtures/__init__.py @@ -5,7 +5,7 @@ WORKFLOW_STANDATA = WorkflowStandata() -def execution_unit_config(application: str, workflow_name: str, unit_name: str) -> Dict[str, Any]: +def get_execution_unit_config_by_application_workflow_unit(application: str, workflow_name: str, unit_name: str) -> Dict[str, Any]: workflows = WORKFLOW_STANDATA.get_by_categories(application, workflow_name) if not workflows: raise ValueError(f"No workflow {workflow_name!r} for application {application!r}") diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index 8bb82185..ff2cdbd0 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -2,7 +2,7 @@ from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata from mat3ra.wode import ExecutionUnit, Unit -from fixtures import execution_unit_config +from fixtures import get_execution_unit_config_by_application_workflow_unit WORKFLOW_STANDATA = WorkflowStandata() APPLICATION_STANDATA = ApplicationStandata() @@ -16,7 +16,7 @@ NEW_CONTEXT_RELAX = {"kgrid": {"density": 0.5}, "convergence": {"threshold": 1e-6}} UNIT_CONFIG_EXECUTION = { - **execution_unit_config(APPLICATION_ESPRESSO, "total_energy", "pw_scf"), + **get_execution_unit_config_by_application_workflow_unit(APPLICATION_ESPRESSO, "total_energy", "pw_scf"), "flowchartId": UNIT_FLOWCHART_ID, "head": True, } diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 6da5115d..05ac7db0 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -216,15 +216,3 @@ def test_calculate_hash(workflow, app): wf = Workflow(**{k: v for k, v in fixture.items() if k != "hash"}) assert wf.hash == expected_hash - -def test_workflow_to_dict_is_json_serializable_after_model_assignment(): - workflow_config = WORKFLOW_STANDATA.get_by_name_first_match(BAND_STRUCTURE_SEARCH_NAME) - workflow = Workflow.create(workflow_config) - method = MethodFactory.create( - {"type": "pseudopotential", "subtype": "us", "data": {}}, - ) - assigned_model = Model(type="dft", subtype="gga", method=method, functional=EXPECTED_MODEL_FUNCTIONAL) - for subworkflow in workflow.subworkflows: - subworkflow.model = assigned_model - - json.dumps(workflow.to_dict()) diff --git a/tests/py/units/test_execution_unit.py b/tests/py/units/test_execution_unit.py index 18baa6d9..01ff5d6a 100644 --- a/tests/py/units/test_execution_unit.py +++ b/tests/py/units/test_execution_unit.py @@ -2,10 +2,10 @@ import pytest from mat3ra.wode.units.execution import ExecutionUnit -from fixtures import execution_unit_config +from fixtures import get_execution_unit_config_by_application_workflow_unit UNIT_CONFIG = { - **execution_unit_config("espresso", "total_energy", "pw_scf"), + **get_execution_unit_config_by_application_workflow_unit("espresso", "total_energy", "pw_scf"), "flowchartId": "abc-123", "head": True, } From 98e10eb9b750d59cc3810a9539bf67701a0f3205 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 20:22:14 -0700 Subject: [PATCH 31/33] chore: lint --- .../context/providers/planewave_cutoffs_context_provider.py | 3 ++- .../wode/context/providers/points_grid_data_provider.py | 2 +- src/py/mat3ra/wode/subworkflows/convergence/factory.py | 1 + src/py/mat3ra/wode/subworkflows/convergence_mixin.py | 6 ++++-- src/py/mat3ra/wode/units/execution.py | 1 + tests/py/test_unit.py | 3 ++- 6 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py b/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py index 54b33606..d0140284 100644 --- a/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py +++ b/src/py/mat3ra/wode/context/providers/planewave_cutoffs_context_provider.py @@ -1,11 +1,12 @@ from typing import Any, Dict, Optional -from .base.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.planewave_cutoffs_context_provider import ( PlanewaveCutoffsContextProviderSchema, ) from pydantic import Field +from .base.context_provider import ContextProvider + class PlanewaveCutoffsContextProvider(PlanewaveCutoffsContextProviderSchema, ContextProvider): name: str = Field(default="cutoffs") diff --git a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py index 99029f9a..85ae7eed 100644 --- a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py +++ b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py @@ -1,12 +1,12 @@ from typing import Any, Dict, List, Optional -from .base.context_provider import ContextProvider from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import ( GridMetricType, PointsGridDataProviderSchema, ) from pydantic import Field +from .base.context_provider import ContextProvider DEFAULT_KPPRA = -1 diff --git a/src/py/mat3ra/wode/subworkflows/convergence/factory.py b/src/py/mat3ra/wode/subworkflows/convergence/factory.py index 398659cc..4a55f1b5 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence/factory.py +++ b/src/py/mat3ra/wode/subworkflows/convergence/factory.py @@ -1,6 +1,7 @@ from typing import Any, List, Optional from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum + from .non_uniform_kgrid import NonUniformKGridConvergence from .non_uniform_kgrid_2d import NonUniformKGridConvergence2D from .parameter import ConvergenceParameter diff --git a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py index cff2b175..a6ccc1a5 100644 --- a/src/py/mat3ra/wode/subworkflows/convergence_mixin.py +++ b/src/py/mat3ra/wode/subworkflows/convergence_mixin.py @@ -187,9 +187,11 @@ def add_convergence( and reciprocal_vector_ratios is None ): kgrid_item = unit_for_convergence.get_context_item("kgrid") - reciprocal_vector_ratios = PointsGridDataProvider().get_reciprocal_vector_ratios( - context=kgrid_item, + provider = PointsGridDataProvider( + data=kgrid_item.get("data"), + is_edited=kgrid_item.get("isEdited"), ) + reciprocal_vector_ratios = provider.get_reciprocal_vector_ratios() if reciprocal_vector_ratios is None: raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.") diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py index 918748a6..7c346228 100644 --- a/src/py/mat3ra/wode/units/execution.py +++ b/src/py/mat3ra/wode/units/execution.py @@ -13,6 +13,7 @@ from .execution_unit_input import ExecutionUnitInput from .unit import Unit + class ExecutionUnit(Unit, ExecutionUnitSchema): type: Literal["execution"] = "execution" application: Application diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py index ff2cdbd0..8e35a8c5 100644 --- a/tests/py/test_unit.py +++ b/tests/py/test_unit.py @@ -1,8 +1,9 @@ import pytest from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.workflows import WorkflowStandata -from mat3ra.wode import ExecutionUnit, Unit + from fixtures import get_execution_unit_config_by_application_workflow_unit +from mat3ra.wode import ExecutionUnit, Unit WORKFLOW_STANDATA = WorkflowStandata() APPLICATION_STANDATA = ApplicationStandata() From 305e5bf5970d471a9e59dc2bd2bfaac2ad81cbe2 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 20:42:22 -0700 Subject: [PATCH 32/33] chore: lint --- tests/py/fixtures/__init__.py | 3 ++- tests/py/test_workflow.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/py/fixtures/__init__.py b/tests/py/fixtures/__init__.py index e8bc2443..8775c11b 100644 --- a/tests/py/fixtures/__init__.py +++ b/tests/py/fixtures/__init__.py @@ -5,7 +5,8 @@ WORKFLOW_STANDATA = WorkflowStandata() -def get_execution_unit_config_by_application_workflow_unit(application: str, workflow_name: str, unit_name: str) -> Dict[str, Any]: +def get_execution_unit_config_by_application_workflow_unit(application: str, workflow_name: str, unit_name: str) -> \ +Dict[str, Any]: workflows = WORKFLOW_STANDATA.get_by_categories(application, workflow_name) if not workflows: raise ValueError(f"No workflow {workflow_name!r} for application {application!r}") diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 05ac7db0..1d67ed7c 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -2,8 +2,6 @@ import os import pytest -from mat3ra.mode.methods.factory import MethodFactory -from mat3ra.mode.model import Model from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.subworkflows import SubworkflowStandata from mat3ra.standata.workflows import WorkflowStandata From 035f2c865e9d0281e5e3fa47fc801be3d512c31a Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 4 Jun 2026 21:58:39 -0700 Subject: [PATCH 33/33] chore: pyproj --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d9686a02..cf3a9510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "mat3ra-code", "mat3ra-utils", "mat3ra-esse", - "mat3ra-mode @ git+https://github.com/Exabyte-io/mode.git@f9136d966dd7f558153f4a02ec8070a8059bd4ef", + "mat3ra-mode", "mat3ra-ade", "mat3ra-made", "mat3ra-standata",