Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ are not supported by the main library.
ScheduleFreeState
sophia
SophiaState
spsa_estimator
spsa_standard_schedule
split_real_and_imaginary
SplitRealAndImaginaryState
scale_by_ademamix
Expand Down
2 changes: 2 additions & 0 deletions optax/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,5 @@
from optax.contrib._sophia import HutchinsonState
from optax.contrib._sophia import sophia
from optax.contrib._sophia import SophiaState
from optax.contrib._spsa import spsa_estimator
from optax.contrib._spsa import spsa_standard_schedule
118 changes: 118 additions & 0 deletions optax/contrib/_spsa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simultaneous Perturbation Stochastic Approximation (SPSA) method."""

from typing import Any, Callable

import jax
import jax.numpy as jnp
import optax.tree
from optax._src import base


def spsa_standard_schedule(
init_value: float,
decay_rate: float,
offset: float = 1.0,
) -> optax.schedules.Schedule:
"""Returns a schedule for the SPSA learning rate or perturbation scale.

The standard SPSA decay schedule is given by:

.. math::
v_k = \\frac{\\text{init\\_value}}{(count + offset)^{\\text{decay\\_rate}}}

Args:
init_value: The initial value of the parameter.
decay_rate: The exponent for the polynomial decay.
offset: The offset added to the count for stability.

Returns:
A function that takes the current count and returns the decayed value.
"""

def schedule(count: jax.typing.ArrayLike) -> jax.typing.ArrayLike:
return init_value / ((count + offset) ** decay_rate)

return schedule


def spsa_estimator(
value_fn: Callable[..., jax.Array],
) -> Callable[..., base.Updates]:
r"""Returns a function that computes the SPSA gradient estimate.

The Simultaneous Perturbation Stochastic Approximation (SPSA) method
estimates the gradient of a function by evaluating it at two symmetrically
perturbed points.

Let :math:`\Delta` be a random vector sampled from the Rademacher
distribution
(values in :math:`\{-1, 1\}` with equal probability). The SPSA gradient
estimate for a function :math:`f` at point :math:`x` is given by:

.. math::
g = \frac{f(x + c \Delta) - f(x - c \Delta)}{2 c \Delta}

Args:
value_fn: The function whose gradient is to be estimated. Its first
argument
should be the parameters (a PyTree) with respect to which the gradient
is estimated. It should return a scalar value.

Returns:
A function with the signature
``grad_fn(params, c, key, *args, **kwargs)``
that computes the SPSA gradient estimate. ``c`` is the
perturbation scale
(a scalar). ``key`` is a ``jax.random.PRNGKey`` used to generate
the random
perturbations.

References:
Spall, `An Overview of the Simultaneous Perturbation Method
for Efficient
Optimization <https://www.jhuapl.edu/SPSA/PDF-SPSA/Spall_An_Overview.PDF>`_,
1998
"""

def grad_fn(
params: base.Params,
c: jax.typing.ArrayLike,
key: base.PRNGKey,
*args: Any,
**kwargs: Any,
) -> base.Updates:
def sample_rademacher(k, shape, dtype):
# Rademacher distribution: uniform over {-1, 1}
return jax.random.rademacher(k, shape, dtype=dtype)

delta = optax.tree.random_like(key, params, sampler=sample_rademacher)

params_plus = jax.tree.map(lambda p, d: p + c * d, params, delta)
params_minus = jax.tree.map(lambda p, d: p - c * d, params, delta)

y_plus = value_fn(params_plus, *args, **kwargs)
y_minus = value_fn(params_minus, *args, **kwargs)

# Note: Since delta_i is either 1 or -1, dividing by delta_i is
# equivalent
# to multiplying by delta_i. We multiply for numerical stability.
safe_c = jnp.maximum(c, jnp.finfo(jnp.result_type(c)).eps)
scalar_diff = (y_plus - y_minus) / (2.0 * safe_c)
grad_estimate = jax.tree.map(lambda d: scalar_diff * d, delta)
return grad_estimate

return grad_fn
94 changes: 94 additions & 0 deletions tests/contrib/spsa_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optax.contrib.spsa."""

from absl.testing import absltest
import chex
import jax
import jax.numpy as jnp
import numpy as np
import optax


class SPSATest(chex.TestCase):
@chex.all_variants
def test_spsa_gradient_estimator(self):
def loss_fn(params):
return jnp.sum(params["w"] ** 2)

estimator = optax.contrib.spsa_estimator(loss_fn)
params = {"w": jnp.array([1.0, -2.0, 3.0])}

# We use a fixed key to ensure deterministic test.
key = jax.random.PRNGKey(42)

@self.variant
def get_grad(p, k, c):
return estimator(p, c, k)

# Average over 10000 samples to test unbiasedness
keys = jax.random.split(key, 10000)
get_grad_vmap = jax.vmap(get_grad, in_axes=(None, 0, None))
grad_estimates = get_grad_vmap(params, keys, 1.0)

mean_grad = jax.tree.map(lambda x: jnp.mean(x, axis=0), grad_estimates)
expected_grad = {"w": jnp.array([2.0, -4.0, 6.0])}

# Check if the mean of the estimates is close to the true gradient
chex.assert_trees_all_close(mean_grad, expected_grad, atol=0.2)

@chex.all_variants
def test_spsa_optimizer_integration(self):
# Test using SPSA estimator with standard optax optimizer (e.g. SGD)
def loss_fn(params):
return jnp.sum((params["w"] - 2.0) ** 2)

estimator = optax.contrib.spsa_estimator(loss_fn)
optimizer = optax.sgd(learning_rate=0.1)

params = {"w": jnp.array([-5.0, 5.0])}
opt_state = optimizer.init(params)
key = jax.random.PRNGKey(0)

@self.variant
def step(p, state, k):
k1, k2 = jax.random.split(k)
grad = estimator(p, 0.1, k1)
updates, new_state = optimizer.update(grad, state, p)
new_params = optax.apply_updates(p, updates)
return new_params, new_state, k2

for _ in range(50):
params, opt_state, key = step(params, opt_state, key)

# After 50 steps, parameters should be close to 2.0
chex.assert_trees_all_close(
params["w"], jnp.array([2.0, 2.0]), atol=1e-2
)

def test_spsa_standard_schedule(self):
schedule = optax.contrib.spsa_standard_schedule(
init_value=1.0, decay_rate=0.5, offset=10.0
)

val_0 = schedule(0)
val_10 = schedule(10)

np.testing.assert_allclose(val_0, 1.0 / (10.0**0.5))
np.testing.assert_allclose(val_10, 1.0 / (20.0**0.5))


if __name__ == "__main__":
absltest.main()
Loading