Skip to content

Use eigenvalue EMA over per-step outer products for per-factor correction#264

Open
runame wants to merge 7 commits into
facebookresearch:mainfrom
runame:pr4/eigenvalue-ema-v2
Open

Use eigenvalue EMA over per-step outer products for per-factor correction#264
runame wants to merge 7 commits into
facebookresearch:mainfrom
runame:pr4/eigenvalue-ema-v2

Conversation

@runame

@runame runame commented May 7, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Replace the per-iteration diag(Q^T M Q) eigenvalue computation (where M is the accumulated factor matrix) with an EMA over per-step projections: E_t = beta2 * E_{t-1} + w * diag(Q^T O_t Q), where O_t is the per-step outer product.
  • This avoids redundant work (the accumulated factor matrix is already an EMA of outer products) and provides smoother eigenvalue tracking across eigenvector updates.
  • Implemented via _post_outer_product_hook and _get_eigenvalue_allocator overrides, with the amortized computation's eigenvalue overwrite disabled per-instance.

Stack

This PR is part of a stack adding per-factor eigenvalue correction to Distributed Shampoo:

  1. Refactor: extract shared EigendecompositionBasedShampooKroneckerFactorsUnwrapped base class #261 — extract shared base class
  2. Refactor: eliminate _compute_outer_product_list via _transform_grad_for_outer_product hook #262 — add _transform_grad_for_outer_product hook (KL refactor)
  3. Add per-factor eigenvalue correction for Distributed Shampoo #263 — per-factor eigenvalue correction (implementation + tests)
  4. This PR — eigenvalue EMA over per-step outer products

Test plan

  • Existing tests pass (distributed_shampoo/tests/, distributed_shampoo/preconditioner/tests/)
  • mypy clean (make type-check)
  • ruff clean

Generated with Claude Code

runame and others added 7 commits May 7, 2026 09:34
…rsUnwrapped base class

Consolidate duplicated eigendecomposition logic from EigendecomposedShampooKroneckerFactorsUnwrapped
and EigenvalueCorrectedShampooKroneckerFactorsUnwrapped into a shared base class. The base class
provides _perform_eigendecomposition and _amortized_computation, with subclass behavior controlled
via hasattr checks on field presence.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…or_outer_product hook

Inline the outer product loop into BaseShampooPreconditionerList._update_factor_matrices
and introduce _transform_grad_for_outer_product as the single extension point. The base
returns grad unchanged; KL-Shampoo subclasses override it to precondition the gradient.
This eliminates _compute_outer_product_list from all three classes that defined it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce PerFactorEigenvalueCorrectedShampoo, which stores m+n eigenvalues
per block (one per factor dimension) computed directly as diag(Q^T M Q), where
Q are cached eigenvectors and M is the already-accumulated factor matrix. This
is more memory-efficient than EShampoo/SOAP's m*n eigenvalues while still
providing eigenvalue correction.

New classes:
- PerFactorEigenvalueCorrectedShampooKroneckerFactorsUnwrapped
- PerFactorEigenvalueCorrectedShampooPreconditionerList
- PerFactorEigenvalueCorrectedKLShampooPreconditionerList (KL variant)
- PerFactorEigenvalueCorrectedShampooPreconditionerConfig
- PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Test the combined PerFactor+KL variant which recomputes eigenvalues
every step and preconditions gradients before outer products. Uses
beta2=0 and epsilon=1.0 to get clean expected values, leveraging the
perturb_before_computation happy path where KL is effectively a no-op
when eigenvalues are equal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…sses

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…tion

Instead of recomputing eigenvalues as diag(Q^T M Q) every step (which causes
discontinuities when eigenvectors Q change at amortized computation steps),
maintain a direct EMA: E_t = beta2 * E_{t-1} + w * diag(Q^T O_t Q).

Key changes:
- Add _get_eigenvalue_allocator hook for zero-init eigenvalues (EMA from zero)
- Add _post_outer_product_hook for eigenvalue EMA update using outer products
- Replace hasattr dispatch with _include_eigenvalues_in_amortized_computation
  ClassVar (addresses existing TODO), set False on PerFactor instances to
  prevent amortized computation from overwriting EMA'd eigenvalues
- Add _prepare_eigenvalues_for_preconditioning hook for bias correction
- Add zero-eigenvalue guard in KL variant to skip preconditioning at init

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 7, 2026
@meta-codesync

meta-codesync Bot commented May 12, 2026

Copy link
Copy Markdown

@hjmshi has imported this pull request. If you are a Meta employee, you can view this in D104875559.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant