Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.11", "3.12", "3.13"]
os: [ubuntu-latest, windows-latest, macOS-latest]
backend: [torch, numpy, jax]

Expand Down Expand Up @@ -55,12 +55,12 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov torch wheel pydantic
pip install pytest pytest-cov wheel pydantic

# We only want to install this on one run, because otherwise we'll have
# duplicate annotations.
- name: Install error reporter
if: ${{ matrix.python-version == '3.10' }}
if: ${{ matrix.python-version == '3.13' }}
run: |
python -m pip install pytest-github-actions-annotate-failures

Expand Down Expand Up @@ -88,8 +88,9 @@ jobs:
run: |
pip install -e ".[dev,${{ matrix.backend }}]"
pip show ${{ env.PROJECT_NAME }}
export JAX_ENABLE_X64=True
shell: bash
env:
JAX_ENABLE_X64: "True"

- name: Test with pytest
run: |
Expand All @@ -98,22 +99,21 @@ jobs:
ls -a
cat .coverage
shell: bash
env:
CASKADE_BACKEND: ${{ matrix.backend }}

- name: Extra coverage report for jax checks
if:
${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
run: |
echo "Running extra coverage report for jax checks"
pip install jax jaxlib
coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/
shell: bash
env:
CASKADE_BACKEND: jax
JAX_ENABLE_X64: "True"
- name: Extra coverage report for numpy checks
if:
${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
run: |
echo "Running extra coverage report for numpy checks"
coverage run --append --source=${{ env.PROJECT_NAME }} -m pytest tests/
Expand All @@ -127,7 +127,7 @@ jobs:
CASKADE_BACKEND: numpy
- name: Upload coverage reports to Codecov with GitHub Action
if:
${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
${{ matrix.python-version == '3.13' && matrix.os == 'ubuntu-latest' && matrix.backend == 'torch' }}
uses: codecov/codecov-action@v4
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
Expand Down
7 changes: 3 additions & 4 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ formats:
# Specify dependencies to be installed
# Define the system dependencies
build:
os: "ubuntu-20.04"
os: "ubuntu-26.04"
tools:
python: "3.10"
python: "3.13"
apt_packages:
- pandoc # Specify pandoc to be installed via apt-get
- graphviz # Specify graphviz to be installed via apt-get
Expand All @@ -39,6 +39,5 @@ build:
python:
install:
- requirements: docs/requirements.txt # Path to your requirements.txt file
- requirements: requirements.txt # Path to your requirements.txt file
- method: pip
path: . # Install the package itself
path: .[dev,torch,jax] # Install the package itself
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,21 @@ argument passing for complex nested simulators.
pip install caskade
```

More details on the [docs page](https://caskade.readthedocs.io/en/latest/install.html).
if you want to use `caskade` with `jax` then run:
This will give you the numpy version of ``caskade``. More details on the [docs page](https://caskade.readthedocs.io/en/latest/install.html).
if you want to use ``caskade`` with ``jax`` then run:

```bash
pip install caskade[jax]
```

Alternately, just pip install `jax`/`jaxlib` separately as they are the only extra requirements.
if you want to use ``caskade`` with ``torch`` then run:

```bash
pip install caskade[torch]
```

Don't worry if you got the wrong one. Just install the appropriate
``jax``/``torch`` directly if you want to add the option.

## Usage

Expand Down
35 changes: 30 additions & 5 deletions docs/source/install.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,51 @@
# Install

In a recent update we made it so that a basic ``caskade`` install just defaults
to the numpy version. This is part of our quest to make ``caskade`` as easy to
use and nimble as possible. ``torch`` is actually a pretty heavy package to
install (taking a lot of space/time to install and several seconds just to
import), so by fully eliminating it as a requirement, anyone with the numpy or
jax versions don't need to wait for an import they wont use.

## Basic install

To get the ``numpy`` version, just directly pip install:

``` bash
pip install caskade
```

> **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11

## Install with PyTorch backend

```bash
pip install caskade[torch]
```

This will simply install ``torch`` along with ``numpy``. You can also always just
pip install ``torch`` yourself after a basic ``caskade`` install.

## Install with jax backend
> **Note:** PyTorch is not compatible with Python 3.12 on all systems, you may need 3.9 - 3.11

## Install with JAX backend

```bash
pip install caskade[jax]
```

This will simply install `jax`/`jaxlib` along with the other dependencies. It is
always possible to use the `torch` and `numpy` backends since they are core
requirements.
This will simply install ``jax`` along with ``numpy``. You can also always just
pip install ``jax`` yourself after a basic ``caskade`` install.

> **Note:** For M1 Mac users there can be compatibility issues with jax/jaxlib. See [discussion here](https://stackoverflow.com/questions/68327863/importing-jax-fails-on-mac-with-m1-chip) and consider installing `jaxlib==0.4.35`.

## Install with PyTorch and JAX

```bash
pip install caskade[torch,jax]
```

All the extra ``[torch,jax]`` bit does is add those to the list of dependencies, so you can always just install them yourself with pip.

## Install from source

1. Fork the repo on [GitHub](https://github.com/ConnorStoneAstro/caskade)
Expand Down
14 changes: 10 additions & 4 deletions docs/source/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,28 @@
[![PyPI - Version](https://img.shields.io/pypi/v/caskade)](https://pypi.org/project/caskade/)
[![Documentation Status](https://readthedocs.org/projects/caskade/badge/?version=latest)](https://caskade.readthedocs.io/en/latest/?badge=latest)

Build scientific simulators, treating them as a directed acyclic graph. Handles
argument passing for complex nested simulators.
Build scientific simulators, treating them abstractly as a directed acyclic
graph. Handles argument passing for complex nested simulators.

## Install

``` bash
pip install caskade
```

if you want to use `caskade` with `jax` then run:
if you want to use ``caskade`` with ``jax`` or ``torch`` then run:

```bash
pip install caskade[jax]
```

Alternately, just pip install `jax`/`jaxlib` separately as they are the only extra requirements.
or

```bash
pip install caskade[torch]
```

Alternately, just pip install ``jax``/``torch`` separately as they are the only extra requirements.

## Usage

Expand Down
19 changes: 10 additions & 9 deletions docs/source/notebooks/WorkedExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@
"lightcurvemodel.to_dynamic()\n",
"fit_vals = likelihood2.get_values()\n",
"hess = -hessian(likelihood2, fit_vals, strict=True)\n",
"hess_inv = torch.linalg.inv(hess) # Invert the Hessian to get the covariance matrix\n",
"light_curve_sigma = torch.sqrt(torch.diag(hess_inv).abs()).numpy()\n",
"hess_inv = torch.linalg.inv(hess).numpy() # Invert the Hessian to get the covariance matrix\n",
"light_curve_sigma = np.sqrt(np.abs(np.diag(hess_inv)))\n",
"print(\n",
" f\"Light Curve t0: {fit_vals[0].item():.2f} ± {light_curve_sigma[0]:.2f} vs {SN_lightcurve.t0.value.item():.2f} (true)\"\n",
")\n",
Expand Down Expand Up @@ -664,23 +664,24 @@
"ax.axhline(SN_lightcurve.peak_flux.value.item(), color=\"r\", linestyle=\"--\")\n",
"ax.set_xlabel(\"Sigma\")\n",
"ax.set_ylabel(\"Peak Flux\")\n",
"lambda_, v = np.linalg.eig(hess_inv[1:, 1:])\n",
"lambda_, v = np.linalg.eigh(hess_inv[1:, 1:])\n",
"lambda_ = np.sqrt(lambda_)\n",
"angle = np.rad2deg(np.arctan2(v[1, 0], v[0, 0]))\n",
"\n",
"angle = np.rad2deg(np.arctan2(v[1, 1], v[0, 1]))\n",
"for k in [1, 2]:\n",
" ellipse = Ellipse(\n",
" xy=(fit_vals[1].item(), fit_vals[2].item()),\n",
" width=lambda_[0] * k * 2,\n",
" height=lambda_[1] * k * 2,\n",
" width=lambda_[1] * k * 2,\n",
" height=lambda_[0] * k * 2,\n",
" angle=angle,\n",
" edgecolor=\"black\",\n",
" facecolor=\"grey\",\n",
" alpha=0.6,\n",
" )\n",
" ax.add_artist(ellipse)\n",
"plt.plot([], [], c=\"k\", label=\"Likelihood Contours\")\n",
"ax.set_xlim(fit_vals[1].item() - lambda_[0] * 3, fit_vals[1].item() + lambda_[0] * 3)\n",
"ax.set_ylim(fit_vals[2].item() - lambda_[1] * 3, fit_vals[2].item() + lambda_[1] * 3)\n",
"ax.set_xlim(fit_vals[1].item() - lambda_[1] * 3, fit_vals[1].item() + lambda_[1] * 3)\n",
"ax.set_ylim(fit_vals[2].item() - lambda_[0] * 3, fit_vals[2].item() + lambda_[0] * 3)\n",
"ax.set_title(\"Light Curve Parameter Uncertainty (Hessian)\")\n",
"ax.legend()\n",
"plt.show()"
Expand Down Expand Up @@ -894,7 +895,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "PY312 (3.12.3)",
"display_name": "PY312 (3.12.3.final.0)",
"language": "python",
"name": "python3"
},
Expand Down
29 changes: 11 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,35 @@ build-backend = "hatchling.build"

[project]
name = "caskade"
dynamic = [
"dependencies",
"version"
]
dynamic = ["version"]
authors = [
{ name="Connor Stone", email="connorstone628@gmail.com" },
{ name="Alexandre Adam", email="alexandre.adam@mila.quebec" },
]
description = "Package for building scientific simulators, with dynamic arguments arranged in a directed acyclic graph."
description = "caskade handles your parameters for you without getting in your way."
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.11"
license = {file = "LICENSE"}
keywords = [
"caskade",
"DAG",
"scientific python",
"differentiable programming",
"pytorch"
"pytorch",
"jax",
]
classifiers=[
"Development Status :: 1 - Planning",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3"
]
dependencies = ["numpy>=1.24.0"]

[project.urls]
Homepage = "https://github.com/ConnorStoneAstro/caskade"
Documentation = "https://github.com/ConnorStoneAstro/caskade"
Homepage = "https://caskade.readthedocs.io"
Documentation = "https://caskade.readthedocs.io"
Repository = "https://github.com/ConnorStoneAstro/caskade"
Issues = "https://github.com/ConnorStoneAstro/caskade/issues"

Expand All @@ -50,16 +49,10 @@ dev = [
"emcee",
"scipy",
"ipywidgets",
"jax",
"jaxlib",
]
torch = []
jax = ["jax", "jaxlib"]
torch = ["torch>=2,<3"]
jax = ["jax>=0.7.0"]
numpy = []
object = []

[tool.hatch.metadata.hooks.requirements_txt]
files = ["requirements.txt"]

[tool.hatch.version]
source = "vcs"
Expand Down
2 changes: 0 additions & 2 deletions requirements.txt

This file was deleted.

24 changes: 17 additions & 7 deletions src/caskade/backend.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import os
import importlib
from typing import Annotated
from typing import TYPE_CHECKING, TypeVar

from torch import Tensor
import numpy as np
from . import utils

ArrayLike = Annotated[
Tensor,
"One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.",
]
if TYPE_CHECKING:
import torch # type: ignore
import jax.numpy as jnp # type: ignore

ArrayLike = TypeVar("ArrayLike", np.ndarray, "torch.Tensor", "jnp.ndarray")
else:
ArrayLike = TypeVar("ArrayLike")


class Backend:
Expand All @@ -23,7 +25,15 @@
@backend.setter
def backend(self, backend):
if backend is None:
backend = os.getenv("CASKADE_BACKEND", "torch")
backend = os.getenv("CASKADE_BACKEND", "none").lower()
if backend == "none": # Try to find available backend
if importlib.util.find_spec("torch") is not None:
backend = "torch"
elif importlib.util.find_spec("jax") is not None:
backend = "jax"
else:
backend = "numpy"

self.module = self._load_backend(backend)
self._backend = backend

Expand Down Expand Up @@ -175,10 +185,10 @@
return self.module.all(array)

def log(self, array):
return self.module.log(array)

Check warning on line 188 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS ubuntu-latest - Backend numpy

invalid value encountered in log

Check warning on line 188 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS macOS-latest - Backend numpy

invalid value encountered in log

Check warning on line 188 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS ubuntu-latest - Backend torch

invalid value encountered in log

def exp(self, array):
return self.module.exp(array)

Check warning on line 191 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS ubuntu-latest - Backend numpy

overflow encountered in exp

Check warning on line 191 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS macOS-latest - Backend numpy

overflow encountered in exp

Check warning on line 191 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS ubuntu-latest - Backend torch

overflow encountered in exp

def sum(self, array, axis=None):
return self.module.sum(array, axis=axis)
Expand All @@ -190,7 +200,7 @@
return self.jax.nn.sigmoid(array)

def _sigmoid_numpy(self, array):
return 1 / (1 + self.module.exp(-array))

Check warning on line 203 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS ubuntu-latest - Backend numpy

overflow encountered in exp

Check warning on line 203 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS macOS-latest - Backend numpy

overflow encountered in exp

Check warning on line 203 in src/caskade/backend.py

View workflow job for this annotation

GitHub Actions / Python 3.13 - OS ubuntu-latest - Backend torch

overflow encountered in exp

def _logit_torch(self, array):
return self.module.logit(array)
Expand Down
4 changes: 2 additions & 2 deletions src/caskade/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def to(self, device=None, dtype=None):

Parameters
----------
device: (Optional[torch.device], optional)
device: (optional)
The device to move the values to. Defaults to None.
dtype: (Optional[torch.dtype], optional)
dtype: (optional)
The desired data type. Defaults to None.
"""

Expand Down
4 changes: 2 additions & 2 deletions src/caskade/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,9 @@ def to(self, device=None, dtype=None) -> "Param":

Parameters
----------
device: (Optional[torch.device], optional)
device: (optional)
The device to move the values to. Defaults to None.
dtype: (Optional[torch.dtype], optional)
dtype: (optional)
The desired data type. Defaults to None.
"""
if device is not None:
Expand Down
Loading
Loading