Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 7 additions & 33 deletions examples/perturbations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,12 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {
"id": "7hQz6zuPwkpZ"
},
"outputs": [],
"source": [
"N_SAMPLES = 100\n",
"SIGMA = 0.5\n",
"GUMBEL = perturbations.Gumbel()\n",
"\n",
"rng = jax.random.PRNGKey(1)\n",
"pert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,\n",
" num_samples=N_SAMPLES,\n",
" sigma=SIGMA,\n",
" noise=GUMBEL)"
]
"source": "N_SAMPLES = 100\nSIGMA = 0.5\nGUMBEL = perturbations.Gumbel()\n\nrng = jax.random.PRNGKey(1)\npert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,\n num_samples=N_SAMPLES,\n scale=SIGMA,\n noise=GUMBEL)"
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -439,21 +429,12 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {
"id": "Equ3_gDPbf5n"
},
"outputs": [],
"source": [
"N_SAMPLES = 100\n",
"SIGMA = 0.2\n",
"GUMBEL = perturbations.Gumbel()\n",
"\n",
"pert_ranking = perturbations.make_perturbed_fun(ranking,\n",
" num_samples=N_SAMPLES,\n",
" sigma=SIGMA,\n",
" noise=GUMBEL)"
]
"source": "N_SAMPLES = 100\nSIGMA = 0.2\nGUMBEL = perturbations.Gumbel()\n\npert_ranking = perturbations.make_perturbed_fun(ranking,\n num_samples=N_SAMPLES,\n scale=SIGMA,\n noise=GUMBEL)"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -712,19 +693,12 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": null,
"metadata": {
"id": "oKuD_cElxSDd"
},
"outputs": [],
"source": [
"N_SAMPLES = 100\n",
"sigma = 1.0\n",
"\n",
"pert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,\n",
" num_samples=N_SAMPLES,\n",
" sigma=SIGMA)"
]
"source": "N_SAMPLES = 100\nsigma = 1.0\n\npert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,\n num_samples=N_SAMPLES,\n scale=SIGMA)"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -921,4 +895,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
27 changes: 20 additions & 7 deletions optax/perturbations/_make_pert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"""Creates a differentiable approximation of a function with perturbations."""


from typing import Callable
from typing import Callable, Optional
import warnings

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -56,10 +57,12 @@ def log_prob(self, inputs: jax.Array) -> jax.Array:

def make_perturbed_fun(
fun: Callable[[base.ArrayTree], base.ArrayTree],
num_samples: int = 1000,
sigma: jax.typing.ArrayLike = 0.1,
num_samples: int = 1,
scale: jax.typing.ArrayLike = 0.1,
noise=Gumbel(),
use_baseline=True,
*,
sigma: Optional[jax.typing.ArrayLike] = None,
) -> Callable[[base.PRNGKey, base.ArrayTree], base.ArrayTree]:
r"""Returns a differentiable approximation of a function, using stochastic perturbations.

Expand All @@ -86,11 +89,13 @@ def make_perturbed_fun(
fun: The function to transform into a differentiable function. The signature
currently supported is from pytree to pytree, whose leaves are JAX arrays.
num_samples: an int, the number of perturbed outputs to average over.
sigma: a float, the scale of the random perturbation.
scale: a float, the scale of the random perturbation (denoted
:math:`\sigma` in the formula above).
noise: a distribution object that implements ``sample`` and ``log_prob``
methods, like :class:`optax.perturbations.Gumbel` (which is the default).
use_baseline: Use the value of the function at the unperturbed input as a
baseline for variance reduction.
sigma: deprecated alias for ``scale``, kept for backward compatibility.

Returns:
A new function with the same signature as the original function, but with a
Expand All @@ -103,7 +108,7 @@ def make_perturbed_fun(
>>> key = jax.random.key(0)
>>> x = jnp.array([0.0, 0.0, 0.0])
>>> f = lambda x: jnp.sum(jnp.maximum(x, 0.0))
>>> fn = make_perturbed_fun(f, 1_000, 0.1)
>>> fn = make_perturbed_fun(f, num_samples=1_000, scale=0.1)
>>> with jnp.printoptions(precision=2):
... print(jax.grad(fn, argnums=1)(key, x))
[0.69 0.72 0.58]
Expand Down Expand Up @@ -137,16 +142,24 @@ def make_perturbed_fun(
.. seealso::
* :doc:`../_collections/examples/perturbations` example.
""" # noqa: E501
if sigma is not None:
warnings.warn(
"The `sigma` argument of `make_perturbed_fun` is deprecated; use"
" `scale` instead.",
DeprecationWarning,
stacklevel=2,
)
scale = sigma

