diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index c30e3143e..c97cf429f 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -47,6 +47,8 @@ are not supported by the main library. ScheduleFreeState sophia SophiaState + spsa_estimator + spsa_standard_schedule split_real_and_imaginary SplitRealAndImaginaryState scale_by_ademamix diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..286c39f0c 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -74,3 +74,5 @@ from optax.contrib._sophia import HutchinsonState from optax.contrib._sophia import sophia from optax.contrib._sophia import SophiaState +from optax.contrib._spsa import spsa_estimator +from optax.contrib._spsa import spsa_standard_schedule diff --git a/optax/contrib/_spsa.py b/optax/contrib/_spsa.py new file mode 100644 index 000000000..26f3308f0 --- /dev/null +++ b/optax/contrib/_spsa.py @@ -0,0 +1,118 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simultaneous Perturbation Stochastic Approximation (SPSA) method.""" + +from typing import Any, Callable + +import jax +import jax.numpy as jnp +import optax.tree +from optax._src import base + + +def spsa_standard_schedule( + init_value: float, + decay_rate: float, + offset: float = 1.0, +) -> optax.schedules.Schedule: + """Returns a schedule for the SPSA learning rate or perturbation scale. + + The standard SPSA decay schedule is given by: + + .. math:: + v_k = \\frac{\\text{init\\_value}}{(count + offset)^{\\text{decay\\_rate}}} + + Args: + init_value: The initial value of the parameter. + decay_rate: The exponent for the polynomial decay. + offset: The offset added to the count for stability. + + Returns: + A function that takes the current count and returns the decayed value. + """ + + def schedule(count: jax.typing.ArrayLike) -> jax.typing.ArrayLike: + return init_value / ((count + offset) ** decay_rate) + + return schedule + + +def spsa_estimator( + value_fn: Callable[..., jax.Array], +) -> Callable[..., base.Updates]: + r"""Returns a function that computes the SPSA gradient estimate. + + The Simultaneous Perturbation Stochastic Approximation (SPSA) method + estimates the gradient of a function by evaluating it at two symmetrically + perturbed points. + + Let :math:`\Delta` be a random vector sampled from the Rademacher + distribution + (values in :math:`\{-1, 1\}` with equal probability). The SPSA gradient + estimate for a function :math:`f` at point :math:`x` is given by: + + .. math:: + g = \frac{f(x + c \Delta) - f(x - c \Delta)}{2 c \Delta} + + Args: + value_fn: The function whose gradient is to be estimated. Its first + argument + should be the parameters (a PyTree) with respect to which the gradient + is estimated. It should return a scalar value. + + Returns: + A function with the signature + ``grad_fn(params, c, key, *args, **kwargs)`` + that computes the SPSA gradient estimate. ``c`` is the + perturbation scale + (a scalar). ``key`` is a ``jax.random.PRNGKey`` used to generate + the random + perturbations. + + References: + Spall, `An Overview of the Simultaneous Perturbation Method + for Efficient + Optimization `_, + 1998 + """ + + def grad_fn( + params: base.Params, + c: jax.typing.ArrayLike, + key: base.PRNGKey, + *args: Any, + **kwargs: Any, + ) -> base.Updates: + def sample_rademacher(k, shape, dtype): + # Rademacher distribution: uniform over {-1, 1} + return jax.random.rademacher(k, shape, dtype=dtype) + + delta = optax.tree.random_like(key, params, sampler=sample_rademacher) + + params_plus = jax.tree.map(lambda p, d: p + c * d, params, delta) + params_minus = jax.tree.map(lambda p, d: p - c * d, params, delta) + + y_plus = value_fn(params_plus, *args, **kwargs) + y_minus = value_fn(params_minus, *args, **kwargs) + + # Note: Since delta_i is either 1 or -1, dividing by delta_i is + # equivalent + # to multiplying by delta_i. We multiply for numerical stability. + safe_c = jnp.maximum(c, jnp.finfo(jnp.result_type(c)).eps) + scalar_diff = (y_plus - y_minus) / (2.0 * safe_c) + grad_estimate = jax.tree.map(lambda d: scalar_diff * d, delta) + return grad_estimate + + return grad_fn diff --git a/tests/contrib/spsa_test.py b/tests/contrib/spsa_test.py new file mode 100644 index 000000000..f29d3fab1 --- /dev/null +++ b/tests/contrib/spsa_test.py @@ -0,0 +1,94 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax.contrib.spsa.""" + +from absl.testing import absltest +import chex +import jax +import jax.numpy as jnp +import numpy as np +import optax + + +class SPSATest(chex.TestCase): + @chex.all_variants + def test_spsa_gradient_estimator(self): + def loss_fn(params): + return jnp.sum(params["w"] ** 2) + + estimator = optax.contrib.spsa_estimator(loss_fn) + params = {"w": jnp.array([1.0, -2.0, 3.0])} + + # We use a fixed key to ensure deterministic test. + key = jax.random.PRNGKey(42) + + @self.variant + def get_grad(p, k, c): + return estimator(p, c, k) + + # Average over 10000 samples to test unbiasedness + keys = jax.random.split(key, 10000) + get_grad_vmap = jax.vmap(get_grad, in_axes=(None, 0, None)) + grad_estimates = get_grad_vmap(params, keys, 1.0) + + mean_grad = jax.tree.map(lambda x: jnp.mean(x, axis=0), grad_estimates) + expected_grad = {"w": jnp.array([2.0, -4.0, 6.0])} + + # Check if the mean of the estimates is close to the true gradient + chex.assert_trees_all_close(mean_grad, expected_grad, atol=0.2) + + @chex.all_variants + def test_spsa_optimizer_integration(self): + # Test using SPSA estimator with standard optax optimizer (e.g. SGD) + def loss_fn(params): + return jnp.sum((params["w"] - 2.0) ** 2) + + estimator = optax.contrib.spsa_estimator(loss_fn) + optimizer = optax.sgd(learning_rate=0.1) + + params = {"w": jnp.array([-5.0, 5.0])} + opt_state = optimizer.init(params) + key = jax.random.PRNGKey(0) + + @self.variant + def step(p, state, k): + k1, k2 = jax.random.split(k) + grad = estimator(p, 0.1, k1) + updates, new_state = optimizer.update(grad, state, p) + new_params = optax.apply_updates(p, updates) + return new_params, new_state, k2 + + for _ in range(50): + params, opt_state, key = step(params, opt_state, key) + + # After 50 steps, parameters should be close to 2.0 + chex.assert_trees_all_close( + params["w"], jnp.array([2.0, 2.0]), atol=1e-2 + ) + + def test_spsa_standard_schedule(self): + schedule = optax.contrib.spsa_standard_schedule( + init_value=1.0, decay_rate=0.5, offset=10.0 + ) + + val_0 = schedule(0) + val_10 = schedule(10) + + np.testing.assert_allclose(val_0, 1.0 / (10.0**0.5)) + np.testing.assert_allclose(val_10, 1.0 / (20.0**0.5)) + + +if __name__ == "__main__": + absltest.main()