Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ are not supported by the main library.
dpsgd
galore
GaLoreState
lnb
madgrad
MadgradState
mechanize
Expand All @@ -51,10 +52,12 @@ are not supported by the main library.
SplitRealAndImaginaryState
scale_by_ademamix
ScaleByAdemamixState
ScaleByLNBState
scale_by_simplified_ademamix
ScaleBySimplifiedAdEMAMixState
scale_by_adopt
scale_by_acprop
scale_by_lnb
scale_by_madgrad
scale_by_muon
hutchinson_estimator_diag_hessian
Expand Down
294 changes: 294 additions & 0 deletions examples/contrib/lnb_mnist.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Linear Neuron Boosting on MNIST\n",
"\n",
"In addition to `jax`, `numpy`, and `optax`, this notebook depends on the following packages:\n",
"\n",
"- `equinox`\n",
"- `torch` (only for the DataLoader, can be CPU-only)\n",
"- `torchvision`\n",
"- `xdg-base-dirs`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"import itertools\n",
"from typing import Any, Tuple, TypeAlias\n",
"\n",
"import equinox as eqx\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.tree_util as jtu\n",
"import numpy as np\n",
"import optax\n",
"import optax.contrib\n",
"import torch\n",
"import torchvision.datasets as tvdatasets\n",
"import torchvision.transforms as tvtransforms\n",
"import xdg_base_dirs\n",
"\n",
"JaxInitializer: TypeAlias = jax.nn.initializers.Initializer\n",
"PyTree: TypeAlias = Any"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Equinox Utilities"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_neuron(node: PyTree) -> bool:\n",
" \"\"\"Returns True if node is a linear neuron, i.e. an `eqx.nn.Linear` or `eqx.nn.Conv` module.\"\"\"\n",
" return isinstance(node, eqx.nn.Linear) or isinstance(node, eqx.nn.Conv)\n",
"\n",
"\n",
"def initialize(\n",
" model: eqx.Module,\n",
" weight_init: JaxInitializer,\n",
" bias_init: JaxInitializer = jax.nn.initializers.zeros,\n",
" *,\n",
" key: jax.Array,\n",
") -> eqx.Module:\n",
" \"\"\"Applies the initializers to the `weight` and `bias` fields of every neuron in model.\n",
"\n",
" Args:\n",
" model: the Module to initialize\n",
" weight_init: the [initializer]\n",
" (https://jax.readthedocs.io/en/latest/jax.nn.initializers.html) to use for the weights\n",
" bias_init: the initializer for the biases, if present.\n",
" key: the random key to split.\n",
" \"\"\"\n",
"\n",
" def _init_node(node: PyTree, key: jax.Array) -> PyTree:\n",
" if is_neuron(node):\n",
" weight_key, bias_key = jax.random.split(key, 2)\n",
" weight = weight_init(weight_key, node.weight.shape, node.weight.dtype)\n",
" node = eqx.tree_at(lambda n: n.weight, node, weight)\n",
" if getattr(node, \"bias\", None) is not None:\n",
" bias = bias_init(bias_key, node.bias.shape, node.bias.dtype)\n",
" node = eqx.tree_at(lambda n: n.bias, node, bias)\n",
" return node\n",
" else:\n",
" return node\n",
"\n",
" leaves, treedef = jtu.tree_flatten(model, is_leaf=is_neuron)\n",
" keys = jax.random.split(key, len(leaves))\n",
" leaves = itertools.starmap(_init_node, zip(leaves, keys, strict=True))\n",
" return jtu.tree_unflatten(treedef, leaves)\n",
"\n",
"\n",
"def wrap_neurons(model: eqx.nn.MLP) -> eqx.Module:\n",
" \"\"\"Wraps model in a Module whose call operator returns the tuple (output, x_neurons).\n",
"\n",
" Here x_neurons: list[Array] holds the input to each neuron in model, in tree traversal order.\n",
" \"\"\"\n",
" leaves, treedef = jtu.tree_flatten(model, is_leaf=is_neuron)\n",
"\n",
" # A buffer containing the inputs to each neuron. Follows the example from:\n",
" # https://github.com/patrick-kidger/equinox/issues/186#issuecomment-1233606690\n",
" x_neurons = [None] * sum(map(is_neuron, leaves))\n",
"\n",
" # Saves the input to the respective location in x_neurons.\n",
" class Neuron(eqx.Module):\n",
" f: eqx.Module\n",
" idx: int\n",
"\n",
" def __call__(self, x: jax.Array, *args: Any, **kwargs: Any) -> jax.Array:\n",
" x_neurons[self.idx] = x\n",
" return self.f(x, *args, **kwargs)\n",
"\n",
" # Wrap around the model and return each neuron's input feature, in tree traversal order.\n",
" class Wrapper(eqx.Module):\n",
" F: eqx.Module\n",
"\n",
" def __call__(self, *args: Any, **kwargs: Any) -> Tuple[Any, list[jax.Array]]:\n",
" return self.F(*args, **kwargs), x_neurons\n",
"\n",
" # Replace every eqx.Module that is a linear neuron with a Neuron wrapper.\n",
" idx = itertools.count()\n",
" leaves = [Neuron(leaf, next(idx)) if is_neuron(leaf) else leaf for leaf in leaves]\n",
" return Wrapper(jtu.tree_unflatten(treedef, leaves))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SEED = 0\n",
"DATASETS_ROOT = xdg_base_dirs.xdg_data_home() / \"lnb\" / \"mnist\"\n",
"X_INVERTED = False\n",
"BATCH_SIZE = 1000\n",
"\n",
"MLP_KWARGS = {\n",
" \"in_size\": 28 * 28,\n",
" \"out_size\": len(tvdatasets.MNIST.classes),\n",
" \"width_size\": 800,\n",
" \"depth\": 1,\n",
"}\n",
"WEIGHT_INIT = jax.nn.initializers.glorot_normal(in_axis=1, out_axis=0)\n",
"\n",
"LNB_KWARGS = {\n",
" \"cg_ridge\": 1.0,\n",
" \"is_neuron\": is_neuron,\n",
"}\n",
"STEP_SIZE = 0.5\n",
"CLIP_NORM = 1.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load MNIST"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def np_collate(batch):\n",
" return jax.tree_util.tree_map(np.asarray, torch.utils.data.default_collate(batch))\n",
"\n",
"\n",
"def load_dataset(dataset):\n",
" loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), collate_fn=np_collate)\n",
" images, labels = next(iter(loader))\n",
" del loader\n",
" return jnp.array(images), jnp.array(labels)\n",
"\n",
"\n",
"transforms = [tvtransforms.ToTensor(), tvtransforms.Lambda(torch.flatten)]\n",
"if X_INVERTED:\n",
" transforms.append(tvtransforms.Lambda(lambda x: 1.0 - x))\n",
"transform = tvtransforms.Compose(transforms)\n",
"mnist_train = tvdatasets.MNIST(DATASETS_ROOT, download=True, train=True, transform=transform)\n",
"mnist_test = tvdatasets.MNIST(DATASETS_ROOT, download=True, train=False, transform=transform)\n",
"train_images, train_labels = load_dataset(mnist_train)\n",
"test_images, test_labels = load_dataset(mnist_test)\n",
"\n",
"assert len(mnist_train) % BATCH_SIZE == 0, \"Keep evenly divisible\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## JIT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@partial(jax.jit, static_argnums=(1, 5))\n",
"def make_step(params, static, opt_state, xs, ys, opt_update):\n",
" def loss_xs(params):\n",
" model = eqx.combine(params, static)\n",
" logitses, xs_neurons = jax.vmap(model)(xs)\n",
" losses = optax.softmax_cross_entropy_with_integer_labels(logitses, ys)\n",
" return losses.mean(), xs_neurons\n",
"\n",
" (loss, xs_neurons), grad = jax.value_and_grad(loss_xs, has_aux=True)(params)\n",
" updates, opt_state = opt_update(grad, opt_state, params, xs_neurons=xs_neurons)\n",
" new_params = optax.apply_updates(params, updates)\n",
" return new_params, opt_state, loss\n",
"\n",
"\n",
"@partial(jax.jit, static_argnums=1)\n",
"def compute_accuracy(params, static, xs, ys):\n",
" model = eqx.combine(params, static)\n",
" predictions = jax.vmap(lambda x: jnp.argmax(model(x)[0]))(xs)\n",
" return jnp.mean(predictions == ys)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = eqx.nn.MLP(**MLP_KWARGS, activation=jax.nn.tanh, key=jax.random.PRNGKey(0))\n",
"model = wrap_neurons(model)\n",
"key = jax.random.PRNGKey(SEED)\n",
"params, static = eqx.partition(model, eqx.is_inexact_array)\n",
"params = initialize(params, weight_init=WEIGHT_INIT, key=key)\n",
"\n",
"opt = optax.chain(\n",
" optax.clip_by_global_norm(CLIP_NORM),\n",
" optax.contrib.lnb(**LNB_KWARGS),\n",
" optax.scale_by_learning_rate(STEP_SIZE),\n",
")\n",
"opt_state = opt.init(params)\n",
"\n",
"data_loader = torch.utils.data.DataLoader(\n",
" dataset=mnist_train,\n",
" batch_size=BATCH_SIZE,\n",
" shuffle=True,\n",
" collate_fn=np_collate,\n",
" generator=torch.Generator().manual_seed(SEED),\n",
")\n",
"num_batches = len(data_loader)\n",
"\n",
"t = 0\n",
"for epoch in range(200):\n",
" for i, (xs, ys) in enumerate(data_loader):\n",
" xs, ys = jnp.array(xs), jnp.array(ys)\n",
" params, opt_state, loss = make_step(params, static, opt_state, xs, ys, opt.update)\n",
"\n",
" if t % 5 == 0:\n",
" loss = loss.item()\n",
" print(f\"[{epoch}, {i:02d}/{num_batches}, {t:04d}] loss: {loss:0.5f}\")\n",
" t += 1\n",
"\n",
" train_acc = compute_accuracy(params, static, train_images, train_labels).item()\n",
" test_acc = compute_accuracy(params, static, test_images, test_labels).item()\n",
" print(f\"[{epoch}, -----, {t:04d}] Train Acc: {train_acc:0.4f} Test Acc: {test_acc:0.4f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 3 additions & 0 deletions optax/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
from optax.contrib._galore import GaLoreDimensionNumbers
from optax.contrib._galore import GaLoreState
from optax.contrib._galore import scale_by_galore
from optax.contrib._lnb import lnb
from optax.contrib._lnb import scale_by_lnb
from optax.contrib._lnb import ScaleByLNBState
from optax.contrib._madgrad import madgrad
from optax.contrib._madgrad import MadgradState
from optax.contrib._madgrad import scale_by_madgrad
Expand Down
Loading
Loading