Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/projections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ Available projections
.. autosummary::
:toctree: generated/

projection_affine_set
projection_box
projection_box_section
projection_hypercube
projection_l1_ball
projection_l1_sphere
Expand Down
2 changes: 2 additions & 0 deletions optax/projections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

# pylint: disable=g-importing-member

from optax.projections._projections import projection_affine_set
from optax.projections._projections import projection_box
from optax.projections._projections import projection_box_section
from optax.projections._projections import projection_halfspace
from optax.projections._projections import projection_hypercube
from optax.projections._projections import projection_hyperplane
Expand Down
129 changes: 129 additions & 0 deletions optax/projections/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,132 @@ def projection_halfspace(x: Any, a: Any, b: jax.typing.ArrayLike) -> Any:
scalar = (b - optax.tree.vdot(x, a)) / optax.tree.vdot(a, a)
scalar = jnp.clip(scalar, max=0)
return optax.tree.add_scale(x, scalar, a)


def projection_affine_set(
x: jax.typing.ArrayLike,
a: jax.typing.ArrayLike,
b: jax.typing.ArrayLike,
) -> jax.Array:
r"""Projection onto an affine set.

Projects a vector ``x`` onto the affine set defined by a matrix ``a`` and a
vector ``b``.

.. math::

\operatorname{argmin}_y \|x - y\|_2^2 \quad \text{subject to} \quad
a y = b

The projection is computed in closed form,
:math:`y = x + a^\top (a a^\top)^{-1} (b - a x)`.

Args:
x: array of shape ``(n,)`` to project.
a: matrix of shape ``(m, n)``, with ``m <= n``, defining the affine set.
Must have linearly independent rows.
b: vector of shape ``(m,)`` defining the affine set.

Returns:
projected array, with the same shape as ``x``.

Example:

>>> import jax.numpy as jnp
>>> from optax import projections
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> a = jnp.array([[1.0, 1.0, 1.0]])
>>> b = jnp.array([3.0])
>>> print(projections.projection_affine_set(x, a, b))
[0. 1. 2.]

.. versionadded:: 0.2.9
"""
x = jnp.asarray(x)
a = jnp.asarray(a)
b = jnp.asarray(b)
return x + a.T @ jnp.linalg.solve(a @ a.T, b - a @ x)


def projection_box_section(
x: jax.typing.ArrayLike,
lower: jax.typing.ArrayLike,
upper: jax.typing.ArrayLike,
w: jax.typing.ArrayLike,
c: jax.typing.ArrayLike,
) -> jax.Array:
r"""Projection onto a section of a box.

Projects a vector ``x`` onto the intersection of a box (hyperrectangle)
and a hyperplane with positive coefficient vector ``w``.

.. math::

\operatorname{argmin}_y \|x - y\|_2^2 \quad \text{subject to} \quad
\text{lower} \le y \le \text{upper}, \langle w, y \rangle = c

The solution has the form :math:`y_i = \operatorname{clip}(x_i + \tau w_i,
\text{lower}_i, \text{upper}_i)`, where the scalar :math:`\tau` is the root
of a monotone function, found by bisection. The projection is
differentiable, via implicit differentiation of the root.

The constraint set is non-empty if and only if :math:`\langle w,
\text{lower} \rangle \le c \le \langle w, \text{upper} \rangle`; the result
is undefined otherwise.

Args:
x: array of shape ``(n,)`` to project.
lower: lower bound of the box, a scalar or an array broadcastable to the
shape of ``x``.
upper: upper bound of the box, a scalar or an array broadcastable to the
shape of ``x``.
w: weights of the hyperplane, an array with the same shape as ``x``. All
entries must be positive.
c: scalar defining the hyperplane.

Returns:
projected array, with the same shape as ``x``.

Example:

>>> import jax.numpy as jnp
>>> from optax import projections
>>> x = jnp.array([0.5, 1.5])
>>> w = jnp.array([1.0, 1.0])
>>> print(projections.projection_box_section(x, 0.0, 1.0, w, 1.0))
[0. 1.]

.. versionadded:: 0.2.9
"""
x = jnp.asarray(x)
w = jnp.asarray(w)

def residual(tau):
# Monotonically non-decreasing in tau, since w > 0.
return jnp.dot(w, jnp.clip(x + tau * w, lower, upper)) - c

# For tau below (resp. above) the bracket, all coordinates of the candidate
# solution hit the lower (resp. upper) bound, so by feasibility the residual
# is non-positive (resp. non-negative) and the bracket contains a root.
bracket_low = jax.lax.stop_gradient(jnp.min((lower - x) / w))
bracket_high = jax.lax.stop_gradient(jnp.max((upper - x) / w))

def bisect(fun, init):
del init # the root is bracketed by construction

def body_fun(_, bracket):
low, high = bracket
mid = 0.5 * (low + high)
go_left = fun(mid) >= 0
return jnp.where(go_left, low, mid), jnp.where(go_left, mid, high)

# 100 iterations narrow the bracket by a factor of 2**100, well below
# float64 resolution for any realistic input.
low, high = jax.lax.fori_loop(0, 100, body_fun, (bracket_low, bracket_high))
return 0.5 * (low + high)

def tangent_solve(g, y):
return y / g(1.0)

tau = jax.lax.custom_root(residual, 0.0, bisect, tangent_solve)
return jnp.clip(x + tau * w, lower, upper)
89 changes: 89 additions & 0 deletions optax/projections/_projections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,95 @@ def test_projection_halfspace_2(self):
y_expected = x
assert optax.tree.allclose(y_actual, y_expected)

def test_projection_affine_set(self):
rng = np.random.RandomState(0)
a = jnp.asarray(rng.randn(3, 7).astype(np.float32))
b = jnp.asarray(rng.randn(3).astype(np.float32))
x = jnp.asarray(rng.randn(7).astype(np.float32))
p = proj.projection_affine_set(x, a, b)

with self.subTest('Check feasibility'):
np.testing.assert_array_almost_equal(a @ p, b)

with self.subTest('Check optimality'):
# p is the projection iff a @ p = b and x - p is orthogonal to the
# null space of a.
null_space = np.linalg.svd(a)[2][a.shape[0]:]
np.testing.assert_array_almost_equal(
null_space @ (x - p), jnp.zeros(null_space.shape[0])
)

with self.subTest('Check idempotence'):
np.testing.assert_array_almost_equal(
proj.projection_affine_set(p, a, b), p, decimal=5
)

with self.subTest('Check consistency with projection_hyperplane'):
p_affine = proj.projection_affine_set(x, a[:1], b[:1])
p_hyperplane = proj.projection_hyperplane(x, a[0], b[0])
np.testing.assert_array_almost_equal(p_affine, p_hyperplane)

def test_projection_box_section(self):
rng = np.random.RandomState(0)
x = jnp.asarray(rng.randn(8).astype(np.float32))
w = jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(np.float32))
lower = jnp.asarray(rng.randn(8).astype(np.float32) - 1.0)
upper = lower + jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(
np.float32))
c = 0.7 * jnp.dot(w, lower) + 0.3 * jnp.dot(w, upper)
p = proj.projection_box_section(x, lower, upper, w, c)

with self.subTest('Check feasibility'):
np.testing.assert_almost_equal(jnp.dot(w, p), c, decimal=4)
self.assertTrue(jnp.all(p >= lower))
self.assertTrue(jnp.all(p <= upper))

with self.subTest('Check idempotence'):
np.testing.assert_array_almost_equal(
proj.projection_box_section(p, lower, upper, w, c), p, decimal=5
)

with self.subTest('Check under jit'):
p_jit = jax.jit(proj.projection_box_section)(x, lower, upper, w, c)
np.testing.assert_array_almost_equal(p_jit, p)

with self.subTest('Check consistency with projection_simplex'):
# With lower=0, an inactive upper bound and unit weights, the box
# section is the unit simplex.
ones = jnp.ones_like(x)
np.testing.assert_array_almost_equal(
proj.projection_box_section(x, 0.0, 10.0, ones, 1.0),
proj.projection_simplex(x),
)

def test_projection_box_section_jacobian(self):
rng = np.random.RandomState(0)
x = jnp.asarray(rng.randn(8).astype(np.float32))
w = jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(np.float32))
lower = jnp.asarray(rng.randn(8).astype(np.float32) - 1.0)
upper = lower + jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(
np.float32))
c = 0.7 * jnp.dot(w, lower) + 0.3 * jnp.dot(w, upper)
v = jnp.asarray(rng.randn(8).astype(np.float32))

fun = lambda x: proj.projection_box_section(x, lower, upper, w, c)

with self.subTest('Check against finite difference'):
jvp = jax.jvp(fun, (x,), (v,))[1]
eps = 1e-3
jvp_finite_diff = (fun(x + eps * v) - fun(x - eps * v)) / (2 * eps)
np.testing.assert_array_almost_equal(jvp, jvp_finite_diff, decimal=3)

with self.subTest('Check gradient with respect to c'):
fun_c = lambda c: jnp.sum(
proj.projection_box_section(x, lower, upper, w, c) ** 2
)
eps = 1e-2
grad_finite_diff = (fun_c(c + eps) - fun_c(c - eps)) / (2 * eps)
np.testing.assert_almost_equal(
jax.grad(fun_c)(c), grad_finite_diff, decimal=3
)


if __name__ == '__main__':
absltest.main()
Loading