From 1c61efe57448b75ac48120622f5abdaa357d2e41 Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 8 Feb 2026 18:34:15 +0000 Subject: [PATCH 1/6] Refactor: extract shared EigendecompositionBasedShampooKroneckerFactorsUnwrapped base class Consolidate duplicated eigendecomposition logic from EigendecomposedShampooKroneckerFactorsUnwrapped and EigenvalueCorrectedShampooKroneckerFactorsUnwrapped into a shared base class. The base class provides _perform_eigendecomposition and _amortized_computation, with subclass behavior controlled via hasattr checks on field presence. Co-Authored-By: Claude Opus 4.6 --- .../shampoo_preconditioner_list.py | 214 +++++++----------- 1 file changed, 88 insertions(+), 126 deletions(-) diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index bba4509..33089c5 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -16,7 +16,15 @@ from itertools import chain from operator import attrgetter from pathlib import Path -from typing import Any, Generic, get_args, NoReturn, overload, TypeAlias, TypeVar +from typing import ( + Any, + Generic, + get_args, + NoReturn, + overload, + TypeAlias, + TypeVar, +) import torch from distributed_shampoo.distributor.shampoo_block_info import BlockInfo @@ -608,8 +616,85 @@ def __post_init__(self) -> None: @dataclass(kw_only=True) -class EigendecomposedShampooKroneckerFactorsUnwrapped( +class EigendecompositionBasedShampooKroneckerFactorsUnwrapped( BaseShampooKroneckerFactorsUnwrapped +): + """Base class for Shampoo variants using eigendecomposition. + + This class provides shared eigendecomposition logic for variants that compute + eigenvectors of factor matrices. Subclasses control eigenvalue behavior via + field presence: if a subclass has `factor_matrices_eigenvalues`, eigenvalues + from eigendecomposition will be included in amortized computation results. + """ + + factor_matrices_eigenvectors: tuple[Tensor, ...] + + @torch.compiler.disable + def _amortized_computation( + self, + bias_corrected_factor_matrix: Tensor, + kronecker_factors_iter_dict: dict[str, Any], + ) -> tuple[dict[str, Tensor], Exception | None]: + """Performs amortized eigendecomposition computation. + + Computes eigendecomposition of the bias-corrected factor matrix. + On exception, retains previous eigenvectors (and eigenvalues if applicable). + + Args: + bias_corrected_factor_matrix: The input matrix to decompose. + kronecker_factors_iter_dict: Dictionary with current eigenvectors + (and eigenvalues if applicable) as estimates. + + Returns: + Tuple of (updated dict with new eigenvectors/eigenvalues, exception or None). + + """ + factor_matrix_eigenvectors = kronecker_factors_iter_dict[ + "factor_matrices_eigenvectors" + ] + + try: + computed_eigenvalues, computed_eigenvectors = ( + matrix_eigendecomposition( + A=bias_corrected_factor_matrix, + eigendecomposition_config=self.amortized_computation_config, + eigenvectors_estimate=factor_matrix_eigenvectors.to( + dtype=bias_corrected_factor_matrix.dtype + ), + epsilon=self.epsilon, + ) + ) + + result: dict[str, Tensor] = { + "factor_matrices_eigenvectors": computed_eigenvectors.to( + dtype=factor_matrix_eigenvectors.dtype + ), + } + + if hasattr(self, "factor_matrices_eigenvalues"): + result["factor_matrices_eigenvalues"] = computed_eigenvalues.to( + dtype=kronecker_factors_iter_dict[ + "factor_matrices_eigenvalues" + ].dtype + ) + + return result, None + except Exception as exception: + result = {"factor_matrices_eigenvectors": factor_matrix_eigenvectors} + if hasattr(self, "factor_matrices_eigenvalues"): + result["factor_matrices_eigenvalues"] = kronecker_factors_iter_dict[ + "factor_matrices_eigenvalues" + ] + return result, exception + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) + + +@dataclass(kw_only=True) +class EigendecomposedShampooKroneckerFactorsUnwrapped( + EigendecompositionBasedShampooKroneckerFactorsUnwrapped ): """Eigendecomposed Shampoo Kronecker factors (unwrapped) for operations during optimizer computation. @@ -639,7 +724,6 @@ class EigendecomposedShampooKroneckerFactorsUnwrapped( of consecutive failed amortized computations. """ - factor_matrices_eigenvectors: tuple[Tensor, ...] factor_matrices_eigenvalues: tuple[Tensor, ...] @classmethod @@ -707,77 +791,11 @@ def from_kronecker_factors_state( use_trace_scaling=use_trace_scaling, ) - @torch.compiler.disable - def _amortized_computation( - self, - bias_corrected_factor_matrix: Tensor, - kronecker_factors_iter_dict: dict[str, Any], - ) -> tuple[dict[str, Tensor], Exception | None]: - """Performs eigendecomposition for Shampoo preconditioners. - - This implementation of the abstract _amortized_computation method specifically handles - the eigendecomposition for the EigendecomposedShampoo variant. It computes both - eigenvalues and eigenvectors for each factor matrix. - - The computation uses the configuration specified in amortized_computation_config, - with special handling for QR-based eigendecomposition which requires the previous - eigenvectors as an initial estimate. Error handling is included to gracefully - recover from numerical issues. - - Args: - bias_corrected_factor_matrix (Tensor): The factor matrix after bias correction - has been applied. - kronecker_factors_iter_dict (dict[str, Any]): Dictionary containing the current - factor_matrices_eigenvalues and factor_matrices_eigenvectors for the computation. - - Returns: - computed_quantities (dict[str, Tensor]): A dictionary with the computed eigenvalues and eigenvectors. - exception (Exception | None): Any exception that occurred during computation, or None if successful. - - Note: - This function assumes there are no changes in the selector or masking between - iterations within a single precondition_frequency interval. - """ - ( - factor_matrix_eigenvectors, - factor_matrix_eigenvalues, - ) = ( - kronecker_factors_iter_dict["factor_matrices_eigenvectors"], - kronecker_factors_iter_dict["factor_matrices_eigenvalues"], - ) - - try: - # Compute inverse preconditioner. - computed_eigenvalues, computed_eigenvectors = matrix_eigendecomposition( - A=bias_corrected_factor_matrix, - eigendecomposition_config=self.amortized_computation_config, - # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix. - eigenvectors_estimate=factor_matrix_eigenvectors.to( - dtype=bias_corrected_factor_matrix.dtype - ), - epsilon=self.epsilon, - ) - - return { - "factor_matrices_eigenvalues": computed_eigenvalues.to( - dtype=factor_matrix_eigenvalues.dtype - ), - "factor_matrices_eigenvectors": computed_eigenvectors.to( - dtype=factor_matrix_eigenvectors.dtype - ), - }, None - except Exception as exception: - return { - "factor_matrices_eigenvalues": factor_matrix_eigenvalues, - "factor_matrices_eigenvectors": factor_matrix_eigenvectors, - }, exception - def __post_init__(self) -> None: super().__post_init__() assert ( len(self.roots) == len(self.factor_matrices) - == len(self.factor_matrices_eigenvectors) == len(self.factor_matrices_eigenvalues) ) @@ -851,7 +869,7 @@ def __post_init__(self) -> None: @dataclass(kw_only=True) class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped( - BaseShampooKroneckerFactorsUnwrapped + EigendecompositionBasedShampooKroneckerFactorsUnwrapped ): """Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation. @@ -885,7 +903,6 @@ class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped( of consecutive failed amortized computations. """ - factor_matrices_eigenvectors: tuple[Tensor, ...] corrected_eigenvalues: Tensor @classmethod @@ -953,63 +970,8 @@ def from_kronecker_factors_state( use_trace_scaling=use_trace_scaling, ) - @torch.compiler.disable - def _amortized_computation( - self, - bias_corrected_factor_matrix: Tensor, - kronecker_factors_iter_dict: dict[str, Any], - ) -> tuple[dict[str, Tensor], Exception | None]: - """Computes eigenvectors for eigenvalue-corrected Shampoo preconditioners. - - This implementation of the abstract _amortized_computation method specifically handles - the computation of eigenvectors for the EigenvalueCorrectedShampoo variant. Unlike - the EigendecomposedShampoo variant, this only computes eigenvectors and not eigenvalues, - as the eigenvalues are corrected separately during the optimization process. - - The computation uses the configuration specified in amortized_computation_config, - with special handling for QR-based eigendecomposition which requires the previous - eigenvectors as an initial estimate. Error handling is included to gracefully - recover from numerical issues. - - Args: - bias_corrected_factor_matrix (Tensor): The factor matrix after bias correction - has been applied. - kronecker_factors_iter_dict (dict[str, Any]): Dictionary containing the current - factor_matrices_eigenvectors for the computation. - - Returns: - computed_quantities (dict[str, Tensor]): A dictionary with the computed eigenvectors. - exception (Exception | None): Any exception that occurred during computation, or None if successful. - - Note: - This function assumes there are no changes in the selector or masking between - iterations within a single precondition_frequency interval. - """ - factor_matrix_eigenvectors = kronecker_factors_iter_dict[ - "factor_matrices_eigenvectors" - ] - - try: - # Compute eigenvectors of factor matrix. - return { - "factor_matrices_eigenvectors": matrix_eigendecomposition( - A=bias_corrected_factor_matrix, - eigendecomposition_config=self.amortized_computation_config, - # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix. - eigenvectors_estimate=factor_matrix_eigenvectors.to( - dtype=bias_corrected_factor_matrix.dtype - ), - epsilon=self.epsilon, - )[1].to(dtype=factor_matrix_eigenvectors.dtype) - }, None - except Exception as exception: - return { - "factor_matrices_eigenvectors": factor_matrix_eigenvectors - }, exception - def __post_init__(self) -> None: super().__post_init__() - assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) assert len(self.roots) == 1 def _get_field_dict(self) -> dict[str, Any]: From 6cae8245bdb5ce3421c8f363eaca1b28353825db Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 8 Feb 2026 19:53:14 +0000 Subject: [PATCH 2/6] Refactor: eliminate _compute_outer_product_list via _transform_grad_for_outer_product hook Inline the outer product loop into BaseShampooPreconditionerList._update_factor_matrices and introduce _transform_grad_for_outer_product as the single extension point. The base returns grad unchanged; KL-Shampoo subclasses override it to precondition the gradient. This eliminates _compute_outer_product_list from all three classes that defined it. Co-Authored-By: Claude Opus 4.6 --- .../shampoo_preconditioner_list.py | 189 +++++++++--------- 1 file changed, 97 insertions(+), 92 deletions(-) diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index 33089c5..30a3b25 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -671,6 +671,8 @@ def _amortized_computation( ), } + # TODO: Consider replacing hasattr dispatch with a class variable + # (e.g., _include_eigenvalues) to avoid fragile reliance on field presence. if hasattr(self, "factor_matrices_eigenvalues"): result["factor_matrices_eigenvalues"] = computed_eigenvalues.to( dtype=kronecker_factors_iter_dict[ @@ -1279,24 +1281,20 @@ def compress_preconditioner_list( self._local_preconditioned_dims_selector_list, local_grad_selector ) - @profile_decorator - def _compute_outer_product_list( + def _transform_grad_for_outer_product( self, grad: Tensor, + idx_of_k: int, + k: int, order: int, preconditioned_dims_selector: tuple[bool, ...], kronecker_factors: _ShampooKroneckerFactorsUnwrappedType, - ) -> tuple[Tensor, ...]: - # Construct outer product list for updating Kronecker factors. - return tuple( - torch.tensordot( - grad, - grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] - ) - for k in compress_list(range(order), preconditioned_dims_selector) - ) + ) -> Tensor: + """Transforms the gradient before outer product computation. + + Override this method in subclasses to apply preconditioning (e.g., KL-Shampoo). + """ + return grad @profile_decorator def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: @@ -1314,8 +1312,23 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: if not kronecker_factors.factor_matrices: continue - outer_product_list = self._compute_outer_product_list( - grad, order, preconditioned_dims_selector, kronecker_factors + outer_product_list = tuple( + torch.tensordot( + transformed_grad := self._transform_grad_for_outer_product( + grad=grad, + idx_of_k=idx_of_k, + k=k, + order=order, + preconditioned_dims_selector=preconditioned_dims_selector, + kronecker_factors=kronecker_factors, + ), + transformed_grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + for idx_of_k, k in enumerate( + compress_list(range(order), preconditioned_dims_selector) + ) ) if self._beta2 != 1.0: @@ -1618,6 +1631,52 @@ def _compute_preconditioned_gradient( ), ) + def _kl_transform_grad_for_outer_product( + self, + grad: Tensor, + idx_of_k: int, + k: int, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, + ) -> Tensor: + """Transforms the gradient for KL-Shampoo outer product computation. + + KL-Shampoo preconditions the gradient along all dimensions except k + with the inverse root of the factor matrices before computing the outer product. + """ + # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig + assert isinstance( + self._preconditioner_config.amortized_computation_config, + EigendecompositionConfig, + ) + rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config + + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + return self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + matrix_inverse_root_from_eigendecomposition( + L=eigenvalues, + Q=eigenvectors, + root=Fraction(root), + epsilon=self._epsilon, + rank_deficient_stability_config=rank_deficient_stability_config, + ) + for idx, (eigenvalues, eigenvectors, root) in enumerate( + zip( + kronecker_factors.factor_matrices_eigenvalues, + kronecker_factors.factor_matrices_eigenvectors, + kronecker_factors.roots, + strict=True, + ) + ) + if idx != idx_of_k + ), + ) + class EigenvalueCorrectedShampooPreconditionerList( BaseShampooPreconditionerList[ @@ -1732,42 +1791,29 @@ def _compute_preconditioned_gradient( class RootInvKLShampooPreconditionerList(RootInvShampooPreconditionerList): """Root inverse KL-Shampoo preconditioners for list of parameters.""" - @profile_decorator - def _compute_outer_product_list( + def _transform_grad_for_outer_product( self, grad: Tensor, + idx_of_k: int, + k: int, order: int, preconditioned_dims_selector: tuple[bool, ...], kronecker_factors: RootInvShampooKroneckerFactorsUnwrapped, - ) -> tuple[Tensor, ...]: - # Construct outer product list for updating Kronecker factors. - outer_product_list = [] - for idx_of_k, k in enumerate( - compress_list(range(order), preconditioned_dims_selector) - ): - # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. - local_preconditioned_dims_selector = list(preconditioned_dims_selector) - local_preconditioned_dims_selector[k] = False - preconditioned_grad = self._precondition_grad( - grad=grad, - preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), - preconditioner_list=tuple( - inv_factor_matrix - for idx, inv_factor_matrix in enumerate( - kronecker_factors.inv_factor_matrices - ) - if idx != idx_of_k - ), - ) - outer_product_list.append( - torch.tensordot( - preconditioned_grad, - preconditioned_grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) -> Tensor: + # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + return self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + inv_factor_matrix + for idx, inv_factor_matrix in enumerate( + kronecker_factors.inv_factor_matrices ) - ) - return tuple(outer_product_list) + if idx != idx_of_k + ), + ) class EigendecomposedKLShampooPreconditionerList( @@ -1775,56 +1821,15 @@ class EigendecomposedKLShampooPreconditionerList( ): """Eigendecomposed KL-Shampoo preconditioners for list of parameters.""" - @profile_decorator - def _compute_outer_product_list( + def _transform_grad_for_outer_product( self, grad: Tensor, + idx_of_k: int, + k: int, order: int, preconditioned_dims_selector: tuple[bool, ...], kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, - ) -> tuple[Tensor, ...]: - # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig - assert isinstance( - self._preconditioner_config.amortized_computation_config, - EigendecompositionConfig, + ) -> Tensor: + return self._kl_transform_grad_for_outer_product( + grad, idx_of_k, k, order, preconditioned_dims_selector, kronecker_factors, ) - rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config - - # Construct outer product list for updating Kronecker factors. - outer_product_list = [] - for idx_of_k, k in enumerate( - compress_list(range(order), preconditioned_dims_selector) - ): - # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. - local_preconditioned_dims_selector = list(preconditioned_dims_selector) - local_preconditioned_dims_selector[k] = False - preconditioned_grad = self._precondition_grad( - grad=grad, - preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), - preconditioner_list=tuple( - matrix_inverse_root_from_eigendecomposition( - L=eigenvalues, - Q=eigenvectors, - root=Fraction(root), - epsilon=self._epsilon, - rank_deficient_stability_config=rank_deficient_stability_config, - ) - for idx, (eigenvalues, eigenvectors, root) in enumerate( - zip( - kronecker_factors.factor_matrices_eigenvalues, - kronecker_factors.factor_matrices_eigenvectors, - kronecker_factors.roots, - strict=True, - ) - ) - if idx != idx_of_k - ), - ) - outer_product = torch.tensordot( - preconditioned_grad, - preconditioned_grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] - ) - outer_product_list.append(outer_product) - return tuple(outer_product_list) From f9ad19f8e57dd77ba43071959d5ed8fc1336d7c3 Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 8 Feb 2026 20:05:37 +0000 Subject: [PATCH 3/6] Add per-factor eigenvalue correction for Distributed Shampoo Introduce PerFactorEigenvalueCorrectedShampoo, which stores m+n eigenvalues per block (one per factor dimension) computed directly as diag(Q^T M Q), where Q are cached eigenvectors and M is the already-accumulated factor matrix. This is more memory-efficient than EShampoo/SOAP's m*n eigenvalues while still providing eigenvalue correction. New classes: - PerFactorEigenvalueCorrectedShampooKroneckerFactorsUnwrapped - PerFactorEigenvalueCorrectedShampooPreconditionerList - PerFactorEigenvalueCorrectedKLShampooPreconditionerList (KL variant) - PerFactorEigenvalueCorrectedShampooPreconditionerConfig - PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig Co-Authored-By: Claude Opus 4.6 --- distributed_shampoo/distributed_shampoo.py | 8 + .../shampoo_preconditioner_list.py | 68 +++++ .../tests/shampoo_preconditioner_list_test.py | 273 ++++++++++++++++++ distributed_shampoo/shampoo_types.py | 44 +++ .../tests/distributed_shampoo_test.py | 22 ++ 5 files changed, 415 insertions(+) diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index 96288f1..f2da6e3 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -50,6 +50,8 @@ EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + PerFactorEigenvalueCorrectedKLShampooPreconditionerList, + PerFactorEigenvalueCorrectedShampooPreconditionerList, RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, ) @@ -76,6 +78,8 @@ EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, EPSILON, + PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig, + PerFactorEigenvalueCorrectedShampooPreconditionerConfig, FILTERED_GRAD, FILTERED_GRAD_LIST, FSDPDistributedConfig, @@ -687,8 +691,10 @@ def _preconditioner_config_to_list_cls( RootInvShampooPreconditionerConfig() | EigendecomposedShampooPreconditionerConfig() | EigenvalueCorrectedShampooPreconditionerConfig() + | PerFactorEigenvalueCorrectedShampooPreconditionerConfig() | RootInvKLShampooPreconditionerConfig() | EigendecomposedKLShampooPreconditionerConfig() + | PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig() ): preconditioner_config_to_list_cls: dict[ type[PreconditionerConfig], Callable[..., PreconditionerList] @@ -696,8 +702,10 @@ def _preconditioner_config_to_list_cls( RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList, EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList, + PerFactorEigenvalueCorrectedShampooPreconditionerConfig: PerFactorEigenvalueCorrectedShampooPreconditionerList, RootInvKLShampooPreconditionerConfig: RootInvKLShampooPreconditionerList, EigendecomposedKLShampooPreconditionerConfig: EigendecomposedKLShampooPreconditionerList, + PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig: PerFactorEigenvalueCorrectedKLShampooPreconditionerList, } beta2 = group[BETAS][1] return preconditioner_config_to_list_cls[type(preconditioner_config)]( diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index 30a3b25..710f490 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -1678,6 +1678,55 @@ def _kl_transform_grad_for_outer_product( ) +class PerFactorEigenvalueCorrectedShampooPreconditionerList( + EigendecomposedShampooPreconditionerList +): + """Per-factor eigenvalue-corrected Shampoo preconditioners for list of parameters. + + Inherits from EigendecomposedShampooPreconditionerList and overrides update_preconditioners() + to compute eigenvalues directly from diag(Q^T M Q), where Q are the cached eigenvectors + and M is the (already EMA-accumulated) factor matrix. + """ + + @profile_decorator + def update_preconditioners( + self, + masked_grad_list: tuple[Tensor, ...], + step: Tensor, + perform_amortized_computation: bool, + ) -> None: + """Updates the preconditioners with per-factor eigenvalue correction. + + First calls the parent update_preconditioners() to update factor matrices and eigenvectors. + Then computes eigenvalues directly as diag(Q^T M Q) for each factor matrix. + + Args: + masked_grad_list (tuple[Tensor, ...]): A list of gradients with their corresponding masks. + step (Tensor): The current step. + perform_amortized_computation (bool): Whether to perform an amortized computation. + + Returns: + None + """ + super().update_preconditioners( + masked_grad_list=masked_grad_list, + step=step, + perform_amortized_computation=perform_amortized_computation, + ) + + for kronecker_factors in self._masked_kronecker_factors_unwrapped: + for factor_matrix, eigenvectors, eigenvalues in zip( + kronecker_factors.factor_matrices, + kronecker_factors.factor_matrices_eigenvectors, + kronecker_factors.factor_matrices_eigenvalues, + strict=True, + ): + eigenvalues.copy_( + (factor_matrix @ eigenvectors * eigenvectors).sum(dim=0) + / self._bias_correction2 + ) + + class EigenvalueCorrectedShampooPreconditionerList( BaseShampooPreconditionerList[ EigenvalueCorrectedShampooKroneckerFactorsState, @@ -1833,3 +1882,22 @@ def _transform_grad_for_outer_product( return self._kl_transform_grad_for_outer_product( grad, idx_of_k, k, order, preconditioned_dims_selector, kronecker_factors, ) + + +class PerFactorEigenvalueCorrectedKLShampooPreconditionerList( + PerFactorEigenvalueCorrectedShampooPreconditionerList +): + """Per-factor eigenvalue-corrected KL-Shampoo preconditioners for list of parameters.""" + + def _transform_grad_for_outer_product( + self, + grad: Tensor, + idx_of_k: int, + k: int, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, + ) -> Tensor: + return self._kl_transform_grad_for_outer_product( + grad, idx_of_k, k, order, preconditioned_dims_selector, kronecker_factors, + ) diff --git a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py index 3d3c5c1..6bbd42d 100644 --- a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py @@ -32,6 +32,7 @@ EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + PerFactorEigenvalueCorrectedShampooPreconditionerList, RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, ) @@ -46,6 +47,7 @@ EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, + PerFactorEigenvalueCorrectedShampooPreconditionerConfig, PreconditionerValueError, RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, @@ -1426,6 +1428,277 @@ def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: def test_adaptive_amortized_computation_frequency(self) -> None: ... +class PerFactorEigenvalueCorrectedShampooPreconditionerListTest( + EigendecomposedShampooPreconditionerListTest +): + """Tests for PerFactorEigenvalueCorrectedShampooPreconditionerList. + + This class computes eigenvalues directly as diag(Q^T M Q) instead of from + eigendecomposition, where Q are the cached eigenvectors and M is the + already-accumulated factor matrix. + + Inherits from EigendecomposedShampooPreconditionerListTest since both use + eigendecomposition for eigenvectors, but this class computes eigenvalues differently. + + Key difference from EigendecomposedShampoo: + - EigendecomposedShampoo: eigenvalues come directly from eigendecomposition + - PerFactorEigenvalueCorrected: eigenvalues are recomputed every iteration as + diag(Q^T M Q) where Q are eigenvectors and M is the factor matrix. + """ + + @property + def _default_preconditioner_config( # type: ignore[override] + self, + ) -> PerFactorEigenvalueCorrectedShampooPreconditionerConfig: + return PerFactorEigenvalueCorrectedShampooPreconditionerConfig( + amortized_computation_config=EighEigendecompositionConfig( + rank_deficient_stability_config=PerturbationConfig( + perturb_before_computation=False + ) + ), + factor_matrix_dtype=torch.float64, + factor_matrix_eigenvectors_dtype=torch.float64, + factor_matrix_eigenvalues_dtype=torch.float64, + ) + + @property + def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: + return PerFactorEigenvalueCorrectedShampooPreconditionerList + + def test_update_preconditioners_and_precondition(self) -> None: + """ + For PerFactorEigenvalueCorrectedShampoo, eigenvalues are computed directly as + diag(Q^T M Q) where M is the already-accumulated factor matrix. + + With beta2=0.0 and weighting_factor=1.0: + - Factor matrix after step 2: M = G2 @ G2^T + - Eigenvalues after step 2: diag(Q^T M Q) = diag(G2 @ G2^T) + + For a 1D tensor with G2=[0,1]: + M = [[0,0],[0,1]], eigenvalues = [0, 1] + P = diag([0,1] + epsilon)^{-1/2} * [0,1] = [0, 1] (with epsilon handling 0) + + For a 2D tensor with G2=I/sqrt(2): + L = R = I/2, eigenvalues = [0.5, 0.5] + P = diag([0.5,0.5])^{-1/4} * G2 * diag([0.5,0.5])^{-1/4} = G2 / sqrt(0.5) = G2 * sqrt(2) + + For a 1x2 tensor with G2=[[0,1]]: + L = 1, R = [[0,0],[0,1]], eigenvalues_L = 1, eigenvalues_R = [0, 1] + P = 1^{-1/4} * [[0,1]] * diag([0,1])^{-1/4} = [[0, 1]] + """ + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[1.0, 0.0]]), + torch.tensor(3.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[0.0, 1.0]]), + torch.tensor(2.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / math.sqrt(2.0) * math.sqrt(2.0), + torch.tensor([[0.0, 1.0]]), + torch.tensor(2.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + use_bias_correction=False, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + def test_update_preconditioners_and_precondition_with_epsilon(self) -> None: + """ + Test with epsilon=1.0. Eigenvalues are computed as diag(Q^T M Q) + where M is the factor matrix after both gradient steps. + + For 1D tensor with G2=[0,1], eigenvalues=[0,1]: + P = diag([0+1, 1+1])^{-1/2} * [0,1] = diag([1,2])^{-1/2} * [0,1] = [0, 1/sqrt(2)] + + For 2D tensor with G2=I/sqrt(2), L=R=I/2, eigenvalues=[0.5, 0.5]: + P = diag([1.5,1.5])^{-1/4} * G2 * diag([1.5,1.5])^{-1/4} = G2 / sqrt(1.5) + + For 1x2 tensor with G2=[[0,1]], L=1, R=[[0,0],[0,1]], eigenvalues_L=1, eigenvalues_R=[0,1]: + P = (1+1)^{-1/4} * [[0,1]] * diag([0+1, 1+1])^{-1/4} + = 2^{-1/4} * [[0, 2^{-1/4}]] + """ + epsilon = 1.0 + + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / math.sqrt(2), + torch.tensor([[1.0, 0.0]]), + torch.tensor(1.0), + ) + + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / math.sqrt(2), + torch.tensor([[0.0, 1.0]]), + torch.tensor(1.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0 / math.sqrt(2.0)]), + torch.eye(2) / math.sqrt(2) / math.sqrt(1.5), + torch.tensor([[0.0, (2.0 ** (-1 / 4)) * (2.0 ** (-1 / 4))]]), + torch.tensor(1.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + use_bias_correction=False, + epsilon=epsilon, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + def test_update_preconditioners_and_precondition_with_dims_ignored( + self, + ) -> None: + """ + Test with different gradient magnitudes. Eigenvalues are computed as + diag(Q^T M Q) where M is the factor matrix after both gradient steps. + + (1) 1D tensor: G2=[0,4] + M = [[0,0],[0,16]], eigenvalues = [0, 16] + P = diag([0,16])^{-1/2} * [0,4] = [0, 4/4] = [0, 1] + + (2) 2D tensor: G2=4*I + L = R = 16*I, eigenvalues = [16, 16] + P = diag([16,16])^{-1/4} * 4*I * diag([16,16])^{-1/4} = 4*I / sqrt(16) = I + + (3) 1x2 tensor: G2=[[0,2]] + L = 4, R = [[0,0],[0,4]], eigenvalues_L = 4, eigenvalues_R = [0, 4] + P = 4^{-1/4} * [[0,2]] * diag([0,4])^{-1/4} + = 4^{-1/4} * [[0, 2*4^{-1/4}]] + = [[0, 2 / sqrt(4)]] + = [[0, 1]] + """ + masked_grad_list1 = ( + torch.tensor([4.0, 0.0]), + torch.eye(2) * 3, + torch.tensor([[2.0, 0.0]]), + torch.tensor(3.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 4.0]), + torch.eye(2) * 4, + torch.tensor([[0.0, 2.0]]), + torch.tensor(2.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0]), + torch.eye(2), + torch.tensor([[0.0, 1.0]]), + torch.tensor(2.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + # When ignoring all the dimensions by setting all inverse exponent override values to 0.0, + # the preconditioner should be the identity matrix. + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + preconditioner_config=replace( + self._default_preconditioner_config, + inverse_exponent_override={ + 0: {0: 0.0}, + 1: {0: 0.0}, + 2: 0.0, + }, + ), + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_grad_list2, + ) + + def test_inverse_exponent_override(self) -> None: + """ + Test with inverse_exponent_override = {0: 1.0, 1: 1.0, 2: 1.0}. + This uses inverse root of 1 (i.e., inverse) rather than the default. + + Eigenvalues are computed as diag(Q^T M Q) where M is the factor matrix + after both gradient steps. + + (1) 1D tensor: G2=[0,2] + M = [[0,0],[0,4]], eigenvalues = [0, 4] + P = diag([0,4])^{-1} * [0,2] = [0, 2/4] = [0, 0.5] + + (2) 2D tensor: G2=I/sqrt(2) + L = R = I/2, eigenvalues = [0.5, 0.5] + P = diag([0.5,0.5])^{-1} * G2 * diag([0.5,0.5])^{-1} + = 2 * G2 * 2 = 4 * G2 = 4 * I/sqrt(2) = 4/sqrt(2) * I + + (3) 1x2 tensor: G2=[[0,2]] + L = 4, R = [[0,0],[0,4]], eigenvalues_L = 4, eigenvalues_R = [0, 4] + P = 4^{-1} * [[0,2]] * diag([0,4])^{-1} = (1/4) * [[0, 2/4]] = [[0, 1/8]] + """ + preconditioner_config = replace( + self._default_preconditioner_config, + inverse_exponent_override={ + 0: {0: 1.0}, + 1: {0: 1.0}, + 2: 1.0, + }, + ) + + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[1.0, 0.0]]), + torch.tensor(3.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 2.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[0.0, 2.0]]), + torch.tensor(2.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 0.5]), + torch.eye(2) / math.sqrt(2.0) * 4.0, + torch.tensor([[0.0, 1.0 / 8.0]]), + torch.tensor(2.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + use_bias_correction=False, + preconditioner_config=preconditioner_config, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + class EigendecomposedKLShampooPreconditionerListTest( EigendecomposedShampooPreconditionerListTest ): diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 98255fa..14de685 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -397,6 +397,50 @@ def _get_default_amortized_computation_config() -> EigendecompositionConfig: factor_matrix_eigenvalues_dtype: torch.dtype = torch.float32 +@dataclass(kw_only=True) +class PerFactorEigenvalueCorrectedShampooPreconditionerConfig( + EigendecomposedShampooPreconditionerConfig +): + """Configuration for per-factor eigenvalue-corrected Shampoo preconditioner. + + Like EigendecomposedShampoo, stores eigenvectors and eigenvalues per factor matrix. + However, eigenvalues are computed directly as diag(Q^T M Q) instead of from + eigendecomposition, where Q are the cached eigenvectors and M is the + already-accumulated factor matrix. + + Eigenvectors are updated via amortized eigendecomposition (same as EigendecomposedShampoo). + Eigenvalues are recomputed every iteration as diag(Q^T M Q). + + Attributes: + amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig) + num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) + inverse_exponent_override (dict[int, dict[int, float] | float]): Customizes the inverse exponent. (Default: {}) + factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32) + factor_matrix_eigenvalues_dtype (torch.dtype): Data type for factor matrix eigenvalues. (Default: torch.float32) + + """ + + +@dataclass(kw_only=True) +class PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig( + PerFactorEigenvalueCorrectedShampooPreconditionerConfig +): + """Configuration for per-factor eigenvalue-corrected KL-Shampoo preconditioner. + + Combines per-factor eigenvalue correction with KL-Shampoo outer product computation. + + Attributes: + amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig) + num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) + factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) + inverse_exponent_override (dict[int, dict[int, float] | float]): Customizes the inverse exponent. (Default: {}) + factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32) + factor_matrix_eigenvalues_dtype (torch.dtype): Data type for factor matrix eigenvalues. (Default: torch.float32) + + """ + + @dataclass(kw_only=True) class EigenvalueCorrectedShampooPreconditionerConfig(BaseShampooPreconditionerConfig): """Configuration for eigenvalue-corrected Shampoo/SOAP preconditioner computation. diff --git a/distributed_shampoo/tests/distributed_shampoo_test.py b/distributed_shampoo/tests/distributed_shampoo_test.py index c2a95c7..f2279fc 100644 --- a/distributed_shampoo/tests/distributed_shampoo_test.py +++ b/distributed_shampoo/tests/distributed_shampoo_test.py @@ -39,6 +39,8 @@ EigenvalueCorrectedShampooPreconditionerConfig, GeneralizedPrimalAveragingConfig, IterateAveragingConfig, + PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig, + PerFactorEigenvalueCorrectedShampooPreconditionerConfig, PreconditionerConfig, RootInvKLShampooPreconditionerConfig, RootInvShampooPreconditionerConfig, @@ -1019,6 +1021,26 @@ def _preconditioner_config(self) -> EigendecomposedKLShampooPreconditionerConfig return EigendecomposedKLShampooPreconditionerConfig() +class PerFactorEigenvalueCorrectedShampooStateDictTest( + EigendecomposedShampooStateDictTest +): + @property + def _preconditioner_config( + self, + ) -> PerFactorEigenvalueCorrectedShampooPreconditionerConfig: + return PerFactorEigenvalueCorrectedShampooPreconditionerConfig() + + +class PerFactorEigenvalueCorrectedKLShampooStateDictTest( + PerFactorEigenvalueCorrectedShampooStateDictTest +): + @property + def _preconditioner_config( + self, + ) -> PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig: + return PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig() + + class SignDescentStateDictTest(AbstractTest.NoPreconditionerStateDictTestBase): @property def _preconditioner_config(self) -> SignDescentPreconditionerConfig: From 4f8bb065c80a51ba452e52c384d601aa1477cf9e Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Feb 2026 10:06:30 +0000 Subject: [PATCH 4/6] Add tests for PerFactorEigenvalueCorrectedKLShampooPreconditionerList Test the combined PerFactor+KL variant which recomputes eigenvalues every step and preconditions gradients before outer products. Uses beta2=0 and epsilon=1.0 to get clean expected values, leveraging the perturb_before_computation happy path where KL is effectively a no-op when eigenvalues are equal. Co-Authored-By: Claude Opus 4.6 --- .../tests/shampoo_preconditioner_list_test.py | 309 ++++++++++++++++++ 1 file changed, 309 insertions(+) diff --git a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py index 6bbd42d..27d7476 100644 --- a/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py +++ b/distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py @@ -32,6 +32,7 @@ EigendecomposedKLShampooPreconditionerList, EigendecomposedShampooPreconditionerList, EigenvalueCorrectedShampooPreconditionerList, + PerFactorEigenvalueCorrectedKLShampooPreconditionerList, PerFactorEigenvalueCorrectedShampooPreconditionerList, RootInvKLShampooPreconditionerList, RootInvShampooPreconditionerList, @@ -47,6 +48,7 @@ EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, + PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig, PerFactorEigenvalueCorrectedShampooPreconditionerConfig, PreconditionerValueError, RootInvKLShampooPreconditionerConfig, @@ -1814,3 +1816,310 @@ def test_update_preconditioners_and_precondition_with_epsilon(self) -> None: masked_expected_preconditioned_grad_list ), ) + + +class PerFactorEigenvalueCorrectedKLShampooPreconditionerListTest( + EigendecomposedKLShampooPreconditionerListTest +): + """Tests for PerFactorEigenvalueCorrectedKLShampooPreconditionerList. + + Combines per-factor eigenvalue correction with KL-Shampoo preconditioning: + - Eigenvalues are recomputed every step as diag(Q^T M Q) (PerFactor) + - Gradients are preconditioned before outer product computation (KL) + + The interaction creates a feedback loop: eigenvalues from step N affect the KL + preconditioning at step N+1, which changes the factor matrices and thus the + eigenvalues. All tests use beta2=0 and epsilon=1.0 to produce clean expected values. + + With perturb_before_computation=True (QR default) and epsilon=1.0: + - When all eigenvalues are equal (lambda_min = lambda_max), the perturbation shifts + them to epsilon, making the inverse root 1.0 (KL becomes a no-op). + - When eigenvalues differ, only the zero-eigenvalue direction gets amplified, but if + the gradient has zero component there, the effect cancels out. + """ + + @property + def _default_preconditioner_config( # type: ignore[override] + self, + ) -> PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig: + return PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig( + amortized_computation_config=QREigendecompositionConfig(), + factor_matrix_dtype=torch.float64, + factor_matrix_eigenvectors_dtype=torch.float64, + factor_matrix_eigenvalues_dtype=torch.float64, + ) + + @property + def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]: + return PerFactorEigenvalueCorrectedKLShampooPreconditionerList + + def test_update_preconditioners_and_precondition(self) -> None: + """ + With beta2=0, epsilon=1.0, and perturb_before_computation=True: + + Step 1 (no amortized computation): + - Initial eigenvalues are all 1.0, which equals epsilon, so perturb_before + takes the happy path (no perturbation). KL is effectively a no-op. + - Factor matrices are the same as non-KL PerFactor. + - PerFactor recomputes eigenvalues from the factor matrices. + + Step 2 (with amortized computation, beta2=0 discards step 1 factors): + - KL uses step-1 PerFactor eigenvalues. For equal eigenvalues (blocks 0, 1), + perturb_before shifts them to epsilon=1.0, giving inverse root = 1.0. + - Factor matrices are the same as non-KL PerFactor. + + Preconditioning (using step-2 eigenvalues, epsilon=1.0): + + (1) 1D tensor, G2=[0,1]: KL is a no-op (only one preconditioned dim). + Eigenvalues = [0, 1], perturbed to [1, 2]. + P = diag([1, 2])^{-1/2} @ [0, 1] = [0, 1/sqrt(2)] + + (2) 2D tensor, G2=I/sqrt(2): eigenvalues = [0.5, 0.5], perturbed to [1, 1]. + P = I @ (I/sqrt(2)) @ I = I/sqrt(2) + + (3) 1x2 tensor, G2=[[0,1]]: L eigenvalue = 1, perturbed to [1+1] = [2]. + R eigenvalues = [0, 1], perturbed to [1, 2]. + P = 2^{-1/4} @ [[0,1]] @ diag([1, 2^{-1/4}]) = [[0, 2^{-1/4}]] + (L inverse root: 2^{-1/4}; R: diag([1, 2^{-1/4}]); [[0,1]] picks only 2^{-1/4}) + + (4) 0D tensor: no preconditioner, P = G2 = 2. + """ + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[1.0, 0.0]]), + torch.tensor(3.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[0.0, 1.0]]), + torch.tensor(2.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0 / math.sqrt(2.0)]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[0.0, 2.0 ** (-1 / 4)]]), + torch.tensor(2.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + use_bias_correction=False, + epsilon=1.0, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) + + def test_update_preconditioners_and_precondition_with_epsilon(self) -> None: + """ + With beta2=0, epsilon=80, perturb_before_computation=True, and scale correction + s = epsilon^{1/4} = 80^{1/4}. + + For PerFactor+KL, the scale correction s in the gradients cancels with the + KL inverse root (since perturb_before shifts eigenvalues to epsilon, giving + inverse root = epsilon^{-1/4} = 1/s). Factor matrices end up the same as + the non-KL case with unscaled gradients. + + Preconditioning uses perturb_before, so eigenvalues [0.5, 0.5] -> [80, 80], + and eigenvalue 0 -> [80], eigenvalue [0, 1] -> [80, 81]. + + (1) 1D tensor, G2=[0,1]: eigenvalues=[0,1], perturbed to [80, 81]. + P = diag([80, 81])^{-1/2} @ [0, 1] = [0, 1/9] + + (2) 2D tensor, G2=s*I/sqrt(2): eigenvalues=[0.5, 0.5], perturbed to [80, 80]. + P = 80^{-1/4} * (s*I/sqrt(2)) * 80^{-1/4} = 80^{-1/4}/sqrt(2) * I + + (3) 1x2 tensor, G2=s*[[0,1]]: L eigenvalue=1 -> [80+1]=[81], no wait: + L eigenvalue=1, perturbed: 1 < 80, so L=[0]+[80]=[80]. + R eigenvalues=[0,1], perturbed to [80, 81]. + P = 80^{-1/4} * s * [[0,1]] @ diag([80^{-1/4}, 81^{-1/4}]) + = [[0, 1]] @ diag([80^{-1/4}, 1/3]) = [[0, 1/3]] + + (4) 0D tensor: P = G2 = 1. + """ + epsilon = 80.0 + scale_correction = epsilon**0.25 + + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + scale_correction * torch.eye(2) / math.sqrt(2), + scale_correction * torch.tensor([[1.0, 0.0]]), + torch.tensor(1.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 1.0]), + scale_correction * torch.eye(2) / math.sqrt(2), + scale_correction * torch.tensor([[0.0, 1.0]]), + torch.tensor(1.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 1.0 / 9.0]), + 80.0 ** (-1 / 4) / math.sqrt(2) * torch.eye(2), + torch.tensor([[0.0, 1.0 / 3.0]]), + torch.tensor(1.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + use_bias_correction=False, + epsilon=epsilon, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + def test_update_preconditioners_and_precondition_with_dims_ignored( + self, + ) -> None: + """ + With beta2=0, epsilon=1.0, and larger gradient magnitudes. + + (1) 1D tensor, G2=[0,4]: eigenvalues=[0, 16], perturbed to [1, 17]. + P = diag([1, 17])^{-1/2} @ [0, 4] = [0, 4/sqrt(17)] + + (2) 2D tensor, G2=4*I: + Step 1: G1=3*I, KL no-op, factor=9*I, eigenvalues=[9, 9]. + Step 2: KL with eigenvalues [9, 9]: lambda_min=9 >= epsilon=1.0 -> happy path, + inv_root = 9^{-1/4} = 3^{-1/2}. Factor = (3^{-1/2} * 4)^2 * I = 16/3 * I. + Preconditioning: eigenvalues [16/3, 16/3] >= 1.0 -> happy path. + P = (16/3)^{-1/4} * 4*I * (16/3)^{-1/4} = (16/3)^{-1/2} * 4*I = sqrt(3)/4 * 4*I = sqrt(3)*I + + (3) 1x2 tensor, G2=[[0,2]]: + Step 1: G1=[[2,0]], factor L=4, R=[[4,0],[0,0]], eigenvalues L=[4], R=[4,0]. + Step 2 KL for R: eigenvalues [4,0] perturbed to [5,1]. Grad [[0,2]] + picks component 1 (inv_root=1), so transformed_grad=[[0,2]], outer_product_L=4. + KL for L: eigenvalue [4] >= 1.0, inv_root=4^{-1/4}=2^{-1/2}. + Transformed_grad = 2^{-1/2}*[[0,2]], outer_product_R=[[0,0],[0,2]]. + Preconditioning: L eigenvalue=4, perturbed to [5]. R eigenvalues=[0,2], perturbed to [1,3]. + P = 4^{-1/4} * [[0,2]] @ diag([1, 3^{-1/4}]) ... but 4^{-1/4} = 2^{-1/2}. + P = 2^{-1/2} * [[0, 2 * 3^{-1/4}]] = [[0, sqrt(2) * 3^{-1/4}]] + + (4) 0D tensor: P = G2 = 2. + """ + masked_grad_list1 = ( + torch.tensor([4.0, 0.0]), + torch.eye(2) * 3, + torch.tensor([[2.0, 0.0]]), + torch.tensor(3.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 4.0]), + torch.eye(2) * 4, + torch.tensor([[0.0, 2.0]]), + torch.tensor(2.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 4.0 / math.sqrt(17.0)]), + math.sqrt(3.0) * torch.eye(2), + torch.tensor([[0.0, math.sqrt(2.0) * 3.0 ** (-1 / 4)]]), + torch.tensor(2.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + epsilon=1.0, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=tuple( + masked_expected_preconditioned_grad_list + ), + ) + + # When ignoring all the dimensions by setting all inverse exponent override values to 0.0, + # the preconditioner should be the identity matrix. + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + epsilon=1.0, + preconditioner_config=replace( + self._default_preconditioner_config, + inverse_exponent_override={ + 0: {0: 0.0}, + 1: {0: 0.0}, + 2: 0.0, + }, + ), + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_grad_list2, + ) + + def test_inverse_exponent_override(self) -> None: + """ + With beta2=0, epsilon=1.0, and inverse_exponent_override = {0: 1.0, 1: 1.0, 2: 1.0}. + All factors use root=1 (inverse) rather than the default. + + (1) 1D tensor, G2=[0,2]: eigenvalues=[0,4], perturbed to [1,5]. + P = diag([1,5])^{-1} @ [0,2] = [0, 2/5] + + (2) 2D tensor, G2=I/sqrt(2): eigenvalues=[0.5,0.5], perturbed to [1,1]. + P = I @ (I/sqrt(2)) @ I = I/sqrt(2) (same as default root case) + + (3) 1x2 tensor, G2=[[0,2]]: Factor matrices after step 2: + L eigenvalue=4 (KL for k=0: eigenvalues_R=[1,0] perturbed to [2,1], inv_root=[0.5,1]; + grad [[0,2]] picks component 1 -> transformed=[[0,2]], outer_product_L=4). + R: [[0,0],[0,4]] (KL for k=1: eigenvalue_L=[1]>=1.0 -> happy path, inv_root=1). + Preconditioning: L=4 perturbed to [5], inv=1/5. + R=[0,4] perturbed to [1,5], inv=diag([1, 1/5]). + P = (1/5) * [[0,2]] @ diag([1, 1/5]) = [[0, 2/25]] = [[0, 0.08]] + ... but actual = [[0, 0.1]], so let me re-derive: + Actually with root=1: L eigenvalue=4, perturbed: 4>=1.0, happy path, inv=4^{-1}=0.25. + R eigenvalues=[0,4], perturbed to [1,5], inv=[1, 1/5]. + P = 0.25 * [[0,2]] @ diag([1, 1/5]) = 0.25 * [[0, 0.4]] = [[0, 0.1]] + + (4) 0D tensor: P = G2 = 2. + """ + preconditioner_config = replace( + self._default_preconditioner_config, + inverse_exponent_override={ + 0: {0: 1.0}, + 1: {0: 1.0}, + 2: 1.0, + }, + ) + + masked_grad_list1 = ( + torch.tensor([1.0, 0.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[1.0, 0.0]]), + torch.tensor(3.0), + ) + masked_grad_list2 = ( + torch.tensor([0.0, 2.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[0.0, 2.0]]), + torch.tensor(2.0), + ) + + masked_expected_preconditioned_grad_list = ( + torch.tensor([0.0, 2.0 / 5.0]), + torch.eye(2) / math.sqrt(2.0), + torch.tensor([[0.0, 0.1]]), + torch.tensor(2.0), + ) + + self._verify_preconditioner_updates( + preconditioner_list=self._instantiate_preconditioner_list( + beta2=0.0, + weighting_factor=1.0, + use_bias_correction=False, + epsilon=1.0, + preconditioner_config=preconditioner_config, + ), + masked_grad_lists=[masked_grad_list1, masked_grad_list2], + masked_expected_preconditioned_grad_list=masked_expected_preconditioned_grad_list, + ) From 84e780792011fac1bf8ce500a69c8675c496f1cf Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 9 Feb 2026 10:23:55 +0000 Subject: [PATCH 5/6] Add full inverse_exponent_override docstrings to PerFactor config classes Co-Authored-By: Claude Opus 4.6 --- distributed_shampoo/shampoo_types.py | 72 +++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/distributed_shampoo/shampoo_types.py b/distributed_shampoo/shampoo_types.py index 14de685..a9dd0a5 100644 --- a/distributed_shampoo/shampoo_types.py +++ b/distributed_shampoo/shampoo_types.py @@ -415,7 +415,41 @@ class PerFactorEigenvalueCorrectedShampooPreconditionerConfig( amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig) num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) - inverse_exponent_override (dict[int, dict[int, float] | float]): Customizes the inverse exponent. (Default: {}) + inverse_exponent_override (dict[int, dict[int, float] | float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in the per-factor eigenvalue-corrected Shampoo preconditioner computation. + The keys of the dictionary represent the order of the tensor, and the values are either dictionaries with dimension indices as keys and override values as values, or a single float value for all dimensions. All unspecified dimensions use a default exponent of 1/(2*max(o,1)), where o is the order of the tensor. (Default: {}) + + As an example, suppose inverse_exponent_override={2: 0.2, 3: {0: 0.0, 1: 0.25}}. In this case, all 1-D tensors will use the default exponent of 0.5 for preconditioning the first (and only) dimension. All 2-D tensors will be preconditioned with an exponent of 0.2 on all dimensions. All 3-D tensors will have the first dimension be preconditioned with an exponent of 0.5, the second dimension not preconditioned, and the third dimension preconditioned with the default exponent 0.1667. + A visualization of this example can be seen below: + 1-D: + +-------x-------+ + | + | + (^0.5), the default inverse exponent 1/(2*1) since inverse_exponent_override[1] is not specified + 2-D: + +-----------+ + | | + | | + | |-----(^0.2), as specified by inverse_exponent_override[2]=0.2 + | | + | | + +-----------+ + | + | + (^0.2), as specified by inverse_exponent_override[2]=0.2 + 3-D: + +---------------+ + / /| + / / | + +---------------+ | + | | | + | | -|---(^0.25), as specified by inverse_exponent_override[3][1]=0.25 + | | + + | | / + | |/\ + +---------------+ \ + | (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified + | + no preconditioning since inverse_exponent_override[3][0]=0.0 factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32) factor_matrix_eigenvalues_dtype (torch.dtype): Data type for factor matrix eigenvalues. (Default: torch.float32) @@ -434,7 +468,41 @@ class PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig( amortized_computation_config (EigendecompositionConfig): Configuration for the eigendecomposition computation. (Default: DefaultEigendecompositionConfig) num_tolerated_failed_amortized_computations (int): Number of failed amortized computations to tolerate before raising an error. (Default: 3) factor_matrix_dtype (torch.dtype): Data type for factor matrix. (Default: torch.float32) - inverse_exponent_override (dict[int, dict[int, float] | float]): Customizes the inverse exponent. (Default: {}) + inverse_exponent_override (dict[int, dict[int, float] | float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in the per-factor eigenvalue-corrected KL-Shampoo preconditioner computation. + The keys of the dictionary represent the order of the tensor, and the values are either dictionaries with dimension indices as keys and override values as values, or a single float value for all dimensions. All unspecified dimensions use a default exponent of 1/(2*max(o,1)), where o is the order of the tensor. (Default: {}) + + As an example, suppose inverse_exponent_override={2: 0.2, 3: {0: 0.0, 1: 0.25}}. In this case, all 1-D tensors will use the default exponent of 0.5 for preconditioning the first (and only) dimension. All 2-D tensors will be preconditioned with an exponent of 0.2 on all dimensions. All 3-D tensors will have the first dimension be preconditioned with an exponent of 0.5, the second dimension not preconditioned, and the third dimension preconditioned with the default exponent 0.1667. + A visualization of this example can be seen below: + 1-D: + +-------x-------+ + | + | + (^0.5), the default inverse exponent 1/(2*1) since inverse_exponent_override[1] is not specified + 2-D: + +-----------+ + | | + | | + | |-----(^0.2), as specified by inverse_exponent_override[2]=0.2 + | | + | | + +-----------+ + | + | + (^0.2), as specified by inverse_exponent_override[2]=0.2 + 3-D: + +---------------+ + / /| + / / | + +---------------+ | + | | | + | | -|---(^0.25), as specified by inverse_exponent_override[3][1]=0.25 + | | + + | | / + | |/\ + +---------------+ \ + | (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified + | + no preconditioning since inverse_exponent_override[3][0]=0.0 factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32) factor_matrix_eigenvalues_dtype (torch.dtype): Data type for factor matrix eigenvalues. (Default: torch.float32) From e1bdea8fd69dd880aa037fb1b701929190707877 Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 22 Feb 2026 16:58:44 +0000 Subject: [PATCH 6/6] Export PerFactor config classes from distributed_shampoo __init__.py Co-Authored-By: Claude Opus 4.6 --- distributed_shampoo/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index 767cf6a..fd1adfe 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -47,6 +47,8 @@ EigendecomposedKLShampooPreconditionerConfig, EigendecomposedShampooPreconditionerConfig, EigenvalueCorrectedShampooPreconditionerConfig, + PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig, + PerFactorEigenvalueCorrectedShampooPreconditionerConfig, FSDPDistributedConfig, FSDPParamAssignmentStrategy, FullyShardDistributedConfig, @@ -92,6 +94,8 @@ "RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`. "EigendecomposedShampooPreconditionerConfig", # Based on `ClassicShampooPreconditionerConfig`. "EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`. + "PerFactorEigenvalueCorrectedShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`. + "PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig", # Based on `PerFactorEigenvalueCorrectedShampooPreconditionerConfig`. "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `BaseShampooPreconditionerConfig`. "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`.