From 50a790252e665c5697831012bd58726942867cd6 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sun, 24 May 2026 07:40:05 -0500 Subject: [PATCH 1/2] feat: add VMEC mgrid export from coils --- essos/__init__.py | 3 + essos/coils.py | 13 ++- essos/mgrid.py | 188 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_mgrid.py | 99 +++++++++++++++++++++++ 4 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 essos/mgrid.py create mode 100644 tests/test_mgrid.py diff --git a/essos/__init__.py b/essos/__init__.py index e69de29b..96f25eda 100644 --- a/essos/__init__.py +++ b/essos/__init__.py @@ -0,0 +1,3 @@ +from .mgrid import MGrid, coils_to_mgrid + +__all__ = ["MGrid", "coils_to_mgrid"] diff --git a/essos/coils.py b/essos/coils.py index 7782fe8e..b64947d6 100644 --- a/essos/coils.py +++ b/essos/coils.py @@ -475,6 +475,17 @@ def to_json(self, filename: str): with open(filename, "w") as file: json.dump(data, file) + def to_mgrid(self, filename: str, **kwargs): + """Write a VMEC mgrid file by evaluating this coil set. + + This mirrors SIMSOPT's ``MagneticField.to_mgrid`` grid convention. + Keyword arguments are forwarded to :func:`essos.mgrid.coils_to_mgrid`. + """ + + from .mgrid import coils_to_mgrid + + return coils_to_mgrid(self, filename, **kwargs) + class Coils_from_json(Coils): def __init__(self, filename: str): import json @@ -657,4 +668,4 @@ def fit_dofs_from_coils( gamma_uni = _resample_closed_curve_uniform_batch(coils_gamma, n_segments) # arclength (vmapped) dofs = _fit_real_fourier_batch(gamma_uni, order) # rFFT-based fit - return dofs, gamma_uni \ No newline at end of file + return dofs, gamma_uni diff --git a/essos/mgrid.py b/essos/mgrid.py new file mode 100644 index 00000000..9f7d6dda --- /dev/null +++ b/essos/mgrid.py @@ -0,0 +1,188 @@ +"""VMEC mgrid read/write helpers for ESSOS coil fields. + +The writer mirrors the SIMSOPT ``MagneticField.to_mgrid`` convention: +fields are evaluated on a cylindrical tensor grid with layout +``(nphi, nz, nr)`` and written to VMEC/MAKEGRID-compatible NetCDF files. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import numpy as np +from scipy.io import netcdf_file + + +def _pad_string(string: str) -> str: + return "{:^30}".format(str(string)).replace(" ", "_") + + +def _unpack(binary_array: Any) -> str: + return "".join(np.char.decode(binary_array)).strip() + + +@dataclass +class MGrid: + """Container for VMEC mgrid field data.""" + + nr: int = 51 + nz: int = 51 + nphi: int = 24 + nfp: int = 2 + rmin: float = 0.20 + rmax: float = 0.40 + zmin: float = -0.10 + zmax: float = 0.10 + br_arr: list[Any] = field(default_factory=list) + bp_arr: list[Any] = field(default_factory=list) + bz_arr: list[Any] = field(default_factory=list) + coil_names: list[str] = field(default_factory=list) + + @property + def n_ext_cur(self) -> int: + return len(self.br_arr) + + def add_field_cylindrical(self, br: Any, bp: Any, bz: Any, name: str | None = None) -> None: + """Append one external-current group in cylindrical components.""" + + expected = (int(self.nphi), int(self.nz), int(self.nr)) + br_arr = np.asarray(br, dtype=float) + bp_arr = np.asarray(bp, dtype=float) + bz_arr = np.asarray(bz, dtype=float) + if br_arr.shape != expected or bp_arr.shape != expected or bz_arr.shape != expected: + raise ValueError(f"mgrid fields must have shape {expected}") + self.br_arr.append(br_arr) + self.bp_arr.append(bp_arr) + self.bz_arr.append(bz_arr) + self.coil_names.append(_pad_string(name or f"magnet_{self.n_ext_cur}")) + + def write(self, filename: str | Path) -> None: + """Write the mgrid to a VMEC-compatible NetCDF file.""" + + filename = str(filename) + with netcdf_file(filename, "w", mmap=False, version=2) as ds: + ds.createDimension("stringsize", 30) + ds.createDimension("dim_00001", 1) + ds.createDimension("external_coil_groups", self.n_ext_cur) + ds.createDimension("external_coils", self.n_ext_cur) + ds.createDimension("rad", int(self.nr)) + ds.createDimension("zee", int(self.nz)) + ds.createDimension("phi", int(self.nphi)) + + ds.createVariable("ir", "i4", tuple()).data[()] = int(self.nr) + ds.createVariable("jz", "i4", tuple()).data[()] = int(self.nz) + ds.createVariable("kp", "i4", tuple()).data[()] = int(self.nphi) + ds.createVariable("nfp", "i4", tuple()).data[()] = int(self.nfp) + ds.createVariable("nextcur", "i4", tuple()).data[()] = int(self.n_ext_cur) + ds.createVariable("rmin", "f8", tuple()).data[()] = float(self.rmin) + ds.createVariable("zmin", "f8", tuple()).data[()] = float(self.zmin) + ds.createVariable("rmax", "f8", tuple()).data[()] = float(self.rmax) + ds.createVariable("zmax", "f8", tuple()).data[()] = float(self.zmax) + + if self.n_ext_cur == 1: + coil_group = ds.createVariable("coil_group", "c", ("stringsize",)) + coil_group[:] = self.coil_names[0] + else: + coil_group = ds.createVariable("coil_group", "c", ("external_coil_groups", "stringsize")) + coil_group[:] = self.coil_names + mode = ds.createVariable("mgrid_mode", "c", ("dim_00001",)) + mode[:] = "N" + raw_current = ds.createVariable("raw_coil_cur", "f8", ("external_coils",)) + raw_current[:] = np.ones(self.n_ext_cur) + + for idx in range(self.n_ext_cur): + tag = f"_{idx + 1:03d}" + ds.createVariable("br" + tag, "f8", ("phi", "zee", "rad"))[:, :, :] = self.br_arr[idx] + ds.createVariable("bp" + tag, "f8", ("phi", "zee", "rad"))[:, :, :] = self.bp_arr[idx] + ds.createVariable("bz" + tag, "f8", ("phi", "zee", "rad"))[:, :, :] = self.bz_arr[idx] + + @classmethod + def from_file(cls, filename: str | Path) -> "MGrid": + """Read an mgrid NetCDF file.""" + + with netcdf_file(str(filename), "r", mmap=False, version=2) as ds: + mgrid = cls( + nr=int(ds.variables["ir"].getValue()), + nz=int(ds.variables["jz"].getValue()), + nphi=int(ds.variables["kp"].getValue()), + nfp=int(ds.variables["nfp"].getValue()), + rmin=float(ds.variables["rmin"].getValue()), + rmax=float(ds.variables["rmax"].getValue()), + zmin=float(ds.variables["zmin"].getValue()), + zmax=float(ds.variables["zmax"].getValue()), + ) + nextcur = int(ds.variables["nextcur"].getValue()) + coil_data = ds.variables["coil_group"][:] + if len(ds.variables["coil_group"].dimensions) == 2: + mgrid.coil_names = [_unpack(coil_data[j]) for j in range(nextcur)] + else: + mgrid.coil_names = [_unpack(coil_data)] + mgrid.mode = ds.variables["mgrid_mode"][:][0].decode() + mgrid.raw_coil_current = np.asarray(ds.variables["raw_coil_cur"][:], dtype=float) + for idx in range(nextcur): + tag = f"_{idx + 1:03d}" + mgrid.br_arr.append(np.asarray(ds.variables["br" + tag][:], dtype=float)) + mgrid.bp_arr.append(np.asarray(ds.variables["bp" + tag][:], dtype=float)) + mgrid.bz_arr.append(np.asarray(ds.variables["bz" + tag][:], dtype=float)) + mgrid.br = mgrid.br_arr[0] if nextcur == 1 else np.sum(mgrid.br_arr, axis=0) + mgrid.bp = mgrid.bp_arr[0] if nextcur == 1 else np.sum(mgrid.bp_arr, axis=0) + mgrid.bz = mgrid.bz_arr[0] if nextcur == 1 else np.sum(mgrid.bz_arr, axis=0) + mgrid.bvec = np.transpose([mgrid.br, mgrid.bp, mgrid.bz]) + return mgrid + + +def coils_to_mgrid( + coils: Any, + filename: str | Path, + *, + nr: int = 10, + nphi: int = 4, + nz: int = 12, + rmin: float = 1.0, + rmax: float = 2.0, + zmin: float = -0.5, + zmax: float = 0.5, + nfp: int | None = None, + name: str = "essos_coils", +) -> MGrid: + """Evaluate ESSOS coils on a cylindrical grid and write an mgrid file.""" + + import jax + import jax.numpy as jnp + + from .fields import BiotSavart + + nfp_eff = int(coils.nfp if nfp is None else nfp) + rs = np.linspace(float(rmin), float(rmax), int(nr), endpoint=True) + phis = np.linspace(0.0, 2.0 * np.pi / nfp_eff, int(nphi), endpoint=False) + zs = np.linspace(float(zmin), float(zmax), int(nz), endpoint=True) + Phi, Z, R = np.meshgrid(phis, zs, rs, indexing="ij") + xyz = np.stack( + [ + R.reshape(-1) * np.cos(Phi.reshape(-1)), + R.reshape(-1) * np.sin(Phi.reshape(-1)), + Z.reshape(-1), + ], + axis=1, + ) + + field = BiotSavart(coils) + b_xyz = np.asarray(jax.vmap(field.B)(jnp.asarray(xyz)), dtype=float) + phi_flat = Phi.reshape(-1) + bx = b_xyz[:, 0] + by = b_xyz[:, 1] + br = bx * np.cos(phi_flat) + by * np.sin(phi_flat) + bp = -bx * np.sin(phi_flat) + by * np.cos(phi_flat) + bz = b_xyz[:, 2] + + mgrid = MGrid(nfp=nfp_eff, nr=int(nr), nz=int(nz), nphi=int(nphi), rmin=rmin, rmax=rmax, zmin=zmin, zmax=zmax) + mgrid.add_field_cylindrical( + br.reshape((int(nphi), int(nz), int(nr))), + bp.reshape((int(nphi), int(nz), int(nr))), + bz.reshape((int(nphi), int(nz), int(nr))), + name=name, + ) + mgrid.write(filename) + return mgrid diff --git a/tests/test_mgrid.py b/tests/test_mgrid.py new file mode 100644 index 00000000..9c0383ee --- /dev/null +++ b/tests/test_mgrid.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from essos.coils import Coils_from_json, Coils_from_simsopt, CreateEquallySpacedCurves +from essos.mgrid import MGrid, coils_to_mgrid + + +def test_mgrid_write_read_roundtrip(tmp_path): + mgrid = MGrid(nr=3, nz=4, nphi=2, nfp=1, rmin=1.0, rmax=2.0, zmin=-0.5, zmax=0.5) + br = np.ones((2, 4, 3)) + bp = 2.0 * br + bz = 3.0 * br + mgrid.add_field_cylindrical(br, bp, bz, name="test_coil") + + filename = tmp_path / "mgrid.test.nc" + mgrid.write(filename) + loaded = MGrid.from_file(filename) + + assert loaded.n_ext_cur == 1 + assert loaded.nr == 3 + assert loaded.nz == 4 + assert loaded.nphi == 2 + assert loaded.nfp == 1 + assert loaded.coil_names[0] == "__________test_coil___________" + np.testing.assert_allclose(loaded.br_arr[0], br) + np.testing.assert_allclose(loaded.bp_arr[0], bp) + np.testing.assert_allclose(loaded.bz_arr[0], bz) + + +def test_coils_to_mgrid_writes_expected_shape_and_finite_values(tmp_path): + coils = CreateEquallySpacedCurves(1, order=1, R=1.0, r=0.2, n_segments=32, nfp=1, stellsym=False) + from essos.coils import Coils + import jax.numpy as jnp + + coil_set = Coils(coils, jnp.asarray([2.0])) + filename = tmp_path / "mgrid.coils.nc" + mgrid = coils_to_mgrid( + coil_set, + filename, + nr=4, + nphi=3, + nz=5, + rmin=0.4, + rmax=1.8, + zmin=-0.7, + zmax=0.7, + nfp=1, + ) + loaded = MGrid.from_file(filename) + + assert mgrid.n_ext_cur == 1 + assert loaded.br_arr[0].shape == (3, 5, 4) + assert loaded.bp_arr[0].shape == (3, 5, 4) + assert loaded.bz_arr[0].shape == (3, 5, 4) + assert np.all(np.isfinite(loaded.br_arr[0])) + assert np.all(np.isfinite(loaded.bp_arr[0])) + assert np.all(np.isfinite(loaded.bz_arr[0])) + + +def test_landreman_paul_qa_essos_json_can_write_mgrid(tmp_path): + path = Path(__file__).resolve().parents[1] / "examples" / "input_files" / "ESSOS_biot_savart_LandremanPaulQA.json" + coils = Coils_from_json(str(path)) + filename = tmp_path / "mgrid.lp_qa.nc" + + coils.to_mgrid(filename, nr=4, nphi=3, nz=5, rmin=0.5, rmax=2.0, zmin=-0.8, zmax=0.8) + loaded = MGrid.from_file(filename) + + assert loaded.nfp == 2 + assert loaded.br_arr[0].shape == (3, 5, 4) + assert np.max(np.abs(loaded.br_arr[0])) > 0.0 + assert np.max(np.abs(loaded.bz_arr[0])) > 0.0 + + +def test_simsopt_to_mgrid_parity_when_simsopt_is_available(tmp_path): + simsopt = pytest.importorskip("simsopt") + from simsopt import load + from simsopt.field import MGrid as SimsoptMGrid + + del simsopt + json_file = Path(__file__).resolve().parents[1] / "examples" / "input_files" / "SIMSOPT_biot_savart_LandremanPaulQA.json" + essos_coils = Coils_from_simsopt(str(json_file), nfp=2, stellsym=True) + simsopt_field = load(str(json_file)) + + kwargs = dict(nr=4, nphi=3, nz=5, rmin=0.5, rmax=2.0, zmin=-0.8, zmax=0.8, nfp=2) + essos_file = tmp_path / "mgrid.essos.nc" + simsopt_file = tmp_path / "mgrid.simsopt.nc" + essos_coils.to_mgrid(essos_file, **kwargs) + simsopt_field.to_mgrid(simsopt_file, **kwargs) + + essos_grid = MGrid.from_file(essos_file) + simsopt_grid = SimsoptMGrid.from_file(simsopt_file) + + np.testing.assert_allclose(essos_grid.br_arr[0], simsopt_grid.br_arr[0], rtol=5.0e-12, atol=1.0e-16) + np.testing.assert_allclose(essos_grid.bp_arr[0], simsopt_grid.bp_arr[0], rtol=5.0e-12, atol=1.0e-16) + np.testing.assert_allclose(essos_grid.bz_arr[0], simsopt_grid.bz_arr[0], rtol=5.0e-12, atol=1.0e-16) From 4e19e7f129fe41c7f5aed53a032d1cf3c90ecfc8 Mon Sep 17 00:00:00 2001 From: Rogerio Jorge Date: Sun, 24 May 2026 08:13:34 -0500 Subject: [PATCH 2/2] test: fix ESSOS CI for current JAX --- essos/augmented_lagrangian.py | 38 +++++++++++++++++++---------------- essos/coil_perturbation.py | 3 +-- tests/test_multiobjectives.py | 2 +- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/essos/augmented_lagrangian.py b/essos/augmented_lagrangian.py index 222c9c51..34d9e0bf 100644 --- a/essos/augmented_lagrangian.py +++ b/essos/augmented_lagrangian.py @@ -11,6 +11,11 @@ import jaxopt import optimistix +if not hasattr(jax, "tree_map"): + jax.tree_map = jax.tree_util.tree_map # compatibility for older jaxopt releases + +_tree_map = jax.tree_util.tree_map + class LagrangeMultiplier(NamedTuple): """A class containing constrain parameters for Augmented Lagrangian Method""" value: Any @@ -29,16 +34,16 @@ def update_method(params,updates,eta,omega,model_mu='Constant',beta=2.0,mu_max=1 pred = lambda x: isinstance(x, LagrangeMultiplier) if model_mu=='Constant': #jax.debug.print('{m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Monotonic': #jax.debug.print('{m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Conditional_True': #jax.debug.print('True {m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Conditional_False': #jax.debug.print('False {m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Tolerance_True': #jax.debug.print('Standard True {m}', m=model_mu) mu_average=penalty_average(params) @@ -46,7 +51,7 @@ def update_method(params,updates,eta,omega,model_mu='Constant',beta=2.0,mu_max=1 #omega=omega/mu_average eta=jnp.maximum(eta/mu_average**(0.1),eta_tol) omega=jnp.maximum(omega/mu_average,omega_tol) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred),eta,omega + return _tree_map(lambda x,y: LagrangeMultiplier(x.penalty*y.value,0.0*x.value,0.0*x.value),params,updates,is_leaf=pred),eta,omega elif model_mu=='Mu_Tolerance_False': #jax.debug.print('Standard False {m}', m=model_mu) mu_average=penalty_average(params) @@ -56,12 +61,12 @@ def update_method(params,updates,eta,omega,model_mu='Constant',beta=2.0,mu_max=1 #jax.debug.print('HMMMMMM mu_av {m}', m=mu_average) #jax.debug.print('HMMMMMM eta {m}', m=eta) omega=jnp.maximum(1./mu_average,omega_tol) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega - #return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + return _tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + #return _tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega elif model_mu=='Mu_Adaptative': #jax.debug.print('True {m}', m=model_mu) #Note that y.penalty is the derivative with respect to mu and so it is 0.5*C(x)**2, like the derivative with respect to lambda is C(x) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon)*y.value,-x.penalty+gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon),-x.sq_grad+alpha*x.sq_grad+(1.-alpha)*y.penalty*2.),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon)*y.value,-x.penalty+gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon),-x.sq_grad+alpha*x.sq_grad+(1.-alpha)*y.penalty*2.),params,updates,is_leaf=pred) @@ -74,16 +79,16 @@ def update_method_squared(params,updates,eta,omega,model_mu='Constant',beta=2.0, pred = lambda x: isinstance(x, LagrangeMultiplier) if model_mu=='Constant': #jax.debug.print('{m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier((y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier((y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Monotonic': #jax.debug.print('{m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Conditional_True': #jax.debug.print('True {m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Conditional_False': #jax.debug.print('False {m}', m=model_mu) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred) elif model_mu=='Mu_Tolerance_True': #jax.debug.print('Squared True {m}', m=model_mu) mu_average=penalty_average(params) @@ -91,7 +96,7 @@ def update_method_squared(params,updates,eta,omega,model_mu='Constant',beta=2.0, #omega=omega/mu_average eta=jnp.maximum(eta/mu_average**(0.1),eta_tol) omega=jnp.maximum(omega/mu_average,omega_tol) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred),eta,omega + return _tree_map(lambda x,y: LagrangeMultiplier(x.penalty*(y.value-x.value/x.penalty),0.0*x.value,0.0*x.value),params,updates,is_leaf=pred),eta,omega elif model_mu=='Mu_Tolerance_False': #jax.debug.print('Squared False {m}', m=model_mu) mu_average=penalty_average(params) @@ -99,12 +104,12 @@ def update_method_squared(params,updates,eta,omega,model_mu='Constant',beta=2.0, #omega=1./mu_average eta=jnp.maximum(1./mu_average**(0.1),eta_tol) omega=jnp.maximum(1./mu_average,omega_tol) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega - #return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + return _tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega + #return _tree_map(lambda x,y: LagrangeMultiplier(0.0*x.value,-x.penalty+jnp.minimum(beta*x.penalty,mu_max),0.0*x.value),params,updates,is_leaf=pred),eta,omega elif model_mu=='Mu_Adaptative': #jax.debug.print('True {m}', m=model_mu) #Note that y.penalty is the derivative with respect to mu and so it is 0.5*C(x)**2, like the derivative with respect to lambda is C(x) - return jax.jax.tree_util.tree_map(lambda x,y: LagrangeMultiplier(gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon)*(y.value-x.value/x.penalty),-x.penalty+gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon),-x.sq_grad+alpha*x.sq_grad+(1.-alpha)*(y.penalty*2.+(x.value/x.penalty)**2)),params,updates,is_leaf=pred) + return _tree_map(lambda x,y: LagrangeMultiplier(gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon)*(y.value-x.value/x.penalty),-x.penalty+gamma/(jnp.sqrt(alpha*x.sq_grad+(1.-alpha)*y.penalty*2.)+epsilon),-x.sq_grad+alpha*x.sq_grad+(1.-alpha)*(y.penalty*2.+(x.value/x.penalty)**2)),params,updates,is_leaf=pred) @@ -718,4 +723,3 @@ def update_fn(params, lag_state,grad,info,eta,omega,beta=beta,mu_max=mu_max,alph return ALM(init_fn,partial(update_fn,beta=beta,mu_max=mu_max,alpha=alpha,gamma=gamma,epsilon=epsilon,eta_tol=eta_tol,omega_tol=omega_tol)) - diff --git a/essos/coil_perturbation.py b/essos/coil_perturbation.py index 7a9778ef..48460abc 100644 --- a/essos/coil_perturbation.py +++ b/essos/coil_perturbation.py @@ -52,7 +52,7 @@ def matrix_sqrt_via_spectral(A): eigvals, Q = jnp.linalg.eigh(A) # A = Q Λ Q^T # Ensure numerical stability (clip small negatives to 0) - eigvals = jnp.clip(eigvals, a_min=0) + eigvals = jnp.clip(eigvals, min=0) sqrt_eigvals = jnp.sqrt(eigvals) sqrt_A = Q @ jnp.diag(sqrt_eigvals) @ Q.T @@ -271,4 +271,3 @@ def perturb_curves_statistic(curves: Curves,sampler:GaussianSampler, key=None): curves.gamma_dash=curves.gamma_dash + perturbation[:,1,:,:] curves.gamma_dashdash=curves.gamma_dashdash + perturbation[:,2,:,:] #return curves - diff --git a/tests/test_multiobjectives.py b/tests/test_multiobjectives.py index 019a0ac5..a9f41dc1 100644 --- a/tests/test_multiobjectives.py +++ b/tests/test_multiobjectives.py @@ -68,7 +68,7 @@ def test_build_available_inputs( vmec=mock_vmec(), dummy_loss_fn=dummy_loss_fn( assert loss_weight_result == 496 optimized_coils=optimizer.optimize_with_optax(weights, method="adam", lr=1e-2) - assert optimized_coils.currents_scale==0.01999998979999997872 + assert optimized_coils.currents_scale == pytest.approx(0.01999998979999997872) dofs_curves=optimized_coils.dofs_curves currents_scale=optimized_coils.currents_scale