def mc_estimator(key: base.PRNGKey, x: base.ArrayTree) -> base.ArrayTree:

def stoch_estimator(
key: base.PRNGKey, x: base.ArrayTree, baseline: base.ArrayTree
) -> base.ArrayTree:
sample = optax.tree.random_like(key, x, sampler=noise.sample)
shifted_sample = jax.tree.map(lambda x, z: x + sigma * z, x, sample)
shifted_sample = jax.tree.map(lambda x, z: x + scale * z, x, sample)
shifted_sample = jax.lax.stop_gradient(shifted_sample)
sample = jax.tree.map(lambda x, y: (y - x) / sigma, x, shifted_sample)
sample = jax.tree.map(lambda x, y: (y - x) / scale, x, shifted_sample)

log_prob_sample = optax.tree.sum(jax.tree.map(noise.log_prob, sample))
box = _magicbox(log_prob_sample)
Expand Down
28 changes: 26 additions & 2 deletions optax/perturbations/_make_pert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Tests for optax.perturbations, checking values and gradients."""

from functools import partial # pylint: disable=g-importing-member
import warnings

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -162,7 +163,7 @@ def loss(tree):
return jax.tree.map(lambda *leaves: sum(leaves) / len(leaves), list_loss)

loss_pert = jax.jit(_make_pert.make_perturbed_fun(
loss, num_samples=100, sigma=0.1, noise=_make_pert.Normal()
loss, num_samples=100, scale=0.1, noise=_make_pert.Normal()
))
keys = jax.random.split(key, 3)
low_loss = loss_pert(keys[0], example_tree) # pytype: disable=wrong-arg-types # noqa: E501
Expand Down Expand Up @@ -233,7 +234,7 @@ def f(x):
return jnp.stack([y0, y1])

f1 = jax.jit(_make_pert.make_perturbed_fun(
f, num_samples=num_samples, sigma=sigma, noise=noise))
f, num_samples=num_samples, scale=sigma, noise=noise))
f2 = simple_make_perturbed_fun(f, num_samples=num_samples, sigma=sigma,
noise=noise)
x = jnp.array([0.3, 0.4, 0.5])
Expand Down Expand Up @@ -279,6 +280,29 @@ def test_hessian(self, sigma, noise):
expected = jax.hessian(fun)(x)
test_utils.assert_trees_all_close(got, expected, atol=1e-1)

def test_sigma_deprecation(self):
  """Checks that the legacy ``sigma`` keyword warns and still sets the scale.

  Two things are verified:
    * building a perturbed function with ``sigma=...`` records a
      ``DeprecationWarning`` whose message mentions ``sigma``;
    * the resulting function returns the same value as one built with
      ``scale=...`` for the same key and input.
  """
  fun = jnp.sum
  x = jnp.array([0.0, 0.0])
  key = jax.random.key(0)
  with warnings.catch_warnings(record=True) as caught:
    # Record every warning, including ones that default filters would hide.
    warnings.simplefilter('always')
    fp = _make_pert.make_perturbed_fun(fun, num_samples=2, sigma=0.1)
    got = fp(key, x)
  # With the 'always' filter, unrelated DeprecationWarnings raised by
  # dependencies may be recorded as well (and possibly first), so select on
  # the message content instead of relying on the position in the list.
  sigma_warnings = [
      w for w in caught
      if issubclass(w.category, DeprecationWarning)
      and 'sigma' in str(w.message)
  ]
  self.assertTrue(
      sigma_warnings,
      msg='No DeprecationWarning mentioning `sigma` was recorded.',
  )
  # The deprecated alias must behave exactly like ``scale``.
  expected = _make_pert.make_perturbed_fun(
      fun, num_samples=2, scale=0.1
  )(key, x)
  test_utils.assert_trees_all_close(got, expected, atol=1e-6)

def test_default_num_samples(self):
  """Smoke test: a perturbed function built with default arguments runs.

  Only ``fun`` is supplied, so ``num_samples`` and every other optional
  argument keep their default values. The test passes as long as the call
  does not raise.
  """
  perturbed_sum = _make_pert.make_perturbed_fun(jnp.sum)
  perturbed_sum(jax.random.key(0), jnp.array([0.0, 0.0]))


if __name__ == '__main__':
absltest.main()
Loading