
Stateful optimization#2188

Draft
f0uriest wants to merge 2 commits into master from stateful_optimization


Conversation

@f0uriest
Member

@f0uriest f0uriest commented Apr 28, 2026

First pass at #1034. For now this just adds the lowest-level API, allowing DESC optimizers to maintain and update state between iterations.
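As a rough illustration of what "state between iterations" could look like at the optimizer level, here is a minimal sketch. It is not the API added in this PR: `stateful_fun`, `gradient_descent`, and the `(value, grad, new_state)` return convention are hypothetical names chosen for the example, and plain NumPy gradient descent stands in for the DESC optimizers.

```python
import numpy as np


def stateful_fun(x, state):
    """Hypothetical objective returning (value, gradient, new_state).

    The state only counts evaluations and remembers the last input; in a
    real objective it might hold the solution of an inner iterative solve.
    """
    value = 0.5 * np.dot(x, x)
    grad = x
    new_state = {"nfev": state["nfev"] + 1, "last_x": x.copy()}
    return value, grad, new_state


def gradient_descent(fun, x0, state0, step=0.5, maxiter=50, gtol=1e-8):
    """Gradient descent that threads `state` through every call to `fun`."""
    x = np.asarray(x0, dtype=float)
    state = state0
    for _ in range(maxiter):
        value, grad, state = fun(x, state)  # state is updated every iteration
        if np.linalg.norm(grad) < gtol:
            break
        x = x - step * grad
    return x, state


x_opt, final_state = gradient_descent(
    stateful_fun, [1.0, -2.0], {"nfev": 0, "last_x": None}
)
```

The point of the sketch is only that the optimizer owns the state and hands the latest copy back to the objective on each call, so the objective itself can stay a pure function.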

Still to do:

  • Plumb this up to objectives
  • Add state to the SciPy wrappers
  • Update proximal to use this?
  • Figure out how objectives declare that they need state, defaults, etc.

Some open questions:

  • The state that we want to keep track of for the forward pass may be different from the state for the tangent/backward pass. Consider an objective that depends on the solution of a linear system $A(x) f = b(x)$, where $f$ is the output of the objective (and the state we want to maintain, i.e., we solve $Af=b$ iteratively) and $x$ is the input. The derivative would satisfy $A(x) df = db/dx(x) - dA/dx(x) f$; note that the state we want to keep track of here is $df$, not $f$. This is all to say that keeping track of state may accelerate forward calculations but, depending on the operation, may not give much benefit for derivatives. As far as I can tell, it's not possible to have JAX return aux information pertaining to the derivative without a ton of custom jvp/vjp logic for each objective.
  • I suppose in the above example, always starting with $df=0$ would be reasonable, since if we're assuming that $f$ doesn't change much then $df$ should be small. So maybe we just always want to take 0 for the tangent state? This may depend on what the state is etc., so it may require some custom jvp logic for stateful objectives.
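To make the linear-system example concrete, here is a small NumPy sketch. Everything in it is an assumption made for illustration: the 2x2 system `A_of`/`b_of`, the scalar input `x`, and the use of Jacobi iteration as the iterative solver. It warm-starts the forward solve from the stored $f$, then forms the tangent right-hand side $db/dx - dA/dx\, f$ and solves the tangent system from $df = 0$. Here `df` means the derivative of $f$ with respect to the scalar $x$.

```python
import numpy as np


# Hypothetical 2x2 system, diagonally dominant so Jacobi iteration converges.
def A_of(x):
    return np.array([[4.0 + x, 1.0], [1.0, 3.0 + x**2]])


def dA_dx(x):
    return np.array([[1.0, 0.0], [0.0, 2.0 * x]])


def b_of(x):
    return np.array([1.0 + x, 2.0 * x])


def db_dx(x):
    return np.array([1.0, 2.0])


def jacobi(A, rhs, guess, tol=1e-10, maxiter=1000):
    """Solve A u = rhs by Jacobi iteration; return (u, iterations used)."""
    d = np.diag(A)
    R = A - np.diag(d)
    u = np.array(guess, dtype=float)
    for k in range(maxiter):
        if np.linalg.norm(A @ u - rhs) < tol:
            return u, k
        u = (rhs - R @ u) / d
    return u, maxiter


x_old, x_new = 1.0, 1.01
f_old, _ = jacobi(A_of(x_old), b_of(x_old), np.zeros(2))

# Forward pass: compare a cold start with a warm start from the stored f_old.
f_cold, iters_cold = jacobi(A_of(x_new), b_of(x_new), np.zeros(2))
f_warm, iters_warm = jacobi(A_of(x_new), b_of(x_new), f_old)

# Tangent pass: A df = db/dx - dA/dx f, started here from df = 0.
rhs_tangent = db_dx(x_new) - dA_dx(x_new) @ f_warm
df, iters_tangent = jacobi(A_of(x_new), rhs_tangent, np.zeros(2))
```

In this toy problem the stored $f$ only helps the forward solve; the tangent solve has a different right-hand side and a different solution, which is why a separate stored $df$ (or some other initial guess) is the relevant state there.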

Resolves #1034

@f0uriest f0uriest marked this pull request as draft April 28, 2026 23:20
@github-actions
Contributor

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x                 |   -2.39 %    |     4.155e+03      |     4.055e+03      |    -99.25    |       34.95        |       31.74        |
| test_proximal_jac_w7x_with_eq_update   |    0.37 %    |     6.553e+03      |     6.577e+03      |    24.37     |       158.82       |       153.16       |
| test_proximal_freeb_jac                |    0.06 %    |     1.343e+04      |     1.343e+04      |     7.75     |       86.15        |       81.22        |
| test_proximal_freeb_jac_blocked        |    0.07 %    |     7.755e+03      |     7.760e+03      |     5.31     |       73.00        |       69.58        |
| test_proximal_freeb_jac_batched        |    0.79 %    |     7.659e+03      |     7.719e+03      |    60.19     |       71.22        |       69.38        |
| test_proximal_jac_ripple               |    1.82 %    |     3.591e+03      |     3.656e+03      |    65.37     |       56.97        |       55.34        |
| test_proximal_jac_ripple_bounce1d      |   -1.30 %    |     3.819e+03      |     3.770e+03      |    -49.64    |       72.76        |       69.32        |
| test_eq_solve                          |    0.30 %    |     2.167e+03      |     2.173e+03      |     6.59     |       92.16        |       90.14        |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@codecov

codecov Bot commented Apr 29, 2026

Codecov Report

❌ Patch coverage is 97.53086% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.44%. Comparing base (4b1c43b) to head (47755fe).
⚠️ Report is 1 commits behind head on master.

| Files with missing lines | Patch % | Lines |
| ------------------------ | ------- | ----- |
| desc/derivatives.py | 78.94% | 4 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2188      +/-   ##
==========================================
+ Coverage   94.40%   94.44%   +0.03%     
==========================================
  Files         101      101              
  Lines       28739    28803      +64     
==========================================
+ Hits        27132    27202      +70     
+ Misses       1607     1601       -6     

| Files with missing lines | Coverage Δ |
| ------------------------ | ---------- |
| desc/optimize/aug_lagrangian.py | 97.14% <100.00%> (+0.20%) ⬆️ |
| desc/optimize/aug_lagrangian_ls.py | 96.04% <100.00%> (+0.23%) ⬆️ |
| desc/optimize/fmin_scalar.py | 98.38% <100.00%> (+0.05%) ⬆️ |
| desc/optimize/least_squares.py | 99.43% <100.00%> (+0.01%) ⬆️ |
| desc/optimize/stochastic.py | 98.23% <100.00%> (+0.21%) ⬆️ |
| desc/optimize/utils.py | 95.41% <100.00%> (+0.03%) ⬆️ |
| desc/derivatives.py | 92.24% <78.94%> (-2.31%) ⬇️ |

... and 3 files with indirect coverage changes


@unalmis
Collaborator

unalmis commented May 16, 2026

> Some open questions:
>
>   • The state that we want to keep track of for the forward pass may be different from the state for the tangent/backward pass. Consider an objective that depends on the solution of a linear system $A(x) f = b(x)$, where $f$ is the output of the objective (and the state we want to maintain, i.e., we solve $Af=b$ iteratively) and $x$ is the input. The derivative would satisfy $A(x) df = db/dx(x) - dA/dx(x) f$; note that the state we want to keep track of here is $df$, not $f$. This is all to say that keeping track of state may accelerate forward calculations but, depending on the operation, may not give much benefit for derivatives. As far as I can tell, it's not possible to have JAX return aux information pertaining to the derivative without a ton of custom jvp/vjp logic for each objective.

The primal system is $A f = b$.

  • Store the solution $f$.

The correct tangent system is on Wikipedia.

  • Store the tangent solution $df$.

Update the state:

  • Initialize the next primal solve at $f + df$.
  • Initialize the next tangent solve at $df + \frac{dx_{\text{old}} \cdot dx}{dx_{\text{old}} \cdot dx_{\text{old}}} (df - df_{\text{old}})$.
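The two initial-guess rules above can be sketched in NumPy as follows. The function names `next_primal_guess` and `next_tangent_guess` are hypothetical, as is the small test system `A_of`/`b_of`. The sketch also assumes that $df$ is the tangent of $f$ along the step $dx$ (so that $f + df$ is a first-order prediction of the next solution) and that $dx_{\text{old}}$ is nonzero.

```python
import numpy as np


def next_primal_guess(f, df):
    """First-order prediction of the next primal solution."""
    return f + df


def next_tangent_guess(df, df_old, dx, dx_old):
    """Extrapolate the tangent using the change seen over the previous step.

    The change (df - df_old) is scaled by the projection of the new step dx
    onto the old step dx_old. Assumes dx_old is nonzero.
    """
    scale = np.dot(dx_old, dx) / np.dot(dx_old, dx_old)
    return df + scale * (df - df_old)


# Hypothetical system A(x) f = b(x) with scalar x, used only to check that
# f + df is a better initial guess than f alone.
def A_of(x):
    return np.array([[4.0 + x, 1.0], [1.0, 3.0 + x**2]])


def b_of(x):
    return np.array([1.0 + x, 2.0 * x])


x, dx = 1.0, 0.01
dA = np.array([[1.0, 0.0], [0.0, 2.0 * x]])  # dA/dx at x
db = np.array([1.0, 2.0])                    # db/dx at x

f = np.linalg.solve(A_of(x), b_of(x))
df = np.linalg.solve(A_of(x), db - dA @ f) * dx  # tangent along the step dx
f_next = np.linalg.solve(A_of(x + dx), b_of(x + dx))

err_plain = np.linalg.norm(f - f_next)
err_first_order = np.linalg.norm(next_primal_guess(f, df) - f_next)
```

When the new step is parallel to the old one and of the same length, the tangent rule reduces to the linear extrapolation $2\,df - df_{\text{old}}$; when the steps are orthogonal, it falls back to reusing $df$.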

> note that the state we want to keep track of here is $df$, not $f$.

Keep both; see above.

@unalmis
Collaborator

unalmis commented May 17, 2026

> I suppose in the above example, always starting with $df=0$ would be reasonable, since if we're assuming that $f$ doesn't change much then $df$ should be small. So maybe we just always want to take 0 for the tangent state? This may depend on what the state is etc., so it may require some custom jvp logic for stateful objectives.

The goal is to use known information to estimate the initial guess for the next primal and tangent solves. My suggestion above does that exactly to first order for the primal solve, and uses past history to do it to nearly first order for the tangent solve. You can choose any other estimation heuristic from the optimization literature to estimate df_next too, e.g. momentum, mirror, etc. In the worst case, the simplest initial guess for the next tangent solve is the previous tangent solution, i.e. df_new = df. Choosing df_new = 0 doesn't make sense to me.


Development

Successfully merging this pull request may close these issues.

Optimizers/objectives with auxiliary output

2 participants