Skip to content
4 changes: 4 additions & 0 deletions distributed_shampoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
EigendecomposedKLShampooPreconditionerConfig,
EigendecomposedShampooPreconditionerConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig,
PerFactorEigenvalueCorrectedShampooPreconditionerConfig,
FSDPDistributedConfig,
FSDPParamAssignmentStrategy,
FullyShardDistributedConfig,
Expand Down Expand Up @@ -92,6 +94,8 @@
"RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`.
"EigendecomposedShampooPreconditionerConfig", # Based on `ClassicShampooPreconditionerConfig`.
"EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`.
"PerFactorEigenvalueCorrectedShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`.
"PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig", # Based on `PerFactorEigenvalueCorrectedShampooPreconditionerConfig`.
"EigenvalueCorrectedShampooPreconditionerConfig", # Based on `BaseShampooPreconditionerConfig`.
"DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`.
"DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`.
Expand Down
8 changes: 8 additions & 0 deletions distributed_shampoo/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
EigendecomposedKLShampooPreconditionerList,
EigendecomposedShampooPreconditionerList,
EigenvalueCorrectedShampooPreconditionerList,
PerFactorEigenvalueCorrectedKLShampooPreconditionerList,
PerFactorEigenvalueCorrectedShampooPreconditionerList,
RootInvKLShampooPreconditionerList,
RootInvShampooPreconditionerList,
)
Expand All @@ -76,6 +78,8 @@
EigendecomposedShampooPreconditionerConfig,
EigenvalueCorrectedShampooPreconditionerConfig,
EPSILON,
PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig,
PerFactorEigenvalueCorrectedShampooPreconditionerConfig,
FILTERED_GRAD,
FILTERED_GRAD_LIST,
FSDPDistributedConfig,
Expand Down Expand Up @@ -687,17 +691,21 @@ def _preconditioner_config_to_list_cls(
RootInvShampooPreconditionerConfig()
| EigendecomposedShampooPreconditionerConfig()
| EigenvalueCorrectedShampooPreconditionerConfig()
| PerFactorEigenvalueCorrectedShampooPreconditionerConfig()
| RootInvKLShampooPreconditionerConfig()
| EigendecomposedKLShampooPreconditionerConfig()
| PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig()
):
preconditioner_config_to_list_cls: dict[
type[PreconditionerConfig], Callable[..., PreconditionerList]
] = {
RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
PerFactorEigenvalueCorrectedShampooPreconditionerConfig: PerFactorEigenvalueCorrectedShampooPreconditionerList,
RootInvKLShampooPreconditionerConfig: RootInvKLShampooPreconditionerList,
EigendecomposedKLShampooPreconditionerConfig: EigendecomposedKLShampooPreconditionerList,
PerFactorEigenvalueCorrectedKLShampooPreconditionerConfig: PerFactorEigenvalueCorrectedKLShampooPreconditionerList,
}
beta2 = group[BETAS][1]
return preconditioner_config_to_list_cls[type(preconditioner_config)](
Expand Down
Loading
Loading