From eb11341a2ecd2988732f8b38d13766eb64475510 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 00:11:49 -0700 Subject: [PATCH 01/18] update: single NB --- .../relax_structure_with_mlff.ipynb | 350 ++++++++++++++++++ 1 file changed, 350 insertions(+) create mode 100644 other/experiments/jupyterlite/relax_structure_with_mlff.ipynb diff --git a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb new file mode 100644 index 000000000..cdf147b14 --- /dev/null +++ b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb @@ -0,0 +1,350 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Relax structure with MLFF (MACE / UMA) — Machine Learned Force Field\n", + "\n", + "Relax atomic positions locally with ASE using a selectable Machine Learned Force Field (MLFF).\n" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## 1. Set Input Parameters\n", + "### 1.1. Interface and Relaxation Parameters\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "FOLDER = \"uploads\" # local directory containing JSON material files\n", + "INTERFACE_NAME = \"Interface\" # name of the interface to load from the folder\n", + "\n", + "# MLFF selector\n", + "MLFF_NAME = \"mace\" # \"mace\" | \"uma\" | \"mattersim\" | \"nequip\"\n", + "\n", + "# MLFF-specific settings\n", + "MLFF_SETTINGS = {\n", + " # MACE settings\n", + " \"mace\": {\n", + " \"model\": \"large\", # \"small\" | \"medium\" | \"large\"\n", + " \"dispersion\": True,\n", + " \"default_dtype\": \"float32\",\n", + " \"device\": \"cpu\",\n", + " },\n", + " # UMA settings\n", + " \"uma\": {\n", + " \"model\": \"f16\", # \"f16\" | \"int8\"\n", + " \"task_name\": \"omat\", # e.g. omat | oc20 | omol\n", + " \"device\": \"cpu\",\n", + " },\n", + " # MatterSim settings\n", + " \"mattersim\": {\n", + " \"model\": \"1m\",\n", + " \"device\": \"cpu\",\n", + " },\n", + " # NequIP settings\n", + " \"nequip\": {\n", + " \"model\": \"oam_s\",\n", + " \"device\": \"cpu\",\n", + " },\n", + "}\n", + "\n", + "# Common relaxation settings\n", + "RELAXATION_PARAMETERS = {\n", + " \"FMAX\": 0.05, # final maximum force on any atom (eV/Å)\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## 2. Install Packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "from mat3ra.notebooks_utils.packages import install_packages\n", + "from mat3ra.notebooks_utils.primitive.environment import is_pyodide_environment\n", + "from mat3ra.notebooks_utils.mlff import get_mlff_install_profiles\n", + "\n", + "profiles = get_mlff_install_profiles(MLFF_NAME)\n", + "await install_packages(profiles)\n", + "\n", + "# PyTorch patches are required in Pyodide for torch-based MLFFs.\n", + "if is_pyodide_environment():\n", + " from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", + "\n", + " apply_all_patches(\n", + " include_fairchem=(MLFF_NAME == \"uma\"),\n", + " include_mattersim=(MLFF_NAME == \"mattersim\"),\n", + " include_nequip=(MLFF_NAME == \"nequip\"),\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## 3. Load Materials" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "from mat3ra.made.material import Material\n", + "from mat3ra.notebooks_utils.material import load_material_from_folder\n", + "from mat3ra.standata.materials import Materials\n", + "\n", + "interface = load_material_from_folder(FOLDER, INTERFACE_NAME) or Material.create(\n", + " Materials.get_by_name_first_match(INTERFACE_NAME))" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "### 3.1. Visualize Input Materials" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "from mat3ra.notebooks_utils.ipython.entity.material.visualize import ViewersEnum, visualize_materials as visualize\n", + "\n", + "visualize([{\"material\": interface, \"title\": interface.name}], viewer=ViewersEnum.wave)\n", + "visualize(interface, repetitions=[1, 1, 1], rotation=\"-90x\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "## 4. Apply Relaxation\n", + "### 4.1. Relax with MACE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "from mat3ra.made.tools.convert import to_ase\n", + "from ase.optimize import BFGS\n", + "\n", + "from mat3ra.notebooks_utils.mlff import create_mlff_calculator\n", + "from mat3ra.notebooks_utils.plot import progress_callback\n", + "\n", + "calculator = create_mlff_calculator(MLFF_NAME, MLFF_SETTINGS[MLFF_NAME])\n", + "\n", + "ase_interface = to_ase(interface)\n", + "ase_interface.calc = calculator\n", + "dyn = BFGS(ase_interface)\n", + "\n", + "update = progress_callback(\n", + " dynamic_object=dyn,\n", + " value_getter=lambda: float(ase_interface.get_total_energy()),\n", + " value_label=\"Energy (eV)\",\n", + " step_label=\"Step\",\n", + " print_format=\"Step: {}, Energy: {:.4f} eV\",\n", + ")\n", + "dyn.attach(update, interval=1)\n", + "dyn.run(fmax=RELAXATION_PARAMETERS[\"FMAX\"])\n", + "\n", + "ase_original_interface = to_ase(interface)\n", + "ase_original_interface.calc = calculator\n", + "ase_final_interface = ase_interface\n", + "\n", + "original_energy = ase_original_interface.get_total_energy()\n", + "relaxed_energy = ase_interface.get_total_energy()\n", + "\n", + "print(f\"The final energy is {float(relaxed_energy):.3f} eV.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "## 5. Analyze Results\n", + "### 5.1. View Structure Before and After Relaxation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "from mat3ra.made.tools.convert import from_ase\n", + "\n", + "material_original = Material.create(from_ase(ase_original_interface))\n", + "material_relaxed = Material.create(from_ase(ase_final_interface))\n", + "material_original.name = interface.name\n", + "material_relaxed.name = interface.name + f\" ({MLFF_NAME.upper()} Relaxed)\"\n", + "\n", + "visualize(\n", + " [\n", + " {\"material\": material_original, \"title\": material_original.name},\n", + " {\"material\": material_relaxed, \"title\": material_relaxed.name},\n", + " ],\n", + " viewer=ViewersEnum.wave,\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "### 5.2. Output interlayer distance before and after relaxation\n", + "This requires labels for substrate and film present in the interface structure, which is already done if interface was created with `mat3ra-made`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "from mat3ra.made.tools.analyze.other import get_average_interlayer_distance\n", + "\n", + "SUBSTRATE_TAG = 0\n", + "FILM_TAG = 1\n", + "\n", + "print(\n", + " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")\n", + "print(\n", + " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": {}, + "source": [ + "### 5.3. Calculate Interface Energy\n", + "Interface should have labels marking substrate and film atoms with different tags (e.g. 0 for substrate and 1 for film) for the separation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "def filter_atoms_by_tag(atoms, material_index):\n", + " return atoms[atoms.get_tags() == material_index]\n", + "\n", + "\n", + "def calculate_energy(atoms, calc):\n", + " atoms.set_calculator(calc)\n", + " return atoms.get_total_energy()\n", + "\n", + "\n", + "def calculate_delta_energy(total_energy, *component_energies):\n", + " return total_energy - sum(component_energies)\n", + "\n", + "\n", + "substrate_original = filter_atoms_by_tag(ase_original_interface, SUBSTRATE_TAG)\n", + "layer_original = filter_atoms_by_tag(ase_original_interface, FILM_TAG)\n", + "substrate_relaxed = filter_atoms_by_tag(ase_final_interface, SUBSTRATE_TAG)\n", + "layer_relaxed = filter_atoms_by_tag(ase_final_interface, FILM_TAG)\n", + "\n", + "original_substrate_energy = calculate_energy(substrate_original, calculator)\n", + "original_layer_energy = calculate_energy(layer_original, calculator)\n", + "relaxed_substrate_energy = calculate_energy(substrate_relaxed, calculator)\n", + "relaxed_layer_energy = calculate_energy(layer_relaxed, calculator)\n", + "\n", + "delta_original = calculate_delta_energy(original_energy, original_substrate_energy, original_layer_energy)\n", + "delta_relaxed = calculate_delta_energy(relaxed_energy, relaxed_substrate_energy, relaxed_layer_energy)\n", + "\n", + "area = ase_original_interface.get_volume() / ase_original_interface.cell[2, 2]\n", + "n_interface = ase_final_interface.get_global_number_of_atoms()\n", + "n_substrate = substrate_relaxed.get_global_number_of_atoms()\n", + "n_layer = layer_relaxed.get_global_number_of_atoms()\n", + "effective_delta_relaxed = (\n", + " relaxed_energy / n_interface\n", + " - (relaxed_substrate_energy / n_substrate + relaxed_layer_energy / n_layer)\n", + " ) / (2 * area)\n", + "\n", + "print(f\"Original Substrate energy: {original_substrate_energy:.4f} eV\")\n", + "print(f\"Relaxed Substrate energy: {relaxed_substrate_energy:.4f} eV\")\n", + "print(f\"Original Layer energy: {original_layer_energy:.4f} eV\")\n", + "print(f\"Relaxed Layer energy: {relaxed_layer_energy:.4f} eV\")\n", + "print(\"\\nDelta between interface energy and sum of component energies\")\n", + "print(f\"Original Delta: {delta_original:.4f} eV\")\n", + "print(f\"Relaxed Delta: {delta_relaxed:.4f} eV\")\n", + "print(f\"Original Delta per area: {delta_original / area:.4f} eV/Ang^2\")\n", + "print(f\"Relaxed Delta per area: {delta_relaxed / area:.4f} eV/Ang^2\")\n", + "print(f\"Relaxed interface energy: {relaxed_energy:.4f} eV\")\n", + "print(\n", + " f\"Effective relaxed Delta per area: {effective_delta_relaxed:.4f} eV/Ang^2 ({effective_delta_relaxed / 0.16:.4f} J/m^2)\")" + ] + }, + { + "cell_type": "markdown", + "id": "17", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "[1] mat3ra-made interface builder: https://github.com/Exabyte-io/made \n", + "[2] MACE-MP-0 foundation model: https://github.com/ACEsuit/mace?tab=readme-ov-file#foundation-models " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (Pyodide)", + "language": "python", + "name": "python" + }, + "language_info": { + "codemirror_mode": { + "name": "python", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 1a7a8309a2e349f2ce1833468ebf14c380aee708 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 00:14:31 -0700 Subject: [PATCH 02/18] update: add model related functions --- .../notebooks_utils/pyodide/packages/mace.py | 23 +- .../pyodide/packages/mattersim.py | 22 + .../pyodide/packages/nequip.py | 41 + .../notebooks_utils/pyodide/packages/torch.py | 1266 ++++++++++++++++- .../notebooks_utils/pyodide/packages/uma.py | 26 + 5 files changed, 1358 insertions(+), 20 deletions(-) create mode 100644 src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py create mode 100644 src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py create mode 100644 src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py index 70adc8121..dddd53935 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py @@ -1,4 +1,6 @@ -from mace.calculators import MACECalculator +from mace.calculators import MACECalculator, mace_mp + +from ...primitive.environment import is_pyodide_environment MODEL_PATHS_MAP = { "small": "/drive/packages/models/2023-12-10-mace-128-L0_energy_epoch-249.model", @@ -14,3 +16,22 @@ def get_mace_model_pyodide(model: str, dispersion=False, default_dtype="float32" return MACECalculator( model_path=model_path, dispersion=dispersion, default_dtype=default_dtype, device=device, **kwargs ) + + +def create_mace_calculator(model="large", dispersion=True, default_dtype="float32", device="cpu", **kwargs): + if is_pyodide_environment(): + return get_mace_model_pyodide( + model=model, + dispersion=dispersion, + default_dtype=default_dtype, + device=device, + **kwargs, + ) + + return mace_mp( + model=model, + dispersion=dispersion, + default_dtype=default_dtype, + device=device, + **kwargs, + ) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py new file mode 100644 index 000000000..2ae1870fc --- /dev/null +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py @@ -0,0 +1,22 @@ +from mattersim.forcefield import MatterSimCalculator + +from ...primitive.environment import is_pyodide_environment + +MODEL_PATHS_MAP = { + "1m": "/drive/packages/models/mattersim-v1.0.0-1M.pth", +} + + +def get_mattersim_model_pyodide(model: str, device="cpu", **kwargs): + if model not in MODEL_PATHS_MAP: + raise ValueError(f"Invalid model name: {model}. Valid options are: {list(MODEL_PATHS_MAP.keys())}") + return MatterSimCalculator.from_checkpoint(load_path=MODEL_PATHS_MAP[model], device=device, **kwargs) + + +def create_mattersim_calculator(model="1m", device="cpu", model_path=None, checkpoint=None, **kwargs): + if is_pyodide_environment(): + return get_mattersim_model_pyodide(model=model, device=device, **kwargs) + + resolved_model_path = model_path or checkpoint + + return MatterSimCalculator.from_checkpoint(load_path=str(resolved_model_path), device=device, **kwargs) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py new file mode 100644 index 000000000..ecee14f5f --- /dev/null +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py @@ -0,0 +1,41 @@ +from nequip.ase import NequIPCalculator +from nequip.data.transforms import ChemicalSpeciesToAtomTypeMapper, NeighborListTransform + +from ...primitive.environment import is_pyodide_environment +from .torch import load_nequip_model + +MODEL_PATHS_MAP = { + "oam_s": "/drive/packages/models/nequip-oam-s-config-sd.pth", +} + + +def get_nequip_model_pyodide(model: str, device="cpu"): + if model not in MODEL_PATHS_MAP: + raise ValueError(f"Invalid model name: {model}. Valid options are: {list(MODEL_PATHS_MAP.keys())}") + + nequip_model = load_nequip_model(MODEL_PATHS_MAP[model]) + return create_nequip_calculator_from_model(nequip_model, device=device) + + +def create_nequip_calculator_from_model(nequip_model, device="cpu"): + r_max = float(nequip_model.metadata["r_max"]) + type_names = nequip_model.metadata["type_names"].split(" ") + + return NequIPCalculator( + model=nequip_model, + device=device, + transforms=[ + ChemicalSpeciesToAtomTypeMapper(type_names), + NeighborListTransform(r_max=r_max), + ], + ) + + +def create_nequip_calculator(model="oam_s", device="cpu", model_path=None, checkpoint=None): + if is_pyodide_environment(): + return get_nequip_model_pyodide(model=model, device=device) + + resolved_model_path = model_path or checkpoint + + nequip_model = load_nequip_model(str(resolved_model_path)) + return create_nequip_calculator_from_model(nequip_model, device=device) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index 48e796c6b..2be0bd07b 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -28,8 +28,6 @@ import numpy as np import torch -from ...primitive.environment import is_pyodide_environment - # Define return types to mimic PyTorch's named tuples EigRet = namedtuple("linalg_eig", ["eigenvalues", "eigenvectors"]) # type: ignore EighRet = namedtuple("linalg_eigh", ["eigenvalues", "eigenvectors"]) # type: ignore @@ -142,13 +140,66 @@ def patch_torch_linalg(): torch.Tensor.__array__ = _tensor_array_compat torch.Tensor.numpy = lambda self: np.array(self.detach().tolist()) - # Fix torch.compiler.is_compiling for Pyodide - if not hasattr(torch, "compiler"): - torch.compiler = types.ModuleType("torch.compiler") - if not hasattr(torch.compiler, "is_compiling"): - torch.compiler.is_compiling = lambda: False + # torch.from_numpy — WASM PyTorch build lacks NumPy interop + _orig_from_numpy = torch.from_numpy + + def _patched_from_numpy(ndarray): + try: + return _orig_from_numpy(ndarray) + except RuntimeError: + return torch.tensor(ndarray.tolist()) + + torch.from_numpy = _patched_from_numpy + + # torch.as_tensor — also uses numpy interop internally + _orig_as_tensor = torch.as_tensor + + def _patched_as_tensor(data, dtype=None, device=None): + try: + return _orig_as_tensor(data, dtype=dtype, device=device) + except (RuntimeError, TypeError): + if hasattr(data, "tolist"): + return torch.tensor(data.tolist(), dtype=dtype, device=device) + return torch.tensor(data, dtype=dtype, device=device) + + torch.as_tensor = _patched_as_tensor + + # Tensor indexing — WASM PyTorch can't infer numpy dtypes for boolean/integer masks + _orig_getitem = torch.Tensor.__getitem__ + + def _patched_getitem(self, key): + if isinstance(key, np.ndarray): + key = torch.tensor(key.tolist()) + return _orig_getitem(self, key) + + torch.Tensor.__getitem__ = _patched_getitem - print("✓ Torch linalg patches applied") + _orig_setitem = torch.Tensor.__setitem__ + + def _patched_setitem(self, key, value): + if isinstance(key, np.ndarray): + key = torch.tensor(key.tolist()) + return _orig_setitem(self, key, value) + + torch.Tensor.__setitem__ = _patched_setitem + + # torch.tensor — handle lists of 0-d tensors (WASM calls len() on elements, fails for 0-d) + _orig_torch_tensor = torch.tensor + + def _patched_torch_tensor(data, *args, **kwargs): + if isinstance(data, (list, tuple)): + + def _unwrap(item): + if isinstance(item, torch.Tensor) and item.ndim == 0: + return item.item() + return item + + data = type(data)(_unwrap(x) for x in data) + return _orig_torch_tensor(data, *args, **kwargs) + + torch.tensor = _patched_torch_tensor + + print("✓ Torch linalg + numpy interop patches applied") # ============================================================================== @@ -191,9 +242,643 @@ def patch_torch_testing(): sys.modules["torch.testing._internal.common_utils"] = _common_utils sys.modules["torch.testing._internal.logging_tensor"] = _logging_tensor + # Import torch.utils.checkpoint AFTER logging_tensor stubs are set up + # (checkpoint.py imports logging_tensor at import time) + import torch.utils.checkpoint # noqa: F401 + print("✓ Torch testing patches applied") +# ============================================================================== +# Torch compiler patches +# ============================================================================== + + +def patch_torch_compiler(): + """ + Patch torch.compiler for Pyodide WASM. + + - Stubs torch.compiler.is_compiling (not available in WASM build) + - Makes torch.compiler.disable a no-op decorator (avoids torch._dynamo + import chain which requires C extensions missing in WASM) + - Needed by e3nn >= 0.5 and fairchem-core which use @torch.compiler.disable + """ + if not hasattr(torch, "compiler"): + torch.compiler = types.ModuleType("torch.compiler") + if not hasattr(torch.compiler, "is_compiling"): + torch.compiler.is_compiling = lambda: False + + def _compiler_disable(fn=None, recursive=True): + if fn is not None: + return fn + return lambda f: f + + torch.compiler.disable = _compiler_disable + + print("✓ Torch compiler patches applied") + + +# ============================================================================== +# Torch distributed patches (for FAIRChem) +# ============================================================================== + + +def patch_torch_distributed(): + """ + Stub torch.distributed modules that require C extensions missing in WASM. + + FAIRChem's mlip_unit.py imports from torch.distributed.checkpoint, + torch.distributed.fsdp, and torch.distributed.device_mesh. + All of these trigger torch._C._distributed_c10d which doesn't exist in WASM. + """ + import enum + + import torch.distributed as _dist + + # Core distributed API stubs + if not hasattr(_dist, "group"): + + class _Group: + WORLD = None + + _dist.group = _Group + _dist.WORLD = None + if not hasattr(_dist, "ReduceOp"): + + class _ReduceOp: + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + BAND = 4 + BOR = 5 + BXOR = 6 + AVG = 7 + + _dist.ReduceOp = _ReduceOp + if not hasattr(_dist, "is_initialized"): + _dist.is_initialized = lambda: False + if not hasattr(_dist, "get_rank"): + _dist.get_rank = lambda group=None: 0 + if not hasattr(_dist, "get_world_size"): + _dist.get_world_size = lambda group=None: 1 + + # torch.distributed.nn.functional + _dist_nn = types.ModuleType("torch.distributed.nn") + _dist_nn.__path__ = [] + _dist_nn.__package__ = "torch.distributed.nn" + _dist_nn_func = types.ModuleType("torch.distributed.nn.functional") + _dist_nn_func.__package__ = "torch.distributed.nn" + _dist_nn_func.all_reduce = lambda tensor, *a, **k: tensor + _dist_nn_func.reduce_scatter = lambda output, input_list, *a, **k: output + _dist_nn_func.all_gather = lambda tensor_list, tensor, *a, **k: tensor_list + _dist_nn.functional = _dist_nn_func + sys.modules["torch.distributed.nn"] = _dist_nn + sys.modules["torch.distributed.nn.functional"] = _dist_nn_func + + # C extension stubs + sys.modules["torch._C._distributed_c10d"] = types.ModuleType("torch._C._distributed_c10d") + _dc10d = types.ModuleType("torch.distributed.distributed_c10d") + _dc10d.__package__ = "torch.distributed" + sys.modules["torch.distributed.distributed_c10d"] = _dc10d + + # torch.distributed._shard + _shard = types.ModuleType("torch.distributed._shard") + _shard.__path__ = [] + _shard.__package__ = "torch.distributed._shard" + sys.modules["torch.distributed._shard"] = _shard + sys.modules["torch.distributed._shard.api"] = types.ModuleType("torch.distributed._shard.api") + _shard_st = types.ModuleType("torch.distributed._shard.sharded_tensor") + _shard_st.__path__ = [] + sys.modules["torch.distributed._shard.sharded_tensor"] = _shard_st + _shard_st_meta = types.ModuleType("torch.distributed._shard.sharded_tensor.metadata") + + class _TensorProperties: + pass + + _shard_st_meta.TensorProperties = _TensorProperties + sys.modules["torch.distributed._shard.sharded_tensor.metadata"] = _shard_st_meta + + # torch.distributed.checkpoint + _dcp = types.ModuleType("torch.distributed.checkpoint") + _dcp.__path__ = [] + _dcp.__package__ = "torch.distributed.checkpoint" + _dcp.save = lambda *a, **k: None + _dcp.load = lambda *a, **k: None + _dcp_meta = types.ModuleType("torch.distributed.checkpoint.metadata") + _dcp_meta.TensorProperties = _TensorProperties + + class _BytesStorageMetadata: + pass + + class _TensorStorageMetadata: + pass + + class _Metadata: + pass + + _dcp_meta.BytesStorageMetadata = _BytesStorageMetadata + _dcp_meta.TensorStorageMetadata = _TensorStorageMetadata + _dcp_meta.Metadata = _Metadata + _dcp.metadata = _dcp_meta + sys.modules["torch.distributed.checkpoint"] = _dcp + sys.modules["torch.distributed.checkpoint.metadata"] = _dcp_meta + + for sub in [ + "state_dict", + "stateful", + "planner", + "storage", + "default_planner", + "filesystem", + "optimizer", + "format_utils", + ]: + _sub_mod = types.ModuleType(f"torch.distributed.checkpoint.{sub}") + _sub_mod.__package__ = "torch.distributed.checkpoint" + sys.modules[f"torch.distributed.checkpoint.{sub}"] = _sub_mod + setattr(_dcp, sub, _sub_mod) + + # format_utils stubs + sys.modules["torch.distributed.checkpoint.format_utils"].dcp_to_torch_save = lambda *a, **k: None + sys.modules["torch.distributed.checkpoint.format_utils"].torch_save_to_dcp = lambda *a, **k: None + + # state_dict stubs + _sd_mod = sys.modules["torch.distributed.checkpoint.state_dict"] + _sd_mod.get_model_state_dict = lambda model, *a, **k: model.state_dict() if hasattr(model, "state_dict") else {} + _sd_mod.set_model_state_dict = ( + lambda model, sd, *a, **k: model.load_state_dict(sd) if hasattr(model, "load_state_dict") else None + ) + _sd_mod.get_optimizer_state_dict = ( + lambda model, optim, *a, **k: optim.state_dict() if hasattr(optim, "state_dict") else {} + ) + _sd_mod.get_state_dict = lambda model, *a, **k: model.state_dict() if hasattr(model, "state_dict") else {} + _sd_mod.set_state_dict = lambda model, sd, *a, **k: None + + class _StateDictOptions: + def __init__(self, **k): + self.__dict__.update(k) + + _sd_mod.StateDictOptions = _StateDictOptions + + # stateful stub + class _Stateful: + pass + + sys.modules["torch.distributed.checkpoint.stateful"].Stateful = _Stateful + + # torch.distributed.fsdp + _fsdp = types.ModuleType("torch.distributed.fsdp") + _fsdp.__path__ = [] + _fsdp.__package__ = "torch.distributed.fsdp" + sys.modules["torch.distributed.fsdp"] = _fsdp + for fsdp_sub in ["fully_sharded_data_parallel", "api", "wrap", "sharded_grad_scaler"]: + _fsub = types.ModuleType(f"torch.distributed.fsdp.{fsdp_sub}") + _fsub.__package__ = "torch.distributed.fsdp" + sys.modules[f"torch.distributed.fsdp.{fsdp_sub}"] = _fsub + setattr(_fsdp, fsdp_sub, _fsub) + + # FSDP wrap policy stubs + class _ModuleWrapPolicy: + def __init__(self, module_classes=None): + self.module_classes = module_classes or set() + + _fsdp_wrap = sys.modules["torch.distributed.fsdp.wrap"] + _fsdp_wrap.ModuleWrapPolicy = _ModuleWrapPolicy + _fsdp_wrap.lambda_auto_wrap_policy = lambda *a, **k: None + _fsdp_wrap.transformer_auto_wrap_policy = lambda *a, **k: None + + # FSDP classes + class _FullyShardedDataParallel(torch.nn.Module): + def __init__(self, module, *a, **k): + super().__init__() + self.module = module + + class _ShardingStrategy(enum.Enum): + FULL_SHARD = "FULL_SHARD" + SHARD_GRAD_OP = "SHARD_GRAD_OP" + NO_SHARD = "NO_SHARD" + HYBRID_SHARD = "HYBRID_SHARD" + + class _MixedPrecision: + def __init__(self, *a, **k): + pass + + class _CPUOffload: + def __init__(self, offload_params=False): + self.offload_params = offload_params + + class _BackwardPrefetch(enum.Enum): + BACKWARD_PRE = "BACKWARD_PRE" + BACKWARD_POST = "BACKWARD_POST" + + class _StateDictTypeFSDP(enum.Enum): + FULL_STATE_DICT = 0 + LOCAL_STATE_DICT = 1 + SHARDED_STATE_DICT = 2 + + _fsdp.FullyShardedDataParallel = _FullyShardedDataParallel + _fsdp.ShardingStrategy = _ShardingStrategy + _fsdp.MixedPrecision = _MixedPrecision + _fsdp.CPUOffload = _CPUOffload + _fsdp.BackwardPrefetch = _BackwardPrefetch + _fsdp.StateDictType = _StateDictTypeFSDP + + # FSDP StateDictConfig stubs + class _StateDictConfig: + def __init__(self, **k): + self.__dict__.update(k) + + class _ShardedStateDictConfig(_StateDictConfig): + pass + + class _FullStateDictConfig(_StateDictConfig): + def __init__(self, offload_to_cpu=False, rank0_only=False, **k): + super().__init__(**k) + self.offload_to_cpu = offload_to_cpu + self.rank0_only = rank0_only + + class _FullOptimStateDictConfig(_StateDictConfig): + def __init__(self, offload_to_cpu=False, rank0_only=False, **k): + super().__init__(**k) + + _fsdp.StateDictConfig = _StateDictConfig + _fsdp.ShardedStateDictConfig = _ShardedStateDictConfig + _fsdp.FullStateDictConfig = _FullStateDictConfig + _fsdp.FullOptimStateDictConfig = _FullOptimStateDictConfig + _fsdp.LocalStateDictConfig = type("LocalStateDictConfig", (_StateDictConfig,), {}) + _fsdp.OptimStateDictConfig = type("OptimStateDictConfig", (_StateDictConfig,), {}) + _fsdp.ShardedOptimStateDictConfig = type("ShardedOptimStateDictConfig", (_StateDictConfig,), {}) + + # torch.distributed.device_mesh + _dm = types.ModuleType("torch.distributed.device_mesh") + _dm.__package__ = "torch.distributed" + + class _DeviceMesh: + def __init__(self, *a, **k): + pass + + _dm.DeviceMesh = _DeviceMesh + _dm.init_device_mesh = lambda *a, **k: _DeviceMesh() + sys.modules["torch.distributed.device_mesh"] = _dm + + # torch.distributed.tensor (DTensor) + _dtensor = types.ModuleType("torch.distributed.tensor") + _dtensor.__path__ = [] + _dtensor.__package__ = "torch.distributed.tensor" + + class _DTensor: + pass + + _dtensor.DTensor = _DTensor + sys.modules["torch.distributed.tensor"] = _dtensor + + # torch.distributed.algorithms + _dalgo = types.ModuleType("torch.distributed.algorithms") + _dalgo.__path__ = [] + _dalgo.__package__ = "torch.distributed.algorithms" + sys.modules["torch.distributed.algorithms"] = _dalgo + + print("✓ Torch distributed patches applied") + + +# ============================================================================== +# FAIRChem heavy dependency stubs +# ============================================================================== + + +def _make_stub_module(name, attrs=None, submodules=None): + """Create a stub module with optional attributes and submodules.""" + from importlib.machinery import ModuleSpec + + mod = types.ModuleType(name) + mod.__path__ = [] + mod.__package__ = name + mod.__version__ = "0.0.0" + mod.__spec__ = ModuleSpec(name, None, is_package=True) + if attrs: + for k, v in attrs.items(): + setattr(mod, k, v) + sys.modules[name] = mod + if submodules: + for sub_name in submodules: + full_name = f"{name}.{sub_name}" + sub_mod = types.ModuleType(full_name) + sub_mod.__path__ = [] + sub_mod.__package__ = name + sub_mod.__spec__ = ModuleSpec(full_name, None, is_package=True) + setattr(mod, sub_name, sub_mod) + sys.modules[full_name] = sub_mod + return mod + + +def patch_fairchem_deps(): + """ + Stub heavy dependencies that fairchem-core imports but doesn't need for inference. + + This stubs: numba, ray (+ serve), wandb, torchtnt, hydra, omegaconf, + submitit, clusterscope, tqdm, huggingface_hub, websockets. + """ + # --- numba --- + numba_mod = _make_stub_module("numba", submodules=["core", "core.types", "typed"]) + numba_mod.njit = lambda *a, **k: (lambda f: f) if not a or callable(a[0]) else lambda f: f + numba_mod.jit = numba_mod.njit + numba_mod.prange = range + for t in ("int32", "int64", "float32", "float64", "boolean"): + setattr(numba_mod, t, t) + + class _TypedList(list): + pass + + sys.modules["numba.typed"].List = _TypedList + + # --- ray --- + ray_mod = _make_stub_module( + "ray", + submodules=[ + "serve", + "runtime_env", + "train", + "data", + "util", + "util.scheduling_strategies", + "util.queue", + ], + ) + + def _ray_remote(*args, **kwargs): + if args and callable(args[0]): + args[0]._remote = lambda *a, **kw: None + return args[0] + + def wrapper(fn_or_cls): + fn_or_cls._remote = lambda *a, **kw: None + return fn_or_cls + + return wrapper + + ray_mod.remote = _ray_remote + ray_mod.init = lambda *a, **k: None + ray_mod.get = lambda *a, **k: None + ray_mod.put = lambda *a, **k: None + ray_mod.wait = lambda *a, **k: ([], []) + ray_mod.is_initialized = lambda: False + + class _ObjectRef: + pass + + ray_mod.ObjectRef = _ObjectRef + + class _PlacementGroupSchedulingStrategy: + def __init__(self, *a, **k): + pass + + sys.modules["ray.util.scheduling_strategies"].PlacementGroupSchedulingStrategy = _PlacementGroupSchedulingStrategy + + # ray.serve stubs + _serve = sys.modules["ray.serve"] + + def _serve_deployment(*args, **kwargs): + if args and callable(args[0]): + return args[0] + return lambda cls_or_fn: cls_or_fn + + _serve.deployment = _serve_deployment + _serve.ingress = lambda *a, **k: (lambda cls: cls) + _serve.run = lambda *a, **k: None + _serve.batch = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0] + + _serve_schema = types.ModuleType("ray.serve.schema") + _serve_schema.__package__ = "ray.serve" + + class _LoggingConfig: + def __init__(self, **k): + self.__dict__.update(k) + + _serve_schema.LoggingConfig = _LoggingConfig + _serve.schema = _serve_schema + sys.modules["ray.serve.schema"] = _serve_schema + + # --- wandb --- + wandb_mod = _make_stub_module("wandb") + wandb_mod.init = lambda *a, **k: None + wandb_mod.log = lambda *a, **k: None + wandb_mod.finish = lambda *a, **k: None + + # --- torchtnt --- + _make_stub_module( + "torchtnt", + submodules=[ + "framework", + "framework.state", + "framework.unit", + "framework.callback", + "framework.auto_unit", + "framework.fit", + "framework.train", + "framework.evaluate", + "framework.predict", + "utils", + "utils.loggers", + "utils.timer", + "utils.distributed", + "utils.prepare_module", + ], + ) + + class _PredictUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _TrainUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _EvalUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _State: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _Callback: + pass + + tnt_framework = sys.modules["torchtnt.framework"] + tnt_unit = sys.modules["torchtnt.framework.unit"] + tnt_state = sys.modules["torchtnt.framework.state"] + tnt_cb = sys.modules["torchtnt.framework.callback"] + for m in [tnt_framework, tnt_unit]: + m.PredictUnit = _PredictUnit + m.TrainUnit = _TrainUnit + m.EvalUnit = _EvalUnit + tnt_state.State = _State + tnt_cb.Callback = _Callback + tnt_framework.State = _State + tnt_framework.Callback = _Callback + + # torchtnt entry point functions + sys.modules["torchtnt.framework.fit"].fit = lambda *a, **k: None + sys.modules["torchtnt.framework.train"].train = lambda *a, **k: None + sys.modules["torchtnt.framework.evaluate"].evaluate = lambda *a, **k: None + sys.modules["torchtnt.framework.predict"].predict = lambda *a, **k: None + + # torchtnt.utils stubs + tnt_dist = sys.modules["torchtnt.utils.distributed"] + tnt_dist.get_file_init_method = lambda *a, **k: "" + tnt_dist.get_tcp_init_method = lambda *a, **k: "" + tnt_dist.spawn_multi_process = lambda *a, **k: None + + tnt_prep = sys.modules["torchtnt.utils.prepare_module"] + tnt_prep.prepare_module = lambda module, *a, **k: module + tnt_prep.FSDPStrategy = type("FSDPStrategy", (), {"__init__": lambda self, **k: None}) + tnt_prep.DDPStrategy = type("DDPStrategy", (), {"__init__": lambda self, **k: None}) + tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None}) + + # --- hydra / omegaconf --- + omegaconf_mod = _make_stub_module("omegaconf") + + class _DictConfig(dict): + pass + + class _ListConfig(list): + pass + + omegaconf_mod.DictConfig = _DictConfig + omegaconf_mod.ListConfig = _ListConfig + omegaconf_mod.OmegaConf = type( + "OmegaConf", + (), + { + "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), + "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), + }, + ) + _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"]) + + def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): + import importlib + + if isinstance(config, dict) and "_target_" in config: + target = config["_target_"] + mod_path, cls_name = target.rsplit(".", 1) + mod = importlib.import_module(mod_path) + cls = getattr(mod, cls_name) + pos_args = list(args) + list(config.get("_args_", [])) + cfg = {k: v for k, v in config.items() if not k.startswith("_")} + if _recursive_: + for k, v in cfg.items(): + if isinstance(v, dict) and "_target_" in v: + cfg[k] = _hydra_instantiate(v) + elif isinstance(v, list): + cfg[k] = [_hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v] + pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] + return cls(*pos_args, **{**cfg, **kwargs}) + return config + + sys.modules["hydra.utils"].instantiate = _hydra_instantiate + sys.modules["hydra"].utils = sys.modules["hydra.utils"] + + # --- submitit / clusterscope --- + _make_stub_module("submitit") + _make_stub_module("clusterscope") + + # --- websockets --- + _make_stub_module("websockets") + + # --- tqdm --- + tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) + + def _tqdm_passthrough(iterable=None, *a, **k): + return iterable if iterable is not None else iter([]) + + tqdm_mod.tqdm = _tqdm_passthrough + sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough + sys.modules["tqdm.std"].tqdm = _tqdm_passthrough + + # --- huggingface_hub --- + hf_mod = _make_stub_module("huggingface_hub", submodules=["utils"]) + hf_mod.hf_hub_download = lambda *a, **k: "" + hf_mod.snapshot_download = lambda *a, **k: "" + + # --- ase_db_backends --- + _make_stub_module("ase_db_backends") + # --- INT8 quantized model support --- + _orig_torch_load = torch.load + + def _int8_aware_torch_load(f, *args, **kwargs): + result = _orig_torch_load(f, *args, **kwargs) + if isinstance(result, dict) and "quantized_ema_state_dict" in result: + import gc as _gc + + from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint + + print(" Dequantizing INT8 -> FP16 (streaming)...") + quantized_ema = result.pop("quantized_ema_state_dict") + scales = result.pop("quantization_scales") + ema_state_dict = {} + names = list(quantized_ema.keys()) + for name in names: + tensor = quantized_ema.pop(name) + if name in scales: + scale = scales.pop(name) + ema_state_dict[name] = (tensor.float() * scale.float()).half() + del scale + else: + ema_state_dict[name] = tensor + del tensor + del quantized_ema, scales + _gc.collect() + checkpoint = MLIPInferenceCheckpoint( + model_config=result["model_config"], + model_state_dict=result.get("model_state_dict", {}), + ema_state_dict=ema_state_dict, + tasks_config=result["tasks_config"], + ) + del result + _gc.collect() + print(" ✓ Dequantization complete") + return checkpoint + return result + + torch.load = _int8_aware_torch_load + + # --- Model registry fallback --- + try: + from fairchem.core.common.registry import registry as _registry + + _orig_get_model = _registry.get_model_class + + def _fallback_get_model_class(name): + try: + return _orig_get_model(name) + except RuntimeError: + import importlib as _imp + + module_path, class_name = name.rsplit(".", 1) + mod = _imp.import_module(module_path) + return getattr(mod, class_name) + + _registry.get_model_class = _fallback_get_model_class + except Exception: + pass + + print("✓ FAIRChem dependency stubs applied") + + # ============================================================================== # Matscipy patches # ============================================================================== @@ -275,14 +960,557 @@ def patch_mace_tools(): # ============================================================================== -def apply_all_patches(): - """Apply all torch and MACE patches for Pyodide in one call.""" - if is_pyodide_environment(): - patch_torch_linalg() - patch_torch_testing() - patch_matscipy() - patch_mace_training() - patch_mace_tools() - print("\n✅ All Pyodide patches applied successfully!") - else: - print("⚠ Not in Pyodide environment. Patches not applied.") +def apply_all_patches(include_fairchem=False, include_mattersim=False, include_nequip=False): + """ + Apply all torch and model patches for Pyodide in one call. + + Args: + include_fairchem: If True, also apply FAIRChem-specific patches + (torch.distributed, heavy dependency stubs). Set this when + using fairchem-core / UMA models. + include_mattersim: If True, also apply MatterSim-specific patches + (loguru, azure, e3nn JIT stubs). Set this when + using MatterSim / M3GNet models. + include_nequip: If True, also apply NequIP-specific patches + (lightning, hydra, torchmetrics, e3nn JIT stubs). Set this + when using NequIP models. + """ + patch_torch_linalg() + patch_torch_compiler() + patch_torch_testing() + patch_matscipy() + patch_mace_training() + patch_mace_tools() + + if include_fairchem: + patch_torch_distributed() + patch_fairchem_deps() + + if include_mattersim: + patch_torch_distributed() + patch_mattersim_deps() + + if include_nequip: + patch_nequip_deps() + + print("\n✅ All Pyodide patches applied successfully!") + + +# ============================================================================== +# MatterSim patches +# ============================================================================== + + +def patch_mattersim_deps(): + """ + Stub heavy dependencies required by MatterSim but not needed for inference. + + Stubs: loguru, azure.*, atomate2, seekpath, phonopy, phono3py, mp_api, + sklearn, and patches e3nn to disable JIT/torch.compile. + """ + loguru_mod = _make_stub_module("loguru") + + class _Logger: + def info(self, msg, *a, **k): + print(f"INFO: {msg}") + + def warning(self, msg, *a, **k): + print(f"WARNING: {msg}") + + def error(self, msg, *a, **k): + print(f"ERROR: {msg}") + + def debug(self, msg, *a, **k): + pass + + def trace(self, msg, *a, **k): + pass + + def success(self, msg, *a, **k): + print(f"✓ {msg}") + + def __getattr__(self, name): + return lambda *a, **k: None + + loguru_mod.logger = _Logger() + + _make_stub_module("azure", submodules=["identity", "storage", "storage.blob"]) + + for pkg in [ + "atomate2", + "seekpath", + "phonopy", + "phono3py", + "mp_api", + "jobflow", + "emmet", + "emmet.core", + "emmet.core.tasks", + "maggma", + ]: + _make_stub_module(pkg) + + _make_stub_module( + "sklearn", + submodules=[ + "base", + "utils", + "utils.validation", + "preprocessing", + "model_selection", + "gaussian_process", + "gaussian_process.kernels", + ], + ) + + class _GPR: + def __init__(self, *a, **k): + pass + + def fit(self, *a, **k): + return self + + def predict(self, X, return_std=False): + import numpy as _np + + mean = _np.zeros(X.shape[0]) + if return_std: + return mean, _np.ones(X.shape[0]) + return mean + + def log_marginal_likelihood(self): + return 0.0 + + sys.modules["sklearn.gaussian_process"].GaussianProcessRegressor = _GPR + + class _Kernel: + pass + + class _DotProduct(_Kernel): + def __init__(self, *a, **k): + pass + + class _Hyperparameter: + def __init__(self, *a, **k): + pass + + _sk_kernels = sys.modules["sklearn.gaussian_process.kernels"] + _sk_kernels.Kernel = _Kernel + _sk_kernels.DotProduct = _DotProduct + _sk_kernels.Hyperparameter = _Hyperparameter + + try: + import pyodide_http # noqa: F401 + + pyodide_http.patch_all() + except ImportError: + pass + + try: + import e3nn + + e3nn._SO3_INITIALIZED = True + except Exception: + pass + + import torch + + if not hasattr(torch.jit, "_original_script"): + _orig_script = torch.jit.script + + def _noop_script(obj=None, *a, **k): + if obj is not None: + return obj + return lambda fn: fn + + torch.jit.script = _noop_script + + _te = _make_stub_module("torch_ema") + + class _EMA: + def __init__(self, *a, **k): + pass + + _te.ExponentialMovingAverage = _EMA + + _tm = _make_stub_module("torchmetrics") + + class _MeanMetric: + def __init__(self, *a, **k): + pass + + _tm.MeanMetric = _MeanMetric + + _make_stub_module("torch_geometric", submodules=["data", "loader", "utils"]) + + class _Data: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def to(self, device): + import torch as _t + + for k, v in self.__dict__.items(): + if isinstance(v, _t.Tensor): + setattr(self, k, v.to(device)) + return self + + sys.modules["torch_geometric.data"].Data = _Data + + class _DataLoader: + def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): + self._dataset = list(dataset) + + def __iter__(self): + import torch as _t + + for item in self._dataset: + for attr in list(vars(item).keys()): + val = getattr(item, attr) + if isinstance(val, (int, float)): + setattr(item, attr, _t.tensor([val])) + if not hasattr(item, "num_graphs"): + item.num_graphs = 1 + if not hasattr(item, "batch"): + n_atoms = item.num_atoms if hasattr(item, "num_atoms") else _t.tensor([0]) + if isinstance(n_atoms, _t.Tensor): + n_atoms = int(n_atoms.item()) + item.batch = _t.zeros(n_atoms, dtype=_t.long) + yield item + + def __len__(self): + return len(self._dataset) + + sys.modules["torch_geometric.loader"].DataLoader = _DataLoader + + _make_stub_module("torch_runstats", submodules=["scatter"]) + import torch as _torch + + def _scatter(src, index, dim_size=None, dim=0, reduce="sum"): + if dim_size is None: + dim_size = int(index.max()) + 1 + out = _torch.zeros(dim_size, *src.shape[1:], dtype=src.dtype, device=src.device) + if src.dim() == 1: + idx = index + else: + idx = index.unsqueeze(-1).expand_as(src) + if reduce == "sum" or reduce == "add": + out.scatter_add_(0, idx, src) + elif reduce == "mean": + out.scatter_add_(0, idx, src) + count = _torch.zeros(dim_size, dtype=src.dtype, device=src.device) + count.scatter_add_(0, index, _torch.ones(index.shape[0], dtype=src.dtype, device=src.device)) + count = count.clamp(min=1) + if src.dim() > 1: + count = count.unsqueeze(-1) + out = out / count + return out + + sys.modules["torch_runstats.scatter"].scatter = _scatter + + def _scatter_mean(src, index, dim_size=None, dim=0): + return _scatter(src, index, dim_size=dim_size, dim=dim, reduce="mean") + + sys.modules["torch_runstats.scatter"].scatter_mean = _scatter_mean + + print("✓ MatterSim dependency stubs applied") + + +# ============================================================================== +# NequIP patches +# ============================================================================== + + +def patch_nequip_deps(): + """ + Stub heavy dependencies required by NequIP but not needed for inference. + + Stubs: hydra, lightning, pytorch_lightning, torchmetrics, lmdb, matscipy. + Patches e3nn and torch.jit for Pyodide compatibility. + Sets NEQUIP_NL=ase to use ASE neighbor lists instead of matscipy. + """ + import os + + import torch + + os.environ["NEQUIP_NL"] = "ase" + + _make_stub_module("lightning_utilities", submodules=["core", "core.rank_zero"]) + + def _rank_prefixed_message(msg, rank=None): + return msg + + def _rank_zero_only(fn): + return fn + + sys.modules["lightning_utilities.core.rank_zero"].rank_prefixed_message = _rank_prefixed_message + sys.modules["lightning_utilities.core.rank_zero"].rank_zero_only = _rank_zero_only + + _make_stub_module( + "lightning", + submodules=[ + "pytorch", + "pytorch.utilities", + "pytorch.utilities.seed", + "pytorch.utilities.warnings", + "pytorch.callbacks", + ], + ) + + class _IsolateRng: + def __enter__(self): + return self + + def __exit__(self, *a): + pass + + def _seed_everything(seed=None, workers=False, **kwargs): + pass + + sys.modules["lightning.pytorch.utilities.seed"].isolate_rng = lambda: _IsolateRng() + sys.modules["lightning.pytorch"].seed_everything = _seed_everything + + class _PossibleUserWarning(UserWarning): + pass + + sys.modules["lightning.pytorch.utilities.warnings"].PossibleUserWarning = _PossibleUserWarning + + class _LightningModule(torch.nn.Module): + def __init__(self, *a, **k): + super().__init__() + + def log(self, *a, **k): + pass + + sys.modules["lightning.pytorch"].LightningModule = _LightningModule + sys.modules["lightning"].pytorch = sys.modules["lightning.pytorch"] + sys.modules["lightning.pytorch.callbacks"].Callback = type("Callback", (), {}) + + _make_stub_module("pytorch_lightning", submodules=["utilities", "utilities.seed"]) + sys.modules["pytorch_lightning.utilities.seed"].isolate_rng = lambda: _IsolateRng() + + _tm = _make_stub_module("torchmetrics") + + class _Metric(torch.nn.Module): + def __init__(self, *a, **k): + super().__init__() + + def add_state(self, name, default=None, dist_reduce_fx=None, **k): + if default is not None: + setattr(self, name, default) + + _tm.Metric = _Metric + _tm.MeanMetric = _Metric + + if "packaging" not in sys.modules: + try: + import packaging.version # noqa: F401 + except (ImportError, ModuleNotFoundError): + _make_stub_module("packaging", submodules=["version"]) + + class _Version: + def __init__(self, v): + self._v = str(v) + parts = self._v.split(".") + self.major = int(parts[0]) if parts else 0 + self.minor = int(parts[1]) if len(parts) > 1 else 0 + + def __lt__(self, other): + return (self.major, self.minor) < (other.major, other.minor) + + def __ge__(self, other): + return not self.__lt__(other) + + def __repr__(self): + return f"Version('{self._v}')" + + sys.modules["packaging.version"].Version = _Version + sys.modules["packaging.version"].parse = lambda v: _Version(v) + + if "lmdb" not in sys.modules: + sys.modules["lmdb"] = types.ModuleType("lmdb") + + if "hydra" not in sys.modules: + _make_stub_module( + "hydra", + submodules=[ + "core", + "core.global_hydra", + "utils", + "_internal", + "_internal.instantiate", + "_internal.instantiate._instantiate2", + ], + ) + + def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): + import importlib + + if isinstance(config, dict) and "_target_" in config: + target = config["_target_"] + mod_path, cls_name = target.rsplit(".", 1) + mod = importlib.import_module(mod_path) + cls = getattr(mod, cls_name) + pos_args = list(args) + list(config.get("_args_", [])) + cfg = {k: v for k, v in config.items() if not k.startswith("_")} + if _recursive_: + for k, v in cfg.items(): + if isinstance(v, dict) and "_target_" in v: + cfg[k] = _hydra_instantiate(v) + elif isinstance(v, list): + cfg[k] = [ + _hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v + ] + pos_args = [ + _hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args + ] + return cls(*pos_args, **{**cfg, **kwargs}) + return config + + sys.modules["hydra.utils"].instantiate = _hydra_instantiate + sys.modules["hydra"].utils = sys.modules["hydra.utils"] + sys.modules["hydra._internal.instantiate._instantiate2"].InstantiationException = type( + "InstantiationException", (Exception,), {} + ) + + def _hydra_get_target(target_str): + import importlib + + mod_path, name = target_str.rsplit(".", 1) + return getattr(importlib.import_module(mod_path), name) + + sys.modules["hydra.utils"].get_method = _hydra_get_target + sys.modules["hydra.utils"].get_class = _hydra_get_target + + if "omegaconf" not in sys.modules: + omegaconf_mod = _make_stub_module("omegaconf") + + class _DictConfig(dict): + pass + + class _ListConfig(list): + pass + + omegaconf_mod.DictConfig = _DictConfig + omegaconf_mod.ListConfig = _ListConfig + omegaconf_mod.OmegaConf = type( + "OmegaConf", + (), + { + "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), + "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), + "register_new_resolver": staticmethod(lambda name, func, **k: None), + }, + ) + + if "matscipy" not in sys.modules: + _matscipy = types.ModuleType("matscipy") + _matscipy.__path__ = [] + _matscipy.__package__ = "matscipy" + _matscipy_neighbours = types.ModuleType("matscipy.neighbours") + _matscipy_neighbours.neighbour_list = _matscipy_neighbour_list_compat + _matscipy.neighbours = _matscipy_neighbours + sys.modules["matscipy"] = _matscipy + sys.modules["matscipy.neighbours"] = _matscipy_neighbours + + try: + import e3nn + + e3nn._SO3_INITIALIZED = True + if hasattr(e3nn, "_OPT_DEFAULTS"): + e3nn._OPT_DEFAULTS["jit_mode"] = "eager" + except Exception: + pass + + if not hasattr(torch.jit, "_original_script"): + _orig_script = torch.jit.script + + def _noop_script(obj=None, *a, **k): + if obj is not None: + return obj + return lambda fn: fn + + torch.jit.script = _noop_script + + try: + import e3nn.util.jit + + e3nn.util.jit.compile_mode = lambda mode: lambda cls: cls + except Exception: + pass + + if "tqdm" not in sys.modules: + tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) + + def _tqdm_passthrough(iterable=None, *a, **k): + return iterable if iterable is not None else iter([]) + + tqdm_mod.tqdm = _tqdm_passthrough + sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough + sys.modules["tqdm.std"].tqdm = _tqdm_passthrough + + print("✓ NequIP dependency stubs applied") + + +def load_nequip_model(checkpoint_path): + import torch + + data = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + import nequip.utils.global_state as _gs + + def _pyodide_set_global_state(allow_tf32=False, warn_on_override=False): + if not _gs._GLOBAL_STATE_INITIALIZED: + torch.set_default_dtype(torch.float64) + try: + import e3nn + + e3nn.set_optimization_defaults( + specialized_code=True, + optimize_einsums=True, + jit_script_fx=False, + ) + except Exception: + pass + _gs._GLOBAL_STATE_INITIALIZED = True + _gs._latest_global_config["allow_tf32"] = allow_tf32 + + _gs.set_global_state = _pyodide_set_global_state + _gs.set_global_state(allow_tf32=False) + + from nequip.model import FullNequIPGNNModel + + model = FullNequIPGNNModel( + seed=0, + model_dtype="float32", + r_max=data["r_max"], + type_names=data["type_names"], + irreps_edge_sh=data["irreps_edge_sh"], + type_embed_num_features=data["type_embed_num_features"], + feature_irreps_hidden=data["feature_irreps_hidden"], + radial_mlp_depth=data["radial_mlp_depth"], + radial_mlp_width=data["radial_mlp_width"], + avg_num_neighbors=data["avg_num_neighbors"], + per_type_energy_scales=data["per_type_energy_scales"], + per_type_energy_shifts=data["per_type_energy_shifts"], + polynomial_cutoff_p=data.get("polynomial_cutoff_p", 6), + ) + + if data.get("has_zbl", False): + from nequip.nn.pair_potential import ZBL + + seq_net = model.model.func + zbl = ZBL( + type_names=data["type_names"], + chemical_species=data["type_names"], + units="metal", + irreps_in=seq_net.irreps_out, + ) + seq_net.insert(name="pair_potential", module=zbl, before="total_energy_sum") + + model.load_state_dict(data["state_dict"], strict=True) + model.eval() + + print(f"✓ NequIP model loaded ({sum(p.numel() for p in model.parameters()):,} parameters)") + return model diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py new file mode 100644 index 000000000..0aad4b331 --- /dev/null +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py @@ -0,0 +1,26 @@ +from fairchem.core import FAIRChemCalculator +from fairchem.core.units.mlip_unit import load_predict_unit + +from ...primitive.environment import is_pyodide_environment + +MODEL_PATHS_MAP = { + "f16": "/drive/packages/models/uma-s-1p1-f16.pt", + "int8": "/drive/packages/models/uma-s-1p1-int8.pt", +} + + +def get_uma_model_pyodide(model: str, task_name="omat", device="cpu", **kwargs): + if model not in MODEL_PATHS_MAP: + raise ValueError(f"Invalid model name: {model}. Valid options are: {list(MODEL_PATHS_MAP.keys())}") + predictor = load_predict_unit(MODEL_PATHS_MAP[model], device=device, **kwargs) + return FAIRChemCalculator(predictor, task_name=task_name) + + +def create_uma_calculator(model="f16", task_name="omat", device="cpu", model_path=None, checkpoint=None, **kwargs): + if is_pyodide_environment(): + return get_uma_model_pyodide(model=model, task_name=task_name, device=device, **kwargs) + + resolved_model_path = model_path or checkpoint + + predictor = load_predict_unit(str(resolved_model_path), device=device, **kwargs) + return FAIRChemCalculator(predictor, task_name=task_name) From 911643ea7b2aa2ed70f5abaef58ae0c0d6e8b6c3 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 00:15:41 -0700 Subject: [PATCH 03/18] update: add top level switch --- src/py/mat3ra/notebooks_utils/mlff.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 src/py/mat3ra/notebooks_utils/mlff.py diff --git a/src/py/mat3ra/notebooks_utils/mlff.py b/src/py/mat3ra/notebooks_utils/mlff.py new file mode 100644 index 000000000..b8b30a4f6 --- /dev/null +++ b/src/py/mat3ra/notebooks_utils/mlff.py @@ -0,0 +1,27 @@ +from importlib import import_module +from typing import Any, Dict + +MLFF_MODULES = { + "mace": ("mat3ra.notebooks_utils.pyodide.packages.mace", "create_mace_calculator"), + "uma": ("mat3ra.notebooks_utils.pyodide.packages.uma", "create_uma_calculator"), + "mattersim": ("mat3ra.notebooks_utils.pyodide.packages.mattersim", "create_mattersim_calculator"), + "nequip": ("mat3ra.notebooks_utils.pyodide.packages.nequip", "create_nequip_calculator"), +} + + +def get_mlff_install_profiles(mlff_name: str) -> str: + mlff = (mlff_name or "").strip().lower() + if mlff in MLFF_MODULES: + return f"made|api_examples|torch|{mlff}" + raise ValueError(f"Unsupported MLFF: {mlff_name!r}") + + +def create_mlff_calculator(mlff_name: str, settings: Dict[str, Any]): + mlff = (mlff_name or "").strip().lower() + if mlff not in MLFF_MODULES: + raise ValueError(f"Unsupported MLFF: {mlff_name!r}") + + module_name, factory_name = MLFF_MODULES[mlff] + module = import_module(module_name) + factory = getattr(module, factory_name) + return factory(**settings) From 12235bc52fa7aeb49832ab197900dbf0f8357846 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 09:13:44 -0700 Subject: [PATCH 04/18] chore: packages via LFS --- packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl | 3 +++ packages/mattersim-1.1.2-py3-none-any.whl | 3 +++ packages/models/nequip-oam-s-config-sd.pth | 3 +++ packages/models/uma-s-1p1-f16.pt | 3 +++ packages/models/uma-s-1p1-int8.pt | 3 +++ packages/nequip-0.15.0-py3-none-any.whl | 3 +++ 6 files changed, 18 insertions(+) create mode 100644 packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl create mode 100644 packages/mattersim-1.1.2-py3-none-any.whl create mode 100644 packages/models/nequip-oam-s-config-sd.pth create mode 100644 packages/models/uma-s-1p1-f16.pt create mode 100644 packages/models/uma-s-1p1-int8.pt create mode 100644 packages/nequip-0.15.0-py3-none-any.whl diff --git a/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl b/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl new file mode 100644 index 000000000..2a6ca7233 --- /dev/null +++ b/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59f990d0c30e35b3ec80e0daa1c2aff9bd3115794fcb8f119125c9235e2e1178 +size 144573 diff --git a/packages/mattersim-1.1.2-py3-none-any.whl b/packages/mattersim-1.1.2-py3-none-any.whl new file mode 100644 index 000000000..a871dc1f0 --- /dev/null +++ b/packages/mattersim-1.1.2-py3-none-any.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ab6c9bdcded031f2657a38837b4e4efdc6b211023078c0555c91c9d0b8422fe +size 72010 diff --git a/packages/models/nequip-oam-s-config-sd.pth b/packages/models/nequip-oam-s-config-sd.pth new file mode 100644 index 000000000..ebaf076c5 --- /dev/null +++ b/packages/models/nequip-oam-s-config-sd.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e547801c211ebdef3750e9dfc1baa8b1919c7d4e66624e9f316e50875e445eea +size 2495009 diff --git a/packages/models/uma-s-1p1-f16.pt b/packages/models/uma-s-1p1-f16.pt new file mode 100644 index 000000000..641385575 --- /dev/null +++ b/packages/models/uma-s-1p1-f16.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dad47e7d35b42b08917bc4ebbca98cc76e32b3df15b6256f817ba3f42b9bdfe +size 294266951 diff --git a/packages/models/uma-s-1p1-int8.pt b/packages/models/uma-s-1p1-int8.pt new file mode 100644 index 000000000..a000109a4 --- /dev/null +++ b/packages/models/uma-s-1p1-int8.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b6fea3cebdb512078d5cb45feaf701b786be7b1fa024cee0a7dd24aafae8fc9 +size 146678700 diff --git a/packages/nequip-0.15.0-py3-none-any.whl b/packages/nequip-0.15.0-py3-none-any.whl new file mode 100644 index 000000000..2462ad72a --- /dev/null +++ b/packages/nequip-0.15.0-py3-none-any.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b3ffe7b2c95dd4954515c150e374890a49dd091b1c8db255772601d037df6e3 +size 260116 From df3e359a329252af85e066e45878b704b2aa7a5d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 09:26:41 -0700 Subject: [PATCH 05/18] update: deps for mlffs --- config.yml | 52 +++++++++++++++++++++++---- src/py/mat3ra/notebooks_utils/mlff.py | 2 +- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/config.yml b/config.yml index da8e0fae1..1a9fa1bdd 100644 --- a/config.yml +++ b/config.yml @@ -83,21 +83,61 @@ notebooks: - name: torch packages_pyodide: - emfs:/drive/packages/torch-2.1.0a0-cp311-cp311-emscripten_3_1_45_wasm32.whl + - name: mlff + packages_pyodide: + - opt_einsum + - nodeps:opt_einsum_fx + - ssl + - h5py + - lmdb - name: mace packages_pyodide: # Packages with dependencies - - opt_einsum - prettytable - orjson - anywidget # Packages without dependencies (using nodeps: prefix) - - nodeps:opt_einsum_fx - nodeps:e3nn==0.4.4 - nodeps:torch_ema - nodeps:lightning-utilities - nodeps:torchmetrics - nodeps:mace-torch - # Stubbed packages (will be patched by torch_pyodide) - - ssl - - h5py - - lmdb + - name: uma + packages_pyodide: + # antlr4 local wheel (required by omegaconf, PyPI version has ATN mismatch) + - emfs:/drive/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl + # Packages with dependencies + - orjson + - pyyaml + - sqlite3 + - omegaconf + - hydra-core + # Packages without dependencies (using nodeps: prefix) + - nodeps:e3nn>=0.5 + - nodeps:ase + - nodeps:monty + - nodeps:fairchem-core + - name: mattersim + packages_pyodide: + # Packages with dependencies + - orjson + - pyyaml + - setuptools + # Packages without dependencies + - nodeps:e3nn>=0.5 + - nodeps:ase + - nodeps:monty + - nodeps:deprecated + - wrapt + # MatterSim local wheel (pure Python, Cython replaced with NumPy) + - emfs:/drive/packages/mattersim-1.1.2-py3-none-any.whl + - name: nequip + packages_pyodide: + # Packages with dependencies + - pyyaml + - setuptools + # Packages without dependencies + - nodeps:e3nn>=0.5 + - nodeps:ase + # NequIP wheel (stripped of heavy deps: hydra, lightning, torchmetrics) + - emfs:/drive/packages/nequip-0.15.0-py3-none-any.whl diff --git a/src/py/mat3ra/notebooks_utils/mlff.py b/src/py/mat3ra/notebooks_utils/mlff.py index b8b30a4f6..538d044f8 100644 --- a/src/py/mat3ra/notebooks_utils/mlff.py +++ b/src/py/mat3ra/notebooks_utils/mlff.py @@ -12,7 +12,7 @@ def get_mlff_install_profiles(mlff_name: str) -> str: mlff = (mlff_name or "").strip().lower() if mlff in MLFF_MODULES: - return f"made|api_examples|torch|{mlff}" + return f"made|api_examples|torch|mlff|{mlff}" raise ValueError(f"Unsupported MLFF: {mlff_name!r}") From 69109efd2a0bf6c38bc3e846dc05e9d942037c73 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 09:51:00 -0700 Subject: [PATCH 06/18] chore: gitattr --- .gitattributes | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitattributes b/.gitattributes index 790431e68..14e098b94 100644 --- a/.gitattributes +++ b/.gitattributes @@ -18,3 +18,6 @@ examples/assets/bash_workflow_template.json filter=lfs diff=lfs merge=lfs -text *.whl filter=lfs diff=lfs merge=lfs -text *.model filter=lfs diff=lfs merge=lfs -text packages filter=lfs diff=lfs merge=lfs -text +packages/models/uma-s-1p1-int8.pt filter=lfs diff=lfs merge=lfs -text +packages/models/nequip-oam-s-config-sd.pth filter=lfs diff=lfs merge=lfs -text +packages/models/uma-s-1p1-f16.pt filter=lfs diff=lfs merge=lfs -text From a72db9c97bac64e1d38d93df044af8e62171eab6 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 09:51:33 -0700 Subject: [PATCH 07/18] chore: gitattr --- .gitattributes | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitattributes b/.gitattributes index 14e098b94..9269836fb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -18,6 +18,4 @@ examples/assets/bash_workflow_template.json filter=lfs diff=lfs merge=lfs -text *.whl filter=lfs diff=lfs merge=lfs -text *.model filter=lfs diff=lfs merge=lfs -text packages filter=lfs diff=lfs merge=lfs -text -packages/models/uma-s-1p1-int8.pt filter=lfs diff=lfs merge=lfs -text -packages/models/nequip-oam-s-config-sd.pth filter=lfs diff=lfs merge=lfs -text -packages/models/uma-s-1p1-f16.pt filter=lfs diff=lfs merge=lfs -text +packages/models/*.pt* filter=lfs diff=lfs merge=lfs -text From 6d8bd59bb76a0f687a403e697883cd280d9f49cb Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 11:00:44 -0700 Subject: [PATCH 08/18] chore: mattersim model --- packages/models/mattersim-v1.0.0-1M.pth | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 packages/models/mattersim-v1.0.0-1M.pth diff --git a/packages/models/mattersim-v1.0.0-1M.pth b/packages/models/mattersim-v1.0.0-1M.pth new file mode 100644 index 000000000..3ba62274c --- /dev/null +++ b/packages/models/mattersim-v1.0.0-1M.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28b0b0b0f13efefee06b47ea4c9105a26bd3e2c8396da193430da96b3b49a8be +size 17932943 From 2945a8ef295df971d5cb93487d6f1500200b6fce Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Mon, 25 May 2026 18:59:00 -0700 Subject: [PATCH 09/18] chore: cleanup --- .../relax_structure_with_mlff.ipynb | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb index cdf147b14..8a4c886af 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb @@ -5,7 +5,7 @@ "id": "0", "metadata": {}, "source": [ - "# Relax structure with MLFF (MACE / UMA) — Machine Learned Force Field\n", + "# Relax structure with MLFF (MACE / UMA / MATTERSIM / NEQUIP) — Machine Learned Force Field\n", "\n", "Relax atomic positions locally with ASE using a selectable Machine Learned Force Field (MLFF).\n" ] @@ -82,10 +82,13 @@ "source": [ "from mat3ra.notebooks_utils.packages import install_packages\n", "from mat3ra.notebooks_utils.primitive.environment import is_pyodide_environment\n", - "from mat3ra.notebooks_utils.mlff import get_mlff_install_profiles\n", + "from mat3ra.notebooks_utils.mlff import ensure_mlff_switchable, get_mlff_install_profiles, mark_active_mlff\n", + "\n", + "ensure_mlff_switchable(MLFF_NAME)\n", "\n", "profiles = get_mlff_install_profiles(MLFF_NAME)\n", "await install_packages(profiles)\n", + "mark_active_mlff(MLFF_NAME)\n", "\n", "# PyTorch patches are required in Pyodide for torch-based MLFFs.\n", "if is_pyodide_environment():\n", @@ -148,7 +151,7 @@ "metadata": {}, "source": [ "## 4. Apply Relaxation\n", - "### 4.1. Relax with MACE" + "### 4.1. Relax with selected MLFF" ] }, { @@ -162,7 +165,7 @@ "from ase.optimize import BFGS\n", "\n", "from mat3ra.notebooks_utils.mlff import create_mlff_calculator\n", - "from mat3ra.notebooks_utils.plot import progress_callback\n", + "from mat3ra.notebooks_utils.ipython.plot._plotly import progress_callback\n", "\n", "calculator = create_mlff_calculator(MLFF_NAME, MLFF_SETTINGS[MLFF_NAME])\n", "\n", @@ -328,13 +331,13 @@ ], "metadata": { "kernelspec": { - "display_name": "Python (Pyodide)", + "display_name": ".venv-3.11.2 (3.11.2)", "language": "python", - "name": "python" + "name": "python3" }, "language_info": { "codemirror_mode": { - "name": "python", + "name": "ipython", "version": 3 }, "file_extension": ".py", @@ -342,7 +345,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8" + "version": "3.11.2" } }, "nbformat": 4, From 913b6802b23687e0d4d9cf22ba532826b0cb2cba Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 16:44:44 -0700 Subject: [PATCH 10/18] chore: remove useless --- .../experiments/jupyterlite/relax_structure_with_mlff.ipynb | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb index 8a4c886af..4fc86afee 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb @@ -82,13 +82,10 @@ "source": [ "from mat3ra.notebooks_utils.packages import install_packages\n", "from mat3ra.notebooks_utils.primitive.environment import is_pyodide_environment\n", - "from mat3ra.notebooks_utils.mlff import ensure_mlff_switchable, get_mlff_install_profiles, mark_active_mlff\n", - "\n", - "ensure_mlff_switchable(MLFF_NAME)\n", + "from mat3ra.notebooks_utils.mlff import get_mlff_install_profiles\n", "\n", "profiles = get_mlff_install_profiles(MLFF_NAME)\n", "await install_packages(profiles)\n", - "mark_active_mlff(MLFF_NAME)\n", "\n", "# PyTorch patches are required in Pyodide for torch-based MLFFs.\n", "if is_pyodide_environment():\n", From d727b4d2ceb3ee00527a6202fa6309e47fa20b01 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 18:05:22 -0700 Subject: [PATCH 11/18] update: split patching --- .../pyodide/packages/__init__.py | 29 + .../pyodide/packages/install.py | 47 +- .../notebooks_utils/pyodide/packages/mace.py | 46 +- .../pyodide/packages/mattersim.py | 250 ++++- .../pyodide/packages/nequip.py | 331 +++++- .../notebooks_utils/pyodide/packages/torch.py | 970 +----------------- .../notebooks_utils/pyodide/packages/uma.py | 320 +++++- 7 files changed, 1044 insertions(+), 949 deletions(-) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py index e69de29bb..028d82af3 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py @@ -0,0 +1,29 @@ +from importlib import import_module + +from .torch import apply_common_torch_patches + +MLFF_PATCH_MODULES = { + "mace": "mat3ra.notebooks_utils.pyodide.packages.mace", + "uma": "mat3ra.notebooks_utils.pyodide.packages.uma", + "mattersim": "mat3ra.notebooks_utils.pyodide.packages.mattersim", + "nequip": "mat3ra.notebooks_utils.pyodide.packages.nequip", +} + + +def normalize_mlff_name(mlff_name: str) -> str: + return (mlff_name or "").strip().lower() + + +def apply_mlff_patches(mlff_name: str): + mlff = normalize_mlff_name(mlff_name) + if mlff not in MLFF_PATCH_MODULES: + raise ValueError(f"Unsupported MLFF: {mlff_name!r}") + + patch_module = import_module(MLFF_PATCH_MODULES[mlff]) + patch_module.apply_patches() + + +def apply_all_patches(mlff_name: str): + apply_common_torch_patches() + apply_mlff_patches(mlff_name) + print("\n✅ All Pyodide patches applied successfully!") diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py index b1aadabe0..294c4b931 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py @@ -1,7 +1,7 @@ import json import os import re -from typing import List +from typing import List, Union from ...primitive.environment import ENVIRONMENT from ...primitive.logger import log @@ -72,12 +72,18 @@ def get_packages_list(requirements_dict: dict, notebook_name_pattern: str = "") # Note: environment specific packages have to be installed first, # because in Pyodide common packages might depend on them - return [ - *packages_default_environment_specific, - *packages_notebook_environment_specific, - *packages_default_common, - *packages_notebook_common, - ] + return deduplicate_packages( + [ + *packages_default_environment_specific, + *packages_notebook_environment_specific, + *packages_default_common, + *packages_notebook_common, + ] + ) + + +def deduplicate_packages(packages: List[str]) -> List[str]: + return list(dict.fromkeys(packages)) async def get_package_list_from_config(config_file_path: str, notebook_name_pattern: str) -> list: @@ -86,7 +92,20 @@ async def get_package_list_from_config(config_file_path: str, notebook_name_patt return packages -async def install_package_pyodide(pkg: str, verbose: bool = True): +def should_reinstall_packages(previous_hash: Union[str, None], requirements_hash: str) -> bool: + return previous_hash is not None and previous_hash != requirements_hash + + +def package_has_version_specifier(pkg: str) -> bool: + spec = pkg.split("nodeps:")[-1] # Remove nodeps: prefix if present + return any(op in spec for op in ("==", ">=", "<=", "!=", "~=", ">", "<")) + + +def should_reinstall_package(pkg: str, profile_changed: bool) -> bool: + return profile_changed and package_has_version_specifier(pkg) + + +async def install_package_pyodide(pkg: str, verbose: bool = True, reinstall: bool = False): """ Install a package in a Pyodide environment. @@ -105,7 +124,7 @@ async def install_package_pyodide(pkg: str, verbose: bool = True): is_url = pkg.startswith("http://") or pkg.startswith("https://") or pkg.startswith("emfs:/") are_dependencies_installed = not is_url - await micropip.install(pkg, deps=are_dependencies_installed) + await micropip.install(pkg, deps=are_dependencies_installed, reinstall=reinstall) pkg_name = pkg.split("/")[-1].split("-")[0] if "://" in pkg else pkg.split("==")[0] if verbose: log(f"Installed {pkg_name}", force_verbose=verbose) @@ -121,9 +140,15 @@ async def install_packages_pyodide(notebook_name_pattern: str, verbose: bool = T """ packages = await get_package_list_from_config(get_config_yml_file_path(""), notebook_name_pattern) requirements_hash = str(hash(json.dumps(packages))) - if os.environ.get("requirements_hash") != requirements_hash: + previous_hash = os.environ.get("requirements_hash") + profile_changed = should_reinstall_packages(previous_hash, requirements_hash) + if previous_hash != requirements_hash: for pkg in packages: - await install_package_pyodide(pkg, verbose) + await install_package_pyodide( + pkg, + verbose, + reinstall=should_reinstall_package(pkg, profile_changed), + ) if verbose: log("Packages installed successfully.", force_verbose=verbose) os.environ["requirements_hash"] = requirements_hash diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py index dddd53935..80a541d79 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/mace.py @@ -1,7 +1,45 @@ -from mace.calculators import MACECalculator, mace_mp +import sys +import types +from importlib import import_module from ...primitive.environment import is_pyodide_environment + +def apply_patches(): + patch_mace_training() + patch_mace_tools() + + +def patch_mace_training(): + """ + Stub lmdb and h5py packages. + + These are C-extension packages used by MACE's training/dataset code + but not needed for inference. Stubs allow imports to succeed. + """ + for package_name in ("lmdb", "h5py"): + if package_name not in sys.modules: + sys.modules[package_name] = types.ModuleType(package_name) + + print("✓ LMDB and HDF5 stubs applied") + + +def patch_mace_tools(): + """ + Fix MACE's torch_geometric import order issues in Pyodide. + + In Pyodide, torch_geometric.data may not be set during circular imports. + Pre-importing ensures the attribute is available when MACE needs it. + """ + try: + torch_geometric = import_module("mace.tools.torch_geometric") + torch_geometric_data = import_module("mace.tools.torch_geometric.data") + torch_geometric.data = torch_geometric_data + print("✓ MACE tools patches applied") + except Exception as exc: + print(f"⚠ MACE tools patches skipped: {exc}") + + MODEL_PATHS_MAP = { "small": "/drive/packages/models/2023-12-10-mace-128-L0_energy_epoch-249.model", "medium": "/drive/packages/models/2023-12-03-mace-128-L1_epoch-199.model", @@ -13,7 +51,8 @@ def get_mace_model_pyodide(model: str, dispersion=False, default_dtype="float32" if model not in MODEL_PATHS_MAP: raise ValueError(f"Invalid model name: {model}. Valid options are: {list(MODEL_PATHS_MAP.keys())}") model_path = MODEL_PATHS_MAP[model] - return MACECalculator( + mace_calculators = import_module("mace.calculators") + return mace_calculators.MACECalculator( model_path=model_path, dispersion=dispersion, default_dtype=default_dtype, device=device, **kwargs ) @@ -28,7 +67,8 @@ def create_mace_calculator(model="large", dispersion=True, default_dtype="float3 **kwargs, ) - return mace_mp( + mace_calculators = import_module("mace.calculators") + return mace_calculators.mace_mp( model=model, dispersion=dispersion, default_dtype=default_dtype, diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py index 2ae1870fc..6186b5d13 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/mattersim.py @@ -1,6 +1,248 @@ -from mattersim.forcefield import MatterSimCalculator +import sys +from importlib import import_module + +import numpy as np +import torch from ...primitive.environment import is_pyodide_environment +from .torch import _make_stub_module, patch_torch_distributed + + +def apply_patches(): + patch_torch_distributed() + patch_mattersim_deps() + + +def patch_mattersim_deps(): + """ + Stub heavy dependencies required by MatterSim but not needed for inference. + + Stubs: loguru, azure.*, atomate2, seekpath, phonopy, phono3py, mp_api, + sklearn, and patches e3nn to disable JIT/torch.compile. + """ + loguru_mod = _make_stub_module("loguru") + + class _Logger: + def info(self, msg, *a, **k): + print(f"INFO: {msg}") + + def warning(self, msg, *a, **k): + print(f"WARNING: {msg}") + + def error(self, msg, *a, **k): + print(f"ERROR: {msg}") + + def debug(self, msg, *a, **k): + pass + + def trace(self, msg, *a, **k): + pass + + def success(self, msg, *a, **k): + print(f"✓ {msg}") + + def __getattr__(self, name): + return lambda *a, **k: None + + loguru_mod.logger = _Logger() + + _make_stub_module("azure", submodules=["identity", "storage", "storage.blob"]) + + for package_name in [ + "atomate2", + "seekpath", + "phonopy", + "phono3py", + "mp_api", + "jobflow", + "emmet", + "emmet.core", + "emmet.core.tasks", + "maggma", + ]: + _make_stub_module(package_name) + + _patch_sklearn() + _patch_pyodide_http() + _patch_e3nn() + _patch_torch_jit() + _patch_torch_ema() + _patch_torchmetrics() + _patch_torch_geometric() + _patch_torch_runstats() + + print("✓ MatterSim dependency stubs applied") + + +def _patch_sklearn(): + _make_stub_module( + "sklearn", + submodules=[ + "base", + "utils", + "utils.validation", + "preprocessing", + "model_selection", + "gaussian_process", + "gaussian_process.kernels", + ], + ) + + class _GPR: + def __init__(self, *a, **k): + pass + + def fit(self, *a, **k): + return self + + def predict(self, X, return_std=False): + mean = np.zeros(X.shape[0]) + if return_std: + return mean, np.ones(X.shape[0]) + return mean + + def log_marginal_likelihood(self): + return 0.0 + + sys.modules["sklearn.gaussian_process"].GaussianProcessRegressor = _GPR + + class _Kernel: + pass + + class _DotProduct(_Kernel): + def __init__(self, *a, **k): + pass + + class _Hyperparameter: + def __init__(self, *a, **k): + pass + + sklearn_kernels = sys.modules["sklearn.gaussian_process.kernels"] + sklearn_kernels.Kernel = _Kernel + sklearn_kernels.DotProduct = _DotProduct + sklearn_kernels.Hyperparameter = _Hyperparameter + + +def _patch_pyodide_http(): + try: + import pyodide_http + + pyodide_http.patch_all() + except ImportError: + pass + + +def _patch_e3nn(): + try: + import e3nn + + e3nn._SO3_INITIALIZED = True + except Exception: + pass + + +def _patch_torch_jit(): + if not hasattr(torch.jit, "_original_script"): + + def _noop_script(obj=None, *a, **k): + if obj is not None: + return obj + return lambda fn: fn + + torch.jit.script = _noop_script + + +def _patch_torch_ema(): + torch_ema_mod = _make_stub_module("torch_ema") + + class _EMA: + def __init__(self, *a, **k): + pass + + torch_ema_mod.ExponentialMovingAverage = _EMA + + +def _patch_torchmetrics(): + torchmetrics_mod = _make_stub_module("torchmetrics") + + class _MeanMetric: + def __init__(self, *a, **k): + pass + + torchmetrics_mod.MeanMetric = _MeanMetric + + +def _patch_torch_geometric(): + _make_stub_module("torch_geometric", submodules=["data", "loader", "utils"]) + + class _Data: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def to(self, device): + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + setattr(self, k, v.to(device)) + return self + + sys.modules["torch_geometric.data"].Data = _Data + + class _DataLoader: + def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): + self._dataset = list(dataset) + + def __iter__(self): + for item in self._dataset: + for attr in list(vars(item).keys()): + val = getattr(item, attr) + if isinstance(val, (int, float)): + setattr(item, attr, torch.tensor([val])) + if not hasattr(item, "num_graphs"): + item.num_graphs = 1 + if not hasattr(item, "batch"): + n_atoms = item.num_atoms if hasattr(item, "num_atoms") else torch.tensor([0]) + if isinstance(n_atoms, torch.Tensor): + n_atoms = int(n_atoms.item()) + item.batch = torch.zeros(n_atoms, dtype=torch.long) + yield item + + def __len__(self): + return len(self._dataset) + + sys.modules["torch_geometric.loader"].DataLoader = _DataLoader + + +def _patch_torch_runstats(): + _make_stub_module("torch_runstats", submodules=["scatter"]) + + def _scatter(src, index, dim_size=None, dim=0, reduce="sum"): + if dim_size is None: + dim_size = int(index.max()) + 1 + out = torch.zeros(dim_size, *src.shape[1:], dtype=src.dtype, device=src.device) + if src.dim() == 1: + idx = index + else: + idx = index.unsqueeze(-1).expand_as(src) + if reduce == "sum" or reduce == "add": + out.scatter_add_(0, idx, src) + elif reduce == "mean": + out.scatter_add_(0, idx, src) + count = torch.zeros(dim_size, dtype=src.dtype, device=src.device) + count.scatter_add_(0, index, torch.ones(index.shape[0], dtype=src.dtype, device=src.device)) + count = count.clamp(min=1) + if src.dim() > 1: + count = count.unsqueeze(-1) + out = out / count + return out + + sys.modules["torch_runstats.scatter"].scatter = _scatter + + def _scatter_mean(src, index, dim_size=None, dim=0): + return _scatter(src, index, dim_size=dim_size, dim=dim, reduce="mean") + + sys.modules["torch_runstats.scatter"].scatter_mean = _scatter_mean + MODEL_PATHS_MAP = { "1m": "/drive/packages/models/mattersim-v1.0.0-1M.pth", @@ -10,7 +252,8 @@ def get_mattersim_model_pyodide(model: str, device="cpu", **kwargs): if model not in MODEL_PATHS_MAP: raise ValueError(f"Invalid model name: {model}. Valid options are: {list(MODEL_PATHS_MAP.keys())}") - return MatterSimCalculator.from_checkpoint(load_path=MODEL_PATHS_MAP[model], device=device, **kwargs) + forcefield = import_module("mattersim.forcefield") + return forcefield.MatterSimCalculator.from_checkpoint(load_path=MODEL_PATHS_MAP[model], device=device, **kwargs) def create_mattersim_calculator(model="1m", device="cpu", model_path=None, checkpoint=None, **kwargs): @@ -19,4 +262,5 @@ def create_mattersim_calculator(model="1m", device="cpu", model_path=None, check resolved_model_path = model_path or checkpoint - return MatterSimCalculator.from_checkpoint(load_path=str(resolved_model_path), device=device, **kwargs) + forcefield = import_module("mattersim.forcefield") + return forcefield.MatterSimCalculator.from_checkpoint(load_path=str(resolved_model_path), device=device, **kwargs) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py index ecee14f5f..15628005a 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/nequip.py @@ -1,8 +1,264 @@ -from nequip.ase import NequIPCalculator -from nequip.data.transforms import ChemicalSpeciesToAtomTypeMapper, NeighborListTransform +import os +import sys +import types +from importlib import import_module + +import torch from ...primitive.environment import is_pyodide_environment -from .torch import load_nequip_model +from .torch import _ensure_omegaconf_stub, _make_stub_module, _matscipy_neighbour_list_compat + + +def apply_patches(): + patch_nequip_deps() + + +def patch_nequip_deps(): + """ + Stub heavy dependencies required by NequIP but not needed for inference. + + Stubs: hydra, lightning, pytorch_lightning, torchmetrics, lmdb, matscipy. + Patches e3nn and torch.jit for Pyodide compatibility. + Sets NEQUIP_NL=ase to use ASE neighbor lists instead of matscipy. + """ + os.environ["NEQUIP_NL"] = "ase" + _patch_lightning_utilities() + _patch_lightning() + _patch_pytorch_lightning() + _patch_torchmetrics() + _patch_packaging() + _patch_lmdb() + _patch_hydra() + _ensure_omegaconf_stub() + _patch_matscipy() + _patch_e3nn() + _patch_torch_jit() + _patch_e3nn_jit() + _patch_tqdm() + + print("✓ NequIP dependency stubs applied") + + +def _patch_lightning_utilities(): + _make_stub_module("lightning_utilities", submodules=["core", "core.rank_zero"]) + + def _rank_prefixed_message(msg, rank=None): + return msg + + def _rank_zero_only(fn): + return fn + + rank_zero_mod = sys.modules["lightning_utilities.core.rank_zero"] + rank_zero_mod.rank_prefixed_message = _rank_prefixed_message + rank_zero_mod.rank_zero_only = _rank_zero_only + + +def _patch_lightning(): + _make_stub_module( + "lightning", + submodules=[ + "pytorch", + "pytorch.utilities", + "pytorch.utilities.seed", + "pytorch.utilities.warnings", + "pytorch.callbacks", + ], + ) + + class _IsolateRng: + def __enter__(self): + return self + + def __exit__(self, *a): + pass + + def _seed_everything(seed=None, workers=False, **kwargs): + pass + + sys.modules["lightning.pytorch.utilities.seed"].isolate_rng = lambda: _IsolateRng() + sys.modules["lightning.pytorch"].seed_everything = _seed_everything + + class _PossibleUserWarning(UserWarning): + pass + + sys.modules["lightning.pytorch.utilities.warnings"].PossibleUserWarning = _PossibleUserWarning + + class _LightningModule(torch.nn.Module): + def __init__(self, *a, **k): + super().__init__() + + def log(self, *a, **k): + pass + + sys.modules["lightning.pytorch"].LightningModule = _LightningModule + sys.modules["lightning"].pytorch = sys.modules["lightning.pytorch"] + sys.modules["lightning.pytorch.callbacks"].Callback = type("Callback", (), {}) + + +def _patch_pytorch_lightning(): + _make_stub_module("pytorch_lightning", submodules=["utilities", "utilities.seed"]) + + class _IsolateRng: + def __enter__(self): + return self + + def __exit__(self, *a): + pass + + sys.modules["pytorch_lightning.utilities.seed"].isolate_rng = lambda: _IsolateRng() + + +def _patch_torchmetrics(): + torchmetrics_mod = _make_stub_module("torchmetrics") + + class _Metric(torch.nn.Module): + def __init__(self, *a, **k): + super().__init__() + + def add_state(self, name, default=None, dist_reduce_fx=None, **k): + if default is not None: + setattr(self, name, default) + + torchmetrics_mod.Metric = _Metric + torchmetrics_mod.MeanMetric = _Metric + + +def _patch_packaging(): + if "packaging" in sys.modules: + return + try: + import packaging.version # noqa: F401 + except (ImportError, ModuleNotFoundError): + _make_stub_module("packaging", submodules=["version"]) + + class _Version: + def __init__(self, v): + self._v = str(v) + parts = self._v.split(".") + self.major = int(parts[0]) if parts else 0 + self.minor = int(parts[1]) if len(parts) > 1 else 0 + + def __lt__(self, other): + return (self.major, self.minor) < (other.major, other.minor) + + def __ge__(self, other): + return not self.__lt__(other) + + def __repr__(self): + return f"Version('{self._v}')" + + sys.modules["packaging.version"].Version = _Version + sys.modules["packaging.version"].parse = lambda v: _Version(v) + + +def _patch_lmdb(): + if "lmdb" not in sys.modules: + sys.modules["lmdb"] = types.ModuleType("lmdb") + + +def _patch_hydra(): + if "hydra" in sys.modules: + return + _make_stub_module( + "hydra", + submodules=[ + "core", + "core.global_hydra", + "utils", + "_internal", + "_internal.instantiate", + "_internal.instantiate._instantiate2", + ], + ) + + def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): + if isinstance(config, dict) and "_target_" in config: + target = config["_target_"] + mod_path, cls_name = target.rsplit(".", 1) + mod = import_module(mod_path) + cls = getattr(mod, cls_name) + pos_args = list(args) + list(config.get("_args_", [])) + cfg = {k: v for k, v in config.items() if not k.startswith("_")} + if _recursive_: + for k, v in cfg.items(): + if isinstance(v, dict) and "_target_" in v: + cfg[k] = _hydra_instantiate(v) + elif isinstance(v, list): + cfg[k] = [_hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v] + pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] + return cls(*pos_args, **{**cfg, **kwargs}) + return config + + sys.modules["hydra.utils"].instantiate = _hydra_instantiate + sys.modules["hydra"].utils = sys.modules["hydra.utils"] + sys.modules["hydra._internal.instantiate._instantiate2"].InstantiationException = type( + "InstantiationException", (Exception,), {} + ) + + def _hydra_get_target(target_str): + mod_path, name = target_str.rsplit(".", 1) + return getattr(import_module(mod_path), name) + + sys.modules["hydra.utils"].get_method = _hydra_get_target + sys.modules["hydra.utils"].get_class = _hydra_get_target + + +def _patch_matscipy(): + if "matscipy" in sys.modules: + return + matscipy_mod = types.ModuleType("matscipy") + matscipy_mod.__path__ = [] + matscipy_mod.__package__ = "matscipy" + matscipy_neighbours = types.ModuleType("matscipy.neighbours") + matscipy_neighbours.neighbour_list = _matscipy_neighbour_list_compat + matscipy_mod.neighbours = matscipy_neighbours + sys.modules["matscipy"] = matscipy_mod + sys.modules["matscipy.neighbours"] = matscipy_neighbours + + +def _patch_e3nn(): + try: + import e3nn + + e3nn._SO3_INITIALIZED = True + if hasattr(e3nn, "_OPT_DEFAULTS"): + e3nn._OPT_DEFAULTS["jit_mode"] = "eager" + except Exception: + pass + + +def _patch_torch_jit(): + if not hasattr(torch.jit, "_original_script"): + + def _noop_script(obj=None, *a, **k): + if obj is not None: + return obj + return lambda fn: fn + + torch.jit.script = _noop_script + + +def _patch_e3nn_jit(): + try: + import e3nn.util.jit + + e3nn.util.jit.compile_mode = lambda mode: lambda cls: cls + except Exception: + pass + + +def _patch_tqdm(): + if "tqdm" in sys.modules: + return + tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) + + def _tqdm_passthrough(iterable=None, *a, **k): + return iterable if iterable is not None else iter([]) + + tqdm_mod.tqdm = _tqdm_passthrough + sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough + sys.modules["tqdm.std"].tqdm = _tqdm_passthrough + MODEL_PATHS_MAP = { "oam_s": "/drive/packages/models/nequip-oam-s-config-sd.pth", @@ -20,13 +276,15 @@ def get_nequip_model_pyodide(model: str, device="cpu"): def create_nequip_calculator_from_model(nequip_model, device="cpu"): r_max = float(nequip_model.metadata["r_max"]) type_names = nequip_model.metadata["type_names"].split(" ") + nequip_ase = import_module("nequip.ase") + transforms = import_module("nequip.data.transforms") - return NequIPCalculator( + return nequip_ase.NequIPCalculator( model=nequip_model, device=device, transforms=[ - ChemicalSpeciesToAtomTypeMapper(type_names), - NeighborListTransform(r_max=r_max), + transforms.ChemicalSpeciesToAtomTypeMapper(type_names), + transforms.NeighborListTransform(r_max=r_max), ], ) @@ -39,3 +297,64 @@ def create_nequip_calculator(model="oam_s", device="cpu", model_path=None, check nequip_model = load_nequip_model(str(resolved_model_path)) return create_nequip_calculator_from_model(nequip_model, device=device) + + +def load_nequip_model(checkpoint_path): + data = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + import nequip.utils.global_state as global_state + + def _pyodide_set_global_state(allow_tf32=False, warn_on_override=False): + if not global_state._GLOBAL_STATE_INITIALIZED: + torch.set_default_dtype(torch.float64) + try: + import e3nn + + e3nn.set_optimization_defaults( + specialized_code=True, + optimize_einsums=True, + jit_script_fx=False, + ) + except Exception: + pass + global_state._GLOBAL_STATE_INITIALIZED = True + global_state._latest_global_config["allow_tf32"] = allow_tf32 + + global_state.set_global_state = _pyodide_set_global_state + global_state.set_global_state(allow_tf32=False) + + from nequip.model import FullNequIPGNNModel + + model = FullNequIPGNNModel( + seed=0, + model_dtype="float32", + r_max=data["r_max"], + type_names=data["type_names"], + irreps_edge_sh=data["irreps_edge_sh"], + type_embed_num_features=data["type_embed_num_features"], + feature_irreps_hidden=data["feature_irreps_hidden"], + radial_mlp_depth=data["radial_mlp_depth"], + radial_mlp_width=data["radial_mlp_width"], + avg_num_neighbors=data["avg_num_neighbors"], + per_type_energy_scales=data["per_type_energy_scales"], + per_type_energy_shifts=data["per_type_energy_shifts"], + polynomial_cutoff_p=data.get("polynomial_cutoff_p", 6), + ) + + if data.get("has_zbl", False): + from nequip.nn.pair_potential import ZBL + + seq_net = model.model.func + zbl = ZBL( + type_names=data["type_names"], + chemical_species=data["type_names"], + units="metal", + irreps_in=seq_net.irreps_out, + ) + seq_net.insert(name="pair_potential", module=zbl, before="total_energy_sum") + + model.load_state_dict(data["state_dict"], strict=True) + model.eval() + + print(f"✓ NequIP model loaded ({sum(p.numel() for p in model.parameters()):,} parameters)") + return model diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index 2be0bd07b..2e7557a5d 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -1,24 +1,13 @@ """ Patches for PyTorch and related packages to work in Pyodide environment. -This module provides patches for various torch-related packages that don't work -in Pyodide's WASM environment, organized by functionality. +This module provides shared Pyodide compatibility patches for torch-based MLFF +packages. Package-specific patches live in their corresponding package modules. Usage: - from mat3ra.notebooks_utils.other.torch_pyodide import ( - patch_torch_linalg, - patch_torch_testing, - patch_matscipy, - patch_lmdb_h5py, - patch_mace_tools, - ) + from mat3ra.notebooks_utils.pyodide.packages import apply_all_patches - # Apply all patches - patch_torch_linalg() - patch_torch_testing() - patch_matscipy() - patch_lmdb_h5py() - patch_mace_tools() + apply_all_patches("mace") """ import sys @@ -32,6 +21,49 @@ EigRet = namedtuple("linalg_eig", ["eigenvalues", "eigenvectors"]) # type: ignore EighRet = namedtuple("linalg_eigh", ["eigenvalues", "eigenvectors"]) # type: ignore LUFactorReturn = namedtuple("LUFactorReturn", ["LU", "pivots"]) +_OMEGACONF_RESOLVERS = {} + + +class _DictConfig(dict): + pass + + +class _ListConfig(list): + pass + + +class _OmegaConf: + @staticmethod + def to_container(cfg, **kwargs): + return dict(cfg) if isinstance(cfg, dict) else cfg + + @staticmethod + def create(data): + return _DictConfig(data) if isinstance(data, dict) else data + + @staticmethod + def register_new_resolver(name, func, **kwargs): + _OMEGACONF_RESOLVERS[name] = func + + @staticmethod + def register_resolver(name, func, **kwargs): + _OMEGACONF_RESOLVERS[name] = func + + @staticmethod + def has_resolver(name): + return name in _OMEGACONF_RESOLVERS + + +def _ensure_omegaconf_stub(): + omegaconf_mod = sys.modules.get("omegaconf") or _make_stub_module("omegaconf") + omegaconf_mod.DictConfig = getattr(omegaconf_mod, "DictConfig", _DictConfig) + omegaconf_mod.ListConfig = getattr(omegaconf_mod, "ListConfig", _ListConfig) + omega_conf = getattr(omegaconf_mod, "OmegaConf", _OmegaConf) + for name in ("register_new_resolver", "register_resolver", "has_resolver"): + if not hasattr(omega_conf, name): + setattr(omega_conf, name, getattr(_OmegaConf, name)) + omegaconf_mod.OmegaConf = omega_conf + return omegaconf_mod def _to_np(tensor): @@ -543,7 +575,7 @@ class _DTensor: # ============================================================================== -# FAIRChem heavy dependency stubs +# Stub helpers # ============================================================================== @@ -572,313 +604,6 @@ def _make_stub_module(name, attrs=None, submodules=None): return mod -def patch_fairchem_deps(): - """ - Stub heavy dependencies that fairchem-core imports but doesn't need for inference. - - This stubs: numba, ray (+ serve), wandb, torchtnt, hydra, omegaconf, - submitit, clusterscope, tqdm, huggingface_hub, websockets. - """ - # --- numba --- - numba_mod = _make_stub_module("numba", submodules=["core", "core.types", "typed"]) - numba_mod.njit = lambda *a, **k: (lambda f: f) if not a or callable(a[0]) else lambda f: f - numba_mod.jit = numba_mod.njit - numba_mod.prange = range - for t in ("int32", "int64", "float32", "float64", "boolean"): - setattr(numba_mod, t, t) - - class _TypedList(list): - pass - - sys.modules["numba.typed"].List = _TypedList - - # --- ray --- - ray_mod = _make_stub_module( - "ray", - submodules=[ - "serve", - "runtime_env", - "train", - "data", - "util", - "util.scheduling_strategies", - "util.queue", - ], - ) - - def _ray_remote(*args, **kwargs): - if args and callable(args[0]): - args[0]._remote = lambda *a, **kw: None - return args[0] - - def wrapper(fn_or_cls): - fn_or_cls._remote = lambda *a, **kw: None - return fn_or_cls - - return wrapper - - ray_mod.remote = _ray_remote - ray_mod.init = lambda *a, **k: None - ray_mod.get = lambda *a, **k: None - ray_mod.put = lambda *a, **k: None - ray_mod.wait = lambda *a, **k: ([], []) - ray_mod.is_initialized = lambda: False - - class _ObjectRef: - pass - - ray_mod.ObjectRef = _ObjectRef - - class _PlacementGroupSchedulingStrategy: - def __init__(self, *a, **k): - pass - - sys.modules["ray.util.scheduling_strategies"].PlacementGroupSchedulingStrategy = _PlacementGroupSchedulingStrategy - - # ray.serve stubs - _serve = sys.modules["ray.serve"] - - def _serve_deployment(*args, **kwargs): - if args and callable(args[0]): - return args[0] - return lambda cls_or_fn: cls_or_fn - - _serve.deployment = _serve_deployment - _serve.ingress = lambda *a, **k: (lambda cls: cls) - _serve.run = lambda *a, **k: None - _serve.batch = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0] - - _serve_schema = types.ModuleType("ray.serve.schema") - _serve_schema.__package__ = "ray.serve" - - class _LoggingConfig: - def __init__(self, **k): - self.__dict__.update(k) - - _serve_schema.LoggingConfig = _LoggingConfig - _serve.schema = _serve_schema - sys.modules["ray.serve.schema"] = _serve_schema - - # --- wandb --- - wandb_mod = _make_stub_module("wandb") - wandb_mod.init = lambda *a, **k: None - wandb_mod.log = lambda *a, **k: None - wandb_mod.finish = lambda *a, **k: None - - # --- torchtnt --- - _make_stub_module( - "torchtnt", - submodules=[ - "framework", - "framework.state", - "framework.unit", - "framework.callback", - "framework.auto_unit", - "framework.fit", - "framework.train", - "framework.evaluate", - "framework.predict", - "utils", - "utils.loggers", - "utils.timer", - "utils.distributed", - "utils.prepare_module", - ], - ) - - class _PredictUnit: - def __init__(self, *a, **k): - pass - - def __class_getitem__(cls, item): - return cls - - class _TrainUnit: - def __init__(self, *a, **k): - pass - - def __class_getitem__(cls, item): - return cls - - class _EvalUnit: - def __init__(self, *a, **k): - pass - - def __class_getitem__(cls, item): - return cls - - class _State: - def __init__(self, *a, **k): - pass - - def __class_getitem__(cls, item): - return cls - - class _Callback: - pass - - tnt_framework = sys.modules["torchtnt.framework"] - tnt_unit = sys.modules["torchtnt.framework.unit"] - tnt_state = sys.modules["torchtnt.framework.state"] - tnt_cb = sys.modules["torchtnt.framework.callback"] - for m in [tnt_framework, tnt_unit]: - m.PredictUnit = _PredictUnit - m.TrainUnit = _TrainUnit - m.EvalUnit = _EvalUnit - tnt_state.State = _State - tnt_cb.Callback = _Callback - tnt_framework.State = _State - tnt_framework.Callback = _Callback - - # torchtnt entry point functions - sys.modules["torchtnt.framework.fit"].fit = lambda *a, **k: None - sys.modules["torchtnt.framework.train"].train = lambda *a, **k: None - sys.modules["torchtnt.framework.evaluate"].evaluate = lambda *a, **k: None - sys.modules["torchtnt.framework.predict"].predict = lambda *a, **k: None - - # torchtnt.utils stubs - tnt_dist = sys.modules["torchtnt.utils.distributed"] - tnt_dist.get_file_init_method = lambda *a, **k: "" - tnt_dist.get_tcp_init_method = lambda *a, **k: "" - tnt_dist.spawn_multi_process = lambda *a, **k: None - - tnt_prep = sys.modules["torchtnt.utils.prepare_module"] - tnt_prep.prepare_module = lambda module, *a, **k: module - tnt_prep.FSDPStrategy = type("FSDPStrategy", (), {"__init__": lambda self, **k: None}) - tnt_prep.DDPStrategy = type("DDPStrategy", (), {"__init__": lambda self, **k: None}) - tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None}) - - # --- hydra / omegaconf --- - omegaconf_mod = _make_stub_module("omegaconf") - - class _DictConfig(dict): - pass - - class _ListConfig(list): - pass - - omegaconf_mod.DictConfig = _DictConfig - omegaconf_mod.ListConfig = _ListConfig - omegaconf_mod.OmegaConf = type( - "OmegaConf", - (), - { - "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), - "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), - }, - ) - _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"]) - - def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): - import importlib - - if isinstance(config, dict) and "_target_" in config: - target = config["_target_"] - mod_path, cls_name = target.rsplit(".", 1) - mod = importlib.import_module(mod_path) - cls = getattr(mod, cls_name) - pos_args = list(args) + list(config.get("_args_", [])) - cfg = {k: v for k, v in config.items() if not k.startswith("_")} - if _recursive_: - for k, v in cfg.items(): - if isinstance(v, dict) and "_target_" in v: - cfg[k] = _hydra_instantiate(v) - elif isinstance(v, list): - cfg[k] = [_hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v] - pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] - return cls(*pos_args, **{**cfg, **kwargs}) - return config - - sys.modules["hydra.utils"].instantiate = _hydra_instantiate - sys.modules["hydra"].utils = sys.modules["hydra.utils"] - - # --- submitit / clusterscope --- - _make_stub_module("submitit") - _make_stub_module("clusterscope") - - # --- websockets --- - _make_stub_module("websockets") - - # --- tqdm --- - tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) - - def _tqdm_passthrough(iterable=None, *a, **k): - return iterable if iterable is not None else iter([]) - - tqdm_mod.tqdm = _tqdm_passthrough - sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough - sys.modules["tqdm.std"].tqdm = _tqdm_passthrough - - # --- huggingface_hub --- - hf_mod = _make_stub_module("huggingface_hub", submodules=["utils"]) - hf_mod.hf_hub_download = lambda *a, **k: "" - hf_mod.snapshot_download = lambda *a, **k: "" - - # --- ase_db_backends --- - _make_stub_module("ase_db_backends") - # --- INT8 quantized model support --- - _orig_torch_load = torch.load - - def _int8_aware_torch_load(f, *args, **kwargs): - result = _orig_torch_load(f, *args, **kwargs) - if isinstance(result, dict) and "quantized_ema_state_dict" in result: - import gc as _gc - - from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint - - print(" Dequantizing INT8 -> FP16 (streaming)...") - quantized_ema = result.pop("quantized_ema_state_dict") - scales = result.pop("quantization_scales") - ema_state_dict = {} - names = list(quantized_ema.keys()) - for name in names: - tensor = quantized_ema.pop(name) - if name in scales: - scale = scales.pop(name) - ema_state_dict[name] = (tensor.float() * scale.float()).half() - del scale - else: - ema_state_dict[name] = tensor - del tensor - del quantized_ema, scales - _gc.collect() - checkpoint = MLIPInferenceCheckpoint( - model_config=result["model_config"], - model_state_dict=result.get("model_state_dict", {}), - ema_state_dict=ema_state_dict, - tasks_config=result["tasks_config"], - ) - del result - _gc.collect() - print(" ✓ Dequantization complete") - return checkpoint - return result - - torch.load = _int8_aware_torch_load - - # --- Model registry fallback --- - try: - from fairchem.core.common.registry import registry as _registry - - _orig_get_model = _registry.get_model_class - - def _fallback_get_model_class(name): - try: - return _orig_get_model(name) - except RuntimeError: - import importlib as _imp - - module_path, class_name = name.rsplit(".", 1) - mod = _imp.import_module(module_path) - return getattr(mod, class_name) - - _registry.get_model_class = _fallback_get_model_class - except Exception: - pass - - print("✓ FAIRChem dependency stubs applied") - - # ============================================================================== # Matscipy patches # ============================================================================== @@ -913,604 +638,9 @@ def patch_matscipy(): print("✓ Matscipy patches applied") -# ============================================================================== -# LMDB and HDF5 patches -# ============================================================================== - - -def patch_mace_training(): - """ - Stub lmdb and h5py packages. - - These are C-extension packages used by MACE's training/dataset code - but not needed for inference. Stubs allow imports to succeed. - """ - for _pkg in ("lmdb", "h5py"): - if _pkg not in sys.modules: - sys.modules[_pkg] = types.ModuleType(_pkg) - - print("✓ LMDB and HDF5 stubs applied") - - -# ============================================================================== -# MACE tools patches -# ============================================================================== - - -def patch_mace_tools(): - """ - Fix MACE's torch_geometric import order issues in Pyodide. - - In Pyodide, torch_geometric.data may not be set during circular imports. - Pre-importing ensures the attribute is available when MACE needs it. - """ - try: - import importlib as _importlib - - _tg = _importlib.import_module("mace.tools.torch_geometric") - _tg_data = _importlib.import_module("mace.tools.torch_geometric.data") - _tg.data = _tg_data - print("✓ MACE tools patches applied") - except Exception as e: - print(f"⚠ MACE tools patches skipped: {e}") - - -# ============================================================================== -# Convenience function to apply all patches -# ============================================================================== - - -def apply_all_patches(include_fairchem=False, include_mattersim=False, include_nequip=False): - """ - Apply all torch and model patches for Pyodide in one call. - - Args: - include_fairchem: If True, also apply FAIRChem-specific patches - (torch.distributed, heavy dependency stubs). Set this when - using fairchem-core / UMA models. - include_mattersim: If True, also apply MatterSim-specific patches - (loguru, azure, e3nn JIT stubs). Set this when - using MatterSim / M3GNet models. - include_nequip: If True, also apply NequIP-specific patches - (lightning, hydra, torchmetrics, e3nn JIT stubs). Set this - when using NequIP models. - """ +def apply_common_torch_patches(): + """Apply shared Pyodide patches needed by torch-based MLFF packages.""" patch_torch_linalg() patch_torch_compiler() patch_torch_testing() patch_matscipy() - patch_mace_training() - patch_mace_tools() - - if include_fairchem: - patch_torch_distributed() - patch_fairchem_deps() - - if include_mattersim: - patch_torch_distributed() - patch_mattersim_deps() - - if include_nequip: - patch_nequip_deps() - - print("\n✅ All Pyodide patches applied successfully!") - - -# ============================================================================== -# MatterSim patches -# ============================================================================== - - -def patch_mattersim_deps(): - """ - Stub heavy dependencies required by MatterSim but not needed for inference. - - Stubs: loguru, azure.*, atomate2, seekpath, phonopy, phono3py, mp_api, - sklearn, and patches e3nn to disable JIT/torch.compile. - """ - loguru_mod = _make_stub_module("loguru") - - class _Logger: - def info(self, msg, *a, **k): - print(f"INFO: {msg}") - - def warning(self, msg, *a, **k): - print(f"WARNING: {msg}") - - def error(self, msg, *a, **k): - print(f"ERROR: {msg}") - - def debug(self, msg, *a, **k): - pass - - def trace(self, msg, *a, **k): - pass - - def success(self, msg, *a, **k): - print(f"✓ {msg}") - - def __getattr__(self, name): - return lambda *a, **k: None - - loguru_mod.logger = _Logger() - - _make_stub_module("azure", submodules=["identity", "storage", "storage.blob"]) - - for pkg in [ - "atomate2", - "seekpath", - "phonopy", - "phono3py", - "mp_api", - "jobflow", - "emmet", - "emmet.core", - "emmet.core.tasks", - "maggma", - ]: - _make_stub_module(pkg) - - _make_stub_module( - "sklearn", - submodules=[ - "base", - "utils", - "utils.validation", - "preprocessing", - "model_selection", - "gaussian_process", - "gaussian_process.kernels", - ], - ) - - class _GPR: - def __init__(self, *a, **k): - pass - - def fit(self, *a, **k): - return self - - def predict(self, X, return_std=False): - import numpy as _np - - mean = _np.zeros(X.shape[0]) - if return_std: - return mean, _np.ones(X.shape[0]) - return mean - - def log_marginal_likelihood(self): - return 0.0 - - sys.modules["sklearn.gaussian_process"].GaussianProcessRegressor = _GPR - - class _Kernel: - pass - - class _DotProduct(_Kernel): - def __init__(self, *a, **k): - pass - - class _Hyperparameter: - def __init__(self, *a, **k): - pass - - _sk_kernels = sys.modules["sklearn.gaussian_process.kernels"] - _sk_kernels.Kernel = _Kernel - _sk_kernels.DotProduct = _DotProduct - _sk_kernels.Hyperparameter = _Hyperparameter - - try: - import pyodide_http # noqa: F401 - - pyodide_http.patch_all() - except ImportError: - pass - - try: - import e3nn - - e3nn._SO3_INITIALIZED = True - except Exception: - pass - - import torch - - if not hasattr(torch.jit, "_original_script"): - _orig_script = torch.jit.script - - def _noop_script(obj=None, *a, **k): - if obj is not None: - return obj - return lambda fn: fn - - torch.jit.script = _noop_script - - _te = _make_stub_module("torch_ema") - - class _EMA: - def __init__(self, *a, **k): - pass - - _te.ExponentialMovingAverage = _EMA - - _tm = _make_stub_module("torchmetrics") - - class _MeanMetric: - def __init__(self, *a, **k): - pass - - _tm.MeanMetric = _MeanMetric - - _make_stub_module("torch_geometric", submodules=["data", "loader", "utils"]) - - class _Data: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - - def to(self, device): - import torch as _t - - for k, v in self.__dict__.items(): - if isinstance(v, _t.Tensor): - setattr(self, k, v.to(device)) - return self - - sys.modules["torch_geometric.data"].Data = _Data - - class _DataLoader: - def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): - self._dataset = list(dataset) - - def __iter__(self): - import torch as _t - - for item in self._dataset: - for attr in list(vars(item).keys()): - val = getattr(item, attr) - if isinstance(val, (int, float)): - setattr(item, attr, _t.tensor([val])) - if not hasattr(item, "num_graphs"): - item.num_graphs = 1 - if not hasattr(item, "batch"): - n_atoms = item.num_atoms if hasattr(item, "num_atoms") else _t.tensor([0]) - if isinstance(n_atoms, _t.Tensor): - n_atoms = int(n_atoms.item()) - item.batch = _t.zeros(n_atoms, dtype=_t.long) - yield item - - def __len__(self): - return len(self._dataset) - - sys.modules["torch_geometric.loader"].DataLoader = _DataLoader - - _make_stub_module("torch_runstats", submodules=["scatter"]) - import torch as _torch - - def _scatter(src, index, dim_size=None, dim=0, reduce="sum"): - if dim_size is None: - dim_size = int(index.max()) + 1 - out = _torch.zeros(dim_size, *src.shape[1:], dtype=src.dtype, device=src.device) - if src.dim() == 1: - idx = index - else: - idx = index.unsqueeze(-1).expand_as(src) - if reduce == "sum" or reduce == "add": - out.scatter_add_(0, idx, src) - elif reduce == "mean": - out.scatter_add_(0, idx, src) - count = _torch.zeros(dim_size, dtype=src.dtype, device=src.device) - count.scatter_add_(0, index, _torch.ones(index.shape[0], dtype=src.dtype, device=src.device)) - count = count.clamp(min=1) - if src.dim() > 1: - count = count.unsqueeze(-1) - out = out / count - return out - - sys.modules["torch_runstats.scatter"].scatter = _scatter - - def _scatter_mean(src, index, dim_size=None, dim=0): - return _scatter(src, index, dim_size=dim_size, dim=dim, reduce="mean") - - sys.modules["torch_runstats.scatter"].scatter_mean = _scatter_mean - - print("✓ MatterSim dependency stubs applied") - - -# ============================================================================== -# NequIP patches -# ============================================================================== - - -def patch_nequip_deps(): - """ - Stub heavy dependencies required by NequIP but not needed for inference. - - Stubs: hydra, lightning, pytorch_lightning, torchmetrics, lmdb, matscipy. - Patches e3nn and torch.jit for Pyodide compatibility. - Sets NEQUIP_NL=ase to use ASE neighbor lists instead of matscipy. - """ - import os - - import torch - - os.environ["NEQUIP_NL"] = "ase" - - _make_stub_module("lightning_utilities", submodules=["core", "core.rank_zero"]) - - def _rank_prefixed_message(msg, rank=None): - return msg - - def _rank_zero_only(fn): - return fn - - sys.modules["lightning_utilities.core.rank_zero"].rank_prefixed_message = _rank_prefixed_message - sys.modules["lightning_utilities.core.rank_zero"].rank_zero_only = _rank_zero_only - - _make_stub_module( - "lightning", - submodules=[ - "pytorch", - "pytorch.utilities", - "pytorch.utilities.seed", - "pytorch.utilities.warnings", - "pytorch.callbacks", - ], - ) - - class _IsolateRng: - def __enter__(self): - return self - - def __exit__(self, *a): - pass - - def _seed_everything(seed=None, workers=False, **kwargs): - pass - - sys.modules["lightning.pytorch.utilities.seed"].isolate_rng = lambda: _IsolateRng() - sys.modules["lightning.pytorch"].seed_everything = _seed_everything - - class _PossibleUserWarning(UserWarning): - pass - - sys.modules["lightning.pytorch.utilities.warnings"].PossibleUserWarning = _PossibleUserWarning - - class _LightningModule(torch.nn.Module): - def __init__(self, *a, **k): - super().__init__() - - def log(self, *a, **k): - pass - - sys.modules["lightning.pytorch"].LightningModule = _LightningModule - sys.modules["lightning"].pytorch = sys.modules["lightning.pytorch"] - sys.modules["lightning.pytorch.callbacks"].Callback = type("Callback", (), {}) - - _make_stub_module("pytorch_lightning", submodules=["utilities", "utilities.seed"]) - sys.modules["pytorch_lightning.utilities.seed"].isolate_rng = lambda: _IsolateRng() - - _tm = _make_stub_module("torchmetrics") - - class _Metric(torch.nn.Module): - def __init__(self, *a, **k): - super().__init__() - - def add_state(self, name, default=None, dist_reduce_fx=None, **k): - if default is not None: - setattr(self, name, default) - - _tm.Metric = _Metric - _tm.MeanMetric = _Metric - - if "packaging" not in sys.modules: - try: - import packaging.version # noqa: F401 - except (ImportError, ModuleNotFoundError): - _make_stub_module("packaging", submodules=["version"]) - - class _Version: - def __init__(self, v): - self._v = str(v) - parts = self._v.split(".") - self.major = int(parts[0]) if parts else 0 - self.minor = int(parts[1]) if len(parts) > 1 else 0 - - def __lt__(self, other): - return (self.major, self.minor) < (other.major, other.minor) - - def __ge__(self, other): - return not self.__lt__(other) - - def __repr__(self): - return f"Version('{self._v}')" - - sys.modules["packaging.version"].Version = _Version - sys.modules["packaging.version"].parse = lambda v: _Version(v) - - if "lmdb" not in sys.modules: - sys.modules["lmdb"] = types.ModuleType("lmdb") - - if "hydra" not in sys.modules: - _make_stub_module( - "hydra", - submodules=[ - "core", - "core.global_hydra", - "utils", - "_internal", - "_internal.instantiate", - "_internal.instantiate._instantiate2", - ], - ) - - def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): - import importlib - - if isinstance(config, dict) and "_target_" in config: - target = config["_target_"] - mod_path, cls_name = target.rsplit(".", 1) - mod = importlib.import_module(mod_path) - cls = getattr(mod, cls_name) - pos_args = list(args) + list(config.get("_args_", [])) - cfg = {k: v for k, v in config.items() if not k.startswith("_")} - if _recursive_: - for k, v in cfg.items(): - if isinstance(v, dict) and "_target_" in v: - cfg[k] = _hydra_instantiate(v) - elif isinstance(v, list): - cfg[k] = [ - _hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v - ] - pos_args = [ - _hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args - ] - return cls(*pos_args, **{**cfg, **kwargs}) - return config - - sys.modules["hydra.utils"].instantiate = _hydra_instantiate - sys.modules["hydra"].utils = sys.modules["hydra.utils"] - sys.modules["hydra._internal.instantiate._instantiate2"].InstantiationException = type( - "InstantiationException", (Exception,), {} - ) - - def _hydra_get_target(target_str): - import importlib - - mod_path, name = target_str.rsplit(".", 1) - return getattr(importlib.import_module(mod_path), name) - - sys.modules["hydra.utils"].get_method = _hydra_get_target - sys.modules["hydra.utils"].get_class = _hydra_get_target - - if "omegaconf" not in sys.modules: - omegaconf_mod = _make_stub_module("omegaconf") - - class _DictConfig(dict): - pass - - class _ListConfig(list): - pass - - omegaconf_mod.DictConfig = _DictConfig - omegaconf_mod.ListConfig = _ListConfig - omegaconf_mod.OmegaConf = type( - "OmegaConf", - (), - { - "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), - "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), - "register_new_resolver": staticmethod(lambda name, func, **k: None), - }, - ) - - if "matscipy" not in sys.modules: - _matscipy = types.ModuleType("matscipy") - _matscipy.__path__ = [] - _matscipy.__package__ = "matscipy" - _matscipy_neighbours = types.ModuleType("matscipy.neighbours") - _matscipy_neighbours.neighbour_list = _matscipy_neighbour_list_compat - _matscipy.neighbours = _matscipy_neighbours - sys.modules["matscipy"] = _matscipy - sys.modules["matscipy.neighbours"] = _matscipy_neighbours - - try: - import e3nn - - e3nn._SO3_INITIALIZED = True - if hasattr(e3nn, "_OPT_DEFAULTS"): - e3nn._OPT_DEFAULTS["jit_mode"] = "eager" - except Exception: - pass - - if not hasattr(torch.jit, "_original_script"): - _orig_script = torch.jit.script - - def _noop_script(obj=None, *a, **k): - if obj is not None: - return obj - return lambda fn: fn - - torch.jit.script = _noop_script - - try: - import e3nn.util.jit - - e3nn.util.jit.compile_mode = lambda mode: lambda cls: cls - except Exception: - pass - - if "tqdm" not in sys.modules: - tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) - - def _tqdm_passthrough(iterable=None, *a, **k): - return iterable if iterable is not None else iter([]) - - tqdm_mod.tqdm = _tqdm_passthrough - sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough - sys.modules["tqdm.std"].tqdm = _tqdm_passthrough - - print("✓ NequIP dependency stubs applied") - - -def load_nequip_model(checkpoint_path): - import torch - - data = torch.load(checkpoint_path, map_location="cpu", weights_only=False) - - import nequip.utils.global_state as _gs - - def _pyodide_set_global_state(allow_tf32=False, warn_on_override=False): - if not _gs._GLOBAL_STATE_INITIALIZED: - torch.set_default_dtype(torch.float64) - try: - import e3nn - - e3nn.set_optimization_defaults( - specialized_code=True, - optimize_einsums=True, - jit_script_fx=False, - ) - except Exception: - pass - _gs._GLOBAL_STATE_INITIALIZED = True - _gs._latest_global_config["allow_tf32"] = allow_tf32 - - _gs.set_global_state = _pyodide_set_global_state - _gs.set_global_state(allow_tf32=False) - - from nequip.model import FullNequIPGNNModel - - model = FullNequIPGNNModel( - seed=0, - model_dtype="float32", - r_max=data["r_max"], - type_names=data["type_names"], - irreps_edge_sh=data["irreps_edge_sh"], - type_embed_num_features=data["type_embed_num_features"], - feature_irreps_hidden=data["feature_irreps_hidden"], - radial_mlp_depth=data["radial_mlp_depth"], - radial_mlp_width=data["radial_mlp_width"], - avg_num_neighbors=data["avg_num_neighbors"], - per_type_energy_scales=data["per_type_energy_scales"], - per_type_energy_shifts=data["per_type_energy_shifts"], - polynomial_cutoff_p=data.get("polynomial_cutoff_p", 6), - ) - - if data.get("has_zbl", False): - from nequip.nn.pair_potential import ZBL - - seq_net = model.model.func - zbl = ZBL( - type_names=data["type_names"], - chemical_species=data["type_names"], - units="metal", - irreps_in=seq_net.irreps_out, - ) - seq_net.insert(name="pair_potential", module=zbl, before="total_energy_sum") - - model.load_state_dict(data["state_dict"], strict=True) - model.eval() - - print(f"✓ NequIP model loaded ({sum(p.numel() for p in model.parameters()):,} parameters)") - return model diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py index 0aad4b331..4acee9c5b 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py @@ -1,7 +1,311 @@ -from fairchem.core import FAIRChemCalculator -from fairchem.core.units.mlip_unit import load_predict_unit +import gc +import sys +import types +from importlib import import_module + +import torch from ...primitive.environment import is_pyodide_environment +from .torch import _ensure_omegaconf_stub, _make_stub_module, patch_torch_distributed + + +def apply_patches(): + patch_torch_distributed() + patch_fairchem_deps() + + +def patch_fairchem_deps(): + """ + Stub heavy dependencies that fairchem-core imports but doesn't need for inference. + + This stubs: numba, ray (+ serve), wandb, torchtnt, hydra, omegaconf, + submitit, clusterscope, tqdm, huggingface_hub, websockets. + """ + numba_mod = _make_stub_module("numba", submodules=["core", "core.types", "typed"]) + numba_mod.njit = lambda *a, **k: (lambda f: f) if not a or callable(a[0]) else lambda f: f + numba_mod.jit = numba_mod.njit + numba_mod.prange = range + for t in ("int32", "int64", "float32", "float64", "boolean"): + setattr(numba_mod, t, t) + + class _TypedList(list): + pass + + sys.modules["numba.typed"].List = _TypedList + ray_mod = _make_stub_module( + "ray", + submodules=[ + "serve", + "runtime_env", + "train", + "data", + "util", + "util.scheduling_strategies", + "util.queue", + ], + ) + + def _ray_remote(*args, **kwargs): + if args and callable(args[0]): + args[0]._remote = lambda *a, **kw: None + return args[0] + + def wrapper(fn_or_cls): + fn_or_cls._remote = lambda *a, **kw: None + return fn_or_cls + + return wrapper + + ray_mod.remote = _ray_remote + ray_mod.init = lambda *a, **k: None + ray_mod.get = lambda *a, **k: None + ray_mod.put = lambda *a, **k: None + ray_mod.wait = lambda *a, **k: ([], []) + ray_mod.is_initialized = lambda: False + + class _ObjectRef: + pass + + ray_mod.ObjectRef = _ObjectRef + + class _PlacementGroupSchedulingStrategy: + def __init__(self, *a, **k): + pass + + sys.modules["ray.util.scheduling_strategies"].PlacementGroupSchedulingStrategy = _PlacementGroupSchedulingStrategy + _patch_ray_serve() + _patch_wandb() + _patch_torchtnt() + _patch_hydra() + _make_stub_module("submitit") + _make_stub_module("clusterscope") + _make_stub_module("websockets") + _patch_tqdm() + _patch_huggingface_hub() + _make_stub_module("ase_db_backends") + _patch_int8_torch_load() + _patch_model_registry_fallback() + + print("✓ FAIRChem dependency stubs applied") + + +def _patch_ray_serve(): + serve_mod = sys.modules["ray.serve"] + + def _serve_deployment(*args, **kwargs): + if args and callable(args[0]): + return args[0] + return lambda cls_or_fn: cls_or_fn + + serve_mod.deployment = _serve_deployment + serve_mod.ingress = lambda *a, **k: (lambda cls: cls) + serve_mod.run = lambda *a, **k: None + serve_mod.batch = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0] + + serve_schema = types.ModuleType("ray.serve.schema") + serve_schema.__package__ = "ray.serve" + + class _LoggingConfig: + def __init__(self, **k): + self.__dict__.update(k) + + serve_schema.LoggingConfig = _LoggingConfig + serve_mod.schema = serve_schema + sys.modules["ray.serve.schema"] = serve_schema + + +def _patch_wandb(): + wandb_mod = _make_stub_module("wandb") + wandb_mod.init = lambda *a, **k: None + wandb_mod.log = lambda *a, **k: None + wandb_mod.finish = lambda *a, **k: None + + +def _patch_torchtnt(): + _make_stub_module( + "torchtnt", + submodules=[ + "framework", + "framework.state", + "framework.unit", + "framework.callback", + "framework.auto_unit", + "framework.fit", + "framework.train", + "framework.evaluate", + "framework.predict", + "utils", + "utils.loggers", + "utils.timer", + "utils.distributed", + "utils.prepare_module", + ], + ) + + class _PredictUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _TrainUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _EvalUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _State: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _Callback: + pass + + tnt_framework = sys.modules["torchtnt.framework"] + tnt_unit = sys.modules["torchtnt.framework.unit"] + tnt_state = sys.modules["torchtnt.framework.state"] + tnt_cb = sys.modules["torchtnt.framework.callback"] + for module in [tnt_framework, tnt_unit]: + module.PredictUnit = _PredictUnit + module.TrainUnit = _TrainUnit + module.EvalUnit = _EvalUnit + tnt_state.State = _State + tnt_cb.Callback = _Callback + tnt_framework.State = _State + tnt_framework.Callback = _Callback + sys.modules["torchtnt.framework.fit"].fit = lambda *a, **k: None + sys.modules["torchtnt.framework.train"].train = lambda *a, **k: None + sys.modules["torchtnt.framework.evaluate"].evaluate = lambda *a, **k: None + sys.modules["torchtnt.framework.predict"].predict = lambda *a, **k: None + + tnt_dist = sys.modules["torchtnt.utils.distributed"] + tnt_dist.get_file_init_method = lambda *a, **k: "" + tnt_dist.get_tcp_init_method = lambda *a, **k: "" + tnt_dist.spawn_multi_process = lambda *a, **k: None + + tnt_prep = sys.modules["torchtnt.utils.prepare_module"] + tnt_prep.prepare_module = lambda module, *a, **k: module + tnt_prep.FSDPStrategy = type("FSDPStrategy", (), {"__init__": lambda self, **k: None}) + tnt_prep.DDPStrategy = type("DDPStrategy", (), {"__init__": lambda self, **k: None}) + tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None}) + + +def _patch_hydra(): + _ensure_omegaconf_stub() + _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"]) + + def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): + if isinstance(config, dict) and "_target_" in config: + target = config["_target_"] + mod_path, cls_name = target.rsplit(".", 1) + mod = import_module(mod_path) + cls = getattr(mod, cls_name) + pos_args = list(args) + list(config.get("_args_", [])) + cfg = {k: v for k, v in config.items() if not k.startswith("_")} + if _recursive_: + for k, v in cfg.items(): + if isinstance(v, dict) and "_target_" in v: + cfg[k] = _hydra_instantiate(v) + elif isinstance(v, list): + cfg[k] = [_hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v] + pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] + return cls(*pos_args, **{**cfg, **kwargs}) + return config + + sys.modules["hydra.utils"].instantiate = _hydra_instantiate + sys.modules["hydra"].utils = sys.modules["hydra.utils"] + + +def _patch_tqdm(): + tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) + + def _tqdm_passthrough(iterable=None, *a, **k): + return iterable if iterable is not None else iter([]) + + tqdm_mod.tqdm = _tqdm_passthrough + sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough + sys.modules["tqdm.std"].tqdm = _tqdm_passthrough + + +def _patch_huggingface_hub(): + hf_mod = _make_stub_module("huggingface_hub", submodules=["utils"]) + hf_mod.hf_hub_download = lambda *a, **k: "" + hf_mod.snapshot_download = lambda *a, **k: "" + + +def _patch_int8_torch_load(): + original_torch_load = torch.load + + def _int8_aware_torch_load(f, *args, **kwargs): + result = original_torch_load(f, *args, **kwargs) + if isinstance(result, dict) and "quantized_ema_state_dict" in result: + return _dequantize_int8_checkpoint(result) + return result + + torch.load = _int8_aware_torch_load + + +def _dequantize_int8_checkpoint(result): + from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint + + print(" Dequantizing INT8 -> FP16 (streaming)...") + quantized_ema = result.pop("quantized_ema_state_dict") + scales = result.pop("quantization_scales") + ema_state_dict = {} + names = list(quantized_ema.keys()) + for name in names: + tensor = quantized_ema.pop(name) + if name in scales: + scale = scales.pop(name) + ema_state_dict[name] = (tensor.float() * scale.float()).half() + del scale + else: + ema_state_dict[name] = tensor + del tensor + del quantized_ema, scales + gc.collect() + checkpoint = MLIPInferenceCheckpoint( + model_config=result["model_config"], + model_state_dict=result.get("model_state_dict", {}), + ema_state_dict=ema_state_dict, + tasks_config=result["tasks_config"], + ) + del result + gc.collect() + print(" ✓ Dequantization complete") + return checkpoint + + +def _patch_model_registry_fallback(): + try: + from fairchem.core.common.registry import registry as fairchem_registry + + original_get_model = fairchem_registry.get_model_class + + def _fallback_get_model_class(name): + try: + return original_get_model(name) + except RuntimeError: + module_path, class_name = name.rsplit(".", 1) + mod = import_module(module_path) + return getattr(mod, class_name) + + fairchem_registry.get_model_class = _fallback_get_model_class + except Exception: + pass + MODEL_PATHS_MAP = { "f16": "/drive/packages/models/uma-s-1p1-f16.pt", @@ -12,8 +316,10 @@ def get_uma_model_pyodide(model: str, task_name="omat", device="cpu", **kwargs): if model not in MODEL_PATHS_MAP: raise ValueError(f"Invalid model name: {model}. Valid options are: {list(MODEL_PATHS_MAP.keys())}") - predictor = load_predict_unit(MODEL_PATHS_MAP[model], device=device, **kwargs) - return FAIRChemCalculator(predictor, task_name=task_name) + fairchem_core = import_module("fairchem.core") + mlip_unit = import_module("fairchem.core.units.mlip_unit") + predictor = mlip_unit.load_predict_unit(MODEL_PATHS_MAP[model], device=device, **kwargs) + return fairchem_core.FAIRChemCalculator(predictor, task_name=task_name) def create_uma_calculator(model="f16", task_name="omat", device="cpu", model_path=None, checkpoint=None, **kwargs): @@ -22,5 +328,7 @@ def create_uma_calculator(model="f16", task_name="omat", device="cpu", model_pat resolved_model_path = model_path or checkpoint - predictor = load_predict_unit(str(resolved_model_path), device=device, **kwargs) - return FAIRChemCalculator(predictor, task_name=task_name) + fairchem_core = import_module("fairchem.core") + mlip_unit = import_module("fairchem.core.units.mlip_unit") + predictor = mlip_unit.load_predict_unit(str(resolved_model_path), device=device, **kwargs) + return fairchem_core.FAIRChemCalculator(predictor, task_name=task_name) From 3d0ba98381d30c61dea1665dedd8cf89be94858e Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 18:05:53 -0700 Subject: [PATCH 12/18] update: use split patching --- .../jupyterlite/relax_structure_with_mlff.ipynb | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb index 4fc86afee..0a811b0d2 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb @@ -30,7 +30,7 @@ "INTERFACE_NAME = \"Interface\" # name of the interface to load from the folder\n", "\n", "# MLFF selector\n", - "MLFF_NAME = \"mace\" # \"mace\" | \"uma\" | \"mattersim\" | \"nequip\"\n", + "MLFF_NAME = \"uma\" # \"mace\" | \"uma\" | \"mattersim\" | \"nequip\"\n", "\n", "# MLFF-specific settings\n", "MLFF_SETTINGS = {\n", @@ -89,13 +89,9 @@ "\n", "# PyTorch patches are required in Pyodide for torch-based MLFFs.\n", "if is_pyodide_environment():\n", - " from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", + " from mat3ra.notebooks_utils.pyodide.packages import apply_all_patches\n", "\n", - " apply_all_patches(\n", - " include_fairchem=(MLFF_NAME == \"uma\"),\n", - " include_mattersim=(MLFF_NAME == \"mattersim\"),\n", - " include_nequip=(MLFF_NAME == \"nequip\"),\n", - " )\n" + " apply_all_patches(MLFF_NAME)\n" ] }, { From 23c4b5dbcb36b04ea47ce6c637d917b82120b55f Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 18:20:55 -0700 Subject: [PATCH 13/18] update: split --- .../relax_structure_with_mlff.ipynb | 2 +- src/py/mat3ra/notebooks_utils/packages.py | 3 +- .../pyodide/packages/__init__.py | 29 ------------------- .../pyodide/packages/patches.py | 29 +++++++++++++++++++ .../notebooks_utils/pyodide/packages/torch.py | 2 +- 5 files changed, 33 insertions(+), 32 deletions(-) create mode 100644 src/py/mat3ra/notebooks_utils/pyodide/packages/patches.py diff --git a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb index 0a811b0d2..82d3b515a 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb @@ -89,7 +89,7 @@ "\n", "# PyTorch patches are required in Pyodide for torch-based MLFFs.\n", "if is_pyodide_environment():\n", - " from mat3ra.notebooks_utils.pyodide.packages import apply_all_patches\n", + " from mat3ra.notebooks_utils.pyodide.packages.patches import apply_all_patches\n", "\n", " apply_all_patches(MLFF_NAME)\n" ] diff --git a/src/py/mat3ra/notebooks_utils/packages.py b/src/py/mat3ra/notebooks_utils/packages.py index c6d74ce6d..12396cd9d 100644 --- a/src/py/mat3ra/notebooks_utils/packages.py +++ b/src/py/mat3ra/notebooks_utils/packages.py @@ -1,6 +1,5 @@ from .ipython.packages.install import install_packages_python from .primitive.environment import is_pyodide_environment -from .pyodide.packages.install import install_packages_pyodide async def install_packages(notebook_name_pattern: str, config_file_path: str = "", verbose: bool = True): @@ -17,6 +16,8 @@ async def install_packages(notebook_name_pattern: str, config_file_path: str = " verbose (bool): Whether to print install progress. """ if is_pyodide_environment(): + from .pyodide.packages.install import install_packages_pyodide + await install_packages_pyodide(notebook_name_pattern, verbose) else: install_packages_python(notebook_name_pattern, verbose) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py index 028d82af3..e69de29bb 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/__init__.py @@ -1,29 +0,0 @@ -from importlib import import_module - -from .torch import apply_common_torch_patches - -MLFF_PATCH_MODULES = { - "mace": "mat3ra.notebooks_utils.pyodide.packages.mace", - "uma": "mat3ra.notebooks_utils.pyodide.packages.uma", - "mattersim": "mat3ra.notebooks_utils.pyodide.packages.mattersim", - "nequip": "mat3ra.notebooks_utils.pyodide.packages.nequip", -} - - -def normalize_mlff_name(mlff_name: str) -> str: - return (mlff_name or "").strip().lower() - - -def apply_mlff_patches(mlff_name: str): - mlff = normalize_mlff_name(mlff_name) - if mlff not in MLFF_PATCH_MODULES: - raise ValueError(f"Unsupported MLFF: {mlff_name!r}") - - patch_module = import_module(MLFF_PATCH_MODULES[mlff]) - patch_module.apply_patches() - - -def apply_all_patches(mlff_name: str): - apply_common_torch_patches() - apply_mlff_patches(mlff_name) - print("\n✅ All Pyodide patches applied successfully!") diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/patches.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/patches.py new file mode 100644 index 000000000..028d82af3 --- /dev/null +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/patches.py @@ -0,0 +1,29 @@ +from importlib import import_module + +from .torch import apply_common_torch_patches + +MLFF_PATCH_MODULES = { + "mace": "mat3ra.notebooks_utils.pyodide.packages.mace", + "uma": "mat3ra.notebooks_utils.pyodide.packages.uma", + "mattersim": "mat3ra.notebooks_utils.pyodide.packages.mattersim", + "nequip": "mat3ra.notebooks_utils.pyodide.packages.nequip", +} + + +def normalize_mlff_name(mlff_name: str) -> str: + return (mlff_name or "").strip().lower() + + +def apply_mlff_patches(mlff_name: str): + mlff = normalize_mlff_name(mlff_name) + if mlff not in MLFF_PATCH_MODULES: + raise ValueError(f"Unsupported MLFF: {mlff_name!r}") + + patch_module = import_module(MLFF_PATCH_MODULES[mlff]) + patch_module.apply_patches() + + +def apply_all_patches(mlff_name: str): + apply_common_torch_patches() + apply_mlff_patches(mlff_name) + print("\n✅ All Pyodide patches applied successfully!") diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index 2e7557a5d..323db67bb 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -5,7 +5,7 @@ packages. Package-specific patches live in their corresponding package modules. Usage: - from mat3ra.notebooks_utils.pyodide.packages import apply_all_patches + from mat3ra.notebooks_utils.pyodide.packages.patches import apply_all_patches apply_all_patches("mace") """ From 6347729da451e532603bd5a48f8d7266c590f32d Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 18:42:08 -0700 Subject: [PATCH 14/18] update: fix reinstall --- .../pyodide/packages/install.py | 76 +++++++++++++++---- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py index 294c4b931..5a8e10cae 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/install.py @@ -1,7 +1,8 @@ import json import os import re -from typing import List, Union +import sys +from typing import List, Tuple, Union from ...primitive.environment import ENVIRONMENT from ...primitive.logger import log @@ -11,6 +12,10 @@ except ImportError: micropip = None # type: ignore +NODEPS_PREFIX = "nodeps:" +URL_PREFIXES = ("http://", "https://", "emfs:/") +VERSION_SPECIFIERS = ("==", ">=", "<=", "!=", "~=", ">", "<") + def get_config_yml_file_path(config_file_path: str) -> str: """ @@ -92,17 +97,61 @@ async def get_package_list_from_config(config_file_path: str, notebook_name_patt return packages -def should_reinstall_packages(previous_hash: Union[str, None], requirements_hash: str) -> bool: - return previous_hash is not None and previous_hash != requirements_hash +def should_install_packages(previous_hash: Union[str, None], requirements_hash: str) -> bool: + return previous_hash != requirements_hash + + +def is_url_package(pkg: str) -> bool: + return pkg.startswith(URL_PREFIXES) + + +def remove_nodeps_prefix(pkg: str) -> str: + return pkg.replace(NODEPS_PREFIX, "", 1) if pkg.startswith(NODEPS_PREFIX) else pkg def package_has_version_specifier(pkg: str) -> bool: - spec = pkg.split("nodeps:")[-1] # Remove nodeps: prefix if present - return any(op in spec for op in ("==", ">=", "<=", "!=", "~=", ">", "<")) + spec = remove_nodeps_prefix(pkg) + return any(op in spec for op in VERSION_SPECIFIERS) def should_reinstall_package(pkg: str, profile_changed: bool) -> bool: - return profile_changed and package_has_version_specifier(pkg) + return profile_changed and package_has_version_specifier(pkg) and not is_url_package(remove_nodeps_prefix(pkg)) + + +def get_package_name(pkg: str) -> Union[str, None]: + spec = remove_nodeps_prefix(pkg) + match = re.match(r"^[A-Za-z0-9_.-]+", spec) + return match.group(0) if match else None + + +def get_import_package_name(package_name: str) -> str: + return package_name.replace("-", "_") + + +def clear_imported_package_modules(package_name: str): + import_name = get_import_package_name(package_name) + module_names = [name for name in sys.modules if name == import_name or name.startswith(f"{import_name}.")] + for module_name in module_names: + sys.modules.pop(module_name, None) + + +async def uninstall_package_pyodide(pkg: str): + package_name = get_package_name(pkg) + if not package_name: + return + if not hasattr(micropip, "uninstall"): + raise RuntimeError(f"Cannot reinstall {package_name}: micropip.uninstall is unavailable.") + + uninstall_result = micropip.uninstall(package_name) + if hasattr(uninstall_result, "__await__"): + await uninstall_result + clear_imported_package_modules(package_name) + + +def get_install_spec_and_deps(pkg: str) -> Tuple[str, bool]: + if pkg.startswith(NODEPS_PREFIX): + return remove_nodeps_prefix(pkg), False + return pkg, not is_url_package(pkg) async def install_package_pyodide(pkg: str, verbose: bool = True, reinstall: bool = False): @@ -117,14 +166,11 @@ async def install_package_pyodide(pkg: str, verbose: bool = True, reinstall: boo await install_package_pyodide("numpy") # installs with deps await install_package_pyodide("nodeps:e3nn==0.4.4") # installs without deps """ - if pkg.startswith("nodeps:"): - pkg = pkg.replace("nodeps:", "") - are_dependencies_installed = False - else: - is_url = pkg.startswith("http://") or pkg.startswith("https://") or pkg.startswith("emfs:/") - are_dependencies_installed = not is_url + pkg, are_dependencies_installed = get_install_spec_and_deps(pkg) + if reinstall: + await uninstall_package_pyodide(pkg) - await micropip.install(pkg, deps=are_dependencies_installed, reinstall=reinstall) + await micropip.install(pkg, deps=are_dependencies_installed) pkg_name = pkg.split("/")[-1].split("-")[0] if "://" in pkg else pkg.split("==")[0] if verbose: log(f"Installed {pkg_name}", force_verbose=verbose) @@ -141,8 +187,8 @@ async def install_packages_pyodide(notebook_name_pattern: str, verbose: bool = T packages = await get_package_list_from_config(get_config_yml_file_path(""), notebook_name_pattern) requirements_hash = str(hash(json.dumps(packages))) previous_hash = os.environ.get("requirements_hash") - profile_changed = should_reinstall_packages(previous_hash, requirements_hash) - if previous_hash != requirements_hash: + profile_changed = previous_hash is not None and previous_hash != requirements_hash + if should_install_packages(previous_hash, requirements_hash): for pkg in packages: await install_package_pyodide( pkg, From 304bed325cb2778a1ce9ea9d5868b79df393251a Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 19:14:19 -0700 Subject: [PATCH 15/18] update: fix uma --- src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py index 4acee9c5b..412a0ef6d 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py @@ -1,3 +1,4 @@ +import enum import gc import sys import types @@ -109,7 +110,15 @@ class _LoggingConfig: def __init__(self, **k): self.__dict__.update(k) + class _ApplicationStatus(enum.Enum): + NOT_STARTED = "NOT_STARTED" + DEPLOYING = "DEPLOYING" + RUNNING = "RUNNING" + DEPLOY_FAILED = "DEPLOY_FAILED" + DELETING = "DELETING" + serve_schema.LoggingConfig = _LoggingConfig + serve_schema.ApplicationStatus = _ApplicationStatus serve_mod.schema = serve_schema sys.modules["ray.serve.schema"] = serve_schema From ebe6f0cd7805e96711cb3ff8d81cce4c29f8550b Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Tue, 9 Jun 2026 19:34:17 -0700 Subject: [PATCH 16/18] update: patches --- .../relax_structure_with_mace.ipynb | 4 +-- .../local/relaxation_mlff_mace.ipynb | 4 +-- .../notebooks_utils/pyodide/packages/uma.py | 30 +++++++++++++++++-- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_mace.ipynb b/other/experiments/jupyterlite/relax_structure_with_mace.ipynb index b0b1a032e..8dc876ac6 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mace.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mace.ipynb @@ -71,9 +71,9 @@ "\n", "await install_packages(\"made|api_examples|torch|mace\")\n", "\n", - "from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", + "from mat3ra.notebooks_utils.pyodide.packages.patches import apply_all_patches\n", "\n", - "apply_all_patches()" + "apply_all_patches(\"mace\")" ] }, { diff --git a/other/materials_designer/workflows/local/relaxation_mlff_mace.ipynb b/other/materials_designer/workflows/local/relaxation_mlff_mace.ipynb index fcfdc9b63..1d918b12e 100644 --- a/other/materials_designer/workflows/local/relaxation_mlff_mace.ipynb +++ b/other/materials_designer/workflows/local/relaxation_mlff_mace.ipynb @@ -44,9 +44,9 @@ "\n", "await install_packages(\"made|api_examples|torch|mace\")\n", "\n", - "from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", + "from mat3ra.notebooks_utils.pyodide.packages.patches import apply_all_patches\n", "\n", - "apply_all_patches()" + "apply_all_patches(\"mace\")" ] }, { diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py index 412a0ef6d..339264b07 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/uma.py @@ -20,7 +20,7 @@ def patch_fairchem_deps(): Stub heavy dependencies that fairchem-core imports but doesn't need for inference. This stubs: numba, ray (+ serve), wandb, torchtnt, hydra, omegaconf, - submitit, clusterscope, tqdm, huggingface_hub, websockets. + submitit, psutil, clusterscope, tqdm, huggingface_hub, websockets. """ numba_mod = _make_stub_module("numba", submodules=["core", "core.types", "typed"]) numba_mod.njit = lambda *a, **k: (lambda f: f) if not a or callable(a[0]) else lambda f: f @@ -78,7 +78,8 @@ def __init__(self, *a, **k): _patch_wandb() _patch_torchtnt() _patch_hydra() - _make_stub_module("submitit") + _patch_submitit() + _patch_psutil() _make_stub_module("clusterscope") _make_stub_module("websockets") _patch_tqdm() @@ -102,6 +103,7 @@ def _serve_deployment(*args, **kwargs): serve_mod.ingress = lambda *a, **k: (lambda cls: cls) serve_mod.run = lambda *a, **k: None serve_mod.batch = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0] + serve_mod.multiplexed = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0] serve_schema = types.ModuleType("ray.serve.schema") serve_schema.__package__ = "ray.serve" @@ -123,6 +125,30 @@ class _ApplicationStatus(enum.Enum): sys.modules["ray.serve.schema"] = serve_schema +def _patch_submitit(): + submitit_mod = _make_stub_module("submitit", submodules=["helpers", "core", "core.utils"]) + + class _SubmititPlaceholder: + pass + + submitit_mod.Job = _SubmititPlaceholder + sys.modules["submitit.helpers"].Checkpointable = _SubmititPlaceholder + sys.modules["submitit.helpers"].DelayedSubmission = _SubmititPlaceholder + sys.modules["submitit.core.utils"].JobPaths = _SubmititPlaceholder + + +def _patch_psutil(): + psutil_mod = types.ModuleType("psutil") + + class _UnavailableProcess: + def __init__(self, *args, **kwargs): + raise RuntimeError("psutil process management is unavailable in Pyodide") + + psutil_mod.Process = _UnavailableProcess + psutil_mod.wait_procs = lambda *args, **kwargs: ([], []) + sys.modules["psutil"] = psutil_mod + + def _patch_wandb(): wandb_mod = _make_stub_module("wandb") wandb_mod.init = lambda *a, **k: None From b0dc169efd67f27533d2a3670a483183a9fe67de Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 10 Jun 2026 10:24:37 -0700 Subject: [PATCH 17/18] update: import fairchem --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 798e0192e..121be4b6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ forcefields = [ # WARNING: mattersim will automatically upgrade numpy and pymatgen # This WILL conflict with base project dependencies # Use this optional dependency ONLY in isolated environments + "fairchem-core>=0.3.0", ] all_dev = [ "mat3ra-notebooks-utils[all]", From 25300f26e59b4e5038b08fffb99bd9f65a9e02bd Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Wed, 10 Jun 2026 19:45:51 -0700 Subject: [PATCH 18/18] update: cleanup --- other/experiments/jupyterlite/relax_structure_with_mace.ipynb | 2 +- other/experiments/jupyterlite/relax_structure_with_mlff.ipynb | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_mace.ipynb b/other/experiments/jupyterlite/relax_structure_with_mace.ipynb index 8dc876ac6..a6ccd9a6d 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mace.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mace.ipynb @@ -69,7 +69,7 @@ "source": [ "from mat3ra.notebooks_utils.packages import install_packages\n", "\n", - "await install_packages(\"made|api_examples|torch|mace\")\n", + "await install_packages(\"made|api_examples|torch|mlff|mace\")\n", "\n", "from mat3ra.notebooks_utils.pyodide.packages.patches import apply_all_patches\n", "\n", diff --git a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb index 82d3b515a..470169162 100644 --- a/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_mlff.ipynb @@ -30,7 +30,7 @@ "INTERFACE_NAME = \"Interface\" # name of the interface to load from the folder\n", "\n", "# MLFF selector\n", - "MLFF_NAME = \"uma\" # \"mace\" | \"uma\" | \"mattersim\" | \"nequip\"\n", + "MLFF_NAME = \"nequip\" # \"mace\" | \"uma\" | \"mattersim\" | \"nequip\"\n", "\n", "# MLFF-specific settings\n", "MLFF_SETTINGS = {\n",