diff --git a/docs/api/projections.rst b/docs/api/projections.rst
index 4bb179047..4490bb6ab 100644
--- a/docs/api/projections.rst
+++ b/docs/api/projections.rst
@@ -33,7 +33,9 @@ Available projections
 .. autosummary::
     :toctree: generated/
 
+    projection_affine_set
     projection_box
+    projection_box_section
     projection_hypercube
     projection_l1_ball
     projection_l1_sphere
diff --git a/optax/projections/__init__.py b/optax/projections/__init__.py
index 6e38839b0..3fc223b01 100644
--- a/optax/projections/__init__.py
+++ b/optax/projections/__init__.py
@@ -17,7 +17,9 @@
 
 # pylint: disable=g-importing-member
 
+from optax.projections._projections import projection_affine_set
 from optax.projections._projections import projection_box
+from optax.projections._projections import projection_box_section
 from optax.projections._projections import projection_halfspace
 from optax.projections._projections import projection_hypercube
 from optax.projections._projections import projection_hyperplane
diff --git a/optax/projections/_projections.py b/optax/projections/_projections.py
index 97209c9d8..78dc600ea 100644
--- a/optax/projections/_projections.py
+++ b/optax/projections/_projections.py
@@ -364,3 +364,132 @@ def projection_halfspace(x: Any, a: Any, b: jax.typing.ArrayLike) -> Any:
   scalar = (b - optax.tree.vdot(x, a)) / optax.tree.vdot(a, a)
   scalar = jnp.clip(scalar, max=0)
   return optax.tree.add_scale(x, scalar, a)
+
+
+def projection_affine_set(
+    x: jax.typing.ArrayLike,
+    a: jax.typing.ArrayLike,
+    b: jax.typing.ArrayLike,
+) -> jax.Array:
+  r"""Projection onto an affine set.
+
+  Projects a vector ``x`` onto the affine set defined by a matrix ``a`` and a
+  vector ``b``.
+
+  .. math::
+
+    \operatorname{argmin}_y \|x - y\|_2^2 \quad \text{subject to} \quad
+    a y = b
+
+  The projection is computed in closed form,
+  :math:`y = x + a^\top (a a^\top)^{-1} (b - a x)`.
+
+  Args:
+    x: array of shape ``(n,)`` to project.
+    a: matrix of shape ``(m, n)``, with ``m <= n``, defining the affine set.
+      Must have linearly independent rows.
+    b: vector of shape ``(m,)`` defining the affine set.
+
+  Returns:
+    projected array, with the same shape as ``x``.
+
+  Example:
+
+    >>> import jax.numpy as jnp
+    >>> from optax import projections
+    >>> x = jnp.array([1.0, 2.0, 3.0])
+    >>> a = jnp.array([[1.0, 1.0, 1.0]])
+    >>> b = jnp.array([3.0])
+    >>> print(projections.projection_affine_set(x, a, b))
+    [0. 1. 2.]
+
+  .. versionadded:: 0.2.9
+  """
+  x = jnp.asarray(x)
+  a = jnp.asarray(a)
+  b = jnp.asarray(b)
+  return x + a.T @ jnp.linalg.solve(a @ a.T, b - a @ x)
+
+
+def projection_box_section(
+    x: jax.typing.ArrayLike,
+    lower: jax.typing.ArrayLike,
+    upper: jax.typing.ArrayLike,
+    w: jax.typing.ArrayLike,
+    c: jax.typing.ArrayLike,
+) -> jax.Array:
+  r"""Projection onto a section of a box.
+
+  Projects a vector ``x`` onto the intersection of a box (hyperrectangle)
+  and a hyperplane with positive coefficient vector ``w``.
+
+  .. math::
+
+    \operatorname{argmin}_y \|x - y\|_2^2 \quad \text{subject to} \quad
+    \text{lower} \le y \le \text{upper}, \langle w, y \rangle = c
+
+  The solution has the form :math:`y_i = \operatorname{clip}(x_i + \tau w_i,
+  \text{lower}_i, \text{upper}_i)`, where the scalar :math:`\tau` is the root
+  of a monotone function, found by bisection. The projection is
+  differentiable, via implicit differentiation of the root.
+
+  The constraint set is non-empty if and only if :math:`\langle w,
+  \text{lower} \rangle \le c \le \langle w, \text{upper} \rangle`; the result
+  is undefined otherwise.
+
+  Args:
+    x: array of shape ``(n,)`` to project.
+    lower: lower bound of the box, a scalar or an array broadcastable to the
+      shape of ``x``.
+    upper: upper bound of the box, a scalar or an array broadcastable to the
+      shape of ``x``.
+    w: weights of the hyperplane, an array with the same shape as ``x``. All
+      entries must be positive.
+    c: scalar defining the hyperplane.
+
+  Returns:
+    projected array, with the same shape as ``x``.
+
+  Example:
+
+    >>> import jax.numpy as jnp
+    >>> from optax import projections
+    >>> x = jnp.array([0.5, 1.5])
+    >>> w = jnp.array([1.0, 1.0])
+    >>> print(projections.projection_box_section(x, 0.0, 1.0, w, 1.0))
+    [0. 1.]
+
+  .. versionadded:: 0.2.9
+  """
+  x = jnp.asarray(x)
+  w = jnp.asarray(w)
+
+  def residual(tau):
+    # Monotonically non-decreasing in tau, since w > 0.
+    return jnp.dot(w, jnp.clip(x + tau * w, lower, upper)) - c
+
+  # For tau below (resp. above) the bracket, all coordinates of the candidate
+  # solution hit the lower (resp. upper) bound, so by feasibility the residual
+  # is non-positive (resp. non-negative) and the bracket contains a root.
+  bracket_low = jax.lax.stop_gradient(jnp.min((lower - x) / w))
+  bracket_high = jax.lax.stop_gradient(jnp.max((upper - x) / w))
+
+  def bisect(fun, init):
+    del init  # the root is bracketed by construction
+
+    def body_fun(_, bracket):
+      low, high = bracket
+      mid = 0.5 * (low + high)
+      go_left = fun(mid) >= 0
+      return jnp.where(go_left, low, mid), jnp.where(go_left, mid, high)
+
+    # 100 iterations narrow the bracket by a factor of 2**100, well below
+    # float64 resolution for any realistic input.
+    low, high = jax.lax.fori_loop(0, 100, body_fun, (bracket_low, bracket_high))
+    return 0.5 * (low + high)
+
+  def tangent_solve(g, y):
+    return y / g(1.0)
+
+  tau = jax.lax.custom_root(residual, 0.0, bisect, tangent_solve)
+  return jnp.clip(x + tau * w, lower, upper)
diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py
index 132298043..8a04b96c8 100644
--- a/optax/projections/_projections_test.py
+++ b/optax/projections/_projections_test.py
@@ -270,6 +270,95 @@ def test_projection_halfspace_2(self):
     y_expected = x
     assert optax.tree.allclose(y_actual, y_expected)
 
