Skip to content

Add affine set and box section projections from JAXopt#1701

Open
discobot wants to merge 1 commit into
google-deepmind:mainfrom
discobot:fix/1280-affine-set-box-section
Open

Add affine set and box section projections from JAXopt#1701
discobot wants to merge 1 commit into
google-deepmind:mainfrom
discobot:fix/1280-affine-set-box-section

Conversation

@discobot

Copy link
Copy Markdown

Fixes #1280.

This adds projection_affine_set and projection_box_section, two of the
three remaining projections from the checklist.

It turns out neither needs to wait on the solver decision in #977. JAXopt's
projection_affine_set delegates to EqualityConstrainedQP, but the
equality-constrained projection has the closed-form KKT solution
p = x + aᵀ (a aᵀ)⁻¹ (b − a x), which is what's implemented here.
projection_box_section only needs the root of a monotone scalar function,
so the bisection is done internally (the bracket
[min((lower − x)/w), max((upper − x)/w)] contains the root whenever the
problem is feasible) and is wrapped in jax.lax.custom_root, making the
projection differentiable by implicit differentiation of the root. Both work
under jit, vmap, and grad. Following optax conventions, JAXopt's
hyperparams tuples are flattened into named arguments, matching
projection_hyperplane and projection_box.

projection_polyhedron is the only item that genuinely requires an
OSQP-class solver, so it is left out pending #977.

Tests check feasibility, optimality (residual orthogonal to the constraint
nullspace), idempotence, consistency with projection_hyperplane (single-row
a) and projection_simplex (unit weights, inactive upper bound), behavior
under jit, and gradients against finite differences. I also verified both
functions against a scipy SLSQP oracle locally (agreement ~1e-15 in float64).

@google-cla

google-cla Bot commented Jun 13, 2026

Copy link
Copy Markdown

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Adds projection_affine_set and projection_box_section to optax.projections, two of the remaining items in issue 1280. Unlike their JAXopt counterparts, neither requires a QP solver: the affine set projection uses the closed-form KKT solution, and the box section projection finds the scalar dual variable by bisection wrapped in jax.lax.custom_root so it is differentiable by implicit differentiation. projection_polyhedron is left out since it depends on an OSQP-style solver, pending issue 977. Includes tests for feasibility, optimality, idempotence, consistency with projection_hyperplane and projection_simplex, and gradient correctness.
@discobot discobot force-pushed the fix/1280-affine-set-box-section branch from 9caad96 to 22713fe Compare June 13, 2026 11:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add missing projections from jaxopt

1 participant