diff --git a/example_notebooks/nuclear_evolution.ipynb b/example_notebooks/nuclear_evolution.ipynb index f007e02..997c6c3 100644 --- a/example_notebooks/nuclear_evolution.ipynb +++ b/example_notebooks/nuclear_evolution.ipynb @@ -118,6 +118,18 @@ "t_vec, a_vec, rho_g_vec, rho_nu_vec, rho_NP_vec, p_NP_vec, Neff_vec = bkg_model(jnp.asarray(0.))" ] }, + { + "cell_type": "markdown", + "source": "### Releasing memory after the first solve\n\nJAX/XLA's compile uses several hundred MB, but does not need them after compile. Run this command to free up that compile-only memory, once after your first `BackgroundModel`/`AbundanceModel` call. Especially useful when running many ranks in parallel (e.g. MCMC or nested sampling).", + "metadata": {} + }, + { + "cell_type": "code", + "source": "from linx import release_unused_memory\nrelease_unused_memory()", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -940,4 +952,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/linx/__init__.py b/linx/__init__.py index e69de29..799f861 100644 --- a/linx/__init__.py +++ b/linx/__init__.py @@ -0,0 +1,3 @@ +"""LINX: Light Isotope Nucleosynthesis with JAX.""" + +from linx.utils import release_unused_memory diff --git a/linx/utils.py b/linx/utils.py new file mode 100644 index 0000000..37466db --- /dev/null +++ b/linx/utils.py @@ -0,0 +1,23 @@ +"""Miscellaneous LINX runtime utilities.""" + +import ctypes +import gc +import sys + + +def release_unused_memory(): + """Return freed heap pages to the OS. + + JAX/XLA compile leaves several hundred MiB of freed-but-untrimmed memory + on systems with glibc. Calling this once -- typically right after + the first BackgroundModel / AbundanceModel call -- can free up to ~half + of the memory needed per solve. + + Safe to call any time, cannot drop live memory. + """ + gc.collect() + if sys.platform.startswith("linux"): + try: + ctypes.CDLL("libc.so.6").malloc_trim(0) + except OSError: + pass