From 16b503562bf46940f374b1fa08b77a24698fd9a0 Mon Sep 17 00:00:00 2001 From: Paramveer singh Date: Tue, 23 Jun 2026 12:14:03 +0530 Subject: [PATCH 1/4] feat: Add SPSA optimization method (Issue #357) --- docs/api/contrib.rst | 2 + optax/contrib/__init__.py | 2 + optax/contrib/_spsa.py | 117 +++++++++++++++++++++++++++++++++++++ tests/contrib/spsa_test.py | 93 +++++++++++++++++++++++++++++ 4 files changed, 214 insertions(+) create mode 100644 optax/contrib/_spsa.py create mode 100644 tests/contrib/spsa_test.py diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index c30e3143e..c97cf429f 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -47,6 +47,8 @@ are not supported by the main library. ScheduleFreeState sophia SophiaState + spsa_estimator + spsa_standard_schedule split_real_and_imaginary SplitRealAndImaginaryState scale_by_ademamix diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index 35c032ef0..286c39f0c 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -74,3 +74,5 @@ from optax.contrib._sophia import HutchinsonState from optax.contrib._sophia import sophia from optax.contrib._sophia import SophiaState +from optax.contrib._spsa import spsa_estimator +from optax.contrib._spsa import spsa_standard_schedule diff --git a/optax/contrib/_spsa.py b/optax/contrib/_spsa.py new file mode 100644 index 000000000..f128bfe47 --- /dev/null +++ b/optax/contrib/_spsa.py @@ -0,0 +1,117 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simultaneous Perturbation Stochastic Approximation (SPSA) method.""" + +from typing import Any, Callable + +import jax +import optax.tree +from optax._src import base + + +def spsa_standard_schedule( + init_value: float, + decay_rate: float, + offset: float = 0.0, +) -> optax.schedules.Schedule: + """Returns a schedule for the SPSA learning rate or perturbation scale. + + The standard SPSA decay schedule is given by: + + .. math:: + v_k = \\frac{\\text{init\\_value}}{(count + offset)^{\\text{decay\\_rate}}} + + Args: + init_value: The initial value of the parameter. + decay_rate: The exponent for the polynomial decay. + offset: The offset added to the count for stability. + + Returns: + A function that takes the current count and returns the decayed value. + """ + + def schedule(count: jax.typing.ArrayLike) -> jax.typing.ArrayLike: + return init_value / ((count + offset) ** decay_rate) + + return schedule + + +def spsa_estimator( + value_fn: Callable[..., jax.Array], +) -> Callable[..., base.Updates]: + r"""Returns a function that computes the SPSA gradient estimate. + + The Simultaneous Perturbation Stochastic Approximation (SPSA) method + estimates the gradient of a function by evaluating it at two symmetrically + perturbed points. + + Let :math:`\Delta` be a random vector sampled from the Rademacher + distribution + (values in :math:`\{-1, 1\}` with equal probability). The SPSA gradient + estimate for a function :math:`f` at point :math:`x` is given by: + + .. math:: + g = \frac{f(x + c \Delta) - f(x - c \Delta)}{2 c \Delta} + + Args: + value_fn: The function whose gradient is to be estimated. Its first + argument + should be the parameters (a PyTree) with respect to which the gradient + is estimated. It should return a scalar value. + + Returns: + A function with the signature + ``grad_fn(params, c, key, *args, **kwargs)`` + that computes the SPSA gradient estimate. ``c`` is the + perturbation scale + (a scalar). ``key`` is a ``jax.random.PRNGKey`` used to generate + the random + perturbations. + + References: + Spall, `An Overview of the Simultaneous Perturbation Method + for Efficient + Optimization `_, + 1998 + """ + + def grad_fn( + params: base.Params, + c: jax.typing.ArrayLike, + key: base.PRNGKey, + *args: Any, + **kwargs: Any, + ) -> base.Updates: + def sample_rademacher(k, shape, dtype): + # Rademacher distribution: uniform over {-1, 1} + return jax.random.rademacher(k, shape, dtype=dtype) + + delta = optax.tree.random_like(key, params, sampler=sample_rademacher) + + params_plus = jax.tree.map(lambda p, d: p + c * d, params, delta) + params_minus = jax.tree.map(lambda p, d: p - c * d, params, delta) + + y_plus = value_fn(params_plus, *args, **kwargs) + y_minus = value_fn(params_minus, *args, **kwargs) + + # Note: Since delta_i is either 1 or -1, dividing by delta_i is + # equivalent + # to multiplying by delta_i. We multiply for numerical stability. + grad_estimate = jax.tree.map( + lambda d: (y_plus - y_minus) / (2.0 * c) * d, delta + ) + return grad_estimate + + return grad_fn diff --git a/tests/contrib/spsa_test.py b/tests/contrib/spsa_test.py new file mode 100644 index 000000000..1ea5436ac --- /dev/null +++ b/tests/contrib/spsa_test.py @@ -0,0 +1,93 @@ +# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optax.contrib.spsa.""" + +from absl.testing import absltest +import chex +import jax +import jax.numpy as jnp +import optax + + +class SPSATest(chex.TestCase): + @chex.all_variants + def test_spsa_gradient_estimator(self): + def loss_fn(params): + return jnp.sum(params["w"] ** 2) + + estimator = optax.contrib.spsa_estimator(loss_fn) + params = {"w": jnp.array([1.0, -2.0, 3.0])} + + # We use a fixed key to ensure deterministic test. + key = jax.random.PRNGKey(42) + + @self.variant + def get_grad(p, k, c): + return estimator(p, c, k) + + # Average over 10000 samples to test unbiasedness + keys = jax.random.split(key, 10000) + get_grad_vmap = jax.vmap(get_grad, in_axes=(None, 0, None)) + grad_estimates = get_grad_vmap(params, keys, 1.0) + + mean_grad = jax.tree.map(lambda x: jnp.mean(x, axis=0), grad_estimates) + expected_grad = {"w": jnp.array([2.0, -4.0, 6.0])} + + # Check if the mean of the estimates is close to the true gradient + chex.assert_trees_all_close(mean_grad, expected_grad, atol=0.2) + + @chex.all_variants + def test_spsa_optimizer_integration(self): + # Test using SPSA estimator with standard optax optimizer (e.g. SGD) + def loss_fn(params): + return jnp.sum((params["w"] - 2.0) ** 2) + + estimator = optax.contrib.spsa_estimator(loss_fn) + optimizer = optax.sgd(learning_rate=0.1) + + params = {"w": jnp.array([-5.0, 5.0])} + opt_state = optimizer.init(params) + key = jax.random.PRNGKey(0) + + @self.variant + def step(p, state, k): + k1, k2 = jax.random.split(k) + grad = estimator(p, 0.1, k1) + updates, new_state = optimizer.update(grad, state, p) + new_params = optax.apply_updates(p, updates) + return new_params, new_state, k2 + + for _ in range(50): + params, opt_state, key = step(params, opt_state, key) + + # After 50 steps, parameters should be close to 2.0 + chex.assert_trees_all_close( + params["w"], jnp.array([2.0, 2.0]), atol=1e-2 + ) + + def test_spsa_standard_schedule(self): + schedule = optax.contrib.spsa_standard_schedule( + init_value=1.0, decay_rate=0.5, offset=10.0 + ) + + val_0 = schedule(0) + val_10 = schedule(10) + + self.assertAlmostEqual(val_0, 1.0 / (10.0**0.5)) + self.assertAlmostEqual(val_10, 1.0 / (20.0**0.5)) + + +if __name__ == "__main__": + absltest.main() From ab59984d1d6f4713f858a57ceb6058dd1b7d959e Mon Sep 17 00:00:00 2001 From: Paramveer singh Date: Tue, 23 Jun 2026 17:27:47 +0530 Subject: [PATCH 2/4] Update optax --- fetch_issues.py | 14 ++++++++++++++ issue_357.json | 0 issues.txt | Bin 0 -> 8024 bytes 3 files changed, 14 insertions(+) create mode 100644 fetch_issues.py create mode 100644 issue_357.json create mode 100644 issues.txt diff --git a/fetch_issues.py b/fetch_issues.py new file mode 100644 index 000000000..34d5acaba --- /dev/null +++ b/fetch_issues.py @@ -0,0 +1,14 @@ +import urllib.request +import json +with open("issues.txt", "w", encoding="utf-8") as f: + try: + req = urllib.request.Request('https://api.github.com/repos/google-deepmind/optax/issues?state=open&per_page=100', headers={'User-Agent': 'Mozilla/5.0'}) + response = urllib.request.urlopen(req) + issues = json.loads(response.read()) + for i in issues: + if 'pull_request' not in i: + labels = [l['name'] for l in i['labels']] + f.write(f"#{i['number']}: {i['title']} (Labels: {labels})\n") + f.write(f"URL: {i['html_url']}\n\n") + except Exception as e: + f.write(f"Error: {e}\n") diff --git a/issue_357.json b/issue_357.json new file mode 100644 index 000000000..e69de29bb diff --git a/issues.txt b/issues.txt new file mode 100644 index 0000000000000000000000000000000000000000..5b29160259c23be6c67a86c53bfd2b5bd0c31594 GIT binary patch literal 8024 zcmdU!Ur!rX5XI-YQoqAWeISv_pMnh~FI8wl(h3q3lS)+-!D3@$Vq@1fDd~rAdwzF( zxx2=x^VH4CV!3z%k0Pw?9g6VOXGF@J++QLnRV^f zZmexH>)S}9+y)xW?31p|G?UrXK4@mF(XEaht(e)RW;1)Fr{;E{vw_Ze%R5^O_l`C9 zM9;nqt8ZOY_Kw5@@CKVl zy53rteXRL}8CN66d{6pem;TZdCfbjBFlN_4sEpaRcmrP}W+y^s5E$}YtQ*TZ_=bhQ zh1q^UWu|lYp}MgKd&>xH*-k*Xxt=h9mlknP-X(uXH`_$w5GEAiV$7?5U3I z1UZs9vZbp?&yeTyc!`SJD6Kfx9rfTin3hrAvi%S@f0lLVsKmNwnt7pfc$VoJ%tLvj zz|hw)Zw7_$9&0}4-&>99;+2k6kuN*qk5_9a;=@JY!&EyQBN!gUj%Ri*wtSMGP^S^G zI+|I{m#$VO?N$wA+{Uuid&*=*b?th3UB1B15}S$FS2`Q17DZE2oiR7naiZ%~keS$k zM_dXyU&);F3w}@`di9!Xez|Yqqu@yu$g!BoJRc*)5!aacP;nBuZwk#Z4qs< zo=57pVok9p))7t~lgaAz&payepS~6NIZ~h#b?jJU??u7P;xl|7h+{XxkFH!Vtn6r~ z$}5R#aGgBLaXFK#=)pN!W&E1+enLw;Tc*mRMkI=rJCs>alAZ^1UAvR5unYvjC&hRo z0BV!}#75WgUOCi(cU!8FioJ(@DJ$1jr?Ls3jJBm%|EqAgn^&G53!N{?mseR{wW@S0 z6IW2(xe?#I_u{^hDY{d;cue!2EgsPpLx^b}>CdIP@iQWwcWh3>IdzF= zOZQ6GqMz4Pcs5UF59s@(NLqzbaD)`)nh1k@qL>swXl7M~~vpb3e!>&i@d@Oz@}!@7b9y2{4tH8XWWQS&sE>3w%k zWGC5|4#iM*5*euniF()CXMS6WDhPFyi{p7vL9QE9J;AD|kk0fB&tMaAB;_T#POH$W zk`bVi7(t!kdgyd+4IQ*eK+J!Oi0T;KOnx9P3}hpfRb1)>)}rlnHj5Qi zOC9N!`jYs0u@9f()}qE+z1NGLK^JH(@Nb}Zz(sx;@?J(YsDMHl75@z*Rtjg*q*RjB z*o>EH4cBUU=sxu2L;MPgzNMN#Q%r-05G{#0Ua^klb^JOMbMBY0Xc5i-U#IOXs13Ty waOkBtL)LX9$fnd4ZY>se_2*ry*uhD0C25I8=6ToZCp~FZOsO6jkcd3`7Z2A+AOHXW literal 0 HcmV?d00001 From 3c0f0aba9048ff0ec7f65e1389023d67392931bc Mon Sep 17 00:00:00 2001 From: Paramveer singh Date: Tue, 23 Jun 2026 20:18:26 +0530 Subject: [PATCH 3/4] chore: remove scratch files accidentally committed --- fetch_issues.py | 14 -------------- issue_357.json | 0 issues.txt | Bin 8024 -> 0 bytes 3 files changed, 14 deletions(-) delete mode 100644 fetch_issues.py delete mode 100644 issue_357.json delete mode 100644 issues.txt diff --git a/fetch_issues.py b/fetch_issues.py deleted file mode 100644 index 34d5acaba..000000000 --- a/fetch_issues.py +++ /dev/null @@ -1,14 +0,0 @@ -import urllib.request -import json -with open("issues.txt", "w", encoding="utf-8") as f: - try: - req = urllib.request.Request('https://api.github.com/repos/google-deepmind/optax/issues?state=open&per_page=100', headers={'User-Agent': 'Mozilla/5.0'}) - response = urllib.request.urlopen(req) - issues = json.loads(response.read()) - for i in issues: - if 'pull_request' not in i: - labels = [l['name'] for l in i['labels']] - f.write(f"#{i['number']}: {i['title']} (Labels: {labels})\n") - f.write(f"URL: {i['html_url']}\n\n") - except Exception as e: - f.write(f"Error: {e}\n") diff --git a/issue_357.json b/issue_357.json deleted file mode 100644 index e69de29bb..000000000 diff --git a/issues.txt b/issues.txt deleted file mode 100644 index 5b29160259c23be6c67a86c53bfd2b5bd0c31594..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8024 zcmdU!Ur!rX5XI-YQoqAWeISv_pMnh~FI8wl(h3q3lS)+-!D3@$Vq@1fDd~rAdwzF( zxx2=x^VH4CV!3z%k0Pw?9g6VOXGF@J++QLnRV^f zZmexH>)S}9+y)xW?31p|G?UrXK4@mF(XEaht(e)RW;1)Fr{;E{vw_Ze%R5^O_l`C9 zM9;nqt8ZOY_Kw5@@CKVl zy53rteXRL}8CN66d{6pem;TZdCfbjBFlN_4sEpaRcmrP}W+y^s5E$}YtQ*TZ_=bhQ zh1q^UWu|lYp}MgKd&>xH*-k*Xxt=h9mlknP-X(uXH`_$w5GEAiV$7?5U3I z1UZs9vZbp?&yeTyc!`SJD6Kfx9rfTin3hrAvi%S@f0lLVsKmNwnt7pfc$VoJ%tLvj zz|hw)Zw7_$9&0}4-&>99;+2k6kuN*qk5_9a;=@JY!&EyQBN!gUj%Ri*wtSMGP^S^G zI+|I{m#$VO?N$wA+{Uuid&*=*b?th3UB1B15}S$FS2`Q17DZE2oiR7naiZ%~keS$k zM_dXyU&);F3w}@`di9!Xez|Yqqu@yu$g!BoJRc*)5!aacP;nBuZwk#Z4qs< zo=57pVok9p))7t~lgaAz&payepS~6NIZ~h#b?jJU??u7P;xl|7h+{XxkFH!Vtn6r~ z$}5R#aGgBLaXFK#=)pN!W&E1+enLw;Tc*mRMkI=rJCs>alAZ^1UAvR5unYvjC&hRo z0BV!}#75WgUOCi(cU!8FioJ(@DJ$1jr?Ls3jJBm%|EqAgn^&G53!N{?mseR{wW@S0 z6IW2(xe?#I_u{^hDY{d;cue!2EgsPpLx^b}>CdIP@iQWwcWh3>IdzF= zOZQ6GqMz4Pcs5UF59s@(NLqzbaD)`)nh1k@qL>swXl7M~~vpb3e!>&i@d@Oz@}!@7b9y2{4tH8XWWQS&sE>3w%k zWGC5|4#iM*5*euniF()CXMS6WDhPFyi{p7vL9QE9J;AD|kk0fB&tMaAB;_T#POH$W zk`bVi7(t!kdgyd+4IQ*eK+J!Oi0T;KOnx9P3}hpfRb1)>)}rlnHj5Qi zOC9N!`jYs0u@9f()}qE+z1NGLK^JH(@Nb}Zz(sx;@?J(YsDMHl75@z*Rtjg*q*RjB z*o>EH4cBUU=sxu2L;MPgzNMN#Q%r-05G{#0Ua^klb^JOMbMBY0Xc5i-U#IOXs13Ty waOkBtL)LX9$fnd4ZY>se_2*ry*uhD0C25I8=6ToZCp~FZOsO6jkcd3`7Z2A+AOHXW From 7d243005dabddc0cdfdece250c22640477493673 Mon Sep 17 00:00:00 2001 From: Param Date: Wed, 24 Jun 2026 09:50:34 +0530 Subject: [PATCH 4/4] fix: address reviewer feedback for SPSA estimator --- optax/contrib/_spsa.py | 9 +++++---- tests/contrib/spsa_test.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/optax/contrib/_spsa.py b/optax/contrib/_spsa.py index f128bfe47..26f3308f0 100644 --- a/optax/contrib/_spsa.py +++ b/optax/contrib/_spsa.py @@ -17,6 +17,7 @@ from typing import Any, Callable import jax +import jax.numpy as jnp import optax.tree from optax._src import base @@ -24,7 +25,7 @@ def spsa_standard_schedule( init_value: float, decay_rate: float, - offset: float = 0.0, + offset: float = 1.0, ) -> optax.schedules.Schedule: """Returns a schedule for the SPSA learning rate or perturbation scale. @@ -109,9 +110,9 @@ def sample_rademacher(k, shape, dtype): # Note: Since delta_i is either 1 or -1, dividing by delta_i is # equivalent # to multiplying by delta_i. We multiply for numerical stability. - grad_estimate = jax.tree.map( - lambda d: (y_plus - y_minus) / (2.0 * c) * d, delta - ) + safe_c = jnp.maximum(c, jnp.finfo(jnp.result_type(c)).eps) + scalar_diff = (y_plus - y_minus) / (2.0 * safe_c) + grad_estimate = jax.tree.map(lambda d: scalar_diff * d, delta) return grad_estimate return grad_fn diff --git a/tests/contrib/spsa_test.py b/tests/contrib/spsa_test.py index 1ea5436ac..f29d3fab1 100644 --- a/tests/contrib/spsa_test.py +++ b/tests/contrib/spsa_test.py @@ -18,6 +18,7 @@ import chex import jax import jax.numpy as jnp +import numpy as np import optax @@ -85,8 +86,8 @@ def test_spsa_standard_schedule(self): val_0 = schedule(0) val_10 = schedule(10) - self.assertAlmostEqual(val_0, 1.0 / (10.0**0.5)) - self.assertAlmostEqual(val_10, 1.0 / (20.0**0.5)) + np.testing.assert_allclose(val_0, 1.0 / (10.0**0.5)) + np.testing.assert_allclose(val_10, 1.0 / (20.0**0.5)) if __name__ == "__main__":