diff --git a/distributed_shampoo/distributed_shampoo.py b/distributed_shampoo/distributed_shampoo.py index faaedef..96288f1 100644 --- a/distributed_shampoo/distributed_shampoo.py +++ b/distributed_shampoo/distributed_shampoo.py @@ -621,9 +621,10 @@ def _initialize_blocked_parameters_state(self) -> None: for block_info in state_lists[DISTRIBUTOR].local_block_info_list: param_state = self.state[block_info.param] assert ( - (block_index := block_info.composable_block_ids[1]) - not in param_state - ), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions." + block_index := block_info.composable_block_ids[1] + ) not in param_state, ( + "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions." + ) param_state[block_index] = {} @torch.no_grad() @@ -718,9 +719,9 @@ def _preconditioner_config_to_list_cls( preconditioner_config=preconditioner_config, ) case SpectralDescentPreconditionerConfig(): - assert ( - group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2 - ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig." + assert group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2, ( + f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig." + ) return SpectralDescentPreconditionerList( block_list=state_lists[DISTRIBUTOR].local_blocked_params, preconditioner_config=preconditioner_config, @@ -733,9 +734,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None: for state_lists, group in zip( self._per_group_state_lists, self.param_groups, strict=True ): - assert ( - group[PRECONDITIONER_CONFIG] is not None - ), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo." + assert group[PRECONDITIONER_CONFIG] is not None, ( + f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo." + ) state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( self._preconditioner_config_to_list_cls( state_lists=state_lists, @@ -1653,9 +1654,9 @@ def _post_load_state_dict_hook(optimizer: Optimizer) -> None: if saved_train_modes: # Mixed train/eval modes across parameter groups is not supported # since train() and eval() always operate on all groups uniformly. - assert all( - m == saved_train_modes[0] for m in saved_train_modes - ), "Mixed train/eval modes across parameter groups is not supported." + assert all(m == saved_train_modes[0] for m in saved_train_modes), ( + "Mixed train/eval modes across parameter groups is not supported." + ) operator.attrgetter("train" if saved_train_modes[0] else "eval")( optimizer )() diff --git a/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py b/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py index 44947d9..7d91201 100644 --- a/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py +++ b/distributed_shampoo/distributor/gpu_tests/shampoo_fsdp_distributor_test.py @@ -101,7 +101,9 @@ def _construct_model( assert ( sum(param.numel() for param in model.parameters()) == sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2 - ), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" + ), ( + f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" + ) distributed_config.param_to_metadata = compile_fsdp_parameter_metadata( model ) diff --git a/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py b/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py index 8b82b1b..1678109 100644 --- a/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py +++ b/distributed_shampoo/distributor/gpu_tests/shampoo_hsdp_distributor_test.py @@ -102,7 +102,9 @@ def _construct_model( assert ( sum(param.numel() for param in model.parameters()) == sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2 - ), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" + ), ( + f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}" + ) distributed_config.param_to_metadata = compile_fsdp_parameter_metadata( model ) diff --git a/distributed_shampoo/distributor/shampoo_ddp_distributor.py b/distributed_shampoo/distributor/shampoo_ddp_distributor.py index deaeee5..a282b7b 100644 --- a/distributed_shampoo/distributor/shampoo_ddp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_ddp_distributor.py @@ -308,9 +308,9 @@ def all_gather_into_tensor() -> None: global_buffers = self._global_dist_blocked_buffers if self._communicate_params: - assert ( - len(local_params) == len(blocked_search_directions) - ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." + assert len(local_params) == len(blocked_search_directions), ( + f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -335,9 +335,9 @@ def all_gather_into_tensor() -> None: ) else: - assert ( - len(local_buffers) == len(blocked_search_directions) - ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." + assert len(local_buffers) == len(blocked_search_directions), ( + f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: diff --git a/distributed_shampoo/distributor/shampoo_distributor.py b/distributed_shampoo/distributor/shampoo_distributor.py index e77dc62..744fdc3 100644 --- a/distributed_shampoo/distributor/shampoo_distributor.py +++ b/distributed_shampoo/distributor/shampoo_distributor.py @@ -280,9 +280,9 @@ def _merge_and_block_gradients( assert grad is not None if self._runtime_config.eager_nan_check: - assert torch.isfinite( - grad - ).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {attrgetter('shape')(grad)}. Check your model for numerical instability or consider gradient clipping." + assert torch.isfinite(grad).all(), ( + f"Encountered gradient containing NaN/Inf in parameter with shape {attrgetter('shape')(grad)}. Check your model for numerical instability or consider gradient clipping." + ) # Obtain blocks for each gradient after merging. blocks_within_grad = multi_dim_split( @@ -354,9 +354,9 @@ def update_params( else self._local_blocked_params ) - assert ( - len(blocked_search_directions) == len(target_params) - ), f"Expected {len(blocked_search_directions)=} to be equal to {len(target_params)=}." + assert len(blocked_search_directions) == len(target_params), ( + f"Expected {len(blocked_search_directions)=} to be equal to {len(target_params)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: diff --git a/distributed_shampoo/distributor/shampoo_fsdp_distributor.py b/distributed_shampoo/distributor/shampoo_fsdp_distributor.py index cbec4c8..b63cf6b 100644 --- a/distributed_shampoo/distributor/shampoo_fsdp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_fsdp_distributor.py @@ -154,9 +154,9 @@ def _merge_and_block_gradients( assert flattened_grad is not None if self._runtime_config.eager_nan_check: - assert torch.isfinite( - flattened_grad - ).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." + assert torch.isfinite(flattened_grad).all(), ( + f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." + ) # Split flattened gradients into valid tensor blocks of the gradient. split_grads = FSDPDistributor._split_tensor_block_recovery( diff --git a/distributed_shampoo/distributor/shampoo_fsdp_utils.py b/distributed_shampoo/distributor/shampoo_fsdp_utils.py index 59446b0..6b18d18 100644 --- a/distributed_shampoo/distributor/shampoo_fsdp_utils.py +++ b/distributed_shampoo/distributor/shampoo_fsdp_utils.py @@ -41,9 +41,9 @@ def compile_fsdp_parameter_metadata( shard_param_infos = flat_param._shard_param_infos sharding_strategy = fsdp_module.sharding_strategy - assert ( - flat_param._params is not None - ), "flat_param._params should not be None! Set the value of `use_orig_params` in FSDP module to True " + assert flat_param._params is not None, ( + "flat_param._params should not be None! Set the value of `use_orig_params` in FSDP module to True " + ) "would populate flat_param._params." params = flat_param._params @@ -156,13 +156,12 @@ def partition_param_list( ) assert ( - ( - unioned_keys := fsdp_params_dict.keys() - | hsdp_params_dict.keys() - | other_params_dict.keys() - ) - == original_params_dict.keys() - ), f"{unioned_keys - original_params_dict.keys()=} {original_params_dict.keys() - unioned_keys=}" + unioned_keys := fsdp_params_dict.keys() + | hsdp_params_dict.keys() + | other_params_dict.keys() + ) == original_params_dict.keys(), ( + f"{unioned_keys - original_params_dict.keys()=} {original_params_dict.keys() - unioned_keys=}" + ) for (name1, dict1), (name2, dict2) in itertools.combinations( ( ("fsdp_params_dict", fsdp_params_dict), @@ -171,9 +170,9 @@ def partition_param_list( ), 2, ): - assert not ( - common_keys := dict1.keys() & dict2.keys() - ), f"{common_keys} exist in both {name1} and {name2}!" + assert not (common_keys := dict1.keys() & dict2.keys()), ( + f"{common_keys} exist in both {name1} and {name2}!" + ) return ( list(fsdp_params_dict.items()), diff --git a/distributed_shampoo/distributor/shampoo_hsdp_distributor.py b/distributed_shampoo/distributor/shampoo_hsdp_distributor.py index 288eeff..a1651ca 100644 --- a/distributed_shampoo/distributor/shampoo_hsdp_distributor.py +++ b/distributed_shampoo/distributor/shampoo_hsdp_distributor.py @@ -262,9 +262,9 @@ def all_gather_into_tensor() -> None: global_buffers = self._global_dist_blocked_buffers if self._communicate_params: - assert ( - len(local_params) == len(blocked_search_directions) - ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." + assert len(local_params) == len(blocked_search_directions), ( + f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -289,9 +289,9 @@ def all_gather_into_tensor() -> None: ) else: - assert ( - len(local_buffers) == len(blocked_search_directions) - ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." + assert len(local_buffers) == len(blocked_search_directions), ( + f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -564,9 +564,9 @@ def _merge_and_block_gradients( assert flattened_grad is not None if self._runtime_config.eager_nan_check: - assert torch.isfinite( - flattened_grad - ).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." + assert torch.isfinite(flattened_grad).all(), ( + f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping." + ) # Split flattened gradients into valid tensor blocks of the gradient. split_grads = HSDPDistributor._split_tensor_block_recovery( diff --git a/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py b/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py index 0e63f3f..8617efd 100644 --- a/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py +++ b/distributed_shampoo/distributor/shampoo_hybrid_shard_distributor.py @@ -287,9 +287,9 @@ def all_gather_into_tensor() -> None: global_buffers = self._global_dist_blocked_buffers if self._communicate_params: - assert ( - len(local_params) == len(blocked_search_directions) - ), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." + assert len(local_params) == len(blocked_search_directions), ( + f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: @@ -314,9 +314,9 @@ def all_gather_into_tensor() -> None: ) else: - assert ( - len(local_buffers) == len(blocked_search_directions) - ), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." + assert len(local_buffers) == len(blocked_search_directions), ( + f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}." + ) # torch._foreach only accepts non-empty list if blocked_search_directions: diff --git a/distributed_shampoo/examples/convnet.py b/distributed_shampoo/examples/convnet.py index 23d4815..9302228 100644 --- a/distributed_shampoo/examples/convnet.py +++ b/distributed_shampoo/examples/convnet.py @@ -67,6 +67,8 @@ def _infer_conv_output_shape( output_shape = [] for input_length in input_shape: output_length = (input_length - kernel_size + 2 * padding) / stride + 1 - assert output_length.is_integer(), f"Stride {stride} is not compatible with input shape {input_shape}, kernel size {kernel_size} and padding {padding}!" + assert output_length.is_integer(), ( + f"Stride {stride} is not compatible with input shape {input_shape}, kernel size {kernel_size} and padding {padding}!" + ) output_shape.append(int(output_length)) return output_shape diff --git a/distributed_shampoo/preconditioner/matrix_functions.py b/distributed_shampoo/preconditioner/matrix_functions.py index 29e70db..65ad1c8 100644 --- a/distributed_shampoo/preconditioner/matrix_functions.py +++ b/distributed_shampoo/preconditioner/matrix_functions.py @@ -808,9 +808,9 @@ def qr_algorithm( Q = eigenvectors_estimate # This assertion provides a more clear error message than the internal error message in `torch.mm`, and assertion makes sure that user-side is unable to catch the error. - assert ( - Q.dtype == A.dtype - ), f"Q and A must have the same dtype! {Q.dtype=} {A.dtype=}" + assert Q.dtype == A.dtype, ( + f"Q and A must have the same dtype! {Q.dtype=} {A.dtype=}" + ) eigenvalues_estimate = Q.T @ A @ Q iteration = 0 @@ -877,9 +877,9 @@ def qr_algorithm( func=eigh_eigenvalue_decomposition, config=eigendecomposition_config )(A=A_ridge) case QREigendecompositionConfig(): - assert ( - eigenvectors_estimate is not None - ), "eigenvectors_estimate should not be None when QR algorithm is used." + assert eigenvectors_estimate is not None, ( + "eigenvectors_estimate should not be None when QR algorithm is used." + ) return _assign_function_args_from_config( func=qr_algorithm, config=eigendecomposition_config )(A=A_ridge, eigenvectors_estimate=eigenvectors_estimate) diff --git a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py index 217c79c..bba4509 100644 --- a/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py +++ b/distributed_shampoo/preconditioner/shampoo_preconditioner_list.py @@ -1374,14 +1374,14 @@ def _precondition_grad( ) -> Tensor: # TODO: Need to refactor this function to be more efficient. Ideally eliminate those branches. # Might consider einsum? - assert ( - sum(preconditioned_dims_selector) == len(preconditioner_list) - ), f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})." + assert sum(preconditioned_dims_selector) == len(preconditioner_list), ( + f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})." + ) # Extract all dtypes and assert they are unique - assert ( - len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1 - ), f"All preconditioners must have the same dtype, but found: {unique_dtypes}" + assert len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1, ( + f"All preconditioners must have the same dtype, but found: {unique_dtypes}" + ) # Use the single dtype if preconditioners exist, otherwise use grad dtype target_dtype = next(iter(unique_dtypes), grad.dtype) @@ -1566,9 +1566,9 @@ def _get_inverse_exponent(self, dimension: int, order: int) -> float: return inverse_exponent_override_on_order.get( dimension, 1 / (2 * max(order, 1)) ) - assert isinstance( - inverse_exponent_override_on_order, float - ), f"Expected inverse_exponent_override_on_order to be a float or a dict, but got {type(inverse_exponent_override_on_order)} instead." + assert isinstance(inverse_exponent_override_on_order, float), ( + f"Expected inverse_exponent_override_on_order to be a float or a dict, but got {type(inverse_exponent_override_on_order)} instead." + ) return inverse_exponent_override_on_order def _create_preconditioned_dims_selector( diff --git a/distributed_shampoo/utils/shampoo_quantization.py b/distributed_shampoo/utils/shampoo_quantization.py index 437f61f..de05058 100644 --- a/distributed_shampoo/utils/shampoo_quantization.py +++ b/distributed_shampoo/utils/shampoo_quantization.py @@ -153,9 +153,9 @@ def __init__( value.dtype == quantized_dtype for value in self.quantized_value_list ) self.quantized_dtype = quantized_dtype - assert ( - computation_dtype in _FLOAT_DTYPES - ), f"{computation_dtype=} is not supported! It must be one of {_FLOAT_DTYPES}!" + assert computation_dtype in _FLOAT_DTYPES, ( + f"{computation_dtype=} is not supported! It must be one of {_FLOAT_DTYPES}!" + ) self.computation_dtype = computation_dtype # All min/max values should be None, or no min/max values are None diff --git a/distributed_shampoo/utils/shampoo_utils.py b/distributed_shampoo/utils/shampoo_utils.py index 1a0e690..23060d9 100644 --- a/distributed_shampoo/utils/shampoo_utils.py +++ b/distributed_shampoo/utils/shampoo_utils.py @@ -79,9 +79,9 @@ def merge_small_dims( return (0,) if isinstance(target_tensor_dimensionality, float): - assert ( - target_tensor_dimensionality == math.inf - ), f"{target_tensor_dimensionality=} has to be an integer or math.inf." + assert target_tensor_dimensionality == math.inf, ( + f"{target_tensor_dimensionality=} has to be an integer or math.inf." + ) return tensor_shape # Squeeze tensor shape to remove dimension with 1; if all dimensions are 1, @@ -151,9 +151,9 @@ def multi_dim_split(tensor: Tensor, split_size: int | float) -> tuple[Tensor, .. """ if isinstance(split_size, float): - assert ( - split_size == math.inf - ), f"{split_size=} has to be an integer or math.inf." + assert split_size == math.inf, ( + f"{split_size=} has to be an integer or math.inf." + ) return (tensor,) return reduce( @@ -190,9 +190,9 @@ def compress_list( Only elements from complete_list where the corresponding selector is True are included. """ - assert ( - len(complete_list) == len(selector) - ), f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!" + assert len(complete_list) == len(selector), ( + f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!" + ) return tuple(compress(complete_list, selector)) diff --git a/gpa/gpa_adamw.py b/gpa/gpa_adamw.py index 462ced0..4c5b127 100644 --- a/gpa/gpa_adamw.py +++ b/gpa/gpa_adamw.py @@ -13,7 +13,6 @@ import torch import torch.optim -from torch import Tensor from gpa.gpa_types import ( BETA1, BETA2, @@ -35,6 +34,7 @@ WEIGHT_SUM, Z_BUFFER, ) +from torch import Tensor from torch.optim.optimizer import ParamsT logger = getLogger() @@ -428,9 +428,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] lr_max = max(lr, self.state[group_first_param][LR_MAX].item()) self.state[group_first_param][LR_MAX].fill_(lr_max) - assert ( - lr_max > 0 - ), f"lr_max must be positive, got lr_max={lr_max}. Check that lr={lr} is positive." + assert lr_max > 0, ( + f"lr_max must be positive, got lr_max={lr_max}. Check that lr={lr} is positive." + ) # Compute avg_coeff ONCE per step (before the parameter loop). # This is important for Schedule-Free: the coefficient should be the same