Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
dd5c3c3
update: adjust {} -> []
VsevolodX Jun 2, 2026
d1cbc39
update: tests:
VsevolodX Jun 2, 2026
a1af9db
update: tests
VsevolodX Jun 2, 2026
13ff5eb
update: restructure unit context
VsevolodX Jun 2, 2026
876aab1
update: model serialization
VsevolodX Jun 2, 2026
a11e694
chore: mode
VsevolodX Jun 2, 2026
9ac6a90
update: unit has no context
VsevolodX Jun 3, 2026
c84862e
update: possible type
VsevolodX Jun 3, 2026
5b714e9
update: execution to have context
VsevolodX Jun 3, 2026
872bc1d
update: adjsut convergence
VsevolodX Jun 3, 2026
f131678
update: use context item
VsevolodX Jun 3, 2026
5409a7e
chore: cleanup
VsevolodX Jun 3, 2026
9c36a36
update: execution unit test
VsevolodX Jun 3, 2026
06646aa
update: check for execution unit
VsevolodX Jun 3, 2026
97d4899
chore: simplify
VsevolodX Jun 3, 2026
aab26fe
update: context to match NBs
VsevolodX Jun 3, 2026
c3d2cb2
update: validate context item
VsevolodX Jun 3, 2026
5993c0c
update: use standata for tests
VsevolodX Jun 3, 2026
8a17b42
update: context shape in tests
VsevolodX Jun 3, 2026
dd0f48e
update: preserve yield data
VsevolodX Jun 3, 2026
066c546
update: cleanup
VsevolodX Jun 3, 2026
f232508
chore: rename
VsevolodX Jun 3, 2026
3e871d3
update: kgrid context
VsevolodX Jun 3, 2026
7ccf85a
update: correction for serialization
VsevolodX Jun 3, 2026
f5d946b
update: set convergence
VsevolodX Jun 4, 2026
3475895
update: adjsut test
VsevolodX Jun 4, 2026
a59b743
update: model
VsevolodX Jun 5, 2026
eb41454
chore: simplify
VsevolodX Jun 5, 2026
e5c2a74
update: test convergence
VsevolodX Jun 5, 2026
84c4137
chore: address pr comment
VsevolodX Jun 5, 2026
98e10eb
chore: lint
VsevolodX Jun 5, 2026
305e5bf
chore: lint
VsevolodX Jun 5, 2026
035f2c8
chore: pyproj
VsevolodX Jun 5, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ extend_skip_glob = ["dist/*"]
[tool.pytest.ini_options]
pythonpath = [
"src/py",
"tests/py",
]
testpaths = [
"tests/py"
Expand Down
Empty file.
13 changes: 13 additions & 0 deletions src/py/mat3ra/wode/context/providers/base/context_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any, Dict

from mat3ra.ade.context.context_provider import ContextProvider as AdeContextProvider

# TODO: Remove context provider from Ade -- sync with JS implementation fully
class ContextProvider(AdeContextProvider):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: remove context provider from Ade

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add todo pls

def get_context_item_data(self) -> Dict[str, Any]:
return {
"name": self.name_str,
"isEdited": self.is_edited,
"data": self.get_data(),
"extraData": self.extra_data or {},
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Dict, Optional

from mat3ra.ade.context.context_provider import ContextProvider
from mat3ra.esse.models.context_providers_directory.planewave_cutoffs_context_provider import (
PlanewaveCutoffsContextProviderSchema,
)
from pydantic import Field

from .base.context_provider import ContextProvider


class PlanewaveCutoffsContextProvider(PlanewaveCutoffsContextProviderSchema, ContextProvider):
name: str = Field(default="cutoffs")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, Dict, List, Optional

from mat3ra.ade.context.context_provider import ContextProvider
from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import (
GridMetricType,
PointsGridDataProviderSchema,
)
from pydantic import Field

from .base.context_provider import ContextProvider

DEFAULT_KPPRA = -1

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List

from mat3ra.ade.context.context_provider import ContextProvider
from .base.context_provider import ContextProvider
from mat3ra.esse.models.context_providers_directory.points_path_data_provider import (
PointsPathDataProviderSchemaItem,
)
Expand Down
29 changes: 15 additions & 14 deletions src/py/mat3ra/wode/mixins/flowchart_units_manager.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,43 @@
from typing import List, Optional, TypeVar
from typing import Generic, List, Optional, TypeVar

from mat3ra.utils import find_by_key_or_regex

from ..units import Unit

UnitT = TypeVar("UnitT", bound=Unit)
T = TypeVar("T")


class FlowchartUnitsManager:
class FlowchartUnitsManager(Generic[UnitT]):
"""
Mixin class providing common unit operations for flowchart units.

This mixin expects the class to have a `units: List[Unit]` attribute.
This mixin expects the class to have a `units: List[UnitT]` attribute.
It provides common methods for managing units in both Workflow and Subworkflow classes.
"""

units: List[Unit]
units: List[UnitT]

def set_units(self, units: List[Unit]) -> None:
def set_units(self, units: List[UnitT]) -> None:
self.units = units

def _set_units_order_in_place(self) -> None:
self.set_units_head(self.units)
self.set_next_links(self.units)

def get_unit(self, flowchart_id: str) -> Optional[Unit]:
def get_unit(self, flowchart_id: str) -> Optional[UnitT]:
for unit in self.units:
if unit.flowchartId == flowchart_id:
return unit
return None

def find_unit_by_id(self, id: str) -> Optional[Unit]:
def find_unit_by_id(self, id: str) -> Optional[UnitT]:
for unit in self.units:
if getattr(unit, "id", None) == id:
return unit
return None

def find_unit_with_tag(self, tag: str) -> Optional[Unit]:
def find_unit_with_tag(self, tag: str) -> Optional[UnitT]:
for unit in self.units:
if hasattr(unit, "tags") and unit.tags is not None and tag in unit.tags:
return unit
Expand All @@ -46,7 +47,7 @@ def get_unit_by_name(
self,
name: Optional[str] = None,
name_regex: Optional[str] = None,
) -> Optional[Unit]:
) -> Optional[UnitT]:
return find_by_key_or_regex(self.units, key="name", value=name, value_regex=name_regex)

@staticmethod
Expand All @@ -68,7 +69,7 @@ def _add_to_list(items: List[T], item: T, head: bool = False, index: int = -1) -
else:
items.append(item)

def set_units_head(self, units: List[Unit]) -> List[Unit]:
def set_units_head(self, units: List[UnitT]) -> List[UnitT]:
"""
Set the head flag on the first unit and unset it on all others.

Expand All @@ -84,7 +85,7 @@ def set_units_head(self, units: List[Unit]) -> List[Unit]:
unit.head = False
return units

def set_next_links(self, units: List[Unit]) -> List[Unit]:
def set_next_links(self, units: List[UnitT]) -> List[UnitT]:
"""
Re-establishes the linked next => flowchartId logic in an array of units.

Expand Down Expand Up @@ -118,7 +119,7 @@ def _clear_link_to_unit(self, flowchart_id: str) -> None:
unit.next = None
break

def add_unit(self, unit: Unit, head: bool = False, index: int = -1) -> None:
def add_unit(self, unit: UnitT, head: bool = False, index: int = -1) -> None:
"""
Add a unit to the units list.

Expand Down Expand Up @@ -175,8 +176,8 @@ def replace_unit(self, index: int, unit: Unit) -> None:

def set_unit(
self,
new_unit: Unit,
unit: Optional[Unit] = None,
new_unit: UnitT,
unit: Optional[UnitT] = None,
unit_flowchart_id: Optional[str] = None,
) -> bool:
"""
Expand Down
1 change: 1 addition & 0 deletions src/py/mat3ra/wode/subworkflows/convergence/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, List, Optional

from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum

from .non_uniform_kgrid import NonUniformKGridConvergence
from .non_uniform_kgrid_2d import NonUniformKGridConvergence2D
from .parameter import ConvergenceParameter
Expand Down
5 changes: 3 additions & 2 deletions src/py/mat3ra/wode/subworkflows/convergence/uniform_kgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ class UniformKGridConvergence(ConvergenceParameter):
def _points_grid_context(
self, dimensions: List[str], reciprocal_vector_ratios: Optional[List[float]] = None
) -> Dict[str, Any]:
return PointsGridDataProvider().yield_data_with_overrides(
provider = PointsGridDataProvider(isEdited=True)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it need to be True here?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We replace dimensions with Jinja vars, so it means that we edit it by definition.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll address all of this when removing CP from Ade and syncing py with JS

provider.data = provider.build_data(
dimensions=dimensions,
reciprocal_vector_ratios=reciprocal_vector_ratios,
is_using_jinja_variables=True,
)
return provider.get_context_item_data()

@property
def increment(self) -> str:
Expand Down
37 changes: 18 additions & 19 deletions src/py/mat3ra/wode/subworkflows/convergence_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from mat3ra.esse.models.workflow.subworkflow.convergence.enum_options import ConvergenceParameterNameEnum
from mat3ra.utils.extra.jinja import JINJA_EXPRESSION_PATTERN, NUMERIC_VALUE_PATTERN, wrap_text_in_raw_block

from .convergence.factory import create_convergence_parameter
from ..context.providers import PointsGridDataProvider
from ..units import Unit
Expand Down Expand Up @@ -67,15 +66,6 @@ def _find_unit_for_convergence(self, result: str):
return unit
return None

@staticmethod
def _merge_convergence_context(unit_context: Dict[str, Any], convergence_context: Dict[str, Any]) -> Dict[str, Any]:
merged_context = dict(unit_context)
merged_kgrid_context = dict(unit_context.get("kgrid") or {})
merged_kgrid_context.update(convergence_context.get("kgrid") or {})
merged_context.update(convergence_context)
if merged_kgrid_context:
merged_context["kgrid"] = merged_kgrid_context
return merged_context

def _build_convergence_units(
self,
Expand Down Expand Up @@ -153,6 +143,17 @@ def _build_convergence_units(
host.add_unit(next_step)
host.add_unit(exit_unit)

param_init.next = prev_result_init.flowchartId
prev_result_init.next = iter_init.flowchartId
iter_init.next = execution_unit_flowchart_id

execution_unit = host.get_unit(execution_unit_flowchart_id)
if execution_unit is not None:
execution_unit.next = store_result.flowchartId

store_result.next = condition_unit.flowchartId
store_prev_result.next = next_iter.flowchartId
next_iter.next = next_step.flowchartId
next_step.next = execution_unit_flowchart_id

def add_convergence(
Expand Down Expand Up @@ -185,9 +186,12 @@ def add_convergence(
)
and reciprocal_vector_ratios is None
):
reciprocal_vector_ratios = PointsGridDataProvider(
context=unit_for_convergence.context
).get_reciprocal_vector_ratios()
kgrid_item = unit_for_convergence.get_context_item("kgrid")
provider = PointsGridDataProvider(
data=kgrid_item.get("data"),
is_edited=kgrid_item.get("isEdited"),
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_reciprocal_vector_ratios()

reciprocal_vector_ratios = provider.get_reciprocal_vector_ratios()
if reciprocal_vector_ratios is None:
raise ValueError("Non-uniform k-grid convergence requires reciprocal_vector_ratios to be provided.")

Expand All @@ -198,11 +202,7 @@ def add_convergence(
reciprocal_vector_ratios=reciprocal_vector_ratios,
)

merged_context = self._merge_convergence_context(
unit_for_convergence.context,
parameter.unit_context,
)
unit_for_convergence.set_context(merged_context)
unit_for_convergence.add_context(parameter.unit_context)

self._build_convergence_units(
parameter_name=parameter.name,
Expand Down Expand Up @@ -266,7 +266,6 @@ def add_template_parameter_convergence(
execution_unit.replace_in_input_content(
pattern, f"{parameter_name} = {scope_reference}", input_name=input_name
)
execution_unit.add_context({parameter_name: parameter_initial})

self._build_convergence_units(
parameter_name=parameter_name,
Expand Down
12 changes: 7 additions & 5 deletions src/py/mat3ra/wode/subworkflows/subworkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mat3ra.mode.model import Model
from mat3ra.mode.models.factory import ModelFactory
from mat3ra.utils.uuid import get_uuid
from pydantic import Field, field_validator
from pydantic import ConfigDict, Field, SerializeAsAny, field_validator

from .convergence_mixin import ConvergenceMixin
from ..mixins import FlowchartUnitsManager
Expand All @@ -22,7 +22,7 @@ class Subworkflow(
SubworkflowSchema,
HashedEntityMixin,
InMemoryEntitySnakeCase,
FlowchartUnitsManager,
FlowchartUnitsManager[Unit],
):
"""
Subworkflow class representing a logical collection of workflow units.
Expand All @@ -35,19 +35,21 @@ class Subworkflow(
properties: List of properties extracted by the subworkflow
"""

model_config = ConfigDict(validate_assignment=True)

id: str = Field(default_factory=get_uuid, alias="_id")
application: Application = Field(
default_factory=lambda: Application(name="", version="", build="", shortName="", summary="")
)
properties: List[str] = Field(default_factory=list)
model: Model = Field(default_factory=DFTModel)
units: List[Union[Unit, ExecutionUnit, SubworkflowUnit]] = Field(default_factory=list)
model: SerializeAsAny[Model] = Field(default_factory=DFTModel)
units: List[SerializeAsAny[Union[Unit, ExecutionUnit, SubworkflowUnit]]] = Field(default_factory=list)

@field_validator("model", mode="before")
@classmethod
def _instantiate_model(cls, value: Any) -> Any:
if isinstance(value, Model):
return value
value = value.to_dict()
if isinstance(value, dict):
return ModelFactory.create(value)
return value
Expand Down
43 changes: 38 additions & 5 deletions src/py/mat3ra/wode/units/execution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Optional

from mat3ra.ade import Application, Executable, Flavor
from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchema
Expand All @@ -8,18 +8,19 @@
remove_empty_lines_from_string,
)
from mat3ra.utils.extra.jinja import replace_in_text
from pydantic import Field, field_validator
from pydantic import Field, SerializeAsAny, field_validator

from .execution_unit_input import ExecutionUnitInput
from .unit import Unit


class ExecutionUnit(Unit, ExecutionUnitSchema):
type: Literal["execution"] = "execution"
executable: Executable = None
flavor: Flavor = None
application: Application = None
application: Application
executable: Executable
flavor: Flavor
input: List[ExecutionUnitInput] = Field(default_factory=list)
context: List[SerializeAsAny[Dict[str, Any]]] = Field(default_factory=list)

@field_validator("input", mode="before")
@classmethod
Expand All @@ -34,6 +35,38 @@ def _instantiate_input(cls, value: Any) -> List[ExecutionUnitInput]:
instantiated.append(ExecutionUnitInput(**item))
return instantiated


@staticmethod
def _context_item_name(item: Any) -> Optional[str]:
if isinstance(item, dict):
return item.get("name")
return str(item.name)

def get_context_item(self, name: str) -> Optional[Dict[str, Any]]:
for item in self.context:
if self._context_item_name(item) == name:
return item
return None

def get_context_item_data(self, name: str, default: Any = None) -> Any:
if default is None:
default = {}
item = self.get_context_item(name)
return item.get("data", default) if item else default

def add_context(self, item: Dict[str, Any]) -> None:
existing_item = self.get_context_item(item["name"])
if existing_item:
existing_item.update(item)
else:
self.context.append(item)

def remove_context(self, name: str) -> None:
self.context = [item for item in self.context if self._context_item_name(item) != name]

def clear_context(self) -> None:
self.context = []

def replace_in_input_content(self, pattern: str, replacement: str, input_name=None) -> None:
for item in self.input:
if input_name is None or item.template.name == input_name:
Expand Down
Loading
Loading