diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cdd43af..c9b0b89 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13"] os: [ubuntu-latest, windows-latest, macOS-latest] backend: [torch, numpy, jax] @@ -55,12 +55,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pytest-cov torch wheel pydantic + pip install pytest pytest-cov wheel pydantic # We only want to install this on one run, because otherwise we'll have # duplicate annotations. - name: Install error reporter - if: ${{ matrix.python-version == '3.10' }} + if: ${{ matrix.python-version == '3.13' }} run: | python -m pip install pytest-github-actions-annotate-failures @@ -88,8 +88,9 @@ jobs: run: | pip install -e ".[dev,${{ matrix.backend }}]" pip show ${{ env.PROJECT_NAME }} - export JAX_ENABLE_X64=True shell: bash + env: + JAX_ENABLE_X64: "True" - name: Test with pytest run: | @@ -98,12 +99,10 @@ jobs: ls -a cat .coverage shell: bash - env: - CASKADE_BACKEND: ${{ matrix.backend }} - name: Extra coverage report for jax checks if: - ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} + ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} run: | echo "Running extra coverage report for jax checks" pip install jax jaxlib @@ -111,9 +110,10 @@ jobs: shell: bash env: CASKADE_BACKEND: jax + JAX_ENABLE_X64: "True" - name: Extra coverage report for numpy checks if: - ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} + ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} run: | echo "Running extra coverage report for numpy checks" coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/ @@ -127,7 +127,7 @@ jobs: CASKADE_BACKEND: numpy - name: Upload coverage reports to Codecov with GitHub Action if: - ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} + ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.readthedocs.yml b/.readthedocs.yml index d57c077..5f4a916 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -21,9 +21,9 @@ formats: # Specify dependencies to be installed # Define the system dependencies build: - os: "ubuntu-20.04" + os: "ubuntu-26.04" tools: - python: "3.10" + python: "3.13" apt_packages: - pandoc # Specify pandoc to be installed via apt-get - graphviz # Specify graphviz to be installed via apt-get @@ -39,6 +39,5 @@ build: python: install: - requirements: docs/requirements.txt # Path to your requirements.txt file - - requirements: requirements.txt # Path to your requirements.txt file - method: pip - path: . # Install the package itself + path: .[dev,torch,jax] # Install the package itself diff --git a/README.md b/README.md index 484276a..f101610 100644 --- a/README.md +++ b/README.md @@ -22,14 +22,21 @@ argument passing for complex nested simulators. pip install caskade ``` -More details on the [docs page](https://caskade.readthedocs.io/en/latest/install.html). -if you want to use `caskade` with `jax` then run: +This will give you the numpy version of ``caskade``. More details on the [docs page](https://caskade.readthedocs.io/en/latest/install.html). +if you want to use ``caskade`` with ``jax`` then run: ```bash pip install caskade[jax] ``` -Alternately, just pip install `jax`/`jaxlib` separately as they are the only extra requirements. +if you want to use ``caskade`` with ``torch`` then run: + +```bash +pip install caskade[torch] +``` + +Don't worry if you got the wrong one. Just install the appropriate +``jax``/``torch`` directly if you want to add the option. ## Usage diff --git a/docs/source/install.md b/docs/source/install.md index fd09505..632e586 100644 --- a/docs/source/install.md +++ b/docs/source/install.md @@ -1,26 +1,51 @@ # Install +In a recent update we made it so that a basic ``caskade`` install just defaults +to the numpy version. This is part of our quest to make ``caskade`` as easy to +use and nimble as possible. ``torch`` is actually a pretty heavy package to +install (taking a lot of space/time to install and several seconds just to +import), so by fully eliminating it as a requirement, anyone with the numpy or +jax versions don't need to wait for an import they wont use. + ## Basic install +To get the ``numpy`` version, just directly pip install: + ``` bash pip install caskade ``` -> **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11 +## Install with PyTorch backend + +```bash +pip install caskade[torch] +``` + +This will simply install ``torch`` along with ``numpy``. You can also always just +pip install ``torch`` yourself after a basic ``caskade`` install. -## Install with jax backend +> **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11 + +## Install with JAX backend ```bash pip install caskade[jax] ``` -This will simply install `jax`/`jaxlib` along with the other dependencies. It is -always possible to use the `torch` and `numpy` backends since they are core -requirements. +This will simply install ``jax`` along with ``numpy``. You can also always just +pip install ``jax`` yourself after a basic ``caskade`` install. > **Note:** For M1 Mac users there can be compatibility issues with jax/jaxlib. See [discussion here](https://stackoverflow.com/questions/68327863/importing-jax-fails-on-mac-with-m1-chip) and consider installing `jaxlib==0.4.35`. +## Install with PyTorch and JAX + +```bash +pip install caskade[torch,jax] +``` + +All the extra ``[torch,jax]`` bit does is add those to the list of dependencies, so you can always just install them yourself with pip. + ## Install from source 1. Fork the repo on [GitHub](https://github.com/ConnorStoneAstro/caskade) diff --git a/docs/source/intro.md b/docs/source/intro.md index 9d661b6..360d39c 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -6,8 +6,8 @@ [![PyPI - Version](https://img.shields.io/pypi/v/caskade)](https://pypi.org/project/caskade/) [![Documentation Status](https://readthedocs.org/projects/caskade/badge/?version=latest)](https://caskade.readthedocs.io/en/latest/?badge=latest) -Build scientific simulators, treating them as a directed acyclic graph. Handles -argument passing for complex nested simulators. +Build scientific simulators, treating them abstractly as a directed acyclic +graph. Handles argument passing for complex nested simulators. ## Install @@ -15,13 +15,19 @@ argument passing for complex nested simulators. pip install caskade ``` -if you want to use `caskade` with `jax` then run: +if you want to use ``caskade`` with ``jax`` or ``torch`` then run: ```bash pip install caskade[jax] ``` -Alternately, just pip install `jax`/`jaxlib` separately as they are the only extra requirements. +or + +```bash +pip install caskade[torch] +``` + +Alternately, just pip install ``jax``/``torch`` separately as they are the only extra requirements. ## Usage diff --git a/docs/source/notebooks/WorkedExample.ipynb b/docs/source/notebooks/WorkedExample.ipynb index 12c532c..eadac21 100644 --- a/docs/source/notebooks/WorkedExample.ipynb +++ b/docs/source/notebooks/WorkedExample.ipynb @@ -635,8 +635,8 @@ "lightcurvemodel.to_dynamic()\n", "fit_vals = likelihood2.get_values()\n", "hess = -hessian(likelihood2, fit_vals, strict=True)\n", - "hess_inv = torch.linalg.inv(hess) # Invert the Hessian to get the covariance matrix\n", - "light_curve_sigma = torch.sqrt(torch.diag(hess_inv).abs()).numpy()\n", + "hess_inv = torch.linalg.inv(hess).numpy() # Invert the Hessian to get the covariance matrix\n", + "light_curve_sigma = np.sqrt(np.abs(np.diag(hess_inv)))\n", "print(\n", " f\"Light Curve t0: {fit_vals[0].item():.2f} ± {light_curve_sigma[0]:.2f} vs {SN_lightcurve.t0.value.item():.2f} (true)\"\n", ")\n", @@ -664,14 +664,15 @@ "ax.axhline(SN_lightcurve.peak_flux.value.item(), color=\"r\", linestyle=\"--\")\n", "ax.set_xlabel(\"Sigma\")\n", "ax.set_ylabel(\"Peak Flux\")\n", - "lambda_, v = np.linalg.eig(hess_inv[1:, 1:])\n", + "lambda_, v = np.linalg.eigh(hess_inv[1:, 1:])\n", "lambda_ = np.sqrt(lambda_)\n", - "angle = np.rad2deg(np.arctan2(v[1, 0], v[0, 0]))\n", + "\n", + "angle = np.rad2deg(np.arctan2(v[1, 1], v[0, 1]))\n", "for k in [1, 2]:\n", " ellipse = Ellipse(\n", " xy=(fit_vals[1].item(), fit_vals[2].item()),\n", - " width=lambda_[0] * k * 2,\n", - " height=lambda_[1] * k * 2,\n", + " width=lambda_[1] * k * 2,\n", + " height=lambda_[0] * k * 2,\n", " angle=angle,\n", " edgecolor=\"black\",\n", " facecolor=\"grey\",\n", @@ -679,8 +680,8 @@ " )\n", " ax.add_artist(ellipse)\n", "plt.plot([], [], c=\"k\", label=\"Likelihood Contours\")\n", - "ax.set_xlim(fit_vals[1].item() - lambda_[0] * 3, fit_vals[1].item() + lambda_[0] * 3)\n", - "ax.set_ylim(fit_vals[2].item() - lambda_[1] * 3, fit_vals[2].item() + lambda_[1] * 3)\n", + "ax.set_xlim(fit_vals[1].item() - lambda_[1] * 3, fit_vals[1].item() + lambda_[1] * 3)\n", + "ax.set_ylim(fit_vals[2].item() - lambda_[0] * 3, fit_vals[2].item() + lambda_[0] * 3)\n", "ax.set_title(\"Light Curve Parameter Uncertainty (Hessian)\")\n", "ax.legend()\n", "plt.show()" @@ -894,7 +895,7 @@ ], "metadata": { "kernelspec": { - "display_name": "PY312 (3.12.3)", + "display_name": "PY312 (3.12.3.final.0)", "language": "python", "name": "python3" }, diff --git a/pyproject.toml b/pyproject.toml index 4ac2230..4e4fe9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,36 +4,35 @@ build-backend = "hatchling.build" [project] name = "caskade" -dynamic = [ - "dependencies", - "version" -] +dynamic = ["version"] authors = [ { name="Connor Stone", email="connorstone628@gmail.com" }, { name="Alexandre Adam", email="alexandre.adam@mila.quebec" }, ] -description = "Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph." +description = "caskade handles your parameters for you without getting in your way." readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" license = {file = "LICENSE"} keywords = [ "caskade", "DAG", "scientific python", "differentiable programming", - "pytorch" + "pytorch", + "jax", ] classifiers=[ - "Development Status :: 1 - Planning", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python :: 3" ] +dependencies = ["numpy>=1.24.0"] [project.urls] -Homepage = "https://github.com/ConnorStoneAstro/caskade" -Documentation = "https://github.com/ConnorStoneAstro/caskade" +Homepage = "https://caskade.readthedocs.io" +Documentation = "https://caskade.readthedocs.io" Repository = "https://github.com/ConnorStoneAstro/caskade" Issues = "https://github.com/ConnorStoneAstro/caskade/issues" @@ -50,16 +49,10 @@ dev = [ "emcee", "scipy", "ipywidgets", - "jax", - "jaxlib", ] -torch = [] -jax = ["jax", "jaxlib"] +torch = ["torch>=2,<3"] +jax = ["jax>=0.7.0"] numpy = [] -object = [] - -[tool.hatch.metadata.hooks.requirements_txt] -files = ["requirements.txt"] [tool.hatch.version] source = "vcs" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9e45725..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torch>=2,<3 -numpy>1.22,<2 \ No newline at end of file diff --git a/src/caskade/backend.py b/src/caskade/backend.py index 7ac1220..44d3e55 100644 --- a/src/caskade/backend.py +++ b/src/caskade/backend.py @@ -1,15 +1,17 @@ import os import importlib -from typing import Annotated +from typing import TYPE_CHECKING, TypeVar -from torch import Tensor import numpy as np from . import utils -ArrayLike = Annotated[ - Tensor, - "One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.", -] +if TYPE_CHECKING: + import torch # type: ignore + import jax.numpy as jnp # type: ignore + + ArrayLike = TypeVar("ArrayLike", np.ndarray, "torch.Tensor", "jnp.ndarray") +else: + ArrayLike = TypeVar("ArrayLike") class Backend: @@ -23,7 +25,15 @@ def backend(self): @backend.setter def backend(self, backend): if backend is None: - backend = os.getenv("CASKADE_BACKEND", "torch") + backend = os.getenv("CASKADE_BACKEND", "none").lower() + if backend == "none": # Try to find available backend + if importlib.util.find_spec("torch") is not None: + backend = "torch" + elif importlib.util.find_spec("jax") is not None: + backend = "jax" + else: + backend = "numpy" + self.module = self._load_backend(backend) self._backend = backend diff --git a/src/caskade/base.py b/src/caskade/base.py index 85e3b4c..302cd1b 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -284,9 +284,9 @@ def to(self, device=None, dtype=None): Parameters ---------- - device: (Optional[torch.device], optional) + device: (optional) The device to move the values to. Defaults to None. - dtype: (Optional[torch.dtype], optional) + dtype: (optional) The desired data type. Defaults to None. """ diff --git a/src/caskade/param.py b/src/caskade/param.py index 7fe79a4..1603cb7 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -402,9 +402,9 @@ def to(self, device=None, dtype=None) -> "Param": Parameters ---------- - device: (Optional[torch.device], optional) + device: (optional) The device to move the values to. Defaults to None. - dtype: (Optional[torch.dtype], optional) + dtype: (optional) The desired data type. Defaults to None. """ if device is not None: diff --git a/src/caskade/utils.py b/src/caskade/utils.py index b709650..ce99293 100644 --- a/src/caskade/utils.py +++ b/src/caskade/utils.py @@ -1,4 +1,3 @@ -import torch import numpy as np @@ -18,6 +17,8 @@ def broadcast_cat_torch(tensors, dim=-1): Returns: Tensor: The concatenated tensor. """ + import torch + if not tensors: raise ValueError("tensors argument must be a non-empty sequence") diff --git a/tests/test_backend.py b/tests/test_backend.py index e324127..894dc65 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,7 +1,7 @@ +import sys + from caskade import backend, Param -from torch import Tensor from numpy import ndarray -from jax import Array import pytest @@ -11,16 +11,55 @@ def test_backend(): init_backend = backend.backend # Change the backend - backend.backend = "torch" - p = Param("p", 1.0) - assert isinstance(p.value, Tensor) + if backend.backend == "torch": + from torch import Tensor + + p = Param("p", 1.0) + assert isinstance(p.value, Tensor) + if backend.backend == "jax": + from jax import Array + + p = Param("p", 1.0) + assert isinstance(p.value, Array) + backend.backend = "numpy" p = Param("p", 1.0) assert isinstance(p.value, ndarray) - backend.backend = "jax" - p = Param("p", 1.0) - assert isinstance(p.value, Array) backend.backend = init_backend with pytest.raises(ValueError): backend.backend = "invalid_backend" + + +def test_auto_set_backend_torch(monkeypatch): + if backend.backend != "torch": + pytest.skip("Skipping test because backend is not torch") + + monkeypatch.delenv("CASKADE_BACKEND", raising=False) + monkeypatch.setitem(sys.modules, "jax", None) + backend.backend = None + + assert backend.backend == "torch" + + +def test_auto_set_backend_jax(monkeypatch): + if backend.backend != "jax": + pytest.skip("Skipping test because backend is not jax") + + monkeypatch.delenv("CASKADE_BACKEND", raising=False) + monkeypatch.setitem(sys.modules, "torch", None) + backend.backend = None + + assert backend.backend == "jax" + + +def test_auto_set_backend_numpy(monkeypatch): + if backend.backend != "numpy": + pytest.skip("Skipping test because backend is not numpy") + + monkeypatch.delenv("CASKADE_BACKEND", raising=False) + monkeypatch.setitem(sys.modules, "torch", None) + monkeypatch.setitem(sys.modules, "jax", None) + backend.backend = None + + assert backend.backend == "numpy" diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index 39ce2af..3c4f610 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -2,7 +2,8 @@ import pytest import runpy import subprocess -import os +import os, sys +import numpy as np import caskade as ck import matplotlib @@ -42,9 +43,11 @@ def cleanup_py_scripts(nbpath): @pytest.mark.filterwarnings("ignore:FigureCanvasAgg") @pytest.mark.parametrize("nb_path", notebooks) -def test_notebook(nb_path): +def test_notebook(nb_path, monkeypatch): if ck.backend.backend != "torch": pytest.skip("Requires torch backend") + monkeypatch.setitem(sys.modules, "jax.numpy", np) + monkeypatch.setitem(sys.modules, "jax", np) convert_notebook_to_py(nb_path) try: runpy.run_path(nb_path.replace(".ipynb", ".py"), run_name="__main__") diff --git a/tests/test_save.py b/tests/test_save.py index f758d27..7d1819d 100644 --- a/tests/test_save.py +++ b/tests/test_save.py @@ -1,4 +1,4 @@ -from caskade import Module, Param, GraphError, SaveStateWarning, backend, BackendError +from caskade import Module, Param, GraphError, SaveStateWarning import numpy as np import gc import h5py