Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 88 additions & 126 deletions distributed_shampoo/preconditioner/shampoo_preconditioner_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@
from itertools import chain
from operator import attrgetter
from pathlib import Path
from typing import Any, Generic, get_args, NoReturn, overload, TypeAlias, TypeVar
from typing import (
Any,
Generic,
get_args,
NoReturn,
overload,
TypeAlias,
TypeVar,
)

import torch
from distributed_shampoo.distributor.shampoo_block_info import BlockInfo
Expand Down Expand Up @@ -608,8 +616,85 @@ def __post_init__(self) -> None:


@dataclass(kw_only=True)
class EigendecomposedShampooKroneckerFactorsUnwrapped(
class EigendecompositionBasedShampooKroneckerFactorsUnwrapped(
    BaseShampooKroneckerFactorsUnwrapped
):
    """Common parent for Shampoo variants that eigendecompose their factor matrices.

    The amortized computation implemented here always yields eigenvectors.
    Eigenvalues are added to the result only when the concrete subclass exposes
    a `factor_matrices_eigenvalues` attribute; the presence of that attribute
    is the only switch.
    """

    factor_matrices_eigenvectors: tuple[Tensor, ...]

    def __post_init__(self) -> None:
        """Run the parent validation, then require one eigenvector entry per factor matrix."""
        super().__post_init__()
        assert len(self.factor_matrices) == len(self.factor_matrices_eigenvectors)

    @torch.compiler.disable
    def _amortized_computation(
        self,
        bias_corrected_factor_matrix: Tensor,
        kronecker_factors_iter_dict: dict[str, Any],
    ) -> tuple[dict[str, Tensor], Exception | None]:
        """Eigendecompose one bias-corrected factor matrix.

        Args:
            bias_corrected_factor_matrix: Matrix handed to
                `matrix_eigendecomposition` as `A`.
            kronecker_factors_iter_dict: Holds the current
                "factor_matrices_eigenvectors" entry and, for subclasses that
                track eigenvalues, the current "factor_matrices_eigenvalues"
                entry.

        Returns:
            A pair `(quantities, exception)`. On success `quantities` holds the
            newly computed tensors cast to the dtypes of the stored ones and
            `exception` is None. On failure `quantities` holds the stored
            tensors unchanged and `exception` is the error that was caught.
        """
        eigenvectors_key = "factor_matrices_eigenvectors"
        eigenvalues_key = "factor_matrices_eigenvalues"

        previous_eigenvectors = kronecker_factors_iter_dict[eigenvectors_key]
        # Subclasses opt into eigenvalue tracking simply by having the attribute.
        tracks_eigenvalues = hasattr(self, eigenvalues_key)

        try:
            new_eigenvalues, new_eigenvectors = matrix_eigendecomposition(
                A=bias_corrected_factor_matrix,
                eigendecomposition_config=self.amortized_computation_config,
                # The estimate is passed in the dtype of the matrix being decomposed.
                eigenvectors_estimate=previous_eigenvectors.to(
                    dtype=bias_corrected_factor_matrix.dtype
                ),
                epsilon=self.epsilon,
            )

            # Cast back so the stored state keeps its original dtype.
            outcome: dict[str, Tensor] = {
                eigenvectors_key: new_eigenvectors.to(
                    dtype=previous_eigenvectors.dtype
                ),
            }
            if tracks_eigenvalues:
                outcome[eigenvalues_key] = new_eigenvalues.to(
                    dtype=kronecker_factors_iter_dict[eigenvalues_key].dtype
                )
        except Exception as exception:
            # Any failure: hand back the stored tensors together with the error.
            outcome = {eigenvectors_key: previous_eigenvectors}
            if tracks_eigenvalues:
                outcome[eigenvalues_key] = kronecker_factors_iter_dict[
                    eigenvalues_key
                ]
            return outcome, exception

        return outcome, None


@dataclass(kw_only=True)
class EigendecomposedShampooKroneckerFactorsUnwrapped(
EigendecompositionBasedShampooKroneckerFactorsUnwrapped
):
"""Eigendecomposed Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.

Expand Down Expand Up @@ -639,7 +724,6 @@ class EigendecomposedShampooKroneckerFactorsUnwrapped(
of consecutive failed amortized computations.
"""

factor_matrices_eigenvectors: tuple[Tensor, ...]
factor_matrices_eigenvalues: tuple[Tensor, ...]

@classmethod
Expand Down Expand Up @@ -707,77 +791,11 @@ def from_kronecker_factors_state(
use_trace_scaling=use_trace_scaling,
)

@torch.compiler.disable
def _amortized_computation(
    self,
    bias_corrected_factor_matrix: Tensor,
    kronecker_factors_iter_dict: dict[str, Any],
) -> tuple[dict[str, Tensor], Exception | None]:
    """Compute eigenvalues and eigenvectors of one bias-corrected factor matrix.

    The decomposition is delegated to `matrix_eigendecomposition`, configured
    by `self.amortized_computation_config` and `self.epsilon`. The stored
    eigenvectors are forwarded as `eigenvectors_estimate`.

    Args:
        bias_corrected_factor_matrix (Tensor): The factor matrix after bias
            correction has been applied.
        kronecker_factors_iter_dict (dict[str, Any]): Must contain the current
            "factor_matrices_eigenvalues" and "factor_matrices_eigenvectors".

    Returns:
        computed_quantities (dict[str, Tensor]): The new eigenvalues and
            eigenvectors cast to the dtypes of the stored ones, or the stored
            tensors unchanged if the computation raised.
        exception (Exception | None): The exception that was caught, or None
            on success.

    Note:
        Carried over from the previous docstring: this function assumes there
        are no changes in the selector or masking between iterations within a
        single precondition_frequency interval.
    """
    previous_eigenvectors = kronecker_factors_iter_dict["factor_matrices_eigenvectors"]
    previous_eigenvalues = kronecker_factors_iter_dict["factor_matrices_eigenvalues"]

    try:
        new_eigenvalues, new_eigenvectors = matrix_eigendecomposition(
            A=bias_corrected_factor_matrix,
            eigendecomposition_config=self.amortized_computation_config,
            # The estimate must share the dtype of the matrix being decomposed.
            eigenvectors_estimate=previous_eigenvectors.to(
                dtype=bias_corrected_factor_matrix.dtype
            ),
            epsilon=self.epsilon,
        )

        # Cast the results back so the stored state keeps its original dtypes.
        refreshed = {
            "factor_matrices_eigenvalues": new_eigenvalues.to(
                dtype=previous_eigenvalues.dtype
            ),
            "factor_matrices_eigenvectors": new_eigenvectors.to(
                dtype=previous_eigenvectors.dtype
            ),
        }
    except Exception as exception:
        # Any failure: keep the stored quantities and report the error.
        unchanged = {
            "factor_matrices_eigenvalues": previous_eigenvalues,
            "factor_matrices_eigenvectors": previous_eigenvectors,
        }
        return unchanged, exception

    return refreshed, None

def __post_init__(self) -> None:
    """Run the parent validation, then require four per-factor collections of equal length.

    The collections compared are `roots`, `factor_matrices`,
    `factor_matrices_eigenvectors` and `factor_matrices_eigenvalues`.
    """
    super().__post_init__()
    # Equal lengths collapse to a single distinct value.
    distinct_lengths = {
        len(self.roots),
        len(self.factor_matrices),
        len(self.factor_matrices_eigenvectors),
        len(self.factor_matrices_eigenvalues),
    }
    assert len(distinct_lengths) == 1

Expand Down Expand Up @@ -851,7 +869,7 @@ def __post_init__(self) -> None:

@dataclass(kw_only=True)
class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped(
BaseShampooKroneckerFactorsUnwrapped
EigendecompositionBasedShampooKroneckerFactorsUnwrapped
):
"""Eigenvalue-corrected Shampoo Kronecker factors (unwrapped) for operations during optimizer computation.

Expand Down Expand Up @@ -885,7 +903,6 @@ class EigenvalueCorrectedShampooKroneckerFactorsUnwrapped(
of consecutive failed amortized computations.
"""

factor_matrices_eigenvectors: tuple[Tensor, ...]
corrected_eigenvalues: Tensor

@classmethod
Expand Down Expand Up @@ -953,63 +970,8 @@ def from_kronecker_factors_state(
use_trace_scaling=use_trace_scaling,
)

@torch.compiler.disable
def _amortized_computation(
    self,
    bias_corrected_factor_matrix: Tensor,
    kronecker_factors_iter_dict: dict[str, Any],
) -> tuple[dict[str, Tensor], Exception | None]:
    """Compute only the eigenvectors of one bias-corrected factor matrix.

    The decomposition is delegated to `matrix_eigendecomposition`, configured
    by `self.amortized_computation_config` and `self.epsilon`. Only the second
    element of its result (the eigenvectors) is kept. The stored eigenvectors
    are forwarded as `eigenvectors_estimate`.

    Args:
        bias_corrected_factor_matrix (Tensor): The factor matrix after bias
            correction has been applied.
        kronecker_factors_iter_dict (dict[str, Any]): Must contain the current
            "factor_matrices_eigenvectors".

    Returns:
        computed_quantities (dict[str, Tensor]): The new eigenvectors cast to
            the dtype of the stored ones, or the stored eigenvectors unchanged
            if the computation raised.
        exception (Exception | None): The exception that was caught, or None
            on success.

    Note:
        Carried over from the previous docstring: this function assumes there
        are no changes in the selector or masking between iterations within a
        single precondition_frequency interval.
    """
    previous_eigenvectors = kronecker_factors_iter_dict["factor_matrices_eigenvectors"]

    try:
        decomposition = matrix_eigendecomposition(
            A=bias_corrected_factor_matrix,
            eigendecomposition_config=self.amortized_computation_config,
            # The estimate must share the dtype of the matrix being decomposed.
            eigenvectors_estimate=previous_eigenvectors.to(
                dtype=bias_corrected_factor_matrix.dtype
            ),
            epsilon=self.epsilon,
        )
        # Index 1 holds the eigenvectors; cast back to the stored dtype.
        new_eigenvectors = decomposition[1].to(dtype=previous_eigenvectors.dtype)
    except Exception as exception:
        # Any failure: keep the stored eigenvectors and report the error.
        return {"factor_matrices_eigenvectors": previous_eigenvectors}, exception

    return {"factor_matrices_eigenvectors": new_eigenvectors}, None

def __post_init__(self) -> None:
    """Run the parent validation, then check eigenvector count and root count.

    Requires one eigenvector entry per factor matrix and exactly one entry in
    `roots`.
    """
    super().__post_init__()
    factor_matrix_count = len(self.factor_matrices)
    eigenvector_count = len(self.factor_matrices_eigenvectors)
    assert factor_matrix_count == eigenvector_count
    assert len(self.roots) == 1

def _get_field_dict(self) -> dict[str, Any]:
Expand Down
Loading