From 3c34c40540a1fa452039736711e01822dac7e900 Mon Sep 17 00:00:00 2001 From: yurekami Date: Sun, 19 Apr 2026 20:59:58 +0800 Subject: [PATCH 1/2] feat(perturbations): rename `sigma` to `scale`, default `num_samples` to 1 Implements the two refinements proposed in #1342: - Rename `make_perturbed_fun`'s `sigma` parameter to `scale`, the more generic name suggested in the issue. The old `sigma` kwarg is kept as a keyword-only deprecated alias that emits a `DeprecationWarning`; positional callers are unaffected. - Change the default value of `num_samples` from 1000 to 1, matching the convention common in stochastic optimization. Updates tests (adds test_sigma_deprecation and test_default_num_samples), the perturbations example notebook, and the docstring. --- examples/perturbations.ipynb | 40 +++++--------------------- optax/perturbations/_make_pert.py | 27 ++++++++++++----- optax/perturbations/_make_pert_test.py | 28 ++++++++++++++++-- 3 files changed, 53 insertions(+), 42 deletions(-) diff --git a/examples/perturbations.ipynb b/examples/perturbations.ipynb index 166ed95a6..67690be19 100644 --- a/examples/perturbations.ipynb +++ b/examples/perturbations.ipynb @@ -128,22 +128,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "id": "7hQz6zuPwkpZ" }, "outputs": [], - "source": [ - "N_SAMPLES = 100\n", - "SIGMA = 0.5\n", - "GUMBEL = perturbations.Gumbel()\n", - "\n", - "rng = jax.random.PRNGKey(1)\n", - "pert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,\n", - " num_samples=N_SAMPLES,\n", - " sigma=SIGMA,\n", - " noise=GUMBEL)" - ] + "source": "N_SAMPLES = 100\nSIGMA = 0.5\nGUMBEL = perturbations.Gumbel()\n\nrng = jax.random.PRNGKey(1)\npert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,\n num_samples=N_SAMPLES,\n scale=SIGMA,\n noise=GUMBEL)" }, { "cell_type": "markdown", @@ -439,21 +429,12 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "id": "Equ3_gDPbf5n" }, "outputs": [], - "source": [ - "N_SAMPLES = 100\n", - "SIGMA = 0.2\n", - "GUMBEL = perturbations.Gumbel()\n", - "\n", - "pert_ranking = perturbations.make_perturbed_fun(ranking,\n", - " num_samples=N_SAMPLES,\n", - " sigma=SIGMA,\n", - " noise=GUMBEL)" - ] + "source": "N_SAMPLES = 100\nSIGMA = 0.2\nGUMBEL = perturbations.Gumbel()\n\npert_ranking = perturbations.make_perturbed_fun(ranking,\n num_samples=N_SAMPLES,\n scale=SIGMA,\n noise=GUMBEL)" }, { "cell_type": "code", @@ -712,19 +693,12 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": { "id": "oKuD_cElxSDd" }, "outputs": [], - "source": [ - "N_SAMPLES = 100\n", - "sigma = 1.0\n", - "\n", - "pert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,\n", - " num_samples=N_SAMPLES,\n", - " sigma=SIGMA)" - ] + "source": "N_SAMPLES = 100\nsigma = 1.0\n\npert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,\n num_samples=N_SAMPLES,\n scale=SIGMA)" }, { "cell_type": "code", @@ -921,4 +895,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/optax/perturbations/_make_pert.py b/optax/perturbations/_make_pert.py index b9ee7df94..60a6f1d0f 100644 --- a/optax/perturbations/_make_pert.py +++ b/optax/perturbations/_make_pert.py @@ -16,7 +16,8 @@ """Creates a differentiable approximation of a function with perturbations.""" -from typing import Callable +from typing import Callable, Optional +import warnings import jax import jax.numpy as jnp @@ -56,10 +57,12 @@ def log_prob(self, inputs: jax.Array) -> jax.Array: def make_perturbed_fun( fun: Callable[[base.ArrayTree], base.ArrayTree], - num_samples: int = 1000, - sigma: jax.typing.ArrayLike = 0.1, + num_samples: int = 1, + scale: jax.typing.ArrayLike = 0.1, noise=Gumbel(), use_baseline=True, + *, + sigma: Optional[jax.typing.ArrayLike] = None, ) -> Callable[[base.PRNGKey, base.ArrayTree], base.ArrayTree]: r"""Returns a differentiable approximation of a function, using stochastic perturbations. @@ -86,11 +89,13 @@ def make_perturbed_fun( fun: The function to transform into a differentiable function. The signature currently supported is from pytree to pytree, whose leaves are JAX arrays. num_samples: an int, the number of perturbed outputs to average over. - sigma: a float, the scale of the random perturbation. + scale: a float, the scale of the random perturbation (denoted + :math:`\sigma` in the formula above). noise: a distribution object that implements ``sample`` and ``log_prob`` methods, like :class:`optax.perturbations.Gumbel` (which is the default). use_baseline: Use the value of the function at the unperturbed input as a baseline for variance reduction. + sigma: deprecated alias for ``scale``, kept for backward compatibility. Returns: A new function with the same signature as the original function, but with a @@ -103,7 +108,7 @@ def make_perturbed_fun( >>> key = jax.random.key(0) >>> x = jnp.array([0.0, 0.0, 0.0]) >>> f = lambda x: jnp.sum(jnp.maximum(x, 0.0)) - >>> fn = make_perturbed_fun(f, 1_000, 0.1) + >>> fn = make_perturbed_fun(f, num_samples=1_000, scale=0.1) >>> with jnp.printoptions(precision=2): ... print(jax.grad(fn, argnums=1)(key, x)) [0.69 0.72 0.58] @@ -137,6 +142,14 @@ def make_perturbed_fun( .. seealso:: * :doc:`../_collections/examples/perturbations` example. """ # noqa: E501 + if sigma is not None: + warnings.warn( + "The `sigma` argument of `make_perturbed_fun` is deprecated; use" + " `scale` instead.", + DeprecationWarning, + stacklevel=2, + ) + scale = sigma def mc_estimator(key: base.PRNGKey, x: base.ArrayTree) -> base.ArrayTree: @@ -144,9 +157,9 @@ def stoch_estimator( key: base.PRNGKey, x: base.ArrayTree, baseline: base.ArrayTree ) -> base.ArrayTree: sample = optax.tree.random_like(key, x, sampler=noise.sample) - shifted_sample = jax.tree.map(lambda x, z: x + sigma * z, x, sample) + shifted_sample = jax.tree.map(lambda x, z: x + scale * z, x, sample) shifted_sample = jax.lax.stop_gradient(shifted_sample) - sample = jax.tree.map(lambda x, y: (y - x) / sigma, x, shifted_sample) + sample = jax.tree.map(lambda x, y: (y - x) / scale, x, shifted_sample) log_prob_sample = optax.tree.sum(jax.tree.map(noise.log_prob, sample)) box = _magicbox(log_prob_sample) diff --git a/optax/perturbations/_make_pert_test.py b/optax/perturbations/_make_pert_test.py index 41203809c..8911f034b 100644 --- a/optax/perturbations/_make_pert_test.py +++ b/optax/perturbations/_make_pert_test.py @@ -16,6 +16,7 @@ """Tests for optax.perturbations, checking values and gradients.""" from functools import partial # pylint: disable=g-importing-member +import warnings from absl.testing import absltest from absl.testing import parameterized @@ -162,7 +163,7 @@ def loss(tree): return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss) loss_pert = jax.jit(_make_pert.make_perturbed_fun( - loss, num_samples=100, sigma=0.1, noise=_make_pert.Normal() + loss, num_samples=100, scale=0.1, noise=_make_pert.Normal() )) keys = jax.random.split(key, 3) low_loss = loss_pert(keys[0], example_tree) # pytype: disable=wrong-arg-types # noqa: E501 @@ -233,7 +234,7 @@ def f(x): return jnp.stack([y0, y1]) f1 = jax.jit(_make_pert.make_perturbed_fun( - f, num_samples=num_samples, sigma=sigma, noise=noise)) + f, num_samples=num_samples, scale=sigma, noise=noise)) f2 = simple_make_perturbed_fun(f, num_samples=num_samples, sigma=sigma, noise=noise) x = jnp.array([0.3, 0.4, 0.5]) @@ -279,6 +280,29 @@ def test_hessian(self, sigma, noise): expected = jax.hessian(fun)(x) test_utils.assert_trees_all_close(got, expected, atol=1e-1) + def test_sigma_deprecation(self): + """Passing the legacy ``sigma`` kwarg emits a DeprecationWarning.""" + fun = lambda x: jnp.sum(x) + x = jnp.array([0.0, 0.0]) + key = jax.random.key(0) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + fp = _make_pert.make_perturbed_fun(fun, num_samples=2, sigma=0.1) + fp(key, x) + deprecation_warnings = [ + w for w in caught if issubclass(w.category, DeprecationWarning) + ] + self.assertTrue(deprecation_warnings) + self.assertIn('sigma', str(deprecation_warnings[0].message)) + + def test_default_num_samples(self): + """The default ``num_samples=1`` executes without error.""" + fun = lambda x: jnp.sum(x) + x = jnp.array([0.0, 0.0]) + key = jax.random.key(0) + fp = _make_pert.make_perturbed_fun(fun) + fp(key, x) + if __name__ == '__main__': absltest.main() From ebfccbb4726d31eca1be4840937e26578b210787 Mon Sep 17 00:00:00 2001 From: yurekami Date: Sun, 19 Apr 2026 21:30:07 +0800 Subject: [PATCH 2/2] fix(perturbations): drop unnecessary lambdas in new tests (pylint W0108) --- optax/perturbations/_make_pert_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optax/perturbations/_make_pert_test.py b/optax/perturbations/_make_pert_test.py index 8911f034b..cc9306839 100644 --- a/optax/perturbations/_make_pert_test.py +++ b/optax/perturbations/_make_pert_test.py @@ -282,7 +282,7 @@ def test_hessian(self, sigma, noise): def test_sigma_deprecation(self): """Passing the legacy ``sigma`` kwarg emits a DeprecationWarning.""" - fun = lambda x: jnp.sum(x) + fun = jnp.sum x = jnp.array([0.0, 0.0]) key = jax.random.key(0) with warnings.catch_warnings(record=True) as caught: @@ -297,7 +297,7 @@ def test_sigma_deprecation(self): def test_default_num_samples(self): """The default ``num_samples=1`` executes without error.""" - fun = lambda x: jnp.sum(x) + fun = jnp.sum x = jnp.array([0.0, 0.0]) key = jax.random.key(0) fp = _make_pert.make_perturbed_fun(fun)