Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions distributed_shampoo/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,9 +621,10 @@ def _initialize_blocked_parameters_state(self) -> None:
for block_info in state_lists[DISTRIBUTOR].local_block_info_list:
param_state = self.state[block_info.param]
assert (
(block_index := block_info.composable_block_ids[1])
not in param_state
), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
block_index := block_info.composable_block_ids[1]
) not in param_state, (
"There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
)
param_state[block_index] = {}

@torch.no_grad()
Expand Down Expand Up @@ -718,9 +719,9 @@ def _preconditioner_config_to_list_cls(
preconditioner_config=preconditioner_config,
)
case SpectralDescentPreconditionerConfig():
assert (
group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
assert group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2, (
f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
)
return SpectralDescentPreconditionerList(
block_list=state_lists[DISTRIBUTOR].local_blocked_params,
preconditioner_config=preconditioner_config,
Expand All @@ -733,9 +734,9 @@ def _instantiate_shampoo_preconditioner_list(self) -> None:
for state_lists, group in zip(
self._per_group_state_lists, self.param_groups, strict=True
):
assert (
group[PRECONDITIONER_CONFIG] is not None
), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
assert group[PRECONDITIONER_CONFIG] is not None, (
f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
)
state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
self._preconditioner_config_to_list_cls(
state_lists=state_lists,
Expand Down Expand Up @@ -1653,9 +1654,9 @@ def _post_load_state_dict_hook(optimizer: Optimizer) -> None:
if saved_train_modes:
# Mixed train/eval modes across parameter groups is not supported
# since train() and eval() always operate on all groups uniformly.
assert all(
m == saved_train_modes[0] for m in saved_train_modes
), "Mixed train/eval modes across parameter groups is not supported."
assert all(m == saved_train_modes[0] for m in saved_train_modes), (
"Mixed train/eval modes across parameter groups is not supported."
)
operator.attrgetter("train" if saved_train_modes[0] else "eval")(
optimizer
)()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def _construct_model(
assert (
sum(param.numel() for param in model.parameters())
== sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2
), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
), (
f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
)
distributed_config.param_to_metadata = compile_fsdp_parameter_metadata(
model
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def _construct_model(
assert (
sum(param.numel() for param in model.parameters())
== sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2
), f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
), (
f"{sum(param.numel() for param in model.parameters())=}, {sum(a * b for a, b in pairwise(model_linear_layers_dims)) // 2=}"
)
distributed_config.param_to_metadata = compile_fsdp_parameter_metadata(
model
)
Expand Down
12 changes: 6 additions & 6 deletions distributed_shampoo/distributor/shampoo_ddp_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ def all_gather_into_tensor() -> None:
global_buffers = self._global_dist_blocked_buffers

if self._communicate_params:
assert (
len(local_params) == len(blocked_search_directions)
), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
assert len(local_params) == len(blocked_search_directions), (
f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand All @@ -335,9 +335,9 @@ def all_gather_into_tensor() -> None:
)

else:
assert (
len(local_buffers) == len(blocked_search_directions)
), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
assert len(local_buffers) == len(blocked_search_directions), (
f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand Down
12 changes: 6 additions & 6 deletions distributed_shampoo/distributor/shampoo_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ def _merge_and_block_gradients(
assert grad is not None

if self._runtime_config.eager_nan_check:
assert torch.isfinite(
grad
).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {attrgetter('shape')(grad)}. Check your model for numerical instability or consider gradient clipping."
assert torch.isfinite(grad).all(), (
f"Encountered gradient containing NaN/Inf in parameter with shape {attrgetter('shape')(grad)}. Check your model for numerical instability or consider gradient clipping."
)

# Obtain blocks for each gradient after merging.
blocks_within_grad = multi_dim_split(
Expand Down Expand Up @@ -354,9 +354,9 @@ def update_params(
else self._local_blocked_params
)

assert (
len(blocked_search_directions) == len(target_params)
), f"Expected {len(blocked_search_directions)=} to be equal to {len(target_params)=}."
assert len(blocked_search_directions) == len(target_params), (
f"Expected {len(blocked_search_directions)=} to be equal to {len(target_params)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand Down
6 changes: 3 additions & 3 deletions distributed_shampoo/distributor/shampoo_fsdp_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def _merge_and_block_gradients(
assert flattened_grad is not None

if self._runtime_config.eager_nan_check:
assert torch.isfinite(
flattened_grad
).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping."
assert torch.isfinite(flattened_grad).all(), (
f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping."
)

# Split flattened gradients into valid tensor blocks of the gradient.
split_grads = FSDPDistributor._split_tensor_block_recovery(
Expand Down
25 changes: 12 additions & 13 deletions distributed_shampoo/distributor/shampoo_fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def compile_fsdp_parameter_metadata(
shard_param_infos = flat_param._shard_param_infos
sharding_strategy = fsdp_module.sharding_strategy

assert (
flat_param._params is not None
), "flat_param._params should not be None! Set the value of `use_orig_params` in FSDP module to True "
assert flat_param._params is not None, (
"flat_param._params should not be None! Set the value of `use_orig_params` in FSDP module to True "
)
"would populate flat_param._params."
params = flat_param._params

Expand Down Expand Up @@ -156,13 +156,12 @@ def partition_param_list(
)

assert (
(
unioned_keys := fsdp_params_dict.keys()
| hsdp_params_dict.keys()
| other_params_dict.keys()
)
== original_params_dict.keys()
), f"{unioned_keys - original_params_dict.keys()=} {original_params_dict.keys() - unioned_keys=}"
unioned_keys := fsdp_params_dict.keys()
| hsdp_params_dict.keys()
| other_params_dict.keys()
) == original_params_dict.keys(), (
f"{unioned_keys - original_params_dict.keys()=} {original_params_dict.keys() - unioned_keys=}"
)
for (name1, dict1), (name2, dict2) in itertools.combinations(
(
("fsdp_params_dict", fsdp_params_dict),
Expand All @@ -171,9 +170,9 @@ def partition_param_list(
),
2,
):
assert not (
common_keys := dict1.keys() & dict2.keys()
), f"{common_keys} exist in both {name1} and {name2}!"
assert not (common_keys := dict1.keys() & dict2.keys()), (
f"{common_keys} exist in both {name1} and {name2}!"
)

return (
list(fsdp_params_dict.items()),
Expand Down
18 changes: 9 additions & 9 deletions distributed_shampoo/distributor/shampoo_hsdp_distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ def all_gather_into_tensor() -> None:
global_buffers = self._global_dist_blocked_buffers

if self._communicate_params:
assert (
len(local_params) == len(blocked_search_directions)
), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
assert len(local_params) == len(blocked_search_directions), (
f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand All @@ -289,9 +289,9 @@ def all_gather_into_tensor() -> None:
)

else:
assert (
len(local_buffers) == len(blocked_search_directions)
), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
assert len(local_buffers) == len(blocked_search_directions), (
f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand Down Expand Up @@ -564,9 +564,9 @@ def _merge_and_block_gradients(
assert flattened_grad is not None

if self._runtime_config.eager_nan_check:
assert torch.isfinite(
flattened_grad
).all(), f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping."
assert torch.isfinite(flattened_grad).all(), (
f"Encountered gradient containing NaN/Inf in parameter with shape {flattened_grad.shape}. Check your model for numerical instability or consider gradient clipping."
)

# Split flattened gradients into valid tensor blocks of the gradient.
split_grads = HSDPDistributor._split_tensor_block_recovery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def all_gather_into_tensor() -> None:
global_buffers = self._global_dist_blocked_buffers

if self._communicate_params:
assert (
len(local_params) == len(blocked_search_directions)
), f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
assert len(local_params) == len(blocked_search_directions), (
f"Expected {len(local_params)=} to be equal to {len(blocked_search_directions)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand All @@ -314,9 +314,9 @@ def all_gather_into_tensor() -> None:
)

else:
assert (
len(local_buffers) == len(blocked_search_directions)
), f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
assert len(local_buffers) == len(blocked_search_directions), (
f"Expected {len(local_buffers)=} to be equal to {len(blocked_search_directions)=}."
)

# torch._foreach only accepts non-empty list
if blocked_search_directions:
Expand Down
4 changes: 3 additions & 1 deletion distributed_shampoo/examples/convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def _infer_conv_output_shape(
output_shape = []
for input_length in input_shape:
output_length = (input_length - kernel_size + 2 * padding) / stride + 1
assert output_length.is_integer(), f"Stride {stride} is not compatible with input shape {input_shape}, kernel size {kernel_size} and padding {padding}!"
assert output_length.is_integer(), (
f"Stride {stride} is not compatible with input shape {input_shape}, kernel size {kernel_size} and padding {padding}!"
)
output_shape.append(int(output_length))
return output_shape
12 changes: 6 additions & 6 deletions distributed_shampoo/preconditioner/matrix_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,9 +808,9 @@ def qr_algorithm(
Q = eigenvectors_estimate

# This assertion provides a more clear error message than the internal error message in `torch.mm`, and assertion makes sure that user-side is unable to catch the error.
assert (
Q.dtype == A.dtype
), f"Q and A must have the same dtype! {Q.dtype=} {A.dtype=}"
assert Q.dtype == A.dtype, (
f"Q and A must have the same dtype! {Q.dtype=} {A.dtype=}"
)

eigenvalues_estimate = Q.T @ A @ Q
iteration = 0
Expand Down Expand Up @@ -877,9 +877,9 @@ def qr_algorithm(
func=eigh_eigenvalue_decomposition, config=eigendecomposition_config
)(A=A_ridge)
case QREigendecompositionConfig():
assert (
eigenvectors_estimate is not None
), "eigenvectors_estimate should not be None when QR algorithm is used."
assert eigenvectors_estimate is not None, (
"eigenvectors_estimate should not be None when QR algorithm is used."
)
return _assign_function_args_from_config(
func=qr_algorithm, config=eigendecomposition_config
)(A=A_ridge, eigenvectors_estimate=eigenvectors_estimate)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1374,14 +1374,14 @@ def _precondition_grad(
) -> Tensor:
# TODO: Need to refactor this function to be more efficient. Ideally eliminate those branches.
# Might consider einsum?
assert (
sum(preconditioned_dims_selector) == len(preconditioner_list)
), f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})."
assert sum(preconditioned_dims_selector) == len(preconditioner_list), (
f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})."
)

# Extract all dtypes and assert they are unique
assert (
len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1
), f"All preconditioners must have the same dtype, but found: {unique_dtypes}"
assert len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1, (
f"All preconditioners must have the same dtype, but found: {unique_dtypes}"
)

# Use the single dtype if preconditioners exist, otherwise use grad dtype
target_dtype = next(iter(unique_dtypes), grad.dtype)
Expand Down Expand Up @@ -1566,9 +1566,9 @@ def _get_inverse_exponent(self, dimension: int, order: int) -> float:
return inverse_exponent_override_on_order.get(
dimension, 1 / (2 * max(order, 1))
)
assert isinstance(
inverse_exponent_override_on_order, float
), f"Expected inverse_exponent_override_on_order to be a float or a dict, but got {type(inverse_exponent_override_on_order)} instead."
assert isinstance(inverse_exponent_override_on_order, float), (
f"Expected inverse_exponent_override_on_order to be a float or a dict, but got {type(inverse_exponent_override_on_order)} instead."
)
return inverse_exponent_override_on_order

def _create_preconditioned_dims_selector(
Expand Down
6 changes: 3 additions & 3 deletions distributed_shampoo/utils/shampoo_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def __init__(
value.dtype == quantized_dtype for value in self.quantized_value_list
)
self.quantized_dtype = quantized_dtype
assert (
computation_dtype in _FLOAT_DTYPES
), f"{computation_dtype=} is not supported! It must be one of {_FLOAT_DTYPES}!"
assert computation_dtype in _FLOAT_DTYPES, (
f"{computation_dtype=} is not supported! It must be one of {_FLOAT_DTYPES}!"
)
self.computation_dtype = computation_dtype

# All min/max values should be None, or no min/max values are None
Expand Down
18 changes: 9 additions & 9 deletions distributed_shampoo/utils/shampoo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def merge_small_dims(
return (0,)

if isinstance(target_tensor_dimensionality, float):
assert (
target_tensor_dimensionality == math.inf
), f"{target_tensor_dimensionality=} has to be an integer or math.inf."
assert target_tensor_dimensionality == math.inf, (
f"{target_tensor_dimensionality=} has to be an integer or math.inf."
)
return tensor_shape

# Squeeze tensor shape to remove dimension with 1; if all dimensions are 1,
Expand Down Expand Up @@ -151,9 +151,9 @@ def multi_dim_split(tensor: Tensor, split_size: int | float) -> tuple[Tensor, ..

"""
if isinstance(split_size, float):
assert (
split_size == math.inf
), f"{split_size=} has to be an integer or math.inf."
assert split_size == math.inf, (
f"{split_size=} has to be an integer or math.inf."
)
return (tensor,)

return reduce(
Expand Down Expand Up @@ -190,9 +190,9 @@ def compress_list(
Only elements from complete_list where the corresponding selector is True are included.

"""
assert (
len(complete_list) == len(selector)
), f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!"
assert len(complete_list) == len(selector), (
f"Inconsistent lengths between complete_list {len(complete_list)} and selector {len(selector)}!"
)
return tuple(compress(complete_list, selector))


Expand Down
Loading
Loading