diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index c30e3143e..8f1b1c679 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -26,6 +26,7 @@ are not supported by the main library. dpsgd galore GaLoreState + lnb madgrad MadgradState mechanize @@ -51,10 +52,12 @@ are not supported by the main library. SplitRealAndImaginaryState scale_by_ademamix ScaleByAdemamixState + ScaleByLNBState scale_by_simplified_ademamix ScaleBySimplifiedAdEMAMixState scale_by_adopt scale_by_acprop + scale_by_lnb scale_by_madgrad scale_by_muon hutchinson_estimator_diag_hessian diff --git a/examples/contrib/lnb_mnist.ipynb b/examples/contrib/lnb_mnist.ipynb new file mode 100644 index 000000000..4c5ade3ae --- /dev/null +++ b/examples/contrib/lnb_mnist.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Linear Neuron Boosting on MNIST\n", + "\n", + "This notebook depends on the following packages:\n", + "\n", + "- `equinox`\n", + "- `torch` (only for the DataLoader, can be CPU-only)\n", + "- `torchvision`\n", + "- `xdg-base-dirs`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "import itertools\n", + "from typing import Any, Tuple, TypeAlias\n", + "\n", + "import equinox as eqx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.tree_util as jtu\n", + "import numpy as np\n", + "import optax\n", + "import optax.contrib\n", + "import torch\n", + "import torchvision.datasets as tvdatasets\n", + "import torchvision.transforms as tvtransforms\n", + "import xdg_base_dirs\n", + "\n", + "JaxInitializer: TypeAlias = jax.nn.initializers.Initializer\n", + "PyTree: TypeAlias = Any" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Equinox Utilities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def is_neuron(node: PyTree) -> bool:\n", + " \"\"\"Returns True if the node is an Equinox module that is a linear neuron.\"\"\"\n", + " return isinstance(node, eqx.nn.Linear) or isinstance(node, eqx.nn.Conv)\n", + "\n", + "\n", + "def initialize(\n", + " model: eqx.Module,\n", + " weight_init: JaxInitializer,\n", + " bias_init: JaxInitializer = jax.nn.initializers.zeros,\n", + " *,\n", + " key: jax.Array,\n", + ") -> eqx.Module:\n", + " \"\"\"Applies the initializers on the `weights` and `bias` fields for any Neuron in model.\n", + "\n", + " Args:\n", + " model: the Module to initialize\n", + " weight_init: the [initializer]\n", + " (https://jax.readthedocs.io/en/latest/jax.nn.initializers.html) to use for the weights\n", + " bias_init: the initializer for the biases, if present.\n", + " key: the random key to split.\n", + " \"\"\"\n", + "\n", + " def _init_node(node: PyTree, key: jax.Array) -> PyTree:\n", + " if is_neuron(node):\n", + " weight_key, bias_key = jax.random.split(key, 2)\n", + " weight = weight_init(weight_key, node.weight.shape, node.weight.dtype)\n", + " node = eqx.tree_at(lambda n: n.weight, node, weight)\n", + " if getattr(node, \"bias\", None) is not None:\n", + " bias = bias_init(bias_key, node.bias.shape, node.bias.dtype)\n", + " node = eqx.tree_at(lambda n: n.bias, node, bias)\n", + " return node\n", + " else:\n", + " return node\n", + "\n", + " leaves, treedef = jtu.tree_flatten(model, is_leaf=is_neuron)\n", + " keys = jax.random.split(key, len(leaves))\n", + " leaves = itertools.starmap(_init_node, zip(leaves, keys, strict=True))\n", + " return jtu.tree_unflatten(treedef, leaves)\n", + "\n", + "\n", + "def wrap_neurons(model: eqx.nn.MLP) -> eqx.Module:\n", + " \"\"\"Wraps model with Module whose call operator returns the tuple [output, x_neurons].\n", + "\n", + " Where x_neurons: list[Array] are the inputs to each Neuron in model.\n", + " \"\"\"\n", + " leaves, treedef = jtu.tree_flatten(model, is_leaf=is_neuron)\n", + "\n", + " # A buffer containing the inputs to each neuron. Follows the example from:\n", + " # https://github.com/patrick-kidger/equinox/issues/186#issuecomment-1233606690\n", + " x_neurons = [None] * sum(map(is_neuron, leaves))\n", + "\n", + " # Saves the input to the respective location in x_neurons.\n", + " class Neuron(eqx.Module):\n", + " f: eqx.Module\n", + " idx: int\n", + "\n", + " def __call__(self, x: jax.Array, *args: Any, **kwargs: Any) -> jax.Array:\n", + " x_neurons[self.idx] = x\n", + " return self.f(x, *args, **kwargs)\n", + "\n", + " # Wrap around the model and return each neuron's input feature, in tree traversal order.\n", + " class Wrapper(eqx.Module):\n", + " F: eqx.Module\n", + "\n", + " def __call__(self, *args: Any, **kwargs: Any) -> Tuple[Any, list[jax.Array]]:\n", + " return self.F(*args, **kwargs), x_neurons\n", + "\n", + " # Replace all eqx.Modules that are linear neurons.\n", + " idx = itertools.count()\n", + " leaves = [Neuron(leaf, next(idx)) if is_neuron(leaf) else leaf for leaf in leaves]\n", + " return Wrapper(jtu.tree_unflatten(treedef, leaves))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SEED = 0\n", + "DATASETS_ROOT = xdg_base_dirs.xdg_data_home() / \"lnb\" / \"mnist\"\n", + "X_INVERTED = False\n", + "BATCH_SIZE = 1000\n", + "\n", + "MLP_KWARGS = {\n", + " \"in_size\": 28 * 28,\n", + " \"out_size\": len(tvdatasets.MNIST.classes),\n", + " \"width_size\": 800,\n", + " \"depth\": 1,\n", + "}\n", + "WEIGHT_INIT = jax.nn.initializers.glorot_normal(in_axis=1, out_axis=0)\n", + "\n", + "LNB_KWARGS = {\n", + " \"cg_ridge\": 1.0,\n", + " \"is_neuron\": is_neuron,\n", + "}\n", + "STEP_SIZE = 0.5\n", + "CLIP_NORM = 1.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def np_collate(batch):\n", + " return jax.tree_util.tree_map(np.asarray, torch.utils.data.default_collate(batch))\n", + "\n", + "\n", + "def load_dataset(dataset):\n", + " loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), collate_fn=np_collate)\n", + " images, labels = next(iter(loader))\n", + " del loader\n", + " return jnp.array(images), jnp.array(labels)\n", + "\n", + "\n", + "transforms = [tvtransforms.ToTensor(), tvtransforms.Lambda(torch.flatten)]\n", + "if X_INVERTED:\n", + " transforms.append(tvtransforms.Lambda(lambda x: 1.0 - x))\n", + "transform = tvtransforms.Compose(transforms)\n", + "mnist_train = tvdatasets.MNIST(DATASETS_ROOT, download=True, train=True, transform=transform)\n", + "mnist_test = tvdatasets.MNIST(DATASETS_ROOT, download=True, train=False, transform=transform)\n", + "train_images, train_labels = load_dataset(mnist_train)\n", + "test_images, test_labels = load_dataset(mnist_test)\n", + "\n", + "assert len(mnist_train) % BATCH_SIZE == 0, \"Keep evenly divisible\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## JIT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@partial(jax.jit, static_argnums=(1, 5))\n", + "def make_step(params, static, opt_state, xs, ys, opt_update):\n", + " def loss_xs(params):\n", + " model = eqx.combine(params, static)\n", + " logitses, xs_neurons = jax.vmap(model)(xs)\n", + " losses = optax.softmax_cross_entropy_with_integer_labels(logitses, ys)\n", + " return losses.mean(), xs_neurons\n", + "\n", + " (loss, xs_neurons), grad = jax.value_and_grad(loss_xs, has_aux=True)(params)\n", + " updates, opt_state = opt_update(grad, opt_state, params, xs_neurons=xs_neurons)\n", + " new_params = optax.apply_updates(params, updates)\n", + " return new_params, opt_state, loss\n", + "\n", + "\n", + "@partial(jax.jit, static_argnums=1)\n", + "def compute_accuracy(params, static, xs, ys):\n", + " model = eqx.combine(params, static)\n", + " predictions = jax.vmap(lambda x: jnp.argmax(model(x)[0]))(xs)\n", + " return jnp.mean(predictions == ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training loop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = eqx.nn.MLP(**MLP_KWARGS, activation=jax.nn.tanh, key=jax.random.PRNGKey(0))\n", + "model = wrap_neurons(model)\n", + "key = jax.random.PRNGKey(SEED)\n", + "params, static = eqx.partition(model, eqx.is_inexact_array)\n", + "params = initialize(params, weight_init=WEIGHT_INIT, key=key)\n", + "\n", + "opt = optax.chain(\n", + " optax.clip_by_global_norm(CLIP_NORM),\n", + " optax.contrib.lnb(**LNB_KWARGS),\n", + " optax.scale_by_learning_rate(STEP_SIZE),\n", + ")\n", + "opt_state = opt.init(params)\n", + "\n", + "data_loader = torch.utils.data.DataLoader(\n", + " dataset=mnist_train,\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=True,\n", + " collate_fn=np_collate,\n", + " generator=torch.Generator().manual_seed(SEED),\n", + ")\n", + "num_batches = len(data_loader)\n", + "\n", + "t = 0\n", + "for epoch in range(200):\n", + " for i, (xs, ys) in enumerate(data_loader):\n", + " xs, ys = jnp.array(xs), jnp.array(ys)\n", + " params, opt_state, loss = make_step(params, static, opt_state, xs, ys, opt.update)\n", + "\n", + " if t % 5 == 0:\n", + " loss = loss.item()\n", + " print(f\"[{epoch}, {i:02d}/{num_batches}, {t:04d}] loss: {loss:0.5f}\")\n", + " t += 1\n", + "\n", + " train_acc = compute_accuracy(params, static, train_images, train_labels).item()\n", + " test_acc = compute_accuracy(params, static, test_images, test_labels).item()\n", + " print(f\"[{epoch}, -----, {t:04d}] Train Acc: {train_acc:0.4f} Test Acc: {test_acc:0.4f}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..e40762883 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -41,6 +41,9 @@ from optax.contrib._galore import GaLoreDimensionNumbers from optax.contrib._galore import GaLoreState from optax.contrib._galore import scale_by_galore +from optax.contrib._lnb import lnb +from optax.contrib._lnb import scale_by_lnb +from optax.contrib._lnb import ScaleByLNBState from optax.contrib._madgrad import madgrad from optax.contrib._madgrad import MadgradState from optax.contrib._madgrad import scale_by_madgrad diff --git a/optax/contrib/_lnb.py b/optax/contrib/_lnb.py new file mode 100644 index 000000000..738505557 --- /dev/null +++ b/optax/contrib/_lnb.py @@ -0,0 +1,429 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear Neuron Boosting. + +We refer to a "Neuron" as a callable PyTree that is linear in its parameters. +The input array to a Neuron must have shape (d_in, ...); this axis assumption +is noted in the code via "AA". +""" + +import functools +import itertools +from typing import Any, Callable, NamedTuple, Optional, Tuple, TypeAlias, Union + +import jax +import jax.numpy as jnp +import jax.scipy +import jax.tree_util as jtu +import optax +from optax._src import base +from optax._src import combine +from optax._src import utils +import optax.tree_utils as otu + +Neuron: TypeAlias = base.PyTree +"""A Callable PyTree that is linear in its parameters (weights and biases).""" + +Weight: TypeAlias = jax.Array +"""A Neuron's weight array with expected shape=(d_out, d_in, ...).""" + +Bias: TypeAlias = jax.Array +"""A Neuron's bias array with expected shape=(d_out,) or (d_out, 1, ..., 1).""" + +IsNeuron: TypeAlias = Callable[[base.PyTree], bool] +"""Returns True if the node is a Neuron; callled during tree flattening.""" + +IsWeight: TypeAlias = Callable[[jtu.KeyPath], bool] +"""Called on a Neuron's KeyPath, returns True if it is for the weights.""" + +IsBias: TypeAlias = Callable[[jtu.KeyPath], bool] +"""Called on a Neuron's KeyPath, returns True if it is for the biases.""" + +Vector: TypeAlias = jax.Array +"""A 1-D array.""" + + +_zip = functools.partial(zip, strict=True) + + +def _get_array( + neuron: Neuron, predicate: Union[IsWeight, IsBias] +) -> Optional[jax.Array]: + """Returns the array for the leaf where the predicate evaluates True.""" + leaves, _ = jtu.tree_flatten_with_path(neuron) + for key_path, arr in leaves: + if predicate(key_path): + return arr + return None + + +def _set_array( + neuron: Neuron, predicate: Union[IsWeight, IsBias], new_array: jax.Array +) -> Neuron: + """Assigns `new_array` for the leaf where the predicate evaluates True.""" + + def _replace(key_path: jtu.KeyPath, curr_array: jax.Array) -> jax.Array: + return new_array if predicate(key_path) else curr_array + + return jtu.tree_map_with_path(_replace, neuron) + + +def _last_name_is_str(key_path: jtu.KeyPath, query: str) -> bool: + """Returns True if the last EntryName in `key_path` matches `query`.""" + last_entry = key_path[-1] + entry_name = getattr(last_entry, "key", getattr(last_entry, "name", None)) + return entry_name == query + + +def _compute_mu(xs: jax.Array) -> Vector: + """Returns the mean feature vector across the spatial dimensions.""" + return jnp.mean(xs, axis=(0, *tuple(range(2, xs.ndim)))) # AA + + +def _shrink(diag: Vector, shrinkage: jax.typing.ArrayLike) -> Vector: + """See https://scikit-learn.org/stable/modules/covariance.html#basic-shrinkage.""" + assert 0.0 <= shrinkage + assert shrinkage < 1.0 + return (1.0 - shrinkage) * diag + ( + shrinkage * jnp.sum(diag) / jnp.size(diag) + ) + + +def default_is_weight(key_path: jtu.KeyPath) -> bool: + """Returns True if the EntryName is "weight".""" + return _last_name_is_str(key_path, "weight") + + +def default_is_bias(key_path: jtu.KeyPath) -> bool: + """Returns True if the EntryName is "bias".""" + return _last_name_is_str(key_path, "bias") + + +def make_pvp( + mu: Vector, + nu: Vector, + shrinkage: jax.typing.ArrayLike, + is_weight: IsWeight, + is_bias: IsBias, +) -> Callable[[Neuron], Neuron]: + """Returns the Neuron's preconditioner mvp for conjugate gradient. + + The preconditioner is the incomplete Cholesky factorization (Section 3.5). + + Args: + mu: mean input features + nu: mean squared input features + shrinkage: covariance shinkage in [0, 1) + is_weight: see type definition + is_bias: see type definition + """ + assert len(mu.shape) == 1 + assert mu.shape == nu.shape + variance_shrunk = _shrink(jnp.maximum(0.0, nu - mu**2), shrinkage) + + # PVP for one component of the parameter vector (with bias). + def _pvp_bias( + weight: jax.Array, bias: jax.Array + ) -> Tuple[jax.Array, jax.Array]: + # We use vmap here because `weight` might be an N-D spatial filter and + # we use scalar broadcasting to equally apply the feature moments over + # the spatial axes (i.e., assumes translational equivariance). + weight_out = jax.vmap( + lambda w, b, m, v: (w - m * b) / v, in_axes=(0, None, 0, 0) + )(weight, bias, mu, variance_shrunk) + bias_out = bias - jnp.sum(jax.vmap(lambda m, w: m * w)(mu, weight_out)) + assert weight_out.shape == weight.shape + assert bias_out.shape == bias.shape + return weight_out, bias_out + + # PVP for one component of the parameter vector (without bias). + def _pvp_nobias(weight: jax.Array) -> jax.Array: + nu_shrunk = variance_shrunk + mu**2 + return jax.vmap(lambda w, n: w / n)(weight, nu_shrunk) + + # PVP for all components of the parameter vector. + def pvp(v: Neuron) -> Neuron: + weight = _get_array(v, is_weight) + assert weight is not None + assert weight.shape[1] == mu.size # AA + bias = _get_array(v, is_bias) + if bias is None: + weight = jax.vmap(_pvp_nobias)(weight) + v = _set_array(v, is_weight, weight) + else: + weight, bias = jax.vmap(_pvp_bias)(weight, bias) + v = _set_array(v, is_weight, weight) + v = _set_array(v, is_bias, bias) + return v + + return pvp + + +def project( + xs: jax.Array, + grad: Neuron, + init: Neuron, + is_weight: IsWeight, + is_bias: IsBias, + ridge: float = 0.0, + **kwargs: Any, +) -> Neuron: + """Preconditions neuron's component of the gradient vector. + + Conjugate gradient is used to solve the linear system (Section 3.2). + + Args: + xs: the batched input tensors to the neuron + grad: the component of the gradient vector for the neuron + init: initialization for conjugate gradient + is_weight: see type definition + is_bias: see type definition + ridge: ridge reglarization; only applied to the weights + kwargs: forwards to jax.scipy.sparse.linalg.cg + """ + + def jvp(v: Neuron) -> jax.Array: + return jax.vmap(v)(xs) + + ys = jvp(init) + num_samples = ys.size / ys.shape[1] # AA + del ys + vjp = jax.linear_transpose( + jvp, init + ) # The linearization point doesn't matter since linear. + + def mvp(v: Neuron) -> Neuron: + return otu.tree_scalar_mul(1.0 / num_samples, vjp(jvp(v))[0]) # J'Jv + + # Create a pytree with same shape as the neuron and use scalar broadcasting + # to implement ridge regression while ignoring bias, if present. + ridge_neuron = _set_array(init, is_weight, ridge) + ridge_neuron = _set_array(ridge_neuron, is_bias, 0.0) + + def mvp_ridge(v: Neuron) -> Neuron: + return otu.tree_add(mvp(v), otu.tree_mul(ridge_neuron, v)) + + _mvp = mvp_ridge if ridge > 0.0 else mvp + return jax.scipy.sparse.linalg.cg(_mvp, grad, x0=init, **kwargs)[0] + + +class ScaleByLNBState(NamedTuple): + """State for the Linear Neuron Boosting algorithm.""" + + # Each state is a list of length number of neurons, in leaf traversal order + h_neurons: list[Neuron] # The previous conjugate gradient solution + mu_state: optax.EmaState # Mean input features (first moment) + nu_state: optax.EmaState # Mean squared input features (second moment) + + +def scale_by_lnb( + b_mu: jax.typing.ArrayLike = 0.9, + b_nu: jax.typing.ArrayLike = 0.999, + min_norm: jax.typing.ArrayLike = 1e-2, + cov_shrinkage: jax.typing.ArrayLike = 0.1, + *, + cg_ridge: float = 0.1, + cg_maxiter: int = 2, + approx_metric: bool = False, + is_neuron: IsNeuron = lambda node: hasattr(node, "weight"), + is_weight: IsWeight = default_is_weight, + is_bias: IsBias = default_is_bias, + accumulator_dtype: Optional[Any] = None, +) -> base.GradientTransformationExtraArgs: + r"""Applies a Linear Neuron Boosting update. + + LNB performs gradient descent in the space of linear functions for each + linear layer of the network. Equivalently, it is a trust region method that + takes a small step under a metric related to the second moment matrix of + each linear layer's input features. Also equivalently, it runs gradient + descent on a reparameterized model where each linear layer whitens its input + features. + + Args: + b_mu: EMA decay for input features' first moment. + b_nu: EMA decay for input features' second momment. + min_norm: The minimum norm to use to avoid dividing by zero. + cov_shrinkage: Covariance shrinkage, in [0, 1), to ensure positive + definiteness. + cg_ridge: Ridge regularizer for conjugate gradient. + cg_maxiter: Number of conjugate gradient iterations. + approx_metric: Approximate the metric with the incomplete Cholesky + factorization, instead of using conjugate gradient to solve the linear + system. + is_neuron: See type definition. + is_weight: See type definition. + is_bias: See type definition. + accumulator_dtype: EMA accumulator dtype. + + Returns: + `GradientTransformationExtraArgs` with an `update_fn` that expects the + `xs_neurons` kwarg to be a list of batched input arrays to each respective + neuron, in leaf traversal order as specified by `is_neuron`. + + References: + D. Munoz, `Simple Linear Neuron Boosting + `_, 2025 + """ + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + mu_ema = optax.ema(b_mu, debias=True, accumulator_dtype=accumulator_dtype) + nu_ema = optax.ema(b_nu, debias=True, accumulator_dtype=accumulator_dtype) + _make_pvp = functools.partial( + make_pvp, shrinkage=cov_shrinkage, is_weight=is_weight, is_bias=is_bias + ) + _project = functools.partial( + project, + is_weight=is_weight, + is_bias=is_bias, + ridge=cg_ridge, + maxiter=cg_maxiter, + ) + + def init_fn(params: base.Params) -> ScaleByLNBState: + def _make_zero_vector(neuron: Neuron) -> Vector: + weight = _get_array(neuron, is_weight) + assert weight is not None + return jnp.zeros(weight.shape[1]) # AA + + neurons = list( + filter(is_neuron, jtu.tree_leaves(params, is_leaf=is_neuron)) + ) + assert all(map(callable, neurons)), "Each Neuron must be callable." + return ScaleByLNBState( + h_neurons=list(map(otu.tree_zeros_like, neurons)), + mu_state=mu_ema.init(list(map(_make_zero_vector, neurons))), + nu_state=nu_ema.init(list(map(_make_zero_vector, neurons))), + ) + + def update_fn( + updates: base.Updates, + state: ScaleByLNBState, + params: base.Params, + xs_neurons: list[jax.Array], + ) -> Tuple[base.Updates, ScaleByLNBState]: + del params + # Update feature moments. + mu_neurons = list(map(_compute_mu, xs_neurons)) + nu_neurons = list(map(_compute_mu, map(jnp.square, xs_neurons))) + mu_neurons, mu_state = mu_ema.update(mu_neurons, state.mu_state) + nu_neurons, nu_state = nu_ema.update(nu_neurons, state.nu_state) + + # Construct PVPs for conjugate gradient using moments. + pvp_neurons = itertools.starmap(_make_pvp, _zip(mu_neurons, nu_neurons)) + + # Grab components of the entire gradient vector that are neurons. + leaves, treedef = jtu.tree_flatten(updates, is_leaf=is_neuron) + idx_neurons, grad_neurons = _zip( + *[(i, node) for i, node in enumerate(leaves) if is_neuron(node)] + ) + + # Precondition the gradient vector and then update its neurons. + if approx_metric: + h_neurons = [ + pvp(grad) for grad, pvp in _zip(grad_neurons, pvp_neurons) + ] + else: + h_neurons = [ + _project(xs, grad, init, M=pvp) + for xs, grad, init, pvp in _zip( + xs_neurons, + grad_neurons, + state.h_neurons, + pvp_neurons, + ) + ] + + # Replace the Neurons (implies the Identity metric for the others). + for idx, neuron in _zip(idx_neurons, h_neurons): + leaves[idx] = neuron + h = jtu.tree_unflatten(treedef, leaves) + + # Rescale under the metric (adaptive step size). + norm = jnp.maximum(min_norm, jnp.sqrt(otu.tree_vdot(h, updates))) + h_unit = otu.tree_scalar_mul(1.0 / norm, h) + return h_unit, ScaleByLNBState(h_neurons, mu_state, nu_state) + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +def lnb( + b_g: jax.typing.ArrayLike = 0.9, + b_mu: jax.typing.ArrayLike = 0.9, + b_nu: jax.typing.ArrayLike = 0.999, + min_norm: jax.typing.ArrayLike = 1e-2, + cov_shrinkage: jax.typing.ArrayLike = 0.1, + weight_decay: base.ScalarOrSchedule = 1e-4, + *, + cg_ridge: float = 0.1, + cg_maxiter: int = 2, + approx_metric: bool = False, + is_neuron: IsNeuron = lambda node: hasattr(node, "weight"), + is_weight: IsWeight = default_is_weight, + is_bias: IsBias = default_is_bias, + accumulator_dtype: Optional[Any] = None, +) -> base.GradientTransformationExtraArgs: + r"""Linear Neuron Boosting. + + LNB performs gradient descent in the space of linear functions for each + linear layer of the network. Equivalently, it is a trust region method that + takes a small step under a metric related to the second moment matrix of + each linear layer's input features. Also equivalently, it runs gradient + descent on a reparameterized model where each linear layer whitens its input + features. + + Args: + b_g: EMA decay for the gradient vector. + b_mu: EMA decay for input features' first moment. + b_nu: EMA decay for input features' second momment. + min_norm: The minimum norm to use to avoid dividing by zero. + cov_shrinkage: Covariance shrinkage, in [0, 1), to ensure positive + definiteness. + weight_decay: Weight decay factor or schedule. + cg_ridge: Ridge regularizer for conjugate gradient. + cg_maxiter: Number of conjugate gradient iterations. + approx_metric: Approximate the metric with the incomplete Cholesky + factorization, instead of using conjugate gradient to solve the linear + system. + is_neuron: See type definition. + is_weight: See type definition. + is_bias: See type definition. + accumulator_dtype: EMA accumulator dtype. + + Returns: + `GradientTransformationExtraArgs` with an `update_fn` that expects the + `xs_neurons` kwarg to be a list of batched input arrays to each respective + neuron, in leaf traversal order as specified by `is_neuron`. + + References: + D. Munoz, `Simple Linear Neuron Boosting + `_, 2025 + """ + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + return combine.chain( + optax.ema(b_g, debias=True, accumulator_dtype=accumulator_dtype), + scale_by_lnb( + b_mu, + b_nu, + min_norm, + cov_shrinkage, + cg_ridge=cg_ridge, + cg_maxiter=cg_maxiter, + approx_metric=approx_metric, + is_neuron=is_neuron, + is_weight=is_weight, + is_bias=is_bias, + accumulator_dtype=accumulator_dtype, + ), + optax.add_decayed_weights(weight_decay), + )