+  def test_projection_affine_set(self):
+    rng = np.random.RandomState(0)
+    a = jnp.asarray(rng.randn(3, 7).astype(np.float32))
+    b = jnp.asarray(rng.randn(3).astype(np.float32))
+    x = jnp.asarray(rng.randn(7).astype(np.float32))
+    p = proj.projection_affine_set(x, a, b)
+
+    with self.subTest('Check feasibility'):
+      np.testing.assert_array_almost_equal(a @ p, b)
+
+    with self.subTest('Check optimality'):
+      # p is the projection iff a @ p = b and x - p is orthogonal to the
+      # null space of a.
+      null_space = np.linalg.svd(a)[2][a.shape[0]:]
+      np.testing.assert_array_almost_equal(
+          null_space @ (x - p), jnp.zeros(null_space.shape[0])
+      )
+
+    with self.subTest('Check idempotence'):
+      np.testing.assert_array_almost_equal(
+          proj.projection_affine_set(p, a, b), p, decimal=5
+      )
+
+    with self.subTest('Check consistency with projection_hyperplane'):
+      p_affine = proj.projection_affine_set(x, a[:1], b[:1])
+      p_hyperplane = proj.projection_hyperplane(x, a[0], b[0])
+      np.testing.assert_array_almost_equal(p_affine, p_hyperplane)
+
+  def test_projection_box_section(self):
+    rng = np.random.RandomState(0)
+    x = jnp.asarray(rng.randn(8).astype(np.float32))
+    w = jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(np.float32))
+    lower = jnp.asarray(rng.randn(8).astype(np.float32) - 1.0)
+    upper = lower + jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(
+        np.float32))
+    c = 0.7 * jnp.dot(w, lower) + 0.3 * jnp.dot(w, upper)
+    p = proj.projection_box_section(x, lower, upper, w, c)
+
+    with self.subTest('Check feasibility'):
+      np.testing.assert_almost_equal(jnp.dot(w, p), c, decimal=4)
+      self.assertTrue(jnp.all(p >= lower))
+      self.assertTrue(jnp.all(p <= upper))
+
+    with self.subTest('Check idempotence'):
+      np.testing.assert_array_almost_equal(
+          proj.projection_box_section(p, lower, upper, w, c), p, decimal=5
+      )
+
+    with self.subTest('Check under jit'):
+      p_jit = jax.jit(proj.projection_box_section)(x, lower, upper, w, c)
+      np.testing.assert_array_almost_equal(p_jit, p)
+
+    with self.subTest('Check consistency with projection_simplex'):
+      # With lower=0, an inactive upper bound and unit weights, the box
+      # section is the unit simplex.
+      ones = jnp.ones_like(x)
+      np.testing.assert_array_almost_equal(
+          proj.projection_box_section(x, 0.0, 10.0, ones, 1.0),
+          proj.projection_simplex(x),
+      )
+
+  def test_projection_box_section_jacobian(self):
+    rng = np.random.RandomState(0)
+    x = jnp.asarray(rng.randn(8).astype(np.float32))
+    w = jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(np.float32))
+    lower = jnp.asarray(rng.randn(8).astype(np.float32) - 1.0)
+    upper = lower + jnp.asarray(rng.uniform(0.5, 2.0, size=8).astype(
+        np.float32))
+    c = 0.7 * jnp.dot(w, lower) + 0.3 * jnp.dot(w, upper)
+    v = jnp.asarray(rng.randn(8).astype(np.float32))
+
+    fun = lambda x: proj.projection_box_section(x, lower, upper, w, c)
+
+    with self.subTest('Check against finite difference'):
+      jvp = jax.jvp(fun, (x,), (v,))[1]
+      eps = 1e-3
+      jvp_finite_diff = (fun(x + eps * v) - fun(x - eps * v)) / (2 * eps)
+      np.testing.assert_array_almost_equal(jvp, jvp_finite_diff, decimal=3)
+
+    with self.subTest('Check gradient with respect to c'):
+      fun_c = lambda c: jnp.sum(
+          proj.projection_box_section(x, lower, upper, w, c) ** 2
+      )
+      eps = 1e-2
+      grad_finite_diff = (fun_c(c + eps) - fun_c(c - eps)) / (2 * eps)
+      np.testing.assert_almost_equal(
+          jax.grad(fun_c)(c), grad_finite_diff, decimal=3
+      )
+
 
 if __name__ == '__main__':
   absltest.main()