From dc60d5ab1fff3a623bf9237714482ba04d088159 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Mon, 18 May 2026 21:01:36 -0700 Subject: [PATCH 01/12] experiment: support for uma --- config.yml | 17 + .../relax_structure_with_uma.ipynb | 300 +++++++++ .../notebooks_utils/pyodide/packages/torch.py | 569 +++++++++++++++++- 3 files changed, 878 insertions(+), 8 deletions(-) create mode 100644 other/experiments/jupyterlite/relax_structure_with_uma.ipynb diff --git a/config.yml b/config.yml index da8e0fae1..5d4673df0 100644 --- a/config.yml +++ b/config.yml @@ -101,3 +101,20 @@ notebooks: - ssl - h5py - lmdb + - name: uma + packages_pyodide: + # Packages with dependencies + - opt_einsum + - orjson + - pyyaml + - sqlite3 + # Packages without dependencies (using nodeps: prefix) + - nodeps:opt_einsum_fx + - nodeps:e3nn>=0.5 + - nodeps:ase + - nodeps:monty + - nodeps:fairchem-core + # Stubbed packages (will be patched by torch_pyodide with include_fairchem=True) + - ssl + - h5py + - lmdb diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb new file mode 100644 index 000000000..568451751 --- /dev/null +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9515ed910c085db8", + "metadata": {}, + "source": [ + "# Relax Structure with FAIRChem UMA — Universal Machine-learning Force Field\n", + "\n", + "Use FAIRChem's [UMA](https://github.com/FAIR-Chem/fairchem) (Universal Model for Atoms) to relax crystal structures using machine-learned interatomic potentials.\n", + "\n", + "

Usage

\n", + "\n", + "1. Drop the materials files into the \"uploads\" folder in the JupyterLab file browser\n", + "1. Set Input Parameters below or use the default values\n", + "1. Click \"Run\" > \"Run All\" to run all cells\n", + "1. Wait for the run to complete. Scroll down to view cell results.\n", + "1. Review the relaxation plot and modify parameters as needed\n", + "\n", + "## Methodology\n", + "\n", + "1. Load materials from JSON files and create structure via `mat3ra-made`\n", + "2. Install FAIRChem UMA and apply Pyodide patches\n", + "3. Convert to ASE atoms with `to_ase()`\n", + "4. Relax the structure with FAIRChem UMA and visualize convergence\n", + "5. Compute relaxation energy" + ] + }, + { + "cell_type": "markdown", + "id": "4737d145950b1cc8", + "metadata": {}, + "source": [ + "## 1. Set Input Parameters\n", + "### 1.1. Structure and Relaxation Parameters\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57dc565952aa2e4e", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "FOLDER = \"uploads\"\n", + "STRUCTURE_NAME = \"Interface\" # Name of the structure to load from local file\n", + "\n", + "RELAXATION_PARAMETERS = {\n", + " \"FMAX\": 0.05,\n", + "}\n", + "UMA_TASK = \"s2ef\" # \"s2ef\" (structure to energy/forces) or \"is2re\" (initial structure to relaxed energy)" + ] + }, + { + "cell_type": "markdown", + "id": "4e89f2d820acb1ed", + "metadata": {}, + "source": [ + "## 2. Install Packages" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from mat3ra.notebooks_utils.packages import install_packages\n", + "\n", + "await install_packages(\"made|api_examples|torch|uma\")\n", + "\n", + " from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", + "\n", + " apply_all_patches(include_fairchem=True)" + ], + "id": "ece62358982306f2" + }, + { + "cell_type": "markdown", + "id": "f89d3c98ddce2ab5", + "metadata": {}, + "source": [ + "## 3. Load Materials" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8fd400dace70549e", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "from mat3ra.made.material import Material\n", + "from mat3ra.notebooks_utils.material import load_material_from_folder\n", + "from mat3ra.standata.materials import Materials\n", + "\n", + "structure = load_material_from_folder(FOLDER, STRUCTURE_NAME) or Material.create(\n", + " Materials.get_by_name_first_match(STRUCTURE_NAME))" + ] + }, + { + "cell_type": "markdown", + "id": "42f12abf6b65aa2c", + "metadata": {}, + "source": [ + "### 3.1. Visualize Input Structure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcbb4e6c1de21233", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "from mat3ra.notebooks_utils.ipython.entity.material.visualize import ViewersEnum, visualize_materials as visualize\n", + "\n", + "visualize([{\"material\": structure, \"title\": structure.name}], viewer=ViewersEnum.wave)\n", + "visualize(structure, repetitions=[1, 1, 1], rotation=\"-90x\")" + ] + }, + { + "cell_type": "markdown", + "id": "e64688fc18c49bb6", + "metadata": {}, + "source": [ + "## 4. Apply Relaxation\n", + "### 4.1. Relax with FAIRChem UMA" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d8746a77f71bab5", + "metadata": { + "jupyter": { + "is_executing": true + }, + "trusted": true + }, + "outputs": [], + "source": [ + "import plotly.graph_objs as go\n", + "from IPython.display import display\n", + "from plotly.subplots import make_subplots\n", + "\n", + "from mat3ra.made.tools.convert import to_ase\n", + "from ase.optimize import BFGS\n", + "\n", + "from fairchem.core import FAIRChemCalculator, pretrained_mlip\n", + "\n", + "# Load UMA model\n", + "predictor = pretrained_mlip.get_predict_unit(\"uma-s-1\", device=\"cpu\")\n", + "calculator = FAIRChemCalculator(predictor, task_name=UMA_TASK)\n", + "\n", + "ase_structure = to_ase(structure)\n", + "ase_structure.set_calculator(calculator)\n", + "dyn = BFGS(ase_structure)\n", + "\n", + "steps = []\n", + "energies = []\n", + "\n", + "fig = make_subplots(rows=1, cols=1, specs=[[{\"type\": \"scatter\"}]])\n", + "scatter = go.Scatter(x=[], y=[], mode=\"lines+markers\", name=\"Energy\")\n", + "fig.add_trace(scatter)\n", + "fig.update_layout(title_text=\"Real-time Optimization Progress\", xaxis_title=\"Step\", yaxis_title=\"Energy (eV)\")\n", + "\n", + "f = go.FigureWidget(fig)\n", + "display(f)\n", + "\n", + "\n", + "def plotly_callback():\n", + " step = dyn.nsteps\n", + " energy = ase_structure.get_total_energy()\n", + " steps.append(step)\n", + " energies.append(energy)\n", + " print(f\"Step: {step}, Energy: {energy:.4f} eV\")\n", + " with f.batch_update():\n", + " f.data[0].x = steps\n", + " f.data[0].y = energies\n", + "\n", + "\n", + "dyn.attach(plotly_callback, interval=1)\n", + "dyn.run(fmax=RELAXATION_PARAMETERS[\"FMAX\"])\n", + "\n", + "ase_original_structure = to_ase(structure)\n", + "ase_original_structure.set_calculator(calculator)\n", + "ase_final_structure = ase_structure\n", + "\n", + "original_energy = ase_original_structure.get_total_energy()\n", + "relaxed_energy = ase_structure.get_total_energy()\n", + "\n", + "print(f\"The final energy is {float(relaxed_energy):.3f} eV.\")" + ] + }, + { + "cell_type": "markdown", + "id": "abfa372909a96bf8", + "metadata": {}, + "source": [ + "## 5. Analyze Results\n", + "### 5.1. View Structure Before and After Relaxation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9565d0931b198f63", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "from mat3ra.made.tools.convert import from_ase\n", + "\n", + "material_original = Material.create(from_ase(ase_original_structure))\n", + "material_relaxed = Material.create(from_ase(ase_final_structure))\n", + "material_original.name = structure.name\n", + "material_relaxed.name = structure.name + \" (UMA Relaxed)\"\n", + "\n", + "visualize(\n", + " [\n", + " {\"material\": material_original, \"title\": material_original.name},\n", + " {\"material\": material_relaxed, \"title\": material_relaxed.name},\n", + " ],\n", + " viewer=ViewersEnum.wave,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e4b49774283e5517", + "metadata": {}, + "source": [ + "### 5.2. Output interlayer distance before and after relaxation\n", + "This requires labels for substrate and film present in the interface structure, which is already done if interface was created with `mat3ra-made`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6dd00402bc2e9d59", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "from mat3ra.made.tools.analyze.other import get_average_interlayer_distance\n", + "\n", + "SUBSTRATE_TAG = 0\n", + "FILM_TAG = 1\n", + "\n", + "print(\n", + " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")\n", + "print(\n", + " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")" + ] + }, + { + "cell_type": "markdown", + "id": "2f60fdb73e44c09c", + "metadata": {}, + "source": [ + "## References\n", + "\n", + "[1] FAIRChem: https://github.com/FAIR-Chem/fairchem \n", + "[2] UMA — Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", + "[3] mat3ra-made interface builder: https://github.com/Exabyte-io/made " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (Pyodide)", + "language": "python", + "name": "python" + }, + "language_info": { + "codemirror_mode": { + "name": "python", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index 3aa373ac2..90ff256c2 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -140,12 +140,6 @@ def patch_torch_linalg(): torch.Tensor.__array__ = _tensor_array_compat torch.Tensor.numpy = lambda self: np.array(self.detach().tolist()) - # Fix torch.compiler.is_compiling for Pyodide - if not hasattr(torch, "compiler"): - torch.compiler = types.ModuleType("torch.compiler") - if not hasattr(torch.compiler, "is_compiling"): - torch.compiler.is_compiling = lambda: False - print("✓ Torch linalg patches applied") @@ -192,6 +186,552 @@ def patch_torch_testing(): print("✓ Torch testing patches applied") +# ============================================================================== +# Torch compiler patches +# ============================================================================== + + +def patch_torch_compiler(): + """ + Patch torch.compiler for Pyodide WASM. + + - Stubs torch.compiler.is_compiling (not available in WASM build) + - Makes torch.compiler.disable a no-op decorator (avoids torch._dynamo + import chain which requires C extensions missing in WASM) + - Needed by e3nn >= 0.5 and fairchem-core which use @torch.compiler.disable + """ + if not hasattr(torch, "compiler"): + torch.compiler = types.ModuleType("torch.compiler") + if not hasattr(torch.compiler, "is_compiling"): + torch.compiler.is_compiling = lambda: False + + def _compiler_disable(fn=None, recursive=True): + if fn is not None: + return fn + return lambda f: f + + torch.compiler.disable = _compiler_disable + + print("✓ Torch compiler patches applied") + + +# ============================================================================== +# Torch distributed patches (for FAIRChem) +# ============================================================================== + + +def patch_torch_distributed(): + """ + Stub torch.distributed modules that require C extensions missing in WASM. + + FAIRChem's mlip_unit.py imports from torch.distributed.checkpoint, + torch.distributed.fsdp, and torch.distributed.device_mesh. + All of these trigger torch._C._distributed_c10d which doesn't exist in WASM. + """ + import enum + + import torch.distributed as _dist + + # Core distributed API stubs + if not hasattr(_dist, "group"): + + class _Group: + WORLD = None + + _dist.group = _Group + _dist.WORLD = None + if not hasattr(_dist, "ReduceOp"): + + class _ReduceOp: + SUM = 0 + PRODUCT = 1 + MIN = 2 + MAX = 3 + BAND = 4 + BOR = 5 + BXOR = 6 + AVG = 7 + + _dist.ReduceOp = _ReduceOp + if not hasattr(_dist, "is_initialized"): + _dist.is_initialized = lambda: False + if not hasattr(_dist, "get_rank"): + _dist.get_rank = lambda group=None: 0 + if not hasattr(_dist, "get_world_size"): + _dist.get_world_size = lambda group=None: 1 + + # torch.distributed.nn.functional + _dist_nn = types.ModuleType("torch.distributed.nn") + _dist_nn.__path__ = [] + _dist_nn.__package__ = "torch.distributed.nn" + _dist_nn_func = types.ModuleType("torch.distributed.nn.functional") + _dist_nn_func.__package__ = "torch.distributed.nn" + _dist_nn_func.all_reduce = lambda tensor, *a, **k: tensor + _dist_nn_func.reduce_scatter = lambda output, input_list, *a, **k: output + _dist_nn_func.all_gather = lambda tensor_list, tensor, *a, **k: tensor_list + _dist_nn.functional = _dist_nn_func + sys.modules["torch.distributed.nn"] = _dist_nn + sys.modules["torch.distributed.nn.functional"] = _dist_nn_func + + # C extension stubs + sys.modules["torch._C._distributed_c10d"] = types.ModuleType("torch._C._distributed_c10d") + _dc10d = types.ModuleType("torch.distributed.distributed_c10d") + _dc10d.__package__ = "torch.distributed" + sys.modules["torch.distributed.distributed_c10d"] = _dc10d + + # torch.distributed._shard + _shard = types.ModuleType("torch.distributed._shard") + _shard.__path__ = [] + _shard.__package__ = "torch.distributed._shard" + sys.modules["torch.distributed._shard"] = _shard + sys.modules["torch.distributed._shard.api"] = types.ModuleType("torch.distributed._shard.api") + _shard_st = types.ModuleType("torch.distributed._shard.sharded_tensor") + _shard_st.__path__ = [] + sys.modules["torch.distributed._shard.sharded_tensor"] = _shard_st + _shard_st_meta = types.ModuleType("torch.distributed._shard.sharded_tensor.metadata") + + class _TensorProperties: + pass + + _shard_st_meta.TensorProperties = _TensorProperties + sys.modules["torch.distributed._shard.sharded_tensor.metadata"] = _shard_st_meta + + # torch.distributed.checkpoint + _dcp = types.ModuleType("torch.distributed.checkpoint") + _dcp.__path__ = [] + _dcp.__package__ = "torch.distributed.checkpoint" + _dcp.save = lambda *a, **k: None + _dcp.load = lambda *a, **k: None + _dcp_meta = types.ModuleType("torch.distributed.checkpoint.metadata") + _dcp_meta.TensorProperties = _TensorProperties + + class _BytesStorageMetadata: + pass + + class _TensorStorageMetadata: + pass + + class _Metadata: + pass + + _dcp_meta.BytesStorageMetadata = _BytesStorageMetadata + _dcp_meta.TensorStorageMetadata = _TensorStorageMetadata + _dcp_meta.Metadata = _Metadata + _dcp.metadata = _dcp_meta + sys.modules["torch.distributed.checkpoint"] = _dcp + sys.modules["torch.distributed.checkpoint.metadata"] = _dcp_meta + + for sub in [ + "state_dict", + "stateful", + "planner", + "storage", + "default_planner", + "filesystem", + "optimizer", + "format_utils", + ]: + _sub_mod = types.ModuleType(f"torch.distributed.checkpoint.{sub}") + _sub_mod.__package__ = "torch.distributed.checkpoint" + sys.modules[f"torch.distributed.checkpoint.{sub}"] = _sub_mod + setattr(_dcp, sub, _sub_mod) + + # format_utils stubs + sys.modules["torch.distributed.checkpoint.format_utils"].dcp_to_torch_save = lambda *a, **k: None + sys.modules["torch.distributed.checkpoint.format_utils"].torch_save_to_dcp = lambda *a, **k: None + + # state_dict stubs + _sd_mod = sys.modules["torch.distributed.checkpoint.state_dict"] + _sd_mod.get_model_state_dict = lambda model, *a, **k: model.state_dict() if hasattr(model, "state_dict") else {} + _sd_mod.set_model_state_dict = ( + lambda model, sd, *a, **k: model.load_state_dict(sd) if hasattr(model, "load_state_dict") else None + ) + _sd_mod.get_optimizer_state_dict = ( + lambda model, optim, *a, **k: optim.state_dict() if hasattr(optim, "state_dict") else {} + ) + _sd_mod.get_state_dict = lambda model, *a, **k: model.state_dict() if hasattr(model, "state_dict") else {} + _sd_mod.set_state_dict = lambda model, sd, *a, **k: None + + class _StateDictOptions: + def __init__(self, **k): + self.__dict__.update(k) + + _sd_mod.StateDictOptions = _StateDictOptions + + # stateful stub + class _Stateful: + pass + + sys.modules["torch.distributed.checkpoint.stateful"].Stateful = _Stateful + + # torch.distributed.fsdp + _fsdp = types.ModuleType("torch.distributed.fsdp") + _fsdp.__path__ = [] + _fsdp.__package__ = "torch.distributed.fsdp" + sys.modules["torch.distributed.fsdp"] = _fsdp + for fsdp_sub in ["fully_sharded_data_parallel", "api", "wrap", "sharded_grad_scaler"]: + _fsub = types.ModuleType(f"torch.distributed.fsdp.{fsdp_sub}") + _fsub.__package__ = "torch.distributed.fsdp" + sys.modules[f"torch.distributed.fsdp.{fsdp_sub}"] = _fsub + setattr(_fsdp, fsdp_sub, _fsub) + + # FSDP wrap policy stubs + class _ModuleWrapPolicy: + def __init__(self, module_classes=None): + self.module_classes = module_classes or set() + + _fsdp_wrap = sys.modules["torch.distributed.fsdp.wrap"] + _fsdp_wrap.ModuleWrapPolicy = _ModuleWrapPolicy + _fsdp_wrap.lambda_auto_wrap_policy = lambda *a, **k: None + _fsdp_wrap.transformer_auto_wrap_policy = lambda *a, **k: None + + # FSDP classes + class _FullyShardedDataParallel(torch.nn.Module): + def __init__(self, module, *a, **k): + super().__init__() + self.module = module + + class _ShardingStrategy(enum.Enum): + FULL_SHARD = "FULL_SHARD" + SHARD_GRAD_OP = "SHARD_GRAD_OP" + NO_SHARD = "NO_SHARD" + HYBRID_SHARD = "HYBRID_SHARD" + + class _MixedPrecision: + def __init__(self, *a, **k): + pass + + class _CPUOffload: + def __init__(self, offload_params=False): + self.offload_params = offload_params + + class _BackwardPrefetch(enum.Enum): + BACKWARD_PRE = "BACKWARD_PRE" + BACKWARD_POST = "BACKWARD_POST" + + class _StateDictTypeFSDP(enum.Enum): + FULL_STATE_DICT = 0 + LOCAL_STATE_DICT = 1 + SHARDED_STATE_DICT = 2 + + _fsdp.FullyShardedDataParallel = _FullyShardedDataParallel + _fsdp.ShardingStrategy = _ShardingStrategy + _fsdp.MixedPrecision = _MixedPrecision + _fsdp.CPUOffload = _CPUOffload + _fsdp.BackwardPrefetch = _BackwardPrefetch + _fsdp.StateDictType = _StateDictTypeFSDP + + # FSDP StateDictConfig stubs + class _StateDictConfig: + def __init__(self, **k): + self.__dict__.update(k) + + class _ShardedStateDictConfig(_StateDictConfig): + pass + + class _FullStateDictConfig(_StateDictConfig): + def __init__(self, offload_to_cpu=False, rank0_only=False, **k): + super().__init__(**k) + self.offload_to_cpu = offload_to_cpu + self.rank0_only = rank0_only + + class _FullOptimStateDictConfig(_StateDictConfig): + def __init__(self, offload_to_cpu=False, rank0_only=False, **k): + super().__init__(**k) + + _fsdp.StateDictConfig = _StateDictConfig + _fsdp.ShardedStateDictConfig = _ShardedStateDictConfig + _fsdp.FullStateDictConfig = _FullStateDictConfig + _fsdp.FullOptimStateDictConfig = _FullOptimStateDictConfig + _fsdp.LocalStateDictConfig = type("LocalStateDictConfig", (_StateDictConfig,), {}) + _fsdp.OptimStateDictConfig = type("OptimStateDictConfig", (_StateDictConfig,), {}) + _fsdp.ShardedOptimStateDictConfig = type("ShardedOptimStateDictConfig", (_StateDictConfig,), {}) + + # torch.distributed.device_mesh + _dm = types.ModuleType("torch.distributed.device_mesh") + _dm.__package__ = "torch.distributed" + + class _DeviceMesh: + def __init__(self, *a, **k): + pass + + _dm.DeviceMesh = _DeviceMesh + _dm.init_device_mesh = lambda *a, **k: _DeviceMesh() + sys.modules["torch.distributed.device_mesh"] = _dm + + # torch.distributed.tensor (DTensor) + _dtensor = types.ModuleType("torch.distributed.tensor") + _dtensor.__path__ = [] + _dtensor.__package__ = "torch.distributed.tensor" + + class _DTensor: + pass + + _dtensor.DTensor = _DTensor + sys.modules["torch.distributed.tensor"] = _dtensor + + # torch.distributed.algorithms + _dalgo = types.ModuleType("torch.distributed.algorithms") + _dalgo.__path__ = [] + _dalgo.__package__ = "torch.distributed.algorithms" + sys.modules["torch.distributed.algorithms"] = _dalgo + + print("✓ Torch distributed patches applied") + + +# ============================================================================== +# FAIRChem heavy dependency stubs +# ============================================================================== + + +def _make_stub_module(name, attrs=None, submodules=None): + """Create a stub module with optional attributes and submodules.""" + mod = types.ModuleType(name) + mod.__path__ = [] + mod.__package__ = name + mod.__version__ = "0.0.0" + if attrs: + for k, v in attrs.items(): + setattr(mod, k, v) + sys.modules[name] = mod + if submodules: + for sub_name in submodules: + full_name = f"{name}.{sub_name}" + sub_mod = types.ModuleType(full_name) + sub_mod.__path__ = [] + sub_mod.__package__ = name + setattr(mod, sub_name, sub_mod) + sys.modules[full_name] = sub_mod + return mod + + +def patch_fairchem_deps(): + """ + Stub heavy dependencies that fairchem-core imports but doesn't need for inference. + + This stubs: numba, ray (+ serve), wandb, torchtnt, hydra, omegaconf, + submitit, clusterscope, tqdm, huggingface_hub, websockets. + """ + # --- numba --- + numba_mod = _make_stub_module("numba", submodules=["core", "core.types", "typed"]) + numba_mod.njit = lambda *a, **k: (lambda f: f) if not a or callable(a[0]) else lambda f: f + numba_mod.jit = numba_mod.njit + numba_mod.prange = range + for t in ("int32", "int64", "float32", "float64", "boolean"): + setattr(numba_mod, t, t) + + class _TypedList(list): + pass + + sys.modules["numba.typed"].List = _TypedList + + # --- ray --- + ray_mod = _make_stub_module( + "ray", + submodules=[ + "serve", + "runtime_env", + "train", + "data", + "util", + "util.scheduling_strategies", + "util.queue", + ], + ) + + def _ray_remote(*args, **kwargs): + if args and callable(args[0]): + args[0]._remote = lambda *a, **kw: None + return args[0] + + def wrapper(fn_or_cls): + fn_or_cls._remote = lambda *a, **kw: None + return fn_or_cls + + return wrapper + + ray_mod.remote = _ray_remote + ray_mod.init = lambda *a, **k: None + ray_mod.get = lambda *a, **k: None + ray_mod.put = lambda *a, **k: None + ray_mod.wait = lambda *a, **k: ([], []) + ray_mod.is_initialized = lambda: False + + class _ObjectRef: + pass + + ray_mod.ObjectRef = _ObjectRef + + class _PlacementGroupSchedulingStrategy: + def __init__(self, *a, **k): + pass + + sys.modules["ray.util.scheduling_strategies"].PlacementGroupSchedulingStrategy = ( + _PlacementGroupSchedulingStrategy + ) + + # ray.serve stubs + _serve = sys.modules["ray.serve"] + + def _serve_deployment(*args, **kwargs): + if args and callable(args[0]): + return args[0] + return lambda cls_or_fn: cls_or_fn + + _serve.deployment = _serve_deployment + _serve.ingress = lambda *a, **k: (lambda cls: cls) + _serve.run = lambda *a, **k: None + _serve.batch = lambda *args, **kwargs: (lambda fn: fn) if not args or not callable(args[0]) else args[0] + + _serve_schema = types.ModuleType("ray.serve.schema") + _serve_schema.__package__ = "ray.serve" + + class _LoggingConfig: + def __init__(self, **k): + self.__dict__.update(k) + + _serve_schema.LoggingConfig = _LoggingConfig + _serve.schema = _serve_schema + sys.modules["ray.serve.schema"] = _serve_schema + + # --- wandb --- + wandb_mod = _make_stub_module("wandb") + wandb_mod.init = lambda *a, **k: None + wandb_mod.log = lambda *a, **k: None + wandb_mod.finish = lambda *a, **k: None + + # --- torchtnt --- + _make_stub_module( + "torchtnt", + submodules=[ + "framework", + "framework.state", + "framework.unit", + "framework.callback", + "framework.auto_unit", + "framework.fit", + "framework.train", + "framework.evaluate", + "framework.predict", + "utils", + "utils.loggers", + "utils.timer", + "utils.distributed", + "utils.prepare_module", + ], + ) + + class _PredictUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _TrainUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _EvalUnit: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _State: + def __init__(self, *a, **k): + pass + + def __class_getitem__(cls, item): + return cls + + class _Callback: + pass + + tnt_framework = sys.modules["torchtnt.framework"] + tnt_unit = sys.modules["torchtnt.framework.unit"] + tnt_state = sys.modules["torchtnt.framework.state"] + tnt_cb = sys.modules["torchtnt.framework.callback"] + for m in [tnt_framework, tnt_unit]: + m.PredictUnit = _PredictUnit + m.TrainUnit = _TrainUnit + m.EvalUnit = _EvalUnit + tnt_state.State = _State + tnt_cb.Callback = _Callback + tnt_framework.State = _State + tnt_framework.Callback = _Callback + + # torchtnt entry point functions + sys.modules["torchtnt.framework.fit"].fit = lambda *a, **k: None + sys.modules["torchtnt.framework.train"].train = lambda *a, **k: None + sys.modules["torchtnt.framework.evaluate"].evaluate = lambda *a, **k: None + sys.modules["torchtnt.framework.predict"].predict = lambda *a, **k: None + + # torchtnt.utils stubs + tnt_dist = sys.modules["torchtnt.utils.distributed"] + tnt_dist.get_file_init_method = lambda *a, **k: "" + tnt_dist.get_tcp_init_method = lambda *a, **k: "" + tnt_dist.spawn_multi_process = lambda *a, **k: None + + tnt_prep = sys.modules["torchtnt.utils.prepare_module"] + tnt_prep.prepare_module = lambda module, *a, **k: module + tnt_prep.FSDPStrategy = type("FSDPStrategy", (), {"__init__": lambda self, **k: None}) + tnt_prep.DDPStrategy = type("DDPStrategy", (), {"__init__": lambda self, **k: None}) + tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None}) + + # --- hydra / omegaconf --- + omegaconf_mod = _make_stub_module("omegaconf") + + class _DictConfig(dict): + pass + + class _ListConfig(list): + pass + + omegaconf_mod.DictConfig = _DictConfig + omegaconf_mod.ListConfig = _ListConfig + omegaconf_mod.OmegaConf = type( + "OmegaConf", + (), + { + "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), + "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), + }, + ) + _make_stub_module("hydra", submodules=["core", "core.global_hydra"]) + + # --- submitit / clusterscope --- + _make_stub_module("submitit") + _make_stub_module("clusterscope") + + # --- websockets --- + _make_stub_module("websockets") + + # --- tqdm --- + tqdm_mod = _make_stub_module("tqdm", submodules=["auto", "std"]) + + def _tqdm_passthrough(iterable=None, *a, **k): + return iterable if iterable is not None else iter([]) + + tqdm_mod.tqdm = _tqdm_passthrough + sys.modules["tqdm.auto"].tqdm = _tqdm_passthrough + sys.modules["tqdm.std"].tqdm = _tqdm_passthrough + + # --- huggingface_hub --- + hf_mod = _make_stub_module("huggingface_hub", submodules=["utils"]) + hf_mod.hf_hub_download = lambda *a, **k: "" + hf_mod.snapshot_download = lambda *a, **k: "" + + # --- ase_db_backends --- + _make_stub_module("ase_db_backends") + + print("✓ FAIRChem dependency stubs applied") + + # ============================================================================== # Matscipy patches # ============================================================================== @@ -273,11 +813,24 @@ def patch_mace_tools(): # ============================================================================== -def apply_all_patches(): - """Apply all torch and MACE patches for Pyodide in one call.""" +def apply_all_patches(include_fairchem=False): + """ + Apply all torch and model patches for Pyodide in one call. + + Args: + include_fairchem: If True, also apply FAIRChem-specific patches + (torch.distributed, heavy dependency stubs). Set this when + using fairchem-core / UMA models. + """ patch_torch_linalg() + patch_torch_compiler() patch_torch_testing() patch_matscipy() patch_mace_training() patch_mace_tools() + + if include_fairchem: + patch_torch_distributed() + patch_fairchem_deps() + print("\n✅ All Pyodide patches applied successfully!") From 9ad7737009eddc3839dfb8cedba375dfacabdd10 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 11:16:31 -0700 Subject: [PATCH 02/12] chore: update uma logic --- config.yml | 4 + .../relax_structure_with_uma.ipynb | 78 +++++++++++++++---- .../notebooks_utils/pyodide/packages/torch.py | 69 +++++++++++++++- 3 files changed, 133 insertions(+), 18 deletions(-) diff --git a/config.yml b/config.yml index 5d4673df0..3beec21b9 100644 --- a/config.yml +++ b/config.yml @@ -103,11 +103,15 @@ notebooks: - lmdb - name: uma packages_pyodide: + # antlr4 local wheel (required by omegaconf, PyPI version has ATN mismatch) + - emfs:/drive/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl # Packages with dependencies - opt_einsum - orjson - pyyaml - sqlite3 + - omegaconf + - hydra-core # Packages without dependencies (using nodeps: prefix) - nodeps:opt_einsum_fx - nodeps:e3nn>=0.5 diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb index 568451751..c7ade9483 100644 --- a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -50,7 +50,19 @@ "RELAXATION_PARAMETERS = {\n", " \"FMAX\": 0.05,\n", "}\n", - "UMA_TASK = \"s2ef\" # \"s2ef\" (structure to energy/forces) or \"is2re\" (initial structure to relaxed energy)" + "UMA_TASK = \"omat\" # Task name for the UMA model (e.g. \"omat\", \"oc20\", \"omol\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e9d9e76d619cd91", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-f16.pt\" # FP16 inference-only checkpoint (~294 MB)" ] }, { @@ -63,20 +75,31 @@ }, { "cell_type": "code", - "metadata": {}, - "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "ece62358982306f2", + "metadata": { + "trusted": true + }, + "outputs": [], "source": [ "from mat3ra.notebooks_utils.packages import install_packages\n", "\n", - "await install_packages(\"made|api_examples|torch|uma\")\n", - "\n", - " from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", + "await install_packages(\"made|api_examples|torch|uma\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7f8e5e6a1b34d90", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", "\n", - " apply_all_patches(include_fairchem=True)" - ], - "id": "ece62358982306f2" + "apply_all_patches(include_fairchem=True)" + ] }, { "cell_type": "markdown", @@ -132,7 +155,34 @@ "metadata": {}, "source": [ "## 4. Apply Relaxation\n", - "### 4.1. Relax with FAIRChem UMA" + "### 4.1. Load UMA Model and Create Calculator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1b2c3d4e5f60001", + "metadata": { + "trusted": true + }, + "outputs": [], + "source": [ + "from fairchem.core import FAIRChemCalculator\n", + "from fairchem.core.units.mlip_unit import load_predict_unit\n", + "\n", + "predictor = load_predict_unit(UMA_MODEL_PATH, device=\"cpu\")\n", + "calculator = FAIRChemCalculator(predictor, task_name=UMA_TASK)\n", + "\n", + "print(f\"UMA model loaded from {UMA_MODEL_PATH}\")\n", + "print(f\"Task: {UMA_TASK}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d4e5f6a7b8c90002", + "metadata": {}, + "source": [ + "### 4.2. Relax with UMA" ] }, { @@ -154,12 +204,6 @@ "from mat3ra.made.tools.convert import to_ase\n", "from ase.optimize import BFGS\n", "\n", - "from fairchem.core import FAIRChemCalculator, pretrained_mlip\n", - "\n", - "# Load UMA model\n", - "predictor = pretrained_mlip.get_predict_unit(\"uma-s-1\", device=\"cpu\")\n", - "calculator = FAIRChemCalculator(predictor, task_name=UMA_TASK)\n", - "\n", "ase_structure = to_ase(structure)\n", "ase_structure.set_calculator(calculator)\n", "dyn = BFGS(ase_structure)\n", diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index 90ff256c2..cc885bec6 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -140,7 +140,66 @@ def patch_torch_linalg(): torch.Tensor.__array__ = _tensor_array_compat torch.Tensor.numpy = lambda self: np.array(self.detach().tolist()) - print("✓ Torch linalg patches applied") + # torch.from_numpy — WASM PyTorch build lacks NumPy interop + _orig_from_numpy = torch.from_numpy + + def _patched_from_numpy(ndarray): + try: + return _orig_from_numpy(ndarray) + except RuntimeError: + return torch.tensor(ndarray.tolist()) + + torch.from_numpy = _patched_from_numpy + + # torch.as_tensor — also uses numpy interop internally + _orig_as_tensor = torch.as_tensor + + def _patched_as_tensor(data, dtype=None, device=None): + try: + return _orig_as_tensor(data, dtype=dtype, device=device) + except (RuntimeError, TypeError): + if hasattr(data, "tolist"): + return torch.tensor(data.tolist(), dtype=dtype, device=device) + return torch.tensor(data, dtype=dtype, device=device) + + torch.as_tensor = _patched_as_tensor + + # Tensor indexing — WASM PyTorch can't infer numpy dtypes for boolean/integer masks + _orig_getitem = torch.Tensor.__getitem__ + + def _patched_getitem(self, key): + if isinstance(key, np.ndarray): + key = torch.tensor(key.tolist()) + return _orig_getitem(self, key) + + torch.Tensor.__getitem__ = _patched_getitem + + _orig_setitem = torch.Tensor.__setitem__ + + def _patched_setitem(self, key, value): + if isinstance(key, np.ndarray): + key = torch.tensor(key.tolist()) + return _orig_setitem(self, key, value) + + torch.Tensor.__setitem__ = _patched_setitem + + # torch.tensor — handle lists of 0-d tensors (WASM calls len() on elements, fails for 0-d) + _orig_torch_tensor = torch.tensor + + def _patched_torch_tensor(data, *args, **kwargs): + if isinstance(data, (list, tuple)): + + def _unwrap(item): + if isinstance(item, torch.Tensor) and item.ndim == 0: + return item.item() + return item + + data = type(data)(_unwrap(x) for x in data) + return _orig_torch_tensor(data, *args, **kwargs) + + torch.tensor = _patched_torch_tensor + + print("✓ Torch linalg + numpy interop patches applied") # ============================================================================== @@ -183,6 +242,10 @@ def patch_torch_testing(): sys.modules["torch.testing._internal.common_utils"] = _common_utils sys.modules["torch.testing._internal.logging_tensor"] = _logging_tensor + # Import torch.utils.checkpoint AFTER logging_tensor stubs are set up + # (checkpoint.py imports logging_tensor at import time) + import torch.utils.checkpoint # noqa: F401 + print("✓ Torch testing patches applied") @@ -486,10 +549,13 @@ class _DTensor: def _make_stub_module(name, attrs=None, submodules=None): """Create a stub module with optional attributes and submodules.""" + from importlib.machinery import ModuleSpec + mod = types.ModuleType(name) mod.__path__ = [] mod.__package__ = name mod.__version__ = "0.0.0" + mod.__spec__ = ModuleSpec(name, None, is_package=True) if attrs: for k, v in attrs.items(): setattr(mod, k, v) @@ -500,6 +566,7 @@ def _make_stub_module(name, attrs=None, submodules=None): sub_mod = types.ModuleType(full_name) sub_mod.__path__ = [] sub_mod.__package__ = name + sub_mod.__spec__ = ModuleSpec(full_name, None, is_package=True) setattr(mod, sub_name, sub_mod) sys.modules[full_name] = sub_mod return mod From a72ee4f57a807dcc89c8bd592afe88ac6efbfc4e Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 11:22:50 -0700 Subject: [PATCH 03/12] chore: antlr package --- packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl diff --git a/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl b/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl new file mode 100644 index 000000000..2a6ca7233 --- /dev/null +++ b/packages/antlr4_python3_runtime-4.9.3-py3-none-any.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59f990d0c30e35b3ec80e0daa1c2aff9bd3115794fcb8f119125c9235e2e1178 +size 144573 From 1b9eba2b56479df8792676e813ba1ec2e5615b0c Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 11:28:26 -0700 Subject: [PATCH 04/12] chore: git attributes --- .gitattributes | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitattributes b/.gitattributes index 790431e68..7f638c094 100644 --- a/.gitattributes +++ b/.gitattributes @@ -17,4 +17,5 @@ images/*.png filter=lfs diff=lfs merge=lfs -text examples/assets/bash_workflow_template.json filter=lfs diff=lfs merge=lfs -text *.whl filter=lfs diff=lfs merge=lfs -text *.model filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text packages filter=lfs diff=lfs merge=lfs -text From fc7eb528f725f05cdd04341502b16b1abe9086f9 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 11:28:39 -0700 Subject: [PATCH 05/12] chore: f16 pt file --- packages/models/uma-s-1p1-f16.pt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 packages/models/uma-s-1p1-f16.pt diff --git a/packages/models/uma-s-1p1-f16.pt b/packages/models/uma-s-1p1-f16.pt new file mode 100644 index 000000000..641385575 --- /dev/null +++ b/packages/models/uma-s-1p1-f16.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dad47e7d35b42b08917bc4ebbca98cc76e32b3df15b6256f817ba3f42b9bdfe +size 294266951 From 16f80ec7bc92d9b46a77e8ee90f8ad5ad80663f9 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 12:39:35 -0700 Subject: [PATCH 06/12] chore: update notebook --- config.yml | 21 +++++++++++++++++++ .../relax_structure_with_uma.ipynb | 8 +++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/config.yml b/config.yml index 3beec21b9..f6aeccc67 100644 --- a/config.yml +++ b/config.yml @@ -122,3 +122,24 @@ notebooks: - ssl - h5py - lmdb + - name: uma_relax + packages_pyodide: + # Minimal mat3ra-made dependencies (no pymatgen-analysis-defects, tabulate, uncertainties, nbformat) + - lzma + - annotated_types>=0.6.0 + - networkx==3.2.1 + - monty==2023.11.3 + - scipy==1.11.2 + - sympy==1.12 + - ase==3.25.0 + - ipywidgets + - plotly>=5.18 + - emfs:/drive/packages/pymatgen-2024.4.13-py3-none-any.whl + - emfs:/drive/packages/spglib-2.0.2-py3-none-any.whl + - emfs:/drive/packages/ruamel.yaml-0.17.32-py3-none-any.whl + - emfs:/drive/packages/pydantic_core-2.18.2-py3-none-any.whl + - emfs:/drive/packages/pydantic-2.7.1-py3-none-any.whl + - mat3ra-periodic-table + - mat3ra-made>=2025.12.29.post0 + # For default structure loading (was in api_examples / specific_examples) + - mat3ra-standata diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb index c7ade9483..3adb68746 100644 --- a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -84,7 +84,7 @@ "source": [ "from mat3ra.notebooks_utils.packages import install_packages\n", "\n", - "await install_packages(\"made|api_examples|torch|uma\")" + "await install_packages(\"uma_relax|torch|uma\")" ] }, { @@ -98,7 +98,9 @@ "source": [ "from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", "\n", - "apply_all_patches(include_fairchem=True)" + "apply_all_patches(include_fairchem=True)\n", + "\n", + "import gc; gc.collect() # Free memory after package installation and patching" ] }, { @@ -167,11 +169,13 @@ }, "outputs": [], "source": [ + "import gc\n", "from fairchem.core import FAIRChemCalculator\n", "from fairchem.core.units.mlip_unit import load_predict_unit\n", "\n", "predictor = load_predict_unit(UMA_MODEL_PATH, device=\"cpu\")\n", "calculator = FAIRChemCalculator(predictor, task_name=UMA_TASK)\n", + "gc.collect() # Free temporary buffers from model deserialization\n", "\n", "print(f\"UMA model loaded from {UMA_MODEL_PATH}\")\n", "print(f\"Task: {UMA_TASK}\")" From dbafc0e0d5947e89577f71e561f106e8f3726b5a Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 13:50:33 -0700 Subject: [PATCH 07/12] chore: update notebook --- .../relax_structure_with_uma.ipynb | 10 +- .../notebooks_utils/pyodide/packages/torch.py | 137 ++++++++++++++++-- 2 files changed, 127 insertions(+), 20 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb index 3adb68746..82891af59 100644 --- a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -5,7 +5,7 @@ "id": "9515ed910c085db8", "metadata": {}, "source": [ - "# Relax Structure with FAIRChem UMA — Universal Machine-learning Force Field\n", + "# Relax Structure with FAIRChem UMA \u2014 Universal Machine-learning Force Field\n", "\n", "Use FAIRChem's [UMA](https://github.com/FAIR-Chem/fairchem) (Universal Model for Atoms) to relax crystal structures using machine-learned interatomic potentials.\n", "\n", @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-f16.pt\" # FP16 inference-only checkpoint (~294 MB)" + "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-int8.pt\" # INT8 quantized checkpoint (~148 MB)" ] }, { @@ -306,9 +306,9 @@ "FILM_TAG = 1\n", "\n", "print(\n", - " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")\n", + " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")\n", "print(\n", - " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")" + " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")" ] }, { @@ -319,7 +319,7 @@ "## References\n", "\n", "[1] FAIRChem: https://github.com/FAIR-Chem/fairchem \n", - "[2] UMA — Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", + "[2] UMA \u2014 Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", "[3] mat3ra-made interface builder: https://github.com/Exabyte-io/made " ] } diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index cc885bec6..e574add35 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -751,25 +751,67 @@ class _Callback: tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None}) # --- hydra / omegaconf --- - omegaconf_mod = _make_stub_module("omegaconf") + # Only stub if the real packages aren't installed + if "omegaconf" not in sys.modules: + omegaconf_mod = _make_stub_module("omegaconf") - class _DictConfig(dict): - pass + class _DictConfig(dict): + pass + + class _ListConfig(list): + pass - class _ListConfig(list): + omegaconf_mod.DictConfig = _DictConfig + omegaconf_mod.ListConfig = _ListConfig + omegaconf_mod.OmegaConf = type( + "OmegaConf", + (), + { + "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), + "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), + }, + ) + # Try to use the real hydra if installed, otherwise stub it. + # Must catch all exceptions since hydra's import chain may fail + # with non-ImportError exceptions in Pyodide (missing C deps, etc.) + _hydra_ok = False + try: + import hydra.utils + hydra.utils.instantiate # verify it actually works + _hydra_ok = True + except Exception: pass - omegaconf_mod.DictConfig = _DictConfig - omegaconf_mod.ListConfig = _ListConfig - omegaconf_mod.OmegaConf = type( - "OmegaConf", - (), - { - "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), - "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), - }, - ) - _make_stub_module("hydra", submodules=["core", "core.global_hydra"]) + if not _hydra_ok: + # Remove any partial hydra from sys.modules + for k in list(sys.modules.keys()): + if k == "hydra" or k.startswith("hydra."): + del sys.modules[k] + hydra_stub = _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"]) + # Add a minimal instantiate that uses _target_ to construct objects + def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): + import importlib + if isinstance(config, dict) and "_target_" in config: + target = config["_target_"] + mod_path, cls_name = target.rsplit(".", 1) + mod = importlib.import_module(mod_path) + cls = getattr(mod, cls_name) + # Extract positional args from _args_ key + pos_args = list(args) + list(config.get("_args_", [])) + # Filter config to remove hydra special keys + cfg = {k: v for k, v in config.items() if not k.startswith("_")} + # Recursively instantiate nested configs + if _recursive_: + for k, v in cfg.items(): + if isinstance(v, dict) and "_target_" in v: + cfg[k] = _hydra_instantiate(v) + elif isinstance(v, list): + cfg[k] = [_hydra_instantiate(item) if isinstance(item, dict) and "_target_" in item else item for item in v] + pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] + return cls(*pos_args, **{**cfg, **kwargs}) + return config + sys.modules["hydra.utils"].instantiate = _hydra_instantiate + sys.modules["hydra"].utils = sys.modules["hydra.utils"] # --- submitit / clusterscope --- _make_stub_module("submitit") @@ -796,6 +838,70 @@ def _tqdm_passthrough(iterable=None, *a, **k): # --- ase_db_backends --- _make_stub_module("ase_db_backends") + # --- matplotlib (imported by fairchem backbone but not needed for inference) --- + if "matplotlib" not in sys.modules: + _make_stub_module("matplotlib", submodules=["pyplot"]) + + # --- INT8 quantized model support --- + # Monkey-patch torch.load so that INT8 quantized checkpoints + # (created by scripts/quantize_uma_model.py) are automatically dequantized + # back to FP16 and returned as MLIPInferenceCheckpoint. + # This makes load_predict_unit() work transparently with INT8 files. + _orig_torch_load = torch.load + + def _int8_aware_torch_load(f, *args, **kwargs): + result = _orig_torch_load(f, *args, **kwargs) + if isinstance(result, dict) and "quantized_ema_state_dict" in result: + import gc as _gc + from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint + + print(" Dequantizing INT8 → FP16...") + quantized_ema = result["quantized_ema_state_dict"] + scales = result["quantization_scales"] + ema_state_dict = {} + for name, tensor in quantized_ema.items(): + if name in scales: + ema_state_dict[name] = (tensor.float() * scales[name].float()).half() + else: + ema_state_dict[name] = tensor + del quantized_ema, scales + checkpoint = MLIPInferenceCheckpoint( + model_config=result["model_config"], + model_state_dict=result.get("model_state_dict", {}), + ema_state_dict=ema_state_dict, + tasks_config=result["tasks_config"], + ) + del result + _gc.collect() + print(" ✓ Dequantization complete") + return checkpoint + return result + + torch.load = _int8_aware_torch_load + + # --- Model registry fallback --- + # In Pyodide, fairchem model modules can't be imported in bulk + # (missing C deps), so the @registry decorators never populate + # model_name_mapping. Patch get_model_class to resolve by full + # dotted import path (e.g. 'fairchem.core.models.uma.escn_moe.eSCNMDMoeBackbone'). + try: + from fairchem.core.common.registry import registry as _registry + + _orig_get_model = _registry.get_model_class + + def _fallback_get_model_class(name): + try: + return _orig_get_model(name) + except RuntimeError: + import importlib as _imp + module_path, class_name = name.rsplit(".", 1) + mod = _imp.import_module(module_path) + return getattr(mod, class_name) + + _registry.get_model_class = _fallback_get_model_class + except Exception: + pass + print("✓ FAIRChem dependency stubs applied") @@ -901,3 +1007,4 @@ def apply_all_patches(include_fairchem=False): patch_fairchem_deps() print("\n✅ All Pyodide patches applied successfully!") + From 8758ac87964547628a3ad790b1ea0156092d9a7e Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 13:51:07 -0700 Subject: [PATCH 08/12] chore: int8 pt file --- packages/models/uma-s-1p1-int8.pt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 packages/models/uma-s-1p1-int8.pt diff --git a/packages/models/uma-s-1p1-int8.pt b/packages/models/uma-s-1p1-int8.pt new file mode 100644 index 000000000..a000109a4 --- /dev/null +++ b/packages/models/uma-s-1p1-int8.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2b6fea3cebdb512078d5cb45feaf701b786be7b1fa024cee0a7dd24aafae8fc9 +size 146678700 From b724ec668317cd7eae25fea49687b4c361270b0e Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 15:48:42 -0700 Subject: [PATCH 09/12] Revert "chore: update notebook" This reverts commit 16f80ec7bc92d9b46a77e8ee90f8ad5ad80663f9. --- config.yml | 21 ------------------- .../relax_structure_with_uma.ipynb | 8 ++----- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/config.yml b/config.yml index f6aeccc67..3beec21b9 100644 --- a/config.yml +++ b/config.yml @@ -122,24 +122,3 @@ notebooks: - ssl - h5py - lmdb - - name: uma_relax - packages_pyodide: - # Minimal mat3ra-made dependencies (no pymatgen-analysis-defects, tabulate, uncertainties, nbformat) - - lzma - - annotated_types>=0.6.0 - - networkx==3.2.1 - - monty==2023.11.3 - - scipy==1.11.2 - - sympy==1.12 - - ase==3.25.0 - - ipywidgets - - plotly>=5.18 - - emfs:/drive/packages/pymatgen-2024.4.13-py3-none-any.whl - - emfs:/drive/packages/spglib-2.0.2-py3-none-any.whl - - emfs:/drive/packages/ruamel.yaml-0.17.32-py3-none-any.whl - - emfs:/drive/packages/pydantic_core-2.18.2-py3-none-any.whl - - emfs:/drive/packages/pydantic-2.7.1-py3-none-any.whl - - mat3ra-periodic-table - - mat3ra-made>=2025.12.29.post0 - # For default structure loading (was in api_examples / specific_examples) - - mat3ra-standata diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb index 82891af59..00ee07966 100644 --- a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -84,7 +84,7 @@ "source": [ "from mat3ra.notebooks_utils.packages import install_packages\n", "\n", - "await install_packages(\"uma_relax|torch|uma\")" + "await install_packages(\"made|api_examples|torch|uma\")" ] }, { @@ -98,9 +98,7 @@ "source": [ "from mat3ra.notebooks_utils.pyodide.packages.torch import apply_all_patches\n", "\n", - "apply_all_patches(include_fairchem=True)\n", - "\n", - "import gc; gc.collect() # Free memory after package installation and patching" + "apply_all_patches(include_fairchem=True)" ] }, { @@ -169,13 +167,11 @@ }, "outputs": [], "source": [ - "import gc\n", "from fairchem.core import FAIRChemCalculator\n", "from fairchem.core.units.mlip_unit import load_predict_unit\n", "\n", "predictor = load_predict_unit(UMA_MODEL_PATH, device=\"cpu\")\n", "calculator = FAIRChemCalculator(predictor, task_name=UMA_TASK)\n", - "gc.collect() # Free temporary buffers from model deserialization\n", "\n", "print(f\"UMA model loaded from {UMA_MODEL_PATH}\")\n", "print(f\"Task: {UMA_TASK}\")" From 9e4a89e463fec7cd83ffe2d33ed9588d2cee0da5 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 15:49:01 -0700 Subject: [PATCH 10/12] Revert "chore: update notebook" This reverts commit dbafc0e0d5947e89577f71e561f106e8f3726b5a. --- .../relax_structure_with_uma.ipynb | 10 +- .../notebooks_utils/pyodide/packages/torch.py | 137 ++---------------- 2 files changed, 20 insertions(+), 127 deletions(-) diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb index 00ee07966..c7ade9483 100644 --- a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -5,7 +5,7 @@ "id": "9515ed910c085db8", "metadata": {}, "source": [ - "# Relax Structure with FAIRChem UMA \u2014 Universal Machine-learning Force Field\n", + "# Relax Structure with FAIRChem UMA — Universal Machine-learning Force Field\n", "\n", "Use FAIRChem's [UMA](https://github.com/FAIR-Chem/fairchem) (Universal Model for Atoms) to relax crystal structures using machine-learned interatomic potentials.\n", "\n", @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-int8.pt\" # INT8 quantized checkpoint (~148 MB)" + "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-f16.pt\" # FP16 inference-only checkpoint (~294 MB)" ] }, { @@ -302,9 +302,9 @@ "FILM_TAG = 1\n", "\n", "print(\n", - " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")\n", + " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")\n", "print(\n", - " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")" + " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")" ] }, { @@ -315,7 +315,7 @@ "## References\n", "\n", "[1] FAIRChem: https://github.com/FAIR-Chem/fairchem \n", - "[2] UMA \u2014 Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", + "[2] UMA — Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", "[3] mat3ra-made interface builder: https://github.com/Exabyte-io/made " ] } diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index e574add35..cc885bec6 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -751,67 +751,25 @@ class _Callback: tnt_prep.NOOPStrategy = type("NOOPStrategy", (), {"__init__": lambda self, **k: None}) # --- hydra / omegaconf --- - # Only stub if the real packages aren't installed - if "omegaconf" not in sys.modules: - omegaconf_mod = _make_stub_module("omegaconf") + omegaconf_mod = _make_stub_module("omegaconf") - class _DictConfig(dict): - pass - - class _ListConfig(list): - pass + class _DictConfig(dict): + pass - omegaconf_mod.DictConfig = _DictConfig - omegaconf_mod.ListConfig = _ListConfig - omegaconf_mod.OmegaConf = type( - "OmegaConf", - (), - { - "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), - "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), - }, - ) - # Try to use the real hydra if installed, otherwise stub it. - # Must catch all exceptions since hydra's import chain may fail - # with non-ImportError exceptions in Pyodide (missing C deps, etc.) - _hydra_ok = False - try: - import hydra.utils - hydra.utils.instantiate # verify it actually works - _hydra_ok = True - except Exception: + class _ListConfig(list): pass - if not _hydra_ok: - # Remove any partial hydra from sys.modules - for k in list(sys.modules.keys()): - if k == "hydra" or k.startswith("hydra."): - del sys.modules[k] - hydra_stub = _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"]) - # Add a minimal instantiate that uses _target_ to construct objects - def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): - import importlib - if isinstance(config, dict) and "_target_" in config: - target = config["_target_"] - mod_path, cls_name = target.rsplit(".", 1) - mod = importlib.import_module(mod_path) - cls = getattr(mod, cls_name) - # Extract positional args from _args_ key - pos_args = list(args) + list(config.get("_args_", [])) - # Filter config to remove hydra special keys - cfg = {k: v for k, v in config.items() if not k.startswith("_")} - # Recursively instantiate nested configs - if _recursive_: - for k, v in cfg.items(): - if isinstance(v, dict) and "_target_" in v: - cfg[k] = _hydra_instantiate(v) - elif isinstance(v, list): - cfg[k] = [_hydra_instantiate(item) if isinstance(item, dict) and "_target_" in item else item for item in v] - pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] - return cls(*pos_args, **{**cfg, **kwargs}) - return config - sys.modules["hydra.utils"].instantiate = _hydra_instantiate - sys.modules["hydra"].utils = sys.modules["hydra.utils"] + omegaconf_mod.DictConfig = _DictConfig + omegaconf_mod.ListConfig = _ListConfig + omegaconf_mod.OmegaConf = type( + "OmegaConf", + (), + { + "to_container": staticmethod(lambda cfg, **k: dict(cfg) if isinstance(cfg, dict) else cfg), + "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), + }, + ) + _make_stub_module("hydra", submodules=["core", "core.global_hydra"]) # --- submitit / clusterscope --- _make_stub_module("submitit") @@ -838,70 +796,6 @@ def _tqdm_passthrough(iterable=None, *a, **k): # --- ase_db_backends --- _make_stub_module("ase_db_backends") - # --- matplotlib (imported by fairchem backbone but not needed for inference) --- - if "matplotlib" not in sys.modules: - _make_stub_module("matplotlib", submodules=["pyplot"]) - - # --- INT8 quantized model support --- - # Monkey-patch torch.load so that INT8 quantized checkpoints - # (created by scripts/quantize_uma_model.py) are automatically dequantized - # back to FP16 and returned as MLIPInferenceCheckpoint. - # This makes load_predict_unit() work transparently with INT8 files. - _orig_torch_load = torch.load - - def _int8_aware_torch_load(f, *args, **kwargs): - result = _orig_torch_load(f, *args, **kwargs) - if isinstance(result, dict) and "quantized_ema_state_dict" in result: - import gc as _gc - from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint - - print(" Dequantizing INT8 → FP16...") - quantized_ema = result["quantized_ema_state_dict"] - scales = result["quantization_scales"] - ema_state_dict = {} - for name, tensor in quantized_ema.items(): - if name in scales: - ema_state_dict[name] = (tensor.float() * scales[name].float()).half() - else: - ema_state_dict[name] = tensor - del quantized_ema, scales - checkpoint = MLIPInferenceCheckpoint( - model_config=result["model_config"], - model_state_dict=result.get("model_state_dict", {}), - ema_state_dict=ema_state_dict, - tasks_config=result["tasks_config"], - ) - del result - _gc.collect() - print(" ✓ Dequantization complete") - return checkpoint - return result - - torch.load = _int8_aware_torch_load - - # --- Model registry fallback --- - # In Pyodide, fairchem model modules can't be imported in bulk - # (missing C deps), so the @registry decorators never populate - # model_name_mapping. Patch get_model_class to resolve by full - # dotted import path (e.g. 'fairchem.core.models.uma.escn_moe.eSCNMDMoeBackbone'). - try: - from fairchem.core.common.registry import registry as _registry - - _orig_get_model = _registry.get_model_class - - def _fallback_get_model_class(name): - try: - return _orig_get_model(name) - except RuntimeError: - import importlib as _imp - module_path, class_name = name.rsplit(".", 1) - mod = _imp.import_module(module_path) - return getattr(mod, class_name) - - _registry.get_model_class = _fallback_get_model_class - except Exception: - pass - print("✓ FAIRChem dependency stubs applied") @@ -1007,4 +901,3 @@ def apply_all_patches(include_fairchem=False): patch_fairchem_deps() print("\n✅ All Pyodide patches applied successfully!") - From 28aec58092b73fa8162b62aa8aef97df773026c5 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 16:19:01 -0700 Subject: [PATCH 11/12] chore: update notebook --- .gitignore | 1 + .../relax_structure_with_uma.ipynb | 56 ++------------ .../notebooks_utils/pyodide/packages/torch.py | 75 ++++++++++++++++++- 3 files changed, 81 insertions(+), 51 deletions(-) diff --git a/.gitignore b/.gitignore index 04de5d664..dc0a1e54a 100644 --- a/.gitignore +++ b/.gitignore @@ -110,3 +110,4 @@ examples/assets/Molecules_Dataset_Collection pw_scf.* examples/**/*.html .DS_Store +node_modules diff --git a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb index c7ade9483..78f0baf9b 100644 --- a/other/experiments/jupyterlite/relax_structure_with_uma.ipynb +++ b/other/experiments/jupyterlite/relax_structure_with_uma.ipynb @@ -5,7 +5,7 @@ "id": "9515ed910c085db8", "metadata": {}, "source": [ - "# Relax Structure with FAIRChem UMA — Universal Machine-learning Force Field\n", + "# Relax Structure with FAIRChem UMA \u2014 Universal Machine-learning Force Field\n", "\n", "Use FAIRChem's [UMA](https://github.com/FAIR-Chem/fairchem) (Universal Model for Atoms) to relax crystal structures using machine-learned interatomic potentials.\n", "\n", @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-f16.pt\" # FP16 inference-only checkpoint (~294 MB)" + "UMA_MODEL_PATH = \"/drive/packages/models/uma-s-1p1-int8.pt\"" ] }, { @@ -197,51 +197,7 @@ }, "outputs": [], "source": [ - "import plotly.graph_objs as go\n", - "from IPython.display import display\n", - "from plotly.subplots import make_subplots\n", - "\n", - "from mat3ra.made.tools.convert import to_ase\n", - "from ase.optimize import BFGS\n", - "\n", - "ase_structure = to_ase(structure)\n", - "ase_structure.set_calculator(calculator)\n", - "dyn = BFGS(ase_structure)\n", - "\n", - "steps = []\n", - "energies = []\n", - "\n", - "fig = make_subplots(rows=1, cols=1, specs=[[{\"type\": \"scatter\"}]])\n", - "scatter = go.Scatter(x=[], y=[], mode=\"lines+markers\", name=\"Energy\")\n", - "fig.add_trace(scatter)\n", - "fig.update_layout(title_text=\"Real-time Optimization Progress\", xaxis_title=\"Step\", yaxis_title=\"Energy (eV)\")\n", - "\n", - "f = go.FigureWidget(fig)\n", - "display(f)\n", - "\n", - "\n", - "def plotly_callback():\n", - " step = dyn.nsteps\n", - " energy = ase_structure.get_total_energy()\n", - " steps.append(step)\n", - " energies.append(energy)\n", - " print(f\"Step: {step}, Energy: {energy:.4f} eV\")\n", - " with f.batch_update():\n", - " f.data[0].x = steps\n", - " f.data[0].y = energies\n", - "\n", - "\n", - "dyn.attach(plotly_callback, interval=1)\n", - "dyn.run(fmax=RELAXATION_PARAMETERS[\"FMAX\"])\n", - "\n", - "ase_original_structure = to_ase(structure)\n", - "ase_original_structure.set_calculator(calculator)\n", - "ase_final_structure = ase_structure\n", - "\n", - "original_energy = ase_original_structure.get_total_energy()\n", - "relaxed_energy = ase_structure.get_total_energy()\n", - "\n", - "print(f\"The final energy is {float(relaxed_energy):.3f} eV.\")" + "import plotly.graph_objs as go\nfrom IPython.display import display\nfrom plotly.subplots import make_subplots\n\nfrom mat3ra.made.tools.convert import to_ase\nfrom ase.optimize import BFGS\n\nase_structure = to_ase(structure)\nase_structure.set_calculator(calculator)\ndyn = BFGS(ase_structure)\n\nsteps = []\nenergies = []\n\nfig = make_subplots(rows=1, cols=1, specs=[[{\"type\": \"scatter\"}]])\nscatter = go.Scatter(x=[], y=[], mode=\"lines+markers\", name=\"Energy\")\nfig.add_trace(scatter)\nfig.update_layout(title_text=\"Real-time Optimization Progress\", xaxis_title=\"Step\", yaxis_title=\"Energy (eV)\")\n\ntry:\n f = go.FigureWidget(fig)\nexcept ImportError:\n f = go.Figure(fig)\ndisplay(f)\n\n\ndef plotly_callback():\n step = dyn.nsteps\n energy = ase_structure.get_total_energy()\n steps.append(step)\n energies.append(energy)\n print(f\"Step: {step}, Energy: {energy:.4f} eV\")\n if hasattr(f, \"batch_update\"):\n with f.batch_update():\n f.data[0].x = steps\n f.data[0].y = energies\n else:\n f.data[0].x = steps\n f.data[0].y = energies\n\n\ndyn.attach(plotly_callback, interval=1)\ndyn.run(fmax=RELAXATION_PARAMETERS[\"FMAX\"])\n\nase_original_structure = to_ase(structure)\nase_original_structure.set_calculator(calculator)\nase_final_structure = ase_structure\n\noriginal_energy = ase_original_structure.get_total_energy()\nrelaxed_energy = ase_structure.get_total_energy()\n\nprint(f\"The final energy is {float(relaxed_energy):.3f} eV.\")" ] }, { @@ -302,9 +258,9 @@ "FILM_TAG = 1\n", "\n", "print(\n", - " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")\n", + " f\"Interlayer distance before relaxation: {get_average_interlayer_distance(material_original, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")\n", "print(\n", - " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} Å\")" + " f\"Interlayer distance after relaxation: {get_average_interlayer_distance(material_relaxed, SUBSTRATE_TAG, FILM_TAG):.4f} \u00c5\")" ] }, { @@ -315,7 +271,7 @@ "## References\n", "\n", "[1] FAIRChem: https://github.com/FAIR-Chem/fairchem \n", - "[2] UMA — Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", + "[2] UMA \u2014 Universal Machine-learning Force Field for Atomistic Systems: https://arxiv.org/abs/2410.22570 \n", "[3] mat3ra-made interface builder: https://github.com/Exabyte-io/made " ] } diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index cc885bec6..bfb2b43a1 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -769,7 +769,29 @@ class _ListConfig(list): "create": staticmethod(lambda d: _DictConfig(d) if isinstance(d, dict) else d), }, ) - _make_stub_module("hydra", submodules=["core", "core.global_hydra"]) + _make_stub_module("hydra", submodules=["core", "core.global_hydra", "utils"]) + + def _hydra_instantiate(config, *args, _recursive_=True, **kwargs): + import importlib + if isinstance(config, dict) and "_target_" in config: + target = config["_target_"] + mod_path, cls_name = target.rsplit(".", 1) + mod = importlib.import_module(mod_path) + cls = getattr(mod, cls_name) + pos_args = list(args) + list(config.get("_args_", [])) + cfg = {k: v for k, v in config.items() if not k.startswith("_")} + if _recursive_: + for k, v in cfg.items(): + if isinstance(v, dict) and "_target_" in v: + cfg[k] = _hydra_instantiate(v) + elif isinstance(v, list): + cfg[k] = [_hydra_instantiate(i) if isinstance(i, dict) and "_target_" in i else i for i in v] + pos_args = [_hydra_instantiate(a) if isinstance(a, dict) and "_target_" in a else a for a in pos_args] + return cls(*pos_args, **{**cfg, **kwargs}) + return config + + sys.modules["hydra.utils"].instantiate = _hydra_instantiate + sys.modules["hydra"].utils = sys.modules["hydra.utils"] # --- submitit / clusterscope --- _make_stub_module("submitit") @@ -795,6 +817,57 @@ def _tqdm_passthrough(iterable=None, *a, **k): # --- ase_db_backends --- _make_stub_module("ase_db_backends") + # --- INT8 quantized model support --- + _orig_torch_load = torch.load + + def _int8_aware_torch_load(f, *args, **kwargs): + result = _orig_torch_load(f, *args, **kwargs) + if isinstance(result, dict) and "quantized_ema_state_dict" in result: + import gc as _gc + from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint + + print(" Dequantizing INT8 → FP16...") + quantized_ema = result["quantized_ema_state_dict"] + scales = result["quantization_scales"] + ema_state_dict = {} + for name, tensor in quantized_ema.items(): + if name in scales: + ema_state_dict[name] = (tensor.float() * scales[name].float()).half() + else: + ema_state_dict[name] = tensor + del quantized_ema, scales + checkpoint = MLIPInferenceCheckpoint( + model_config=result["model_config"], + model_state_dict=result.get("model_state_dict", {}), + ema_state_dict=ema_state_dict, + tasks_config=result["tasks_config"], + ) + del result + _gc.collect() + print(" ✓ Dequantization complete") + return checkpoint + return result + + torch.load = _int8_aware_torch_load + + # --- Model registry fallback --- + try: + from fairchem.core.common.registry import registry as _registry + + _orig_get_model = _registry.get_model_class + + def _fallback_get_model_class(name): + try: + return _orig_get_model(name) + except RuntimeError: + import importlib as _imp + module_path, class_name = name.rsplit(".", 1) + mod = _imp.import_module(module_path) + return getattr(mod, class_name) + + _registry.get_model_class = _fallback_get_model_class + except Exception: + pass print("✓ FAIRChem dependency stubs applied") From 3acca48cd5ec7c3cd9439f4f3f680026f235b584 Mon Sep 17 00:00:00 2001 From: Timur Bazhirov Date: Tue, 19 May 2026 17:46:35 -0700 Subject: [PATCH 12/12] experiment: dequantization to save space/memory --- .../notebooks_utils/pyodide/packages/torch.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py index bfb2b43a1..d4eb832d3 100644 --- a/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py +++ b/src/py/mat3ra/notebooks_utils/pyodide/packages/torch.py @@ -826,16 +826,22 @@ def _int8_aware_torch_load(f, *args, **kwargs): import gc as _gc from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint - print(" Dequantizing INT8 → FP16...") - quantized_ema = result["quantized_ema_state_dict"] - scales = result["quantization_scales"] + print(" Dequantizing INT8 → FP16 (streaming)...") + quantized_ema = result.pop("quantized_ema_state_dict") + scales = result.pop("quantization_scales") ema_state_dict = {} - for name, tensor in quantized_ema.items(): + names = list(quantized_ema.keys()) + for name in names: + tensor = quantized_ema.pop(name) if name in scales: - ema_state_dict[name] = (tensor.float() * scales[name].float()).half() + scale = scales.pop(name) + ema_state_dict[name] = (tensor.float() * scale.float()).half() + del scale else: ema_state_dict[name] = tensor + del tensor del quantized_ema, scales + _gc.collect() checkpoint = MLIPInferenceCheckpoint( model_config=result["model_config"], model_state_dict=result.get("model_state_dict", {}),