From 1c61efe57448b75ac48120622f5abdaa357d2e41 Mon Sep 17 00:00:00 2001
From: runame
Date: Sun, 8 Feb 2026 18:34:15 +0000
Subject: [PATCH] Refactor: extract shared EigendecompositionBasedShampooKroneckerFactorsUnwrapped base class

Consolidate duplicated eigendecomposition logic from
EigendecomposedShampooKroneckerFactorsUnwrapped and
EigenvalueCorrectedShampooKroneckerFactorsUnwrapped into a shared base
class. The base class provides the shared _amortized_computation and the
eigenvector length check in __post_init__, with subclass behavior
controlled via hasattr checks on field presence.

Co-Authored-By: Claude Opus 4.6
---
Reviewer notes:
- _amortized_computation now reads the previous eigenvectors/eigenvalues
  before the try block (as the removed EigendecomposedShampoo
  implementation did) and reuses them as the fallback result.
- Restored the "Note:" section and the dtype why-comment that the removed
  implementations carried.
- NOTE(review): the post-image blob id on the "index" line predates these
  edits, and leading whitespace of context/removed lines was reconstructed
  from the hunk line counts -- confirm against the pre-image before applying.

 .../shampoo_preconditioner_list.py | 214 +++++++-----------
 1 file changed, 88 insertions(+), 126 deletions(-)

diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py
index bba4509..33089c5 100644
--- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py
+++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py
@@ -16,7 +16,15 @@
 from itertools import chain
 from operator import attrgetter
 from pathlib import Path
-from typing import Any, Generic, get_args, NoReturn, overload, TypeAlias, TypeVar
+from typing import (
+    Any,
+    Generic,
+    get_args,
+    NoReturn,
+    overload,
+    TypeAlias,
+    TypeVar,
+)
 
 import torch
 from distributed_shampoo.distributor.shampoo_block_info import BlockInfo
@@ -608,8 +616,85 @@ def __post_init__(self) -> None:
 
 
 @dataclass(kw_only=True)
-class EigendecomposedShampooKroneckerFactorsUnwrapped(
+class EigendecompositionBasedShampooKroneckerFactorsUnwrapped(
     BaseShampooKroneckerFactorsUnwrapped
+):
+    """Base class for Shampoo variants using eigendecomposition.
+
+    This class provides shared eigendecomposition logic for variants that compute
+    eigenvectors of factor matrices. Subclasses control eigenvalue behavior via
+    field presence: if a subclass has `factor_matrices_eigenvalues`, eigenvalues
+    from eigendecomposition will be included in amortized computation results.
+    """
+
+    factor_matrices_eigenvectors: tuple[Tensor, ...]
+
+    @torch.compiler.disable
+    def _amortized_computation(
+        self,
+        bias_corrected_factor_matrix: Tensor,
+        kronecker_factors_iter_dict: dict[str, Any],
+    ) -> tuple[dict[str, Tensor], Exception | None]:
+        """Performs amortized eigendecomposition computation.
+
+        Computes eigendecomposition of the bias-corrected factor matrix.
+        On exception, retains previous eigenvectors (and eigenvalues if applicable).
+
+        Args:
+            bias_corrected_factor_matrix: The input matrix to decompose.
+            kronecker_factors_iter_dict: Dictionary with current eigenvectors
+                (and eigenvalues if applicable) as estimates.
+
+        Returns:
+            Tuple of (updated dict with new eigenvectors/eigenvalues, exception or None).
+
+        Note:
+            This function assumes there are no changes in the selector or masking between
+            iterations within a single precondition_frequency interval.
+        """
+        # Read the previous quantities before entering the try block: they double as
+        # the fallback result, and a missing key fails here instead of being re-raised
+        # from inside the exception handler.
+        previous_quantities: dict[str, Tensor] = {
+            "factor_matrices_eigenvectors": kronecker_factors_iter_dict[
+                "factor_matrices_eigenvectors"
+            ],
+        }
+        if hasattr(self, "factor_matrices_eigenvalues"):
+            previous_quantities["factor_matrices_eigenvalues"] = (
+                kronecker_factors_iter_dict["factor_matrices_eigenvalues"]
+            )
+
+        try:
+            computed_eigenvalues, computed_eigenvectors = matrix_eigendecomposition(
+                A=bias_corrected_factor_matrix,
+                eigendecomposition_config=self.amortized_computation_config,
+                # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix.
+                eigenvectors_estimate=previous_quantities[
+                    "factor_matrices_eigenvectors"
+                ].to(dtype=bias_corrected_factor_matrix.dtype),
+                epsilon=self.epsilon,
+            )
+            computed_quantities = {
+                "factor_matrices_eigenvectors": computed_eigenvectors,
+                "factor_matrices_eigenvalues": computed_eigenvalues,
+            }
+            # Cast each computed quantity back to the dtype of the value it replaces.
+            return {
+                name: computed_quantities[name].to(dtype=previous.dtype)
+                for name, previous in previous_quantities.items()
+            }, None
+        except Exception as exception:
+            return previous_quantities, exception
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors)
+
+
+@dataclass(kw_only=True)
+class EigendecomposedShampooKroneckerFactorsUnwrapped(
+    EigendecompositionBasedShampooKroneckerFactorsUnwrapped
 ):
     """Eigendecomposed Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.
 
@@ -639,7 +724,6 @@ class EigendecomposedShampooKroneckerFactorsUnwrapped(
             of consecutive failed amortized computations.
     """
 
-    factor_matrices_eigenvectors: tuple[Tensor, ...]
     factor_matrices_eigenvalues: tuple[Tensor, ...]
 
     @classmethod
@@ -707,77 +791,11 @@ def from_kronecker_factors_state(
             use_trace_scaling=use_trace_scaling,
         )
 
-    @torch.compiler.disable
-    def _amortized_computation(
-        self,
-        bias_corrected_factor_matrix: Tensor,
-        kronecker_factors_iter_dict: dict[str, Any],
-    ) -> tuple[dict[str, Tensor], Exception | None]:
-        """Performs eigendecomposition for Shampoo preconditioners.
-
-        This implementation of the abstract _amortized_computation method specifically handles
-        the eigendecomposition for the EigendecomposedShampoo variant. It computes both
-        eigenvalues and eigenvectors for each factor matrix.
-
-        The computation uses the configuration specified in amortized_computation_config,
-        with special handling for QR-based eigendecomposition which requires the previous
-        eigenvectors as an initial estimate. Error handling is included to gracefully
-        recover from numerical issues.
-
-        Args:
-            bias_corrected_factor_matrix (Tensor): The factor matrix after bias correction
-                has been applied.
-            kronecker_factors_iter_dict (dict[str, Any]): Dictionary containing the current
-                factor_matrices_eigenvalues and factor_matrices_eigenvectors for the computation.
-
-        Returns:
-            computed_quantities (dict[str, Tensor]): A dictionary with the computed eigenvalues and eigenvectors.
-            exception (Exception | None): Any exception that occurred during computation, or None if successful.
-
-        Note:
-            This function assumes there are no changes in the selector or masking between
-            iterations within a single precondition_frequency interval.
-        """
-        (
-            factor_matrix_eigenvectors,
-            factor_matrix_eigenvalues,
-        ) = (
-            kronecker_factors_iter_dict["factor_matrices_eigenvectors"],
-            kronecker_factors_iter_dict["factor_matrices_eigenvalues"],
-        )
-
-        try:
-            # Compute inverse preconditioner.
-            computed_eigenvalues, computed_eigenvectors = matrix_eigendecomposition(
-                A=bias_corrected_factor_matrix,
-                eigendecomposition_config=self.amortized_computation_config,
-                # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix.
-                eigenvectors_estimate=factor_matrix_eigenvectors.to(
-                    dtype=bias_corrected_factor_matrix.dtype
-                ),
-                epsilon=self.epsilon,
-            )
-
-            return {
-                "factor_matrices_eigenvalues": computed_eigenvalues.to(
-                    dtype=factor_matrix_eigenvalues.dtype
-                ),
-                "factor_matrices_eigenvectors": computed_eigenvectors.to(
-                    dtype=factor_matrix_eigenvectors.dtype
-                ),
-            }, None
-        except Exception as exception:
-            return {
-                "factor_matrices_eigenvalues": factor_matrix_eigenvalues,
-                "factor_matrices_eigenvectors": factor_matrix_eigenvectors,
-            }, exception
-
     def __post_init__(self) -> None:
         super().__post_init__()
         assert (
             len(self.roots)
             == len(self.factor_matrices)
-            == len(self.factor_matrices_eigenvectors)
             == len(self.factor_matrices_eigenvalues)
         )
 
@@ -851,7 +869,7 @@ def __post_init__(self) -> None:
 
 @dataclass(kw_only=True)
 class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped(
-    BaseShampooKroneckerFactorsUnwrapped
+    EigendecompositionBasedShampooKroneckerFactorsUnwrapped
 ):
     """Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.
 
@@ -885,7 +903,6 @@ class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped(
             of consecutive failed amortized computations.
     """
 
-    factor_matrices_eigenvectors: tuple[Tensor, ...]
     corrected_eigenvalues: Tensor
 
     @classmethod
@@ -953,63 +970,8 @@ def from_kronecker_factors_state(
             use_trace_scaling=use_trace_scaling,
         )
 
-    @torch.compiler.disable
-    def _amortized_computation(
-        self,
-        bias_corrected_factor_matrix: Tensor,
-        kronecker_factors_iter_dict: dict[str, Any],
-    ) -> tuple[dict[str, Tensor], Exception | None]:
-        """Computes eigenvectors for eigenvalue-corrected Shampoo preconditioners.
-
-        This implementation of the abstract _amortized_computation method specifically handles
-        the computation of eigenvectors for the EigenvalueCorrectedShampoo variant. Unlike
-        the EigendecomposedShampoo variant, this only computes eigenvectors and not eigenvalues,
-        as the eigenvalues are corrected separately during the optimization process.
-
-        The computation uses the configuration specified in amortized_computation_config,
-        with special handling for QR-based eigendecomposition which requires the previous
-        eigenvectors as an initial estimate. Error handling is included to gracefully
-        recover from numerical issues.
-
-        Args:
-            bias_corrected_factor_matrix (Tensor): The factor matrix after bias correction
-                has been applied.
-            kronecker_factors_iter_dict (dict[str, Any]): Dictionary containing the current
-                factor_matrices_eigenvectors for the computation.
-
-        Returns:
-            computed_quantities (dict[str, Tensor]): A dictionary with the computed eigenvectors.
-            exception (Exception | None): Any exception that occurred during computation, or None if successful.
-
-        Note:
-            This function assumes there are no changes in the selector or masking between
-            iterations within a single precondition_frequency interval.
-        """
-        factor_matrix_eigenvectors = kronecker_factors_iter_dict[
-            "factor_matrices_eigenvectors"
-        ]
-
-        try:
-            # Compute eigenvectors of factor matrix.
-            return {
-                "factor_matrices_eigenvectors": matrix_eigendecomposition(
-                    A=bias_corrected_factor_matrix,
-                    eigendecomposition_config=self.amortized_computation_config,
-                    # To estimate the eigenvalues based on the previous eigenvectors, we need to pass in the previous eigenvectors with the same dtype as the input matrix, i.e., factor_matrix.
-                    eigenvectors_estimate=factor_matrix_eigenvectors.to(
-                        dtype=bias_corrected_factor_matrix.dtype
-                    ),
-                    epsilon=self.epsilon,
-                )[1].to(dtype=factor_matrix_eigenvectors.dtype)
-            }, None
-        except Exception as exception:
-            return {
-                "factor_matrices_eigenvectors": factor_matrix_eigenvectors
-            }, exception
-
     def __post_init__(self) -> None:
         super().__post_init__()
-        assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors)
         assert len(self.roots) == 1
 
     def _get_field_dict(self) -> dict[str, Any]: