From 1c61efe57448b75ac48120622f5abdaa357d2e41 Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 8 Feb 2026 18:34:15 +0000 Subject: [PATCH 1/2] Refactor: extract shared EigendecompositionBasedShampooKroneckerFactorsUnwrapped base class Consolidate duplicated eigendecomposition logic from EigendecomposedShampooKroneckerFactorsUnwrapped and EigenvalueCorrectedShampooKroneckerFactorsUnwrapped into a shared base class. The base class provides _perform_eigendecomposition and _amortized_computation, with subclass behavior controlled via hasattr checks on field presence. Co-Authored-By: Claude Opus 4.6 --- .../shampoo_preconditioner_list.py | 214 +++++++----------- 1 file changed, 88 insertions(+), 126 deletions(-) diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index bba4509..33089c5 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -16,7 +16,15 @@ from itertools import chain from operator import attrgetter from pathlib import Path -from typing import Any, Generic, get_args, NoReturn, overload, TypeAlias, TypeVar +from typing import ( + Any, + Generic, + get_args, + NoReturn, + overload, + TypeAlias, + TypeVar, +) import torch from distributed_shampoo.distributor.shampoo_block_info import BlockInfo @@ -608,8 +616,85 @@ def __post_init__(self) -> None: @dataclass(kw_only=True) -class EigendecomposedShampooKroneckerFactorsUnwrapped( +class EigendecompositionBasedShampooKroneckerFactorsUnwrapped( BaseShampooKroneckerFactorsUnwrapped +): + """Base class for Shampoo variants using eigendecomposition. + + This class provides shared eigendecomposition logic for variants that compute + eigenvectors of factor matrices. Subclasses control eigenvalue behavior via + field presence: if a subclass has `factor_matrices_eigenvalues`, eigenvalues + from eigendecomposition will be included in amortized computation results. + """ + + factor_matrices_eigenvectors: tuple[Tensor, ...] + + @torch.compiler.disable + def _amortized_computation( + self, + bias_corrected_factor_matrix: Tensor, + kronecker_factors_iter_dict: dict[str, Any], + ) -> tuple[dict[str, Tensor], Exception | None]: + """Performs amortized eigendecomposition computation. + + Computes eigendecomposition of the bias-corrected factor matrix. + On exception, retains previous eigenvectors (and eigenvalues if applicable). + + Args: + bias_corrected_factor_matrix: The input matrix to decompose. + kronecker_factors_iter_dict: Dictionary with current eigenvectors + (and eigenvalues if applicable) as estimates. + + Returns: + Tuple of (updated dict with new eigenvectors/eigenvalues, exception or None). + + """ + factor_matrix_eigenvectors = kronecker_factors_iter_dict[ + "factor_matrices_eigenvectors" + ] + + try: + computed_eigenvalues, computed_eigenvectors = ( + matrix_eigendecomposition( + A=bias_corrected_factor_matrix, + eigendecomposition_config=self.amortized_computation_config, + eigenvectors_estimate=factor_matrix_eigenvectors.to( + dtype=bias_corrected_factor_matrix.dtype + ), + epsilon=self.epsilon, + ) + ) + + result: dict[str, Tensor] = { + "factor_matrices_eigenvectors": computed_eigenvectors.to( + dtype=factor_matrix_eigenvectors.dtype + ), + } + + if hasattr(self, "factor_matrices_eigenvalues"): + result["factor_matrices_eigenvalues"] = computed_eigenvalues.to( + dtype=kronecker_factors_iter_dict[ + "factor_matrices_eigenvalues" + ].dtype + ) + + return result, None + except Exception as exception: + result = {"factor_matrices_eigenvectors": factor_matrix_eigenvectors} + if hasattr(self, "factor_matrices_eigenvalues"): + result["factor_matrices_eigenvalues"] = kronecker_factors_iter_dict[ + "factor_matrices_eigenvalues" + ] + return result, exception + + def __post_init__(self) -> None: + super().__post_init__() + assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) + + +@dataclass(kw_only=True) +class EigendecomposedShampooKroneckerFactorsUnwrapped( + EigendecompositionBasedShampooKroneckerFactorsUnwrapped ): """Eigendecomposed Shampoo Kronecker factors (unwrapped) for operations during optimizer computation. @@ -639,7 +724,6 @@ class EigendecomposedShampooKroneckerFactorsUnwrapped( of consecutive failed amortized computations. """ - factor_matrices_eigenvectors: tuple[Tensor, ...] factor_matrices_eigenvalues: tuple[Tensor, ...] @classmethod @@ -707,77 +791,11 @@ def from_kronecker_factors_state( use_trace_scaling=use_trace_scaling, ) - @torch.compiler.disable - def _amortized_computation( - self, - bias_corrected_factor_matrix: Tensor, - kronecker_factors_iter_dict: dict[str, Any], - ) -> tuple[dict[str, Tensor], Exception | None]: - """Performs eigendecomposition for Shampoo preconditioners. - - This implementation of the abstract _amortized_computation method specifically handles - the eigendecomposition for the EigendecomposedShampoo variant. It computes both - eigenvalues and eigenvectors for each factor matrix. - - The computation uses the configuration specified in amortized_computation_config, - with special handling for QR-based eigendecomposition which requires the previous - eigenvectors as an initial estimate. Error handling is included to gracefully - recover from numerical issues. - - Args: - bias_corrected_factor_matrix (Tensor): The factor matrix after bias correction - has been applied. - kronecker_factors_iter_dict (dict[str, Any]): Dictionary containing the current - factor_matrices_eigenvalues and factor_matrices_eigenvectors for the computation. - - Returns: - computed_quantities (dict[str, Tensor]): A dictionary with the computed eigenvalues and eigenvectors. - exception (Exception | None): Any exception that occurred during computation, or None if successful. - - Note: - This function assumes there are no changes in the selector or masking between - iterations within a single precondition_frequency interval. - """ - ( - factor_matrix_eigenvectors, - factor_matrix_eigenvalues, - ) = ( - kronecker_factors_iter_dict["factor_matrices_eigenvectors"], - kronecker_factors_iter_dict["factor_matrices_eigenvalues"], - ) - - try: - # Compute inverse preconditioner. - computed_eigenvalues, computed_eigenvectors = matrix_eigendecomposition( - A=bias_corrected_factor_matrix, - eigendecomposition_config=self.amortized_computation_config, - # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix. - eigenvectors_estimate=factor_matrix_eigenvectors.to( - dtype=bias_corrected_factor_matrix.dtype - ), - epsilon=self.epsilon, - ) - - return { - "factor_matrices_eigenvalues": computed_eigenvalues.to( - dtype=factor_matrix_eigenvalues.dtype - ), - "factor_matrices_eigenvectors": computed_eigenvectors.to( - dtype=factor_matrix_eigenvectors.dtype - ), - }, None - except Exception as exception: - return { - "factor_matrices_eigenvalues": factor_matrix_eigenvalues, - "factor_matrices_eigenvectors": factor_matrix_eigenvectors, - }, exception - def __post_init__(self) -> None: super().__post_init__() assert ( len(self.roots) == len(self.factor_matrices) - == len(self.factor_matrices_eigenvectors) == len(self.factor_matrices_eigenvalues) ) @@ -851,7 +869,7 @@ def __post_init__(self) -> None: @dataclass(kw_only=True) class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped( - BaseShampooKroneckerFactorsUnwrapped + EigendecompositionBasedShampooKroneckerFactorsUnwrapped ): """Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation. @@ -885,7 +903,6 @@ class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped( of consecutive failed amortized computations. """ - factor_matrices_eigenvectors: tuple[Tensor, ...] corrected_eigenvalues: Tensor @classmethod @@ -953,63 +970,8 @@ def from_kronecker_factors_state( use_trace_scaling=use_trace_scaling, ) - @torch.compiler.disable - def _amortized_computation( - self, - bias_corrected_factor_matrix: Tensor, - kronecker_factors_iter_dict: dict[str, Any], - ) -> tuple[dict[str, Tensor], Exception | None]: - """Computes eigenvectors for eigenvalue-corrected Shampoo preconditioners. - - This implementation of the abstract _amortized_computation method specifically handles - the computation of eigenvectors for the EigenvalueCorrectedShampoo variant. Unlike - the EigendecomposedShampoo variant, this only computes eigenvectors and not eigenvalues, - as the eigenvalues are corrected separately during the optimization process. - - The computation uses the configuration specified in amortized_computation_config, - with special handling for QR-based eigendecomposition which requires the previous - eigenvectors as an initial estimate. Error handling is included to gracefully - recover from numerical issues. - - Args: - bias_corrected_factor_matrix (Tensor): The factor matrix after bias correction - has been applied. - kronecker_factors_iter_dict (dict[str, Any]): Dictionary containing the current - factor_matrices_eigenvectors for the computation. - - Returns: - computed_quantities (dict[str, Tensor]): A dictionary with the computed eigenvectors. - exception (Exception | None): Any exception that occurred during computation, or None if successful. - - Note: - This function assumes there are no changes in the selector or masking between - iterations within a single precondition_frequency interval. - """ - factor_matrix_eigenvectors = kronecker_factors_iter_dict[ - "factor_matrices_eigenvectors" - ] - - try: - # Compute eigenvectors of factor matrix. - return { - "factor_matrices_eigenvectors": matrix_eigendecomposition( - A=bias_corrected_factor_matrix, - eigendecomposition_config=self.amortized_computation_config, - # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix. - eigenvectors_estimate=factor_matrix_eigenvectors.to( - dtype=bias_corrected_factor_matrix.dtype - ), - epsilon=self.epsilon, - )[1].to(dtype=factor_matrix_eigenvectors.dtype) - }, None - except Exception as exception: - return { - "factor_matrices_eigenvectors": factor_matrix_eigenvectors - }, exception - def __post_init__(self) -> None: super().__post_init__() - assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors) assert len(self.roots) == 1 def _get_field_dict(self) -> dict[str, Any]: From 6cae8245bdb5ce3421c8f363eaca1b28353825db Mon Sep 17 00:00:00 2001 From: runame Date: Sun, 8 Feb 2026 19:53:14 +0000 Subject: [PATCH 2/2] Refactor: eliminate _compute_outer_product_list via _transform_grad_for_outer_product hook Inline the outer product loop into BaseShampooPreconditionerList._update_factor_matrices and introduce _transform_grad_for_outer_product as the single extension point. The base returns grad unchanged; KL-Shampoo subclasses override it to precondition the gradient. This eliminates _compute_outer_product_list from all three classes that defined it. Co-Authored-By: Claude Opus 4.6 --- .../shampoo_preconditioner_list.py | 189 +++++++++--------- 1 file changed, 97 insertions(+), 92 deletions(-) diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index 33089c5..30a3b25 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -671,6 +671,8 @@ def _amortized_computation( ), } + # TODO: Consider replacing hasattr dispatch with a class variable + # (e.g., _include_eigenvalues) to avoid fragile reliance on field presence. if hasattr(self, "factor_matrices_eigenvalues"): result["factor_matrices_eigenvalues"] = computed_eigenvalues.to( dtype=kronecker_factors_iter_dict[ @@ -1279,24 +1281,20 @@ def compress_preconditioner_list( self._local_preconditioned_dims_selector_list, local_grad_selector ) - @profile_decorator - def _compute_outer_product_list( + def _transform_grad_for_outer_product( self, grad: Tensor, + idx_of_k: int, + k: int, order: int, preconditioned_dims_selector: tuple[bool, ...], kronecker_factors: _ShampooKroneckerFactorsUnwrappedType, - ) -> tuple[Tensor, ...]: - # Construct outer product list for updating Kronecker factors. - return tuple( - torch.tensordot( - grad, - grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] - ) - for k in compress_list(range(order), preconditioned_dims_selector) - ) + ) -> Tensor: + """Transforms the gradient before outer product computation. + + Override this method in subclasses to apply preconditioning (e.g., KL-Shampoo). + """ + return grad @profile_decorator def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: @@ -1314,8 +1312,23 @@ def _update_factor_matrices(self, masked_grad_list: tuple[Tensor, ...]) -> None: if not kronecker_factors.factor_matrices: continue - outer_product_list = self._compute_outer_product_list( - grad, order, preconditioned_dims_selector, kronecker_factors + outer_product_list = tuple( + torch.tensordot( + transformed_grad := self._transform_grad_for_outer_product( + grad=grad, + idx_of_k=idx_of_k, + k=k, + order=order, + preconditioned_dims_selector=preconditioned_dims_selector, + kronecker_factors=kronecker_factors, + ), + transformed_grad, + # Contracts across all dimensions except for k. + dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) + for idx_of_k, k in enumerate( + compress_list(range(order), preconditioned_dims_selector) + ) ) if self._beta2 != 1.0: @@ -1618,6 +1631,52 @@ def _compute_preconditioned_gradient( ), ) + def _kl_transform_grad_for_outer_product( + self, + grad: Tensor, + idx_of_k: int, + k: int, + order: int, + preconditioned_dims_selector: tuple[bool, ...], + kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, + ) -> Tensor: + """Transforms the gradient for KL-Shampoo outer product computation. + + KL-Shampoo preconditions the gradient along all dimensions except k + with the inverse root of the factor matrices before computing the outer product. + """ + # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig + assert isinstance( + self._preconditioner_config.amortized_computation_config, + EigendecompositionConfig, + ) + rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config + + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + return self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + matrix_inverse_root_from_eigendecomposition( + L=eigenvalues, + Q=eigenvectors, + root=Fraction(root), + epsilon=self._epsilon, + rank_deficient_stability_config=rank_deficient_stability_config, + ) + for idx, (eigenvalues, eigenvectors, root) in enumerate( + zip( + kronecker_factors.factor_matrices_eigenvalues, + kronecker_factors.factor_matrices_eigenvectors, + kronecker_factors.roots, + strict=True, + ) + ) + if idx != idx_of_k + ), + ) + class EigenvalueCorrectedShampooPreconditionerList( BaseShampooPreconditionerList[ @@ -1732,42 +1791,29 @@ def _compute_preconditioned_gradient( class RootInvKLShampooPreconditionerList(RootInvShampooPreconditionerList): """Root inverse KL-Shampoo preconditioners for list of parameters.""" - @profile_decorator - def _compute_outer_product_list( + def _transform_grad_for_outer_product( self, grad: Tensor, + idx_of_k: int, + k: int, order: int, preconditioned_dims_selector: tuple[bool, ...], kronecker_factors: RootInvShampooKroneckerFactorsUnwrapped, - ) -> tuple[Tensor, ...]: - # Construct outer product list for updating Kronecker factors. - outer_product_list = [] - for idx_of_k, k in enumerate( - compress_list(range(order), preconditioned_dims_selector) - ): - # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. - local_preconditioned_dims_selector = list(preconditioned_dims_selector) - local_preconditioned_dims_selector[k] = False - preconditioned_grad = self._precondition_grad( - grad=grad, - preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), - preconditioner_list=tuple( - inv_factor_matrix - for idx, inv_factor_matrix in enumerate( - kronecker_factors.inv_factor_matrices - ) - if idx != idx_of_k - ), - ) - outer_product_list.append( - torch.tensordot( - preconditioned_grad, - preconditioned_grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] + ) -> Tensor: + # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. + local_preconditioned_dims_selector = list(preconditioned_dims_selector) + local_preconditioned_dims_selector[k] = False + return self._precondition_grad( + grad=grad, + preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), + preconditioner_list=tuple( + inv_factor_matrix + for idx, inv_factor_matrix in enumerate( + kronecker_factors.inv_factor_matrices ) - ) - return tuple(outer_product_list) + if idx != idx_of_k + ), + ) class EigendecomposedKLShampooPreconditionerList( @@ -1775,56 +1821,15 @@ class EigendecomposedKLShampooPreconditionerList( ): """Eigendecomposed KL-Shampoo preconditioners for list of parameters.""" - @profile_decorator - def _compute_outer_product_list( + def _transform_grad_for_outer_product( self, grad: Tensor, + idx_of_k: int, + k: int, order: int, preconditioned_dims_selector: tuple[bool, ...], kronecker_factors: EigendecomposedShampooKroneckerFactorsUnwrapped, - ) -> tuple[Tensor, ...]: - # TODO: remove assertion when rank_deficient_stability_config is generalized to MatrixFunctionConfig - assert isinstance( - self._preconditioner_config.amortized_computation_config, - EigendecompositionConfig, + ) -> Tensor: + return self._kl_transform_grad_for_outer_product( + grad, idx_of_k, k, order, preconditioned_dims_selector, kronecker_factors, ) - rank_deficient_stability_config = self._preconditioner_config.amortized_computation_config.rank_deficient_stability_config - - # Construct outer product list for updating Kronecker factors. - outer_product_list = [] - for idx_of_k, k in enumerate( - compress_list(range(order), preconditioned_dims_selector) - ): - # KL-Shampoo uses the gradient preconditioned (along all dimensions that are contracted) with the inverse root of the factor matrices to compute the outer products. - local_preconditioned_dims_selector = list(preconditioned_dims_selector) - local_preconditioned_dims_selector[k] = False - preconditioned_grad = self._precondition_grad( - grad=grad, - preconditioned_dims_selector=tuple(local_preconditioned_dims_selector), - preconditioner_list=tuple( - matrix_inverse_root_from_eigendecomposition( - L=eigenvalues, - Q=eigenvectors, - root=Fraction(root), - epsilon=self._epsilon, - rank_deficient_stability_config=rank_deficient_stability_config, - ) - for idx, (eigenvalues, eigenvectors, root) in enumerate( - zip( - kronecker_factors.factor_matrices_eigenvalues, - kronecker_factors.factor_matrices_eigenvectors, - kronecker_factors.roots, - strict=True, - ) - ) - if idx != idx_of_k - ), - ) - outer_product = torch.tensordot( - preconditioned_grad, - preconditioned_grad, - # Contracts across all dimensions except for k. - dims=[[*chain(range(k), range(k + 1, order))]] * 2, # type: ignore[has-type] - ) - outer_product_list.append(outer_product) - return tuple(outer_product_list)