diff --git a/.gitattributes b/.gitattributes
index 790431e68..7f638c094 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -17,4 +17,5 @@ images/*.png filter=lfs diff=lfs merge=lfs -text
examples/assets/bash_workflow_template.json filter=lfs diff=lfs merge=lfs -text
*.whl filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
packages filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
index 04de5d664..dc0a1e54a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,3 +110,4 @@ examples/assets/Molecules_Dataset_Collection
pw_scf.*
examples/**/*.html
.DS_Store
+node_modules
diff --git a/config.yml b/config.yml
index da8e0fae1..3beec21b9 100644
--- a/config.yml
+++ b/config.yml
@@ -101,3 +101,24 @@ notebooks:
- ssl
- h5py
- lmdb
+ - name: uma
+ packages_pyodide:
+ # antlr4 local wheel (required by omegaconf, PyPI version has ATN mismatch)
+ - emfs:/drive/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl
+ # Packages with dependencies
+ - opt_einsum
+ - orjson
+ - pyyaml
+ - sqlite3
+ - omegaconf
+ - hydra-core
+ # Packages without dependencies (using nodeps: prefix)
+ - nodeps:opt_einsum_fx
+ - nodeps:e3nn>=0.5
+ - nodeps:ase
+ - nodeps:monty
+ - nodeps:fairchem-core
+ # Stubbed packages (will be patched by torch_pyodide with include_fairchem=True)
+ - ssl
+ - h5py
+ - lmdb
diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb
new file mode 100644
index 000000000..78f0baf9b
--- /dev/null
+++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb
@@ -0,0 +1,300 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "9515ed910c085db8",
+ "metadata": {},
+ "source": [
+ "# Relax Structure with FAIRChem UMA \u2014 Universal Machine-learning Force Field\n",
+ "\n",
+ "Use FAIRChem's [UMA](https://github.com/FAIR-Chem/fairchem) (Universal Model for Atoms) to relax crystal structures using machine-learned interatomic potentials.\n",
+ "\n",
+ "## Usage\n",
+ "\n",
+ "1. Drop the materials files into the \"uploads\" folder in the JupyterLab file browser\n",
+ "1. Set Input Parameters below or use the default values\n",
+ "1. Click \"Run\" > \"Run All\" to run all cells\n",
+ "1. Wait for the run to complete. Scroll down to view cell results.\n",
+ "1. Review the relaxation plot and modify parameters as needed\n",
+ "\n",
+ "## Methodology\n",
+ "\n",
+ "1. Load materials from JSON files and create structure via `mat3ra-made`\n",
+ "2. Install FAIRChem UMA and apply Pyodide patches\n",
+ "3. Convert to ASE atoms with `to_ase()`\n",
+ "4. Relax the structure with FAIRChem UMA and visualize convergence\n",
+ "5. Compute relaxation energy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4737d145950b1cc8",
+ "metadata": {},
+ "source": [
+ "## 1. Set Input Parameters\n",
+ "### 1.1. Structure and Relaxation Parameters\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "57dc565952aa2e4e",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "FOLDER = \"uploads\"\n",
+ "STRUCTURE_NAME = \"Interface\" # Name of the structure to load from local file\n",
+ "\n",
+ "RELAXATION_PARAMETERS = {\n",
+ " \"FMAX\": 0.05,\n",
+ "}\n",
+ "UMA_TASK = \"omat\" # Task name for the UMA model (e.g. \"omat\", \"oc20\", \"omol\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e9d9e76d619cd91",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-int8.pt\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4e89f2d820acb1ed",
+ "metadata": {},
+ "source": [
+ "## 2. Install Packages"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ece62358982306f2",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from mat3ra.notebooks_utils.packages import install_packages\n",
+ "\n",
+ "await install_packages(\"made|api_examples|torch|uma\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c7f8e5e6a1b34d90",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n",
+ "\n",
+ "apply_all_patches(include_fairchem=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f89d3c98ddce2ab5",
+ "metadata": {},
+ "source": [
+ "## 3. Load Materials"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8fd400dace70549e",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from mat3ra.made.material import Material\n",
+ "from mat3ra.notebooks_utils.material import load_material_from_folder\n",
+ "from mat3ra.standata.materials import Materials\n",
+ "\n",
+ "structure = load_material_from_folder(FOLDER, STRUCTURE_NAME) or Material.create(\n",
+ " Materials.get_by_name_first_match(STRUCTURE_NAME))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "42f12abf6b65aa2c",
+ "metadata": {},
+ "source": [
+ "### 3.1. Visualize Input Structure"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fcbb4e6c1de21233",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from mat3ra.notebooks_utils.ipython.entity.material.visualize import ViewersEnum, visualize_materials as visualize\n",
+ "\n",
+ "visualize([{\"material\": structure, \"title\": structure.name}], viewer=ViewersEnum.wave)\n",
+ "visualize(structure, repetitions=[1, 1, 1], rotation=\"-90x\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e64688fc18c49bb6",
+ "metadata": {},
+ "source": [
+ "## 4. Apply Relaxation\n",
+ "### 4.1. Load UMA Model and Create Calculator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4e5f60001",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from fairchem.core import FAIRChemCalculator\n",
+ "from fairchem.core.units.mlip_unit import load_predict_unit\n",
+ "\n",
+ "predictor = load_predict_unit(UMA_MODEL_PATH, device=\"cpu\")\n",
+ "calculator = FAIRChemCalculator(predictor, task_name=UMA_TASK)\n",
+ "\n",
+ "print(f\"UMA model loaded from {UMA_MODEL_PATH}\")\n",
+ "print(f\"Task: {UMA_TASK}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4e5f6a7b8c90002",
+ "metadata": {},
+ "source": [
+ "### 4.2. Relax with UMA"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3d8746a77f71bab5",
+ "metadata": {
+ "jupyter": {
+ "is_executing": true
+ },
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "import plotly.graph_objs as go\nfrom IPython.display import display\nfrom plotly.subplots import make_subplots\n\nfrom mat3ra.made.tools.convert import to_ase\nfrom ase.optimize import BFGS\n\nase_structure = to_ase(structure)\nase_structure.set_calculator(calculator)\ndyn = BFGS(ase_structure)\n\nsteps = []\nenergies = []\n\nfig = make_subplots(rows=1, cols=1, specs=[[{\"type\": \"scatter\"}]])\nscatter = go.Scatter(x=[], y=[], mode=\"lines+markers\", name=\"Energy\")\nfig.add_trace(scatter)\nfig.update_layout(title_text=\"Real-time Optimization Progress\", xaxis_title=\"Step\", yaxis_title=\"Energy (eV)\")\n\ntry:\n f = go.FigureWidget(fig)\nexcept ImportError:\n f = go.Figure(fig)\ndisplay(f)\n\n\ndef plotly_callback():\n step = dyn.nsteps\n energy = ase_structure.get_total_energy()\n steps.append(step)\n energies.append(energy)\n print(f\"Step: {step}, Energy: {energy:.4f} eV\")\n if hasattr(f, \"batch_update\"):\n with f.batch_update():\n f.data[0].x = steps\n f.data[0].y = energies\n else:\n f.data[0].x = steps\n f.data[0].y = energies\n\n\ndyn.attach(plotly_callback, interval=1)\ndyn.run(fmax=RELAXATION_PARAMETERS[\"FMAX\"])\n\nase_original_structure = to_ase(structure)\nase_original_structure.set_calculator(calculator)\nase_final_structure = ase_structure\n\noriginal_energy = ase_original_structure.get_total_energy()\nrelaxed_energy = ase_structure.get_total_energy()\n\nprint(f\"The final energy is {float(relaxed_energy):.3f} eV.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "abfa372909a96bf8",
+ "metadata": {},
+ "source": [
+ "## 5. Analyze Results\n",
+ "### 5.1. View Structure Before and After Relaxation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9565d0931b198f63",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from mat3ra.made.tools.convert import from_ase\n",
+ "\n",
+ "material_original = Material.create(from_ase(ase_original_structure))\n",
+ "material_relaxed = Material.create(from_ase(ase_final_structure))\n",
+ "material_original.name = structure.name\n",
+ "material_relaxed.name = structure.name + \" (UMA Relaxed)\"\n",
+ "\n",
+ "visualize(\n",
+ " [\n",
+ " {\"material\": material_original, \"title\": material_original.name},\n",
+ " {\"material\": material_relaxed, \"title\": material_relaxed.name},\n",
+ " ],\n",
+ " viewer=ViewersEnum.wave,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4b49774283e5517",
+ "metadata": {},
+ "source": [
+ "### 5.2. Output interlayer distance before and after relaxation\n",
+ "This requires substrate and film labels to be present in the interface structure; they are set automatically if the interface was created with `mat3ra-made`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6dd00402bc2e9d59",
+ "metadata": {
+ "trusted": true
+ },
+ "outputs": [],
+ "source": [
+ "from mat3ra.made.tools.analyze.other import get_average_interlayer_distance\n",
+ "\n",
+ "SUBSTRATE_TAG = 0\n",
+ "FILM_TAG = 1\n",
+ "\n",
+ "print(\n",
+ " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")\n",
+ "print(\n",
+ " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f60fdb73e44c09c",
+ "metadata": {},
+ "source": [
+ "## References\n",
+ "\n",
+ "[1] FAIRChem: https://github.com/FAIR-Chem/fairchem \n",
+ "[2] UMA: A Family of Universal Models for Atoms: https://arxiv.org/abs/2506.23971  \n",
+ "[3] mat3ra-made interface builder: https://github.com/Exabyte-io/made "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python (Pyodide)",
+ "language": "python",
+ "name": "python"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "python",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl b/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl
new file mode 100644
index 000000000..2a6ca7233
--- /dev/null
+++ b/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59f990d0c30e35b3ec80e0daa1c2aff9bd3115794fcb8f119125c9235e2e1178
+size 144573
diff --git a/packages/models/uma-s-1p1-f16.pt b/packages/models/uma-s-1p1-f16.pt
new file mode 100644
index 000000000..641385575
--- /dev/null
+++ b/packages/models/uma-s-1p1-f16.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3dad47e7d35b42b08917bc4ebbca98cc76e32b3df15b6256f817ba3f42b9bdfe
+size 294266951
diff --git a/packages/models/uma-s-1p1-int8.pt b/packages/models/uma-s-1p1-int8.pt
new file mode 100644
index 000000000..a000109a4
--- /dev/null
+++ b/packages/models/uma-s-1p1-int8.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b6fea3cebdb512078d5cb45feaf701b786be7b1fa024cee0a7dd24aafae8fc9
+size 146678700
diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py
index 3aa373ac2..d4eb832d3 100644
--- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py
+++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py
@@ -140,13 +140,66 @@ def patch_torch_linalg():
torch.Tensor.__array__ = _tensor_array_compat
torch.Tensor.numpy = lambda self: np.array(self.detach().tolist())
- # Fix torch.compiler.is_compiling for Pyodide
- if not hasattr(torch, "compiler"):
- torch.compiler = types.ModuleType("torch.compiler")
- if not hasattr(torch.compiler, "is_compiling"):
- torch.compiler.is_compiling = lambda: False
+ # torch.from_numpy — WASM PyTorch build lacks NumPy interop
+ _orig_from_numpy = torch.from_numpy
+
+ def _patched_from_numpy(ndarray):
+ try:
+ return _orig_from_numpy(ndarray)
+ except RuntimeError:
+ return torch.tensor(ndarray.tolist())
+
+ torch.from_numpy = _patched_from_numpy
+
+ # torch.as_tensor — also uses numpy interop internally
+ _orig_as_tensor = torch.as_tensor
+
+ def _patched_as_tensor(data, dtype=None, device=None):
+ try:
+ return _orig_as_tensor(data, dtype=dtype, device=device)
+ except (RuntimeError, TypeError):
+ if hasattr(data, "tolist"):
+ return torch.tensor(data.tolist(), dtype=dtype, device=device)
+ return torch.tensor(data, dtype=dtype, device=device)
+
+ torch.as_tensor = _patched_as_tensor
+
+ # Tensor indexing — WASM PyTorch can't infer numpy dtypes for boolean/integer masks
+ _orig_getitem = torch.Tensor.__getitem__
+
+ def _patched_getitem(self, key):
+ if isinstance(key, np.ndarray):
+ key = torch.tensor(key.tolist())
+ return _orig_getitem(self, key)
- print("✓ Torch linalg patches applied")
+ torch.Tensor.__getitem__ = _patched_getitem
+
+ _orig_setitem = torch.Tensor.__setitem__
+
+ def _patched_setitem(self, key, value):
+ if isinstance(key, np.ndarray):
+ key = torch.tensor(key.tolist())
+ return _orig_setitem(self, key, value)
+
+ torch.Tensor.__setitem__ = _patched_setitem
+
+ # torch.tensor — handle lists of 0-d tensors (WASM calls len() on elements, fails for 0-d)
+ _orig_torch_tensor = torch.tensor
+
+ def _patched_torch_tensor(data, *args, **kwargs):
+ if isinstance(data, (list, tuple)):
+
+ def _unwrap(item):
+ if isinstance(item, torch.Tensor) and item.ndim == 0:
+ return item.item()
+ return item
+
+ data = type(data)(_unwrap(x) for x in data)
+ return _orig_torch_tensor(data, *args, **kwargs)
+
+ torch.tensor = _patched_torch_tensor
+
+ print("✓ Torch linalg + numpy interop patches applied")
# ==============================================================================
@@ -189,9 +242,642 @@ def patch_torch_testing():
sys.modules["torch.testing._internal.common_utils"] = _common_utils
sys.modules["torch.testing._internal.logging_tensor"] = _logging_tensor
+ # Import torch.utils.checkpoint AFTER logging_tensor stubs are set up
+ # (checkpoint.py imports logging_tensor at import time)
+ import torch.utils.checkpoint # noqa: F401
+
print("✓ Torch testing patches applied")
+# ==============================================================================
+# Torch compiler patches
+# ==============================================================================
+
+
def patch_torch_compiler():
    """
    Patch torch.compiler for Pyodide WASM.

    - Creates a stub ``torch.compiler`` module when the attribute is missing
      and registers it in ``sys.modules`` so that ``import torch.compiler``
      resolves as well as attribute access.
    - Stubs ``torch.compiler.is_compiling`` (returns False) when it is absent.
    - Unconditionally replaces ``torch.compiler.disable`` with a no-op
      decorator (avoids the torch._dynamo import chain which requires
      C extensions missing in WASM).
    - Needed by e3nn >= 0.5 and fairchem-core which use @torch.compiler.disable

    Side effects:
        Mutates the global ``torch`` module and prints a confirmation line.
    """
    if not hasattr(torch, "compiler"):
        torch.compiler = types.ModuleType("torch.compiler")
        # Attribute access alone is not enough for `import torch.compiler`.
        sys.modules.setdefault("torch.compiler", torch.compiler)
    if not hasattr(torch.compiler, "is_compiling"):
        torch.compiler.is_compiling = lambda: False

    def _compiler_disable(fn=None, recursive=True, **_ignored):
        """
        No-op replacement for ``torch.compiler.disable``.

        Supports the bare form (``@disable``), the call form
        (``@disable(recursive=False)``) and any extra keyword-only options
        (for example ``reason=...``), which are accepted and ignored.
        """
        if fn is not None:
            return fn
        return lambda f: f

    torch.compiler.disable = _compiler_disable

    print("✓ Torch compiler patches applied")
+
+
+# ==============================================================================
+# Torch distributed patches (for FAIRChem)
+# ==============================================================================
+
+
def patch_torch_distributed():
    """
    Stub torch.distributed modules that require C extensions missing in WASM.

    FAIRChem's mlip_unit.py imports from torch.distributed.checkpoint,
    torch.distributed.fsdp, and torch.distributed.device_mesh.
    All of these trigger torch._C._distributed_c10d which doesn't exist in WASM.

    The stubs describe a single, non-initialized process: rank 0, world size 1,
    collectives that hand their input back, and checkpoint helpers that fall
    back to the plain ``state_dict`` / ``load_state_dict`` methods.

    Side effects:
        - Adds missing attributes to the imported ``torch.distributed`` module
          (existing attributes are left untouched).
        - Unconditionally replaces the ``sys.modules`` entries for the stubbed
          submodules (nn, distributed_c10d, _shard, checkpoint, fsdp,
          device_mesh, tensor, algorithms and ``torch._C._distributed_c10d``).
        - Prints a confirmation line.
    """
    import enum

    import torch.distributed as _dist

    # Core distributed API stubs (only filled in when the build lacks them)
    if not hasattr(_dist, "group"):

        class _Group:
            WORLD = None

        _dist.group = _Group
        _dist.WORLD = None
    if not hasattr(_dist, "ReduceOp"):

        class _ReduceOp:
            SUM = 0
            PRODUCT = 1
            MIN = 2
            MAX = 3
            BAND = 4
            BOR = 5
            BXOR = 6
            AVG = 7

        _dist.ReduceOp = _ReduceOp
    if not hasattr(_dist, "is_initialized"):
        _dist.is_initialized = lambda: False
    if not hasattr(_dist, "get_rank"):
        _dist.get_rank = lambda group=None: 0
    if not hasattr(_dist, "get_world_size"):
        _dist.get_world_size = lambda group=None: 1

    # torch.distributed.nn.functional
    def _all_gather(*args, **kwargs):
        """
        Single-process ``all_gather``.

        Accepts the functional form ``all_gather(tensor, group=...)`` and
        returns a one-element tuple (one entry per rank). The earlier stub
        form ``all_gather(tensor_list, tensor)`` is still honoured and returns
        ``tensor_list`` unchanged.
        """
        if args:
            first = args[0]
        elif "tensor_list" in kwargs:
            first = kwargs["tensor_list"]
        else:
            first = kwargs["tensor"]
        if isinstance(first, (list, tuple)):
            return first
        return (first,)

    _dist_nn = types.ModuleType("torch.distributed.nn")
    _dist_nn.__path__ = []
    _dist_nn.__package__ = "torch.distributed.nn"
    _dist_nn_func = types.ModuleType("torch.distributed.nn.functional")
    _dist_nn_func.__package__ = "torch.distributed.nn"
    _dist_nn_func.all_reduce = lambda tensor, *a, **k: tensor
    # NOTE(review): returns `output` untouched; presumably never reached in
    # single-process inference — confirm against FAIRChem's call sites.
    _dist_nn_func.reduce_scatter = lambda output, input_list, *a, **k: output
    _dist_nn_func.all_gather = _all_gather
    _dist_nn.functional = _dist_nn_func
    sys.modules["torch.distributed.nn"] = _dist_nn
    sys.modules["torch.distributed.nn.functional"] = _dist_nn_func

    # C extension stubs
    sys.modules["torch._C._distributed_c10d"] = types.ModuleType("torch._C._distributed_c10d")
    _dc10d = types.ModuleType("torch.distributed.distributed_c10d")
    _dc10d.__package__ = "torch.distributed"
    sys.modules["torch.distributed.distributed_c10d"] = _dc10d

    # torch.distributed._shard
    _shard = types.ModuleType("torch.distributed._shard")
    _shard.__path__ = []
    _shard.__package__ = "torch.distributed._shard"
    sys.modules["torch.distributed._shard"] = _shard
    sys.modules["torch.distributed._shard.api"] = types.ModuleType("torch.distributed._shard.api")
    _shard_st = types.ModuleType("torch.distributed._shard.sharded_tensor")
    _shard_st.__path__ = []
    sys.modules["torch.distributed._shard.sharded_tensor"] = _shard_st
    _shard_st_meta = types.ModuleType("torch.distributed._shard.sharded_tensor.metadata")

    class _TensorProperties:
        pass

    _shard_st_meta.TensorProperties = _TensorProperties
    sys.modules["torch.distributed._shard.sharded_tensor.metadata"] = _shard_st_meta

    # torch.distributed.checkpoint
    _dcp = types.ModuleType("torch.distributed.checkpoint")
    _dcp.__path__ = []
    _dcp.__package__ = "torch.distributed.checkpoint"
    _dcp.save = lambda *a, **k: None
    _dcp.load = lambda *a, **k: None
    _dcp_meta = types.ModuleType("torch.distributed.checkpoint.metadata")
    _dcp_meta.TensorProperties = _TensorProperties

    class _BytesStorageMetadata:
        pass

    class _TensorStorageMetadata:
        pass

    class _Metadata:
        pass

    _dcp_meta.BytesStorageMetadata = _BytesStorageMetadata
    _dcp_meta.TensorStorageMetadata = _TensorStorageMetadata
    _dcp_meta.Metadata = _Metadata
    _dcp.metadata = _dcp_meta
    sys.modules["torch.distributed.checkpoint"] = _dcp
    sys.modules["torch.distributed.checkpoint.metadata"] = _dcp_meta

    # Empty checkpoint submodules; the ones FAIRChem reads from are filled below.
    for sub in [
        "state_dict",
        "stateful",
        "planner",
        "storage",
        "default_planner",
        "filesystem",
        "optimizer",
        "format_utils",
    ]:
        _sub_mod = types.ModuleType(f"torch.distributed.checkpoint.{sub}")
        _sub_mod.__package__ = "torch.distributed.checkpoint"
        sys.modules[f"torch.distributed.checkpoint.{sub}"] = _sub_mod
        setattr(_dcp, sub, _sub_mod)

    # format_utils stubs
    sys.modules["torch.distributed.checkpoint.format_utils"].dcp_to_torch_save = lambda *a, **k: None
    sys.modules["torch.distributed.checkpoint.format_utils"].torch_save_to_dcp = lambda *a, **k: None

    # state_dict stubs: delegate to the plain (non-sharded) state-dict methods
    _sd_mod = sys.modules["torch.distributed.checkpoint.state_dict"]
    _sd_mod.get_model_state_dict = lambda model, *a, **k: model.state_dict() if hasattr(model, "state_dict") else {}
    _sd_mod.set_model_state_dict = (
        lambda model, sd, *a, **k: model.load_state_dict(sd) if hasattr(model, "load_state_dict") else None
    )
    _sd_mod.get_optimizer_state_dict = (
        lambda model, optim, *a, **k: optim.state_dict() if hasattr(optim, "state_dict") else {}
    )
    _sd_mod.get_state_dict = lambda model, *a, **k: model.state_dict() if hasattr(model, "state_dict") else {}
    _sd_mod.set_state_dict = lambda model, sd, *a, **k: None

    class _StateDictOptions:
        def __init__(self, **k):
            self.__dict__.update(k)

    _sd_mod.StateDictOptions = _StateDictOptions

    # stateful stub
    class _Stateful:
        pass

    sys.modules["torch.distributed.checkpoint.stateful"].Stateful = _Stateful

    # torch.distributed.fsdp
    _fsdp = types.ModuleType("torch.distributed.fsdp")
    _fsdp.__path__ = []
    _fsdp.__package__ = "torch.distributed.fsdp"
    sys.modules["torch.distributed.fsdp"] = _fsdp
    for fsdp_sub in ["fully_sharded_data_parallel", "api", "wrap", "sharded_grad_scaler"]:
        _fsub = types.ModuleType(f"torch.distributed.fsdp.{fsdp_sub}")
        _fsub.__package__ = "torch.distributed.fsdp"
        sys.modules[f"torch.distributed.fsdp.{fsdp_sub}"] = _fsub
        setattr(_fsdp, fsdp_sub, _fsub)

    # FSDP wrap policy stubs
    class _ModuleWrapPolicy:
        def __init__(self, module_classes=None):
            self.module_classes = module_classes or set()

    _fsdp_wrap = sys.modules["torch.distributed.fsdp.wrap"]
    _fsdp_wrap.ModuleWrapPolicy = _ModuleWrapPolicy
    _fsdp_wrap.lambda_auto_wrap_policy = lambda *a, **k: None
    _fsdp_wrap.transformer_auto_wrap_policy = lambda *a, **k: None

    # FSDP classes
    class _FullyShardedDataParallel(torch.nn.Module):
        def __init__(self, module, *a, **k):
            super().__init__()
            self.module = module

    class _ShardingStrategy(enum.Enum):
        FULL_SHARD = "FULL_SHARD"
        SHARD_GRAD_OP = "SHARD_GRAD_OP"
        NO_SHARD = "NO_SHARD"
        HYBRID_SHARD = "HYBRID_SHARD"

    class _MixedPrecision:
        def __init__(self, *a, **k):
            pass

    class _CPUOffload:
        def __init__(self, offload_params=False):
            self.offload_params = offload_params

    class _BackwardPrefetch(enum.Enum):
        BACKWARD_PRE = "BACKWARD_PRE"
        BACKWARD_POST = "BACKWARD_POST"

    class _StateDictTypeFSDP(enum.Enum):
        FULL_STATE_DICT = 0
        LOCAL_STATE_DICT = 1
        SHARDED_STATE_DICT = 2

    _fsdp.FullyShardedDataParallel = _FullyShardedDataParallel
    _fsdp.ShardingStrategy = _ShardingStrategy
    _fsdp.MixedPrecision = _MixedPrecision
    _fsdp.CPUOffload = _CPUOffload
    _fsdp.BackwardPrefetch = _BackwardPrefetch
    _fsdp.StateDictType = _StateDictTypeFSDP

    # FSDP StateDictConfig stubs
    class _StateDictConfig:
        def __init__(self, **k):
            self.__dict__.update(k)

    class _ShardedStateDictConfig(_StateDictConfig):
        pass

    class _FullStateDictConfig(_StateDictConfig):
        def __init__(self, offload_to_cpu=False, rank0_only=False, **k):
            super().__init__(**k)
            self.offload_to_cpu = offload_to_cpu
            self.rank0_only = rank0_only

    class _FullOptimStateDictConfig(_StateDictConfig):
        def __init__(self, offload_to_cpu=False, rank0_only=False, **k):
            super().__init__(**k)
            # Keep the options readable, mirroring _FullStateDictConfig.
            self.offload_to_cpu = offload_to_cpu
            self.rank0_only = rank0_only

    _fsdp.StateDictConfig = _StateDictConfig
    _fsdp.ShardedStateDictConfig = _ShardedStateDictConfig
    _fsdp.FullStateDictConfig = _FullStateDictConfig
    _fsdp.FullOptimStateDictConfig = _FullOptimStateDictConfig
    _fsdp.LocalStateDictConfig = type("LocalStateDictConfig", (_StateDictConfig,), {})
    _fsdp.OptimStateDictConfig = type("OptimStateDictConfig", (_StateDictConfig,), {})
    _fsdp.ShardedOptimStateDictConfig = type("ShardedOptimStateDictConfig", (_StateDictConfig,), {})

    # torch.distributed.device_mesh
    _dm = types.ModuleType("torch.distributed.device_mesh")
    _dm.__package__ = "torch.distributed"

    class _DeviceMesh:
        def __init__(self, *a, **k):
            pass

    _dm.DeviceMesh = _DeviceMesh
    _dm.init_device_mesh = lambda *a, **k: _DeviceMesh()
    sys.modules["torch.distributed.device_mesh"] = _dm

    # torch.distributed.tensor (DTensor)
    _dtensor = types.ModuleType("torch.distributed.tensor")
    _dtensor.__path__ = []
    _dtensor.__package__ = "torch.distributed.tensor"

    class _DTensor:
        pass

    _dtensor.DTensor = _DTensor
    sys.modules["torch.distributed.tensor"] = _dtensor

    # torch.distributed.algorithms
    _dalgo = types.ModuleType("torch.distributed.algorithms")
    _dalgo.__path__ = []
    _dalgo.__package__ = "torch.distributed.algorithms"
    sys.modules["torch.distributed.algorithms"] = _dalgo

    print("✓ Torch distributed patches applied")
+
+
+# ==============================================================================
+# FAIRChem heavy dependency stubs
+# ==============================================================================
+
+
+def _make_stub_module(name, attrs=None, submodules=None):
+ """Create a stub module with optional attributes and submodules."""
+ from importlib.machinery import ModuleSpec
+
+ mod = types.ModuleType(name)
+ mod.__path__ = []
+ mod.__package__ = name
+ mod.__version__ = "0.0.0"
+ mod.__spec__ = ModuleSpec(name, None, is_package=True)
+ if attrs:
+ for k, v in attrs.items():
+ setattr(mod, k, v)
+ sys.modules[name] = mod
+ if submodules:
+ for sub_name in submodules:
+ full_name = f"{name}.{sub_name}"
+ sub_mod = types.ModuleType(full_name)
+ sub_mod.__path__ = []
+ sub_mod.__package__ = name
+ sub_mod.__spec__ = ModuleSpec(full_name, None, is_package=True)
+ setattr(mod, sub_name, sub_mod)
+ sys.modules[full_name] = sub_mod
+ return mod
+
+
def patch_fairchem_deps():
    """
    Stub heavy dependencies that fairchem-core imports but doesn't need for inference.

    This stubs: numba, ray (+ serve), wandb, torchtnt, hydra, omegaconf,
    submitit, clusterscope, websockets, tqdm, huggingface_hub, ase_db_backends.

    In addition it:
    - wraps ``torch.load`` so checkpoints carrying a
      ``quantized_ema_state_dict`` are dequantized into an
      ``MLIPInferenceCheckpoint`` (all other results pass through unchanged);
    - wraps the FAIRChem registry's ``get_model_class`` so a name the registry
      rejects with ``RuntimeError`` is resolved as a dotted import path.

    Both wrappers are installed at most once, so calling this function again
    (e.g. re-running a notebook cell) does not stack wrappers.

    Side effects:
        Replaces ``sys.modules`` entries for every stubbed package, mutates
        ``torch.load`` and the FAIRChem registry, and prints status lines.
    """

    def _passthrough_decorator(*args, **kwargs):
        """
        No-op stand-in for ``numba.njit`` / ``numba.jit``.

        Bare usage (``@njit``) must return the decorated function itself;
        parametrized usage (``@njit(cache=True)``) must return a decorator.
        """
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda fn: fn

    # --- numba ---
    numba_mod = _make_stub_module("numba", submodules=["core", "core.types", "typed"])
    numba_mod.njit = _passthrough_decorator
    numba_mod.jit = numba_mod.njit
    numba_mod.prange = range
    # Type objects are only placeholders; the stub exposes them as plain strings.
    for t in ("int32", "int64", "float32", "float64", "boolean"):
        setattr(numba_mod, t, t)

    class _TypedList(list):
        pass

    sys.modules["numba.typed"].List = _TypedList

    # --- ray ---
    ray_mod = _make_stub_module(
        "ray",
        submodules=[
            "serve",
            "runtime_env",
            "train",
            "data",
            "util",
            "util.scheduling_strategies",
            "util.queue",
        ],
    )

    def _ray_remote(*args, **kwargs):
        # Bare @ray.remote: tag and return the target itself.
        if args and callable(args[0]):
            args[0]._remote = lambda *a, **kw: None
            return args[0]

        # Parametrized @ray.remote(...): return a decorator doing the same.
        def wrapper(fn_or_cls):
            fn_or_cls._remote = lambda *a, **kw: None
            return fn_or_cls

        return wrapper

    ray_mod.remote = _ray_remote
    ray_mod.init = lambda *a, **k: None
    ray_mod.get = lambda *a, **k: None
    ray_mod.put = lambda *a, **k: None
    ray_mod.wait = lambda *a, **k: ([], [])
    ray_mod.is_initialized = lambda: False

    class _ObjectRef:
        pass

    ray_mod.ObjectRef = _ObjectRef

    class _PlacementGroupSchedulingStrategy:
        def __init__(self, *a, **k):
            pass

    sys.modules["ray.util.scheduling_strategies"].PlacementGroupSchedulingStrategy = (
        _PlacementGroupSchedulingStrategy
    )

    # ray.serve stubs
    _serve = sys.modules["ray.serve"]

    def _serve_deployment(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda cls_or_fn: cls_or_fn

    _serve.deployment = _serve_deployment
    _serve.ingress = lambda *a, **k: (lambda cls: cls)
    _serve.run = lambda *a, **k: None
    _serve.batch = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0]

    _serve_schema = types.ModuleType("ray.serve.schema")
    _serve_schema.__package__ = "ray.serve"

    class _LoggingConfig:
        def __init__(self, **k):
            self.__dict__.update(k)

    _serve_schema.LoggingConfig = _LoggingConfig
    _serve.schema = _serve_schema
    sys.modules["ray.serve.schema"] = _serve_schema

    # --- wandb ---
    wandb_mod = _make_stub_module("wandb")
    wandb_mod.init = lambda *a, **k: None
    wandb_mod.log = lambda *a, **k: None
    wandb_mod.finish = lambda *a, **k: None

    # --- torchtnt ---
    _make_stub_module(
        "torchtnt",
        submodules=[
            "framework",
            "framework.state",
            "framework.unit",
            "framework.callback",
            "framework.auto_unit",
            "framework.fit",
            "framework.train",
            "framework.evaluate",
            "framework.predict",
            "utils",
            "utils.loggers",
            "utils.timer",
            "utils.distributed",
            "utils.prepare_module",
        ],
    )

    # The unit/state classes accept any constructor arguments and support
    # subscripting (e.g. PredictUnit[...]) by returning the class itself.
    class _PredictUnit:
        def __init__(self, *a, **k):
            pass

        def __class_getitem__(cls, item):
            return cls

    class _TrainUnit:
        def __init__(self, *a, **k):
            pass

        def __class_getitem__(cls, item):
            return cls

    class _EvalUnit:
        def __init__(self, *a, **k):
            pass

        def __class_getitem__(cls, item):
            return cls

    class _State:
        def __init__(self, *a, **k):
            pass

        def __class_getitem__(cls, item):
            return cls

    class _Callback:
        pass

    tnt_framework = sys.modules["torchtnt.framework"]
    tnt_unit = sys.modules["torchtnt.framework.unit"]
    tnt_state = sys.modules["torchtnt.framework.state"]
    tnt_cb = sys.modules["torchtnt.framework.callback"]
    for m in [tnt_framework, tnt_unit]:
        m.PredictUnit = _PredictUnit
        m.TrainUnit = _TrainUnit
        m.EvalUnit = _EvalUnit
    tnt_state.State = _State
    tnt_cb.Callback = _Callback
    tnt_framework.State = _State
    tnt_framework.Callback = _Callback

    # torchtnt entry point functions
    sys.modules["torchtnt.framework.fit"].fit = lambda *a, **k: None
    sys.modules["torchtnt.framework.train"].train = lambda *a, **k: None
    sys.modules["torchtnt.framework.evaluate"].evaluate = lambda *a, **k: None
    sys.modules["torchtnt.framework.predict"].predict = lambda *a, **k: None

    # torchtnt.utils stubs
    tnt_dist = sys.modules["torchtnt.utils.distributed"]
    tnt_dist.get_file_init_method = lambda *a, **k: ""
    tnt_dist.get_tcp_init_method = lambda *a, **k: ""
    tnt_dist.spawn_multi_process = lambda *a, **k: None

    tnt_prep = sys.modules["torchtnt.utils.prepare_module"]
    tnt_prep.prepare_module = lambda module, *a, **k: module
    tnt_prep.FSDPStrategy = type("FSDPStrategy", (), {"__init__": lambda self, **k: None})
    tnt_prep.DDPStrategy = type("DDPStrategy", (), {"__init__": lambda self, **k: None})
    tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None})

    # --- hydra / omegaconf ---
    # NOTE(review): these stubs replace any real omegaconf / hydra packages
    # already present in sys.modules — confirm that is intended when the
    # environment also installs omegaconf and hydra-core.
    omegaconf_mod = _make_stub_module("omegaconf")

    class _DictConfig(dict):
        pass

    class _ListConfig(list):
        pass

    omegaconf_mod.DictConfig = _DictConfig
    omegaconf_mod.ListConfig = _ListConfig
    omegaconf_mod.OmegaConf = type(
        "OmegaConf",
        (),
        {
            "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg),
            "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d),
        },
    )
    _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"])

    def _hydra_instantiate(config, *args, _recursive_=True, **kwargs):
        """
        Minimal ``hydra.utils.instantiate``: import ``_target_`` and call it.

        Keys starting with ``_`` are treated as directives and not forwarded;
        ``_args_`` supplies extra positional arguments. Configs without a
        ``_target_`` are returned unchanged.
        """
        import importlib
        if isinstance(config, dict) and "_target_" in config:
            target = config["_target_"]
            mod_path, cls_name = target.rsplit(".", 1)
            mod = importlib.import_module(mod_path)
            cls = getattr(mod, cls_name)
            pos_args = list(args) + list(config.get("_args_", []))
            cfg = {k: v for k, v in config.items() if not k.startswith("_")}
            if _recursive_:
                for k, v in cfg.items():
                    if isinstance(v, dict) and "_target_" in v:
                        cfg[k] = _hydra_instantiate(v)
                    elif isinstance(v, list):
                        cfg[k] = [_hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v]
                pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args]
            return cls(*pos_args, **{**cfg, **kwargs})
        return config

    sys.modules["hydra.utils"].instantiate = _hydra_instantiate
    sys.modules["hydra"].utils = sys.modules["hydra.utils"]

    # --- submitit / clusterscope ---
    _make_stub_module("submitit")
    _make_stub_module("clusterscope")

    # --- websockets ---
    _make_stub_module("websockets")

    # --- tqdm ---
    tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"])

    def _tqdm_passthrough(iterable=None, *a, **k):
        # No progress bar: iterate the wrapped iterable directly.
        return iterable if iterable is not None else iter([])

    tqdm_mod.tqdm = _tqdm_passthrough
    sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough
    sys.modules["tqdm.std"].tqdm = _tqdm_passthrough

    # --- huggingface_hub ---
    hf_mod = _make_stub_module("huggingface_hub", submodules=["utils"])
    hf_mod.hf_hub_download = lambda *a, **k: ""
    hf_mod.snapshot_download = lambda *a, **k: ""

    # --- ase_db_backends ---
    _make_stub_module("ase_db_backends")

    # --- INT8 quantized model support ---
    # Install the wrapper once; a repeated call would otherwise nest wrappers.
    if not getattr(torch.load, "_int8_aware", False):
        _orig_torch_load = torch.load

        def _int8_aware_torch_load(f, *args, **kwargs):
            result = _orig_torch_load(f, *args, **kwargs)
            if isinstance(result, dict) and "quantized_ema_state_dict" in result:
                import gc as _gc
                from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint

                print(" Dequantizing INT8 → FP16 (streaming)...")
                quantized_ema = result.pop("quantized_ema_state_dict")
                # Tensors without a scale entry are kept as stored.
                scales = result.pop("quantization_scales", {})
                ema_state_dict = {}
                names = list(quantized_ema.keys())
                # Pop entries one at a time so each source tensor can be
                # released as soon as its converted copy exists.
                for name in names:
                    tensor = quantized_ema.pop(name)
                    if name in scales:
                        scale = scales.pop(name)
                        ema_state_dict[name] = (tensor.float() * scale.float()).half()
                        del scale
                    else:
                        ema_state_dict[name] = tensor
                    del tensor
                del quantized_ema, scales
                _gc.collect()
                checkpoint = MLIPInferenceCheckpoint(
                    model_config=result["model_config"],
                    model_state_dict=result.get("model_state_dict", {}),
                    ema_state_dict=ema_state_dict,
                    tasks_config=result["tasks_config"],
                )
                del result
                _gc.collect()
                print(" ✓ Dequantization complete")
                return checkpoint
            return result

        _int8_aware_torch_load._int8_aware = True
        torch.load = _int8_aware_torch_load

    # --- Model registry fallback ---
    try:
        from fairchem.core.common.registry import registry as _registry

        _orig_get_model = _registry.get_model_class

        if not getattr(_orig_get_model, "_import_fallback", False):

            def _fallback_get_model_class(name):
                try:
                    return _orig_get_model(name)
                except RuntimeError:
                    # Only dotted names can be treated as import paths.
                    if "." not in name:
                        raise
                    import importlib as _imp
                    module_path, class_name = name.rsplit(".", 1)
                    mod = _imp.import_module(module_path)
                    return getattr(mod, class_name)

            _fallback_get_model_class._import_fallback = True
            _registry.get_model_class = _fallback_get_model_class
    except Exception as exc:
        # Best effort: the stubs above stay usable, but say why the fallback is missing.
        print(f"⚠ FAIRChem model registry fallback not installed: {exc}")

    print("✓ FAIRChem dependency stubs applied")
+
+
# ==============================================================================
# Matscipy patches
# ==============================================================================
@@ -273,11 +959,24 @@ def patch_mace_tools():
# ==============================================================================
-def apply_all_patches():
- """Apply all torch and MACE patches for Pyodide in one call."""
def apply_all_patches(include_fairchem=False):
    """
    Run every Pyodide compatibility patch in a fixed order.

    The base torch / matscipy / MACE patches always run. The FAIRChem-only
    patches (torch.distributed stubs and heavy dependency stubs) are appended
    when requested.

    Args:
        include_fairchem: When truthy, also apply ``patch_torch_distributed``
            and ``patch_fairchem_deps``. Needed for fairchem-core / UMA models.
    """
    patch_steps = [
        patch_torch_linalg,
        patch_torch_compiler,
        patch_torch_testing,
        patch_matscipy,
        patch_mace_training,
        patch_mace_tools,
    ]
    if include_fairchem:
        patch_steps += [patch_torch_distributed, patch_fairchem_deps]

    for apply_patch in patch_steps:
        apply_patch()

    print("\n✅ All Pyodide patches applied successfully!")