From 7527d3dd9f519be5fe69ab227a0fc39578a11640 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 11:59:31 -0400 Subject: [PATCH 01/13] version updates for caskade --- .github/workflows/ci.yml | 15 ++++++--------- .readthedocs.yml | 8 +++----- README.md | 13 ++++++++++--- pyproject.toml | 16 ++++++---------- requirements.txt | 2 -- src/caskade/backend.py | 24 +++++++++++++++++------- src/caskade/base.py | 4 ++-- src/caskade/param.py | 4 ++-- src/caskade/utils.py | 3 ++- 9 files changed, 48 insertions(+), 41 deletions(-) delete mode 100644 requirements.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cdd43af..ebbba60 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.11", "3.12", "3.13"] os: [ubuntu-latest, windows-latest, macOS-latest] backend: [torch, numpy, jax] @@ -55,12 +55,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pytest-cov torch wheel pydantic + pip install pytest pytest-cov wheel pydantic # We only want to install this on one run, because otherwise we'll have # duplicate annotations. - name: Install error reporter - if: ${{ matrix.python-version == '3.10' }} + if: ${{ matrix.python-version == '3.13' }} run: | python -m pip install pytest-github-actions-annotate-failures @@ -98,22 +98,19 @@ jobs: ls -a cat .coverage shell: bash - env: - CASKADE_BACKEND: ${{ matrix.backend }} - name: Extra coverage report for jax checks if: - ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} + ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} run: | echo "Running extra coverage report for jax checks" - pip install jax jaxlib coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/ shell: bash env: CASKADE_BACKEND: jax - name: Extra coverage report for numpy checks if: - ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} + ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} run: | echo "Running extra coverage report for numpy checks" coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/ @@ -127,7 +124,7 @@ jobs: CASKADE_BACKEND: numpy - name: Upload coverage reports to Codecov with GitHub Action if: - ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} + ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.readthedocs.yml b/.readthedocs.yml index c2374a5..0afa10b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -10,9 +10,8 @@ sphinx: configuration: docs/source/conf.py # Optionally build your docs in additional formats such as PDF and ePub -# formats: -# - pdf -# - epub +formats: + - pdf # Optional but recommended, declare the Python requirements required # to build your documentation @@ -22,7 +21,7 @@ sphinx: build: os: "ubuntu-20.04" tools: - python: "3.10" + python: "3.13" apt_packages: - pandoc # Specify pandoc to be installed via apt-get - graphviz # Specify graphviz to be installed via apt-get @@ -38,6 +37,5 @@ build: python: install: - requirements: docs/requirements.txt # Path to your requirements.txt file - - requirements: requirements.txt # Path to your requirements.txt file - method: pip path: . # Install the package itself \ No newline at end of file diff --git a/README.md b/README.md index 484276a..f101610 100644 --- a/README.md +++ b/README.md @@ -22,14 +22,21 @@ argument passing for complex nested simulators. pip install caskade ``` -More details on the [docs page](https://caskade.readthedocs.io/en/latest/install.html). -if you want to use `caskade` with `jax` then run: +This will give you the numpy version of ``caskade``. More details on the [docs page](https://caskade.readthedocs.io/en/latest/install.html). +if you want to use ``caskade`` with ``jax`` then run: ```bash pip install caskade[jax] ``` -Alternately, just pip install `jax`/`jaxlib` separately as they are the only extra requirements. +if you want to use ``caskade`` with ``torch`` then run: + +```bash +pip install caskade[torch] +``` + +Don't worry if you got the wrong one. Just install the appropriate +``jax``/``torch`` directly if you want to add the option. ## Usage diff --git a/pyproject.toml b/pyproject.toml index 4ac2230..294b991 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,14 +14,15 @@ authors = [ ] description = "Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph." readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" license = {file = "LICENSE"} keywords = [ "caskade", "DAG", "scientific python", "differentiable programming", - "pytorch" + "pytorch", + "jax", ] classifiers=[ "Development Status :: 1 - Planning", @@ -30,6 +31,7 @@ classifiers=[ "Operating System :: OS Independent", "Programming Language :: Python :: 3" ] +dependencies = ["numpy>=2.0.0"] [project.urls] Homepage = "https://github.com/ConnorStoneAstro/caskade" @@ -50,16 +52,10 @@ dev = [ "emcee", "scipy", "ipywidgets", - "jax", - "jaxlib", ] -torch = [] -jax = ["jax", "jaxlib"] +torch = ["torch>=2,<3"] +jax = ["jax>=0.8.0"] numpy = [] -object = [] - -[tool.hatch.metadata.hooks.requirements_txt] -files = ["requirements.txt"] [tool.hatch.version] source = "vcs" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9e45725..0000000 --- a/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -torch>=2,<3 -numpy>1.22,<2 \ No newline at end of file diff --git a/src/caskade/backend.py b/src/caskade/backend.py index 7ac1220..44d3e55 100644 --- a/src/caskade/backend.py +++ b/src/caskade/backend.py @@ -1,15 +1,17 @@ import os import importlib -from typing import Annotated +from typing import TYPE_CHECKING, TypeVar -from torch import Tensor import numpy as np from . import utils -ArrayLike = Annotated[ - Tensor, - "One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.", -] +if TYPE_CHECKING: + import torch # type: ignore + import jax.numpy as jnp # type: ignore + + ArrayLike = TypeVar("ArrayLike", np.ndarray, "torch.Tensor", "jnp.ndarray") +else: + ArrayLike = TypeVar("ArrayLike") class Backend: @@ -23,7 +25,15 @@ def backend(self): @backend.setter def backend(self, backend): if backend is None: - backend = os.getenv("CASKADE_BACKEND", "torch") + backend = os.getenv("CASKADE_BACKEND", "none").lower() + if backend == "none": # Try to find available backend + if importlib.util.find_spec("torch") is not None: + backend = "torch" + elif importlib.util.find_spec("jax") is not None: + backend = "jax" + else: + backend = "numpy" + self.module = self._load_backend(backend) self._backend = backend diff --git a/src/caskade/base.py b/src/caskade/base.py index 85e3b4c..302cd1b 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -284,9 +284,9 @@ def to(self, device=None, dtype=None): Parameters ---------- - device: (Optional[torch.device], optional) + device: (optional) The device to move the values to. Defaults to None. - dtype: (Optional[torch.dtype], optional) + dtype: (optional) The desired data type. Defaults to None. """ diff --git a/src/caskade/param.py b/src/caskade/param.py index 7fe79a4..1603cb7 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -402,9 +402,9 @@ def to(self, device=None, dtype=None) -> "Param": Parameters ---------- - device: (Optional[torch.device], optional) + device: (optional) The device to move the values to. Defaults to None. - dtype: (Optional[torch.dtype], optional) + dtype: (optional) The desired data type. Defaults to None. """ if device is not None: diff --git a/src/caskade/utils.py b/src/caskade/utils.py index b709650..ce99293 100644 --- a/src/caskade/utils.py +++ b/src/caskade/utils.py @@ -1,4 +1,3 @@ -import torch import numpy as np @@ -18,6 +17,8 @@ def broadcast_cat_torch(tensors, dim=-1): Returns: Tensor: The concatenated tensor. """ + import torch + if not tensors: raise ValueError("tensors argument must be a non-empty sequence") From 0a340f0e49ad8560852170c571d7967b7e4c8cf7 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 12:03:01 -0400 Subject: [PATCH 02/13] fix readthedocs version config --- .readthedocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 0afa10b..8d8148d 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -19,7 +19,7 @@ formats: # Specify dependencies to be installed # Define the system dependencies build: - os: "ubuntu-20.04" + os: "ubuntu-26.04" tools: python: "3.13" apt_packages: From 8570fba504f4560aa6119b951f723354895c8bcd Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 12:04:20 -0400 Subject: [PATCH 03/13] fix pyproject.toml dependencies dynamic --- pyproject.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 294b991..83da61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,10 +4,7 @@ build-backend = "hatchling.build" [project] name = "caskade" -dynamic = [ - "dependencies", - "version" -] +dynamic = ["version"] authors = [ { name="Connor Stone", email="connorstone628@gmail.com" }, { name="Alexandre Adam", email="alexandre.adam@mila.quebec" }, From f46014c5f08ffe67bbd35e9fd15f6267b6080133 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 12:59:26 -0400 Subject: [PATCH 04/13] remove explicit imports of packages --- .readthedocs.yml | 2 +- docs/source/install.md | 33 +++++++++++++++++++++++++++++---- tests/test_backend.py | 19 +++++++++++-------- tests/test_save.py | 2 +- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 34c44c4..5f4a916 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -40,4 +40,4 @@ python: install: - requirements: docs/requirements.txt # Path to your requirements.txt file - method: pip - path: . # Install the package itself + path: .[dev,torch,jax] # Install the package itself diff --git a/docs/source/install.md b/docs/source/install.md index fd09505..d290614 100644 --- a/docs/source/install.md +++ b/docs/source/install.md @@ -1,26 +1,51 @@ # Install +In a recent update we made it so that a basic ``caskade`` install just defaults +to the numpy version. This is part of our quest to make ``caskade`` as easy to +use and nimble as possible. ``torch`` is actually a pretty heavy package to +install (taking a lot of space/time to install and several seconds just to +import), so by fully eliminating it as a requirement, anyone with the numpy or +jax versions don't need to wait for an import they wont use. + ## Basic install +To get the ``numpy`` version, just directly pip install: + ``` bash pip install caskade ``` > **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11 +## Install with PyTorch backend + +```bash +pip install caskade[torch] +``` + +This will simply install ``torch`` along with ``numpy``. You can also always just +pip install ``torch`` yourself after a basic ``caskade`` install. -## Install with jax backend + +## Install with JAX backend ```bash pip install caskade[jax] ``` -This will simply install `jax`/`jaxlib` along with the other dependencies. It is -always possible to use the `torch` and `numpy` backends since they are core -requirements. +This will simply install ``jax`` along with ``numpy``. You can also always just +pip install ``jax`` yourself after a basic ``caskade`` install. > **Note:** For M1 Mac users there can be compatibility issues with jax/jaxlib. See [discussion here](https://stackoverflow.com/questions/68327863/importing-jax-fails-on-mac-with-m1-chip) and consider installing `jaxlib==0.4.35`. +## Install with PyTorch and JAX + +```bash +pip install caskade[torch,jax] +``` + +All the extra ``[torch,jax]`` bit does is add those to the list of dependencies, so you can always just install them yourself with pip. + ## Install from source 1. Fork the repo on [GitHub](https://github.com/ConnorStoneAstro/caskade) diff --git a/tests/test_backend.py b/tests/test_backend.py index e324127..743e402 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,7 +1,5 @@ from caskade import backend, Param -from torch import Tensor from numpy import ndarray -from jax import Array import pytest @@ -11,15 +9,20 @@ def test_backend(): init_backend = backend.backend # Change the backend - backend.backend = "torch" - p = Param("p", 1.0) - assert isinstance(p.value, Tensor) + if backend.backend == "torch": + from torch import Tensor + + p = Param("p", 1.0) + assert isinstance(p.value, Tensor) + if backend.backend == "jax": + from jax import Array + + p = Param("p", 1.0) + assert isinstance(p.value, Array) + backend.backend = "numpy" p = Param("p", 1.0) assert isinstance(p.value, ndarray) - backend.backend = "jax" - p = Param("p", 1.0) - assert isinstance(p.value, Array) backend.backend = init_backend with pytest.raises(ValueError): diff --git a/tests/test_save.py b/tests/test_save.py index f758d27..7d1819d 100644 --- a/tests/test_save.py +++ b/tests/test_save.py @@ -1,4 +1,4 @@ -from caskade import Module, Param, GraphError, SaveStateWarning, backend, BackendError +from caskade import Module, Param, GraphError, SaveStateWarning import numpy as np import gc import h5py From e875d6005ad643ae1d2ec7fb6fdfc1441e7d8af3 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 13:13:11 -0400 Subject: [PATCH 05/13] update notebook test to patch jax for the sake of testing in torch only environment --- tests/test_notebooks.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index 39ce2af..3c4f610 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -2,7 +2,8 @@ import pytest import runpy import subprocess -import os +import os, sys +import numpy as np import caskade as ck import matplotlib @@ -42,9 +43,11 @@ def cleanup_py_scripts(nbpath): @pytest.mark.filterwarnings("ignore:FigureCanvasAgg") @pytest.mark.parametrize("nb_path", notebooks) -def test_notebook(nb_path): +def test_notebook(nb_path, monkeypatch): if ck.backend.backend != "torch": pytest.skip("Requires torch backend") + monkeypatch.setitem(sys.modules, "jax.numpy", np) + monkeypatch.setitem(sys.modules, "jax", np) convert_notebook_to_py(nb_path) try: runpy.run_path(nb_path.replace(".ipynb", ".py"), run_name="__main__") From 292f2bc08315ae226026b71534d9a4433e2c6181 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 13:21:34 -0400 Subject: [PATCH 06/13] fix example with explicit conversion to numpy --- docs/source/notebooks/WorkedExample.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/notebooks/WorkedExample.ipynb b/docs/source/notebooks/WorkedExample.ipynb index 12c532c..1286db4 100644 --- a/docs/source/notebooks/WorkedExample.ipynb +++ b/docs/source/notebooks/WorkedExample.ipynb @@ -635,8 +635,8 @@ "lightcurvemodel.to_dynamic()\n", "fit_vals = likelihood2.get_values()\n", "hess = -hessian(likelihood2, fit_vals, strict=True)\n", - "hess_inv = torch.linalg.inv(hess) # Invert the Hessian to get the covariance matrix\n", - "light_curve_sigma = torch.sqrt(torch.diag(hess_inv).abs()).numpy()\n", + "hess_inv = torch.linalg.inv(hess).numpy() # Invert the Hessian to get the covariance matrix\n", + "light_curve_sigma = np.sqrt(np.abs(np.diag(hess_inv)))\n", "print(\n", " f\"Light Curve t0: {fit_vals[0].item():.2f} ± {light_curve_sigma[0]:.2f} vs {SN_lightcurve.t0.value.item():.2f} (true)\"\n", ")\n", @@ -894,7 +894,7 @@ ], "metadata": { "kernelspec": { - "display_name": "PY312 (3.12.3)", + "display_name": "PY312 (3.12.3.final.0)", "language": "python", "name": "python3" }, From ea5b9228dd3516af4fec045e8c492722daf78d5b Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 13:41:38 -0400 Subject: [PATCH 07/13] fix eig from complex to eigh real eigenvectors --- docs/source/notebooks/WorkedExample.ipynb | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/docs/source/notebooks/WorkedExample.ipynb b/docs/source/notebooks/WorkedExample.ipynb index 1286db4..eadac21 100644 --- a/docs/source/notebooks/WorkedExample.ipynb +++ b/docs/source/notebooks/WorkedExample.ipynb @@ -664,14 +664,15 @@ "ax.axhline(SN_lightcurve.peak_flux.value.item(), color=\"r\", linestyle=\"--\")\n", "ax.set_xlabel(\"Sigma\")\n", "ax.set_ylabel(\"Peak Flux\")\n", - "lambda_, v = np.linalg.eig(hess_inv[1:, 1:])\n", + "lambda_, v = np.linalg.eigh(hess_inv[1:, 1:])\n", "lambda_ = np.sqrt(lambda_)\n", - "angle = np.rad2deg(np.arctan2(v[1, 0], v[0, 0]))\n", + "\n", + "angle = np.rad2deg(np.arctan2(v[1, 1], v[0, 1]))\n", "for k in [1, 2]:\n", " ellipse = Ellipse(\n", " xy=(fit_vals[1].item(), fit_vals[2].item()),\n", - " width=lambda_[0] * k * 2,\n", - " height=lambda_[1] * k * 2,\n", + " width=lambda_[1] * k * 2,\n", + " height=lambda_[0] * k * 2,\n", " angle=angle,\n", " edgecolor=\"black\",\n", " facecolor=\"grey\",\n", @@ -679,8 +680,8 @@ " )\n", " ax.add_artist(ellipse)\n", "plt.plot([], [], c=\"k\", label=\"Likelihood Contours\")\n", - "ax.set_xlim(fit_vals[1].item() - lambda_[0] * 3, fit_vals[1].item() + lambda_[0] * 3)\n", - "ax.set_ylim(fit_vals[2].item() - lambda_[1] * 3, fit_vals[2].item() + lambda_[1] * 3)\n", + "ax.set_xlim(fit_vals[1].item() - lambda_[1] * 3, fit_vals[1].item() + lambda_[1] * 3)\n", + "ax.set_ylim(fit_vals[2].item() - lambda_[0] * 3, fit_vals[2].item() + lambda_[0] * 3)\n", "ax.set_title(\"Light Curve Parameter Uncertainty (Hessian)\")\n", "ax.legend()\n", "plt.show()" From b5a177eb0000568bae1e87afe643a86a7a7b3b84 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 13:46:41 -0400 Subject: [PATCH 08/13] install jax for jax checks --- .github/workflows/ci.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ebbba60..c9b0b89 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,8 +88,9 @@ jobs: run: | pip install -e ".[dev,${{ matrix.backend }}]" pip show ${{ env.PROJECT_NAME }} - export JAX_ENABLE_X64=True shell: bash + env: + JAX_ENABLE_X64: "True" - name: Test with pytest run: | @@ -104,10 +105,12 @@ jobs: ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} run: | echo "Running extra coverage report for jax checks" + pip install jax jaxlib coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/ shell: bash env: CASKADE_BACKEND: jax + JAX_ENABLE_X64: "True" - name: Extra coverage report for numpy checks if: ${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }} From b192ddbf6615d8348e44b962e5022b39e6942a08 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 14:31:26 -0400 Subject: [PATCH 09/13] coverage for automatic backend setting --- tests/test_backend.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_backend.py b/tests/test_backend.py index 743e402..894dc65 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,3 +1,5 @@ +import sys + from caskade import backend, Param from numpy import ndarray @@ -27,3 +29,37 @@ def test_backend(): with pytest.raises(ValueError): backend.backend = "invalid_backend" + + +def test_auto_set_backend_torch(monkeypatch): + if backend.backend != "torch": + pytest.skip("Skipping test because backend is not torch") + + monkeypatch.delenv("CASKADE_BACKEND", raising=False) + monkeypatch.setitem(sys.modules, "jax", None) + backend.backend = None + + assert backend.backend == "torch" + + +def test_auto_set_backend_jax(monkeypatch): + if backend.backend != "jax": + pytest.skip("Skipping test because backend is not jax") + + monkeypatch.delenv("CASKADE_BACKEND", raising=False) + monkeypatch.setitem(sys.modules, "torch", None) + backend.backend = None + + assert backend.backend == "jax" + + +def test_auto_set_backend_numpy(monkeypatch): + if backend.backend != "numpy": + pytest.skip("Skipping test because backend is not numpy") + + monkeypatch.delenv("CASKADE_BACKEND", raising=False) + monkeypatch.setitem(sys.modules, "torch", None) + monkeypatch.setitem(sys.modules, "jax", None) + backend.backend = None + + assert backend.backend == "numpy" From 55e339a45a66e1fa8601ac933f0d1cf4a2c4c3ab Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 14:43:45 -0400 Subject: [PATCH 10/13] update intro to show backend install options --- docs/source/intro.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/source/intro.md b/docs/source/intro.md index 9d661b6..6ae78aa 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -15,13 +15,19 @@ argument passing for complex nested simulators. pip install caskade ``` -if you want to use `caskade` with `jax` then run: +if you want to use ``caskade`` with ``jax`` or ``torch`` then run: ```bash pip install caskade[jax] ``` -Alternately, just pip install `jax`/`jaxlib` separately as they are the only extra requirements. +or + +```bash +pip install caskade[torch] +``` + +Alternately, just pip install ``jax``/``torch`` separately as they are the only extra requirements. ## Usage From 1241a88a01e3e4f2fe0df357dc2c1ba982ec4cca Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 14:44:20 -0400 Subject: [PATCH 11/13] move pytorch note to pytorch section --- docs/source/install.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/install.md b/docs/source/install.md index d290614..632e586 100644 --- a/docs/source/install.md +++ b/docs/source/install.md @@ -15,7 +15,6 @@ To get the ``numpy`` version, just directly pip install: pip install caskade ``` -> **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11 ## Install with PyTorch backend @@ -26,6 +25,7 @@ pip install caskade[torch] This will simply install ``torch`` along with ``numpy``. You can also always just pip install ``torch`` yourself after a basic ``caskade`` install. +> **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11 ## Install with JAX backend From 910f352bc1964df3e7ce8572e1dacb00f72c9f33 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 14:53:30 -0400 Subject: [PATCH 12/13] cleanup pyproject --- docs/source/intro.md | 4 ++-- pyproject.toml | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/intro.md b/docs/source/intro.md index 6ae78aa..360d39c 100644 --- a/docs/source/intro.md +++ b/docs/source/intro.md @@ -6,8 +6,8 @@ [![PyPI - Version](https://img.shields.io/pypi/v/caskade)](https://pypi.org/project/caskade/) [![Documentation Status](https://readthedocs.org/projects/caskade/badge/?version=latest)](https://caskade.readthedocs.io/en/latest/?badge=latest) -Build scientific simulators, treating them as a directed acyclic graph. Handles -argument passing for complex nested simulators. +Build scientific simulators, treating them abstractly as a directed acyclic +graph. Handles argument passing for complex nested simulators. ## Install diff --git a/pyproject.toml b/pyproject.toml index 83da61c..6c19b1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ { name="Connor Stone", email="connorstone628@gmail.com" }, { name="Alexandre Adam", email="alexandre.adam@mila.quebec" }, ] -description = "Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph." +description = "caskade handles your parameters for you without getting in your way." readme = "README.md" requires-python = ">=3.11" license = {file = "LICENSE"} @@ -22,7 +22,7 @@ keywords = [ "jax", ] classifiers=[ - "Development Status :: 1 - Planning", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", @@ -31,14 +31,14 @@ classifiers=[ dependencies = ["numpy>=2.0.0"] [project.urls] -Homepage = "https://github.com/ConnorStoneAstro/caskade" -Documentation = "https://github.com/ConnorStoneAstro/caskade" +Homepage = "https://caskade.readthedocs.io" +Documentation = "https://caskade.readthedocs.io" Repository = "https://github.com/ConnorStoneAstro/caskade" Issues = "https://github.com/ConnorStoneAstro/caskade/issues" [project.optional-dependencies] dev = [ - "pytest>=8.0,<9", + "pytest>=8.0", "pytest-cov>=4.1,<5", "pytest-mock>=3.12,<4", "pre-commit>=3.6,<4", From 9a8b3e8bb15317a53f75f706edfd14e4925e92bc Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 23 Jun 2026 15:03:39 -0400 Subject: [PATCH 13/13] relax numpy requirement --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6c19b1e..4e4fe9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ classifiers=[ "Operating System :: OS Independent", "Programming Language :: Python :: 3" ] -dependencies = ["numpy>=2.0.0"] +dependencies = ["numpy>=1.24.0"] [project.urls] Homepage = "https://caskade.readthedocs.io" @@ -38,7 +38,7 @@ Issues = "https://github.com/ConnorStoneAstro/caskade/issues" [project.optional-dependencies] dev = [ - "pytest>=8.0", + "pytest>=8.0,<9", "pytest-cov>=4.1,<5", "pytest-mock>=3.12,<4", "pre-commit>=3.6,<4", @@ -51,7 +51,7 @@ dev = [ "ipywidgets", ] torch = ["torch>=2,<3"] -jax = ["jax>=0.8.0"] +jax = ["jax>=0.7.0"] numpy = [] [tool.hatch.version]