diff --git a/distributed_shampoo/__init__.py b/distributed_shampoo/__init__.py index b767a52..767cf6a 100644 --- a/distributed_shampoo/__init__.py +++ b/distributed_shampoo/__init__.py @@ -34,7 +34,8 @@ from distributed_shampoo.shampoo_types import ( AdaGradPreconditionerConfig, AdamPreconditionerConfig, - AmortizedPreconditionerConfig, + BaseShampooPreconditionerConfig, + ClassicShampooPreconditionerConfig, DDPDistributedConfig, DefaultEigenvalueCorrectedShampooConfig, DefaultShampooConfig, @@ -59,7 +60,6 @@ RootInvShampooPreconditionerConfig, ScheduleFreeConfig, SGDPreconditionerConfig, - ShampooPreconditionerConfig, ShampooPT2CompileConfig, SignDescentPreconditionerConfig, SingleDeviceDistributedConfig, @@ -85,14 +85,14 @@ # `precision_config`. # `preconditioner_config` options. "PreconditionerConfig", # Abstract base class. - "AmortizedPreconditionerConfig", # Abstract base class (based on `PreconditionerConfig`). - "ShampooPreconditionerConfig", # Abstract base class (based on `AmortizedPreconditionerConfig`). - "RootInvShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. + "BaseShampooPreconditionerConfig", # Abstract base class (based on `PreconditionerConfig`). + "ClassicShampooPreconditionerConfig", # Abstract base class (based on `BaseShampooPreconditionerConfig`). + "RootInvShampooPreconditionerConfig", # Based on `ClassicShampooPreconditionerConfig`. "DefaultShampooConfig", # Default `RootInvShampooPreconditionerConfig` using `EigenConfig`. "RootInvKLShampooPreconditionerConfig", # Based on `RootInvShampooPreconditionerConfig`. - "EigendecomposedShampooPreconditionerConfig", # Based on `ShampooPreconditionerConfig`. + "EigendecomposedShampooPreconditionerConfig", # Based on `ClassicShampooPreconditionerConfig`. "EigendecomposedKLShampooPreconditionerConfig", # Based on `EigendecomposedShampooPreconditionerConfig`. - "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `AmortizedPreconditionerConfig`. + "EigenvalueCorrectedShampooPreconditionerConfig", # Based on `BaseShampooPreconditionerConfig`. "DefaultEigenvalueCorrectedShampooConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `EighEigendecompositionConfig`. "DefaultSOAPConfig", # Default `EigenvalueCorrectedShampooPreconditionerConfig` using `QREigendecompositionConfig`. "SpectralDescentPreconditionerConfig", # Based on `PreconditionerConfig`.