From afd5a84dbb865f41bf8d6b75181d5b421d73806b Mon Sep 17 00:00:00 2001 From: Hana Joo Date: Thu, 2 Jul 2026 05:48:14 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 941652029 --- init2winit/main.py | 6 +++--- .../linalg/paterson_stockmeyer.py | 10 +++++----- .../optimizer_lib/linalg/pth_inv_root_rmn.py | 6 +++--- .../linalg/pth_inv_root_rmn_test.py | 2 +- init2winit/optimizer_lib/sla.py | 2 +- .../notebook_utils/parquet_util.py | 10 +++++----- .../optlrschedule/notebook_utils/plot_util.py | 18 +++++++++--------- .../twopointslinear_schedule_family.py | 2 +- .../search_algorithm/grid_search.py | 2 +- .../search_algorithm/random_search.py | 2 +- .../optlrschedule/workload/base_workload.py | 2 +- .../optlrschedule/workload/cifar10_cnn.py | 6 +++--- .../workload/datasets/wikitext_103.py | 6 +++--- .../workload/linear_regression.py | 4 ++-- .../workload/wikitext103_transformer.py | 6 +++--- init2winit/trainer_lib/trainer.py | 4 ++-- init2winit/trainer_lib/training_algorithm.py | 12 ++++++------ init2winit/utils.py | 2 +- 18 files changed, 51 insertions(+), 51 deletions(-) diff --git a/init2winit/main.py b/init2winit/main.py index bfd2f54b..b430763d 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -203,7 +203,7 @@ def _run( meta_data = {'worker_id': worker_id, 'status': 'incomplete'} if jax.process_index() == 0: logging.info('rng: %s', rng) - makedirs(trial_dir, mode=0o775) + makedirs(trial_dir, mode=0o775) # pyrefly: ignore[unexpected-keyword] # Set up the metric loggers for host 0. metrics_logger, init_logger = utils.set_up_loggers(trial_dir, xm_work_unit) hparams_fname = os.path.join(trial_dir, 'hparams.json') @@ -291,11 +291,11 @@ def main(unused_argv): # CNS2 cell, as it's just the parent for the trial directories. # The trial directories themselves will get the correct cell placement. kwargs = {} - makedirs(experiment_dir, mode=0o775, **kwargs) + makedirs(experiment_dir, mode=0o775, **kwargs) # pyrefly: ignore[unexpected-keyword] log_dir = os.path.join( experiment_dir, CNS_LOGS_ENCODING, 'logs', str(worker_id) ) - makedirs(log_dir, mode=0o775) + makedirs(log_dir, mode=0o775) # pyrefly: ignore[unexpected-keyword] log_path = os.path.join( log_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index()) ) diff --git a/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py b/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py index 5a285e4d..b453ef70 100644 --- a/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py +++ b/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py @@ -25,11 +25,11 @@ def _powers(x: T, n: int, product: Callable[[T, T], T]) -> List[T]: """Returns the list [x, x², ..., xⁿ].""" xp = [None] * (n + 1) - xp[1] = x + xp[1] = x # pyrefly: ignore[unsupported-operation] for j in range(2, n + 1): # To reduce round-off, compute xʲ as the result of O(log j) mutliplies - xp[j] = product(xp[j // 2], xp[(j + 1) // 2]) - return xp[1:] + xp[j] = product(xp[j // 2], xp[(j + 1) // 2]) # pyrefly: ignore[bad-argument-type, unsupported-operation] + return xp[1:] # pyrefly: ignore[bad-return] def polynomial_no_constant( @@ -82,5 +82,5 @@ def polynomial_no_constant( i = (n + s - 1) // s - 1 y = inner_poly(i) for i in reversed(range(i)): - y = inner_poly(i) + product(xp[s - 1], y) - return y + y = inner_poly(i) + product(xp[s - 1], y) # pyrefly: ignore[bad-argument-type, unsupported-operation] + return y # pyrefly: ignore[bad-return] diff --git a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py index 8c9e4505..40170356 100644 --- a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py +++ b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py @@ -47,7 +47,7 @@ def _xxt_params(m: int, n: int) -> dict[str, int] | None: def pthroot_xxt(x: chex.Array, scale: chex.Array | None = None) -> chex.Array: """(x * scale) @ x.T .""" - return (x if scale is None else x * scale) @ x.T + return (x if scale is None else x * scale) @ x.T # pyrefly: ignore[unsupported-operation] def _scalar_power(x: chex.Array, n: int) -> chex.Array: @@ -78,7 +78,7 @@ def _scalar_inverse_root(x: chex.Array, n: int) -> chex.Array: r = x ** (1 / n) # One step of Newton's method to polish the root r = ((n - 1) / n) * r + (x / n) / _scalar_power(r, n - 1) - return 1 / r + return 1 / r # pyrefly: ignore[bad-return] @functools.cache @@ -219,7 +219,7 @@ def inside_iter2(s, w_sum): if p > 1: x += x @ paterson_stockmeyer.polynomial_no_constant( - bc[1:], w, operator.matmul + bc[1:], w, operator.matmul # pyrefly: ignore[bad-argument-type] ) x = cpm1 * x return x, y diff --git a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py index b5261150..2e534b78 100644 --- a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py +++ b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py @@ -138,7 +138,7 @@ def test_random_diagonal_matrix(self, p): n = 16 eps = np.finfo(np.float32).eps rng = np.random.RandomState(seed=37) - s = _random_singular_values(n, eps, rng) + s = _random_singular_values(n, eps, rng) # pyrefly: ignore[bad-argument-type] exact = s.astype(np.float64) ** (-1 / p) x = _root(np.diag(s).astype(np.float32), p).astype(np.float64) # since the matrix is diagonal, the error should be small despite the diff --git a/init2winit/optimizer_lib/sla.py b/init2winit/optimizer_lib/sla.py index bc9df7db..11ef983f 100644 --- a/init2winit/optimizer_lib/sla.py +++ b/init2winit/optimizer_lib/sla.py @@ -119,7 +119,7 @@ def update_fn( steps_since_sync=steps_since_sync, ) - return optax.GradientTransformation(init_fn, update_fn) + return optax.GradientTransformation(init_fn, update_fn) # pyrefly: ignore[bad-argument-type] def super_lookahead( diff --git a/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py b/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py index 121e4fa1..32a9b08e 100644 --- a/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py +++ b/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py @@ -47,12 +47,12 @@ def load_parquet_file( """ if file_name: - path = epath.Path(path) / file_name + path = epath.Path(path) / file_name # pyrefly: ignore[bad-assignment] else: - path = epath.Path(path) + path = epath.Path(path) # pyrefly: ignore[bad-assignment] # Read the file - with path.open('rb') as in_f: + with path.open('rb') as in_f: # pyrefly: ignore[missing-attribute] buf = io.BytesIO(in_f.read()) df = pd.read_parquet(buf) @@ -105,7 +105,7 @@ def load_all_parquet_files( if dfs: # Concat will ignore empty DataFrames properly. merged_df = pd.concat(dfs, ignore_index=True) - return merged_df + return merged_df # pyrefly: ignore[bad-return] else: return pd.DataFrame() @@ -148,6 +148,6 @@ def load_all_parquet_files_sequentially( if dfs: merged_df = pd.concat(dfs, ignore_index=True) - return merged_df + return merged_df # pyrefly: ignore[bad-return] else: return pd.DataFrame() diff --git a/init2winit/projects/optlrschedule/notebook_utils/plot_util.py b/init2winit/projects/optlrschedule/notebook_utils/plot_util.py index 4de7aa8c..2fcdab05 100644 --- a/init2winit/projects/optlrschedule/notebook_utils/plot_util.py +++ b/init2winit/projects/optlrschedule/notebook_utils/plot_util.py @@ -790,7 +790,7 @@ def plot_base_lr_heatmap( ) ax.grid(False) # Turn off grid lines on heatmap - cbar = fig.colorbar(image, ax=ax) + cbar = fig.colorbar(image, ax=ax) # pyrefly: ignore[missing-attribute] cbar.set_label(metric_label) ax.set_yticks(np.arange(len(sched_names))) @@ -970,13 +970,13 @@ def make_coordinate_descent_plots( if subplot_cols is not None: # Dimensions of subplot - n_rows = (n_plots - 1) // n_cols + 1 + n_rows = (n_plots - 1) // n_cols + 1 # pyrefly: ignore[unsupported-operation] if fig is None: # Create subplots fig, axes = plt.subplots( n_rows, n_cols, - figsize=(single_figsize[0] * n_cols, single_figsize[1] * n_rows), + figsize=(single_figsize[0] * n_cols, single_figsize[1] * n_rows), # pyrefly: ignore[unsupported-operation] ) else: axes = fig.subplots( @@ -994,8 +994,8 @@ def make_coordinate_descent_plots( for plot_idx, sweep_param in enumerate(param_list): init_param_val = initial_param_dict[sweep_param] if subplot_cols is not None: - row, col = divmod(plot_idx, n_cols) - ax = axes[row, col] + row, col = divmod(plot_idx, n_cols) # pyrefly: ignore[no-matching-overload] + ax = axes[row, col] # pyrefly: ignore[unsupported-operation] else: ax = None ax = make_single_descent_plot( @@ -1006,10 +1006,10 @@ def make_coordinate_descent_plots( if subplot_cols is None: return ax_list else: - for plot_idx in range(n_plots, n_rows * n_cols): - row, col = divmod(plot_idx, n_cols) - fig.delaxes(axes[row, col]) - fig.tight_layout() + for plot_idx in range(n_plots, n_rows * n_cols): # pyrefly: ignore[unbound-name, unsupported-operation] + row, col = divmod(plot_idx, n_cols) # pyrefly: ignore[no-matching-overload] + fig.delaxes(axes[row, col]) # pyrefly: ignore[missing-attribute, unsupported-operation] + fig.tight_layout() # pyrefly: ignore[missing-attribute] return fig diff --git a/init2winit/projects/optlrschedule/scheduler/twopointslinear_schedule_family.py b/init2winit/projects/optlrschedule/scheduler/twopointslinear_schedule_family.py index 5ca70553..3f3509c6 100644 --- a/init2winit/projects/optlrschedule/scheduler/twopointslinear_schedule_family.py +++ b/init2winit/projects/optlrschedule/scheduler/twopointslinear_schedule_family.py @@ -130,7 +130,7 @@ def _compute_linear_interpolation( # Create a piecewise linear interpolation lr_values = np.interp(np.arange(total_steps), x_steps, y_points) - return lr_values + return lr_values # pyrefly: ignore[bad-return] def list_schedule_parameter_keys(self) -> list[str]: """List the keys of the schedule parameters.""" diff --git a/init2winit/projects/optlrschedule/search_algorithm/grid_search.py b/init2winit/projects/optlrschedule/search_algorithm/grid_search.py index b199e34f..600f9fae 100644 --- a/init2winit/projects/optlrschedule/search_algorithm/grid_search.py +++ b/init2winit/projects/optlrschedule/search_algorithm/grid_search.py @@ -160,7 +160,7 @@ def update( if best_score_in_gen < self.internal_state["best_score"]: self.internal_state["best_score"] = best_score_in_gen self.internal_state["best_augumented_param"] = ( - current_best_augmented_param + current_best_augmented_param # pyrefly: ignore[bad-assignment] ) self.internal_state["generation"] = gen_idx diff --git a/init2winit/projects/optlrschedule/search_algorithm/random_search.py b/init2winit/projects/optlrschedule/search_algorithm/random_search.py index 950f0ffa..1107a839 100644 --- a/init2winit/projects/optlrschedule/search_algorithm/random_search.py +++ b/init2winit/projects/optlrschedule/search_algorithm/random_search.py @@ -181,7 +181,7 @@ def update( # Update best solution if current generation found a better one if best_score_in_gen < self._best_score: self._best_score = best_score_in_gen - self._best_augmented_param = current_best_augmented_param + self._best_augmented_param = current_best_augmented_param # pyrefly: ignore[bad-assignment] self._generation = gen_idx def get_best_solution(self) -> Tuple[Dict[str, float], float]: diff --git a/init2winit/projects/optlrschedule/workload/base_workload.py b/init2winit/projects/optlrschedule/workload/base_workload.py index 5152c8b7..2471ced7 100644 --- a/init2winit/projects/optlrschedule/workload/base_workload.py +++ b/init2winit/projects/optlrschedule/workload/base_workload.py @@ -72,7 +72,7 @@ def __init__(self, config: ConfigType) -> None: self.config = copy.deepcopy(config) if isinstance(self.config, ml_collections.ConfigDict): - self.config = config.to_dict() + self.config = config.to_dict() # pyrefly: ignore[missing-attribute] # Add optimizer configuration default_optimizer_config = { diff --git a/init2winit/projects/optlrschedule/workload/cifar10_cnn.py b/init2winit/projects/optlrschedule/workload/cifar10_cnn.py index 9035af26..564ddf6b 100644 --- a/init2winit/projects/optlrschedule/workload/cifar10_cnn.py +++ b/init2winit/projects/optlrschedule/workload/cifar10_cnn.py @@ -281,7 +281,7 @@ def _load_dummy_data( x_train = x_train.astype(jnp.float32) / 255.0 x_test = x_test.astype(jnp.float32) / 255.0 - return x_train, y_train, x_test, y_test + return x_train, y_train, x_test, y_test # pyrefly: ignore[bad-return] def _load_data( self, @@ -330,7 +330,7 @@ def create_train_state( model = CNN() params = model.init(init_param_rng, np.ones([1, 32, 32, 3])) - tx = optimizers.get_optimizer_from_config(self.config) + tx = optimizers.get_optimizer_from_config(self.config) # pyrefly: ignore[bad-argument-type] state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx @@ -524,7 +524,7 @@ def train_and_evaluate_models( self.make_global_array(batch_labels[start:end]), ) vmap_states, _ = self.train_step( - vmap_states, batch, schedules[:, global_step] + vmap_states, batch, schedules[:, global_step] # pyrefly: ignore[bad-argument-type] ) # Evaluate based on configuration diff --git a/init2winit/projects/optlrschedule/workload/datasets/wikitext_103.py b/init2winit/projects/optlrschedule/workload/datasets/wikitext_103.py index 022ea892..341350c1 100644 --- a/init2winit/projects/optlrschedule/workload/datasets/wikitext_103.py +++ b/init2winit/projects/optlrschedule/workload/datasets/wikitext_103.py @@ -146,9 +146,9 @@ def get_wikitext103_dataset( test_path = os.path.join(data_dir, TEST_FILENAME) # Get TextLineDataset from raw files - train_text_dataset = tf.data.TextLineDataset(train_path) - valid_text_dataset = tf.data.TextLineDataset(valid_path) - test_text_dataset = tf.data.TextLineDataset(test_path) + train_text_dataset = tf.data.TextLineDataset(train_path) # pyrefly: ignore[bad-instantiation] + valid_text_dataset = tf.data.TextLineDataset(valid_path) # pyrefly: ignore[bad-instantiation] + test_text_dataset = tf.data.TextLineDataset(test_path) # pyrefly: ignore[bad-instantiation] # Tokenize data tokenizer = get_trained_tokenizer( diff --git a/init2winit/projects/optlrschedule/workload/linear_regression.py b/init2winit/projects/optlrschedule/workload/linear_regression.py index 8d7a1cee..c552fe63 100644 --- a/init2winit/projects/optlrschedule/workload/linear_regression.py +++ b/init2winit/projects/optlrschedule/workload/linear_regression.py @@ -57,8 +57,8 @@ def generate_ntk_matrix( ) -> jnp.ndarray: """Generates an NTK (Neural Tangent Kernel) matrix with a known spectrum.""" if spectrum is None: - weight_matrix = random.normal(key, shape=(num_data, num_params)) - return weight_matrix @ weight_matrix.T / num_data + weight_matrix = random.normal(key, shape=(num_data, num_params)) # pyrefly: ignore[bad-argument-type] + return weight_matrix @ weight_matrix.T / num_data # pyrefly: ignore[unsupported-operation] else: orthogonal_matrix = random.orthogonal(key, len(spectrum)) return orthogonal_matrix @ jnp.diag(spectrum) @ orthogonal_matrix.T diff --git a/init2winit/projects/optlrschedule/workload/wikitext103_transformer.py b/init2winit/projects/optlrschedule/workload/wikitext103_transformer.py index c1a9557a..0c640c60 100644 --- a/init2winit/projects/optlrschedule/workload/wikitext103_transformer.py +++ b/init2winit/projects/optlrschedule/workload/wikitext103_transformer.py @@ -258,7 +258,7 @@ def _load_dummy_data( ) y_test = x_test - return x_train, y_train, x_test, y_test + return x_train, y_train, x_test, y_test # pyrefly: ignore[bad-return] def _load_data( self, @@ -280,7 +280,7 @@ def _load_data( wikitext_103.get_wikitext103_dataset() ) - return ( + return ( # pyrefly: ignore[bad-return] train_dataset['inputs'].astype(jnp.float32), train_dataset['targets'].astype(jnp.float32), validation_dataset['inputs'].astype(jnp.float32), @@ -325,7 +325,7 @@ def create_train_state( train=False, ) - tx = optimizers.get_optimizer_from_config(self.config) + tx = optimizers.get_optimizer_from_config(self.config) # pyrefly: ignore[bad-argument-type] state = train_state.TrainState.create( apply_fn=transformer.apply, params=variables['params'], tx=tx diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index 272b4fa4..11b47ef8 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -27,7 +27,7 @@ class Trainer(base_trainer.BaseTrainer): """Default trainer.""" - def update(self, batch, rng, metrics_state, training_cost): + def update(self, batch, rng, metrics_state, training_cost): # pyrefly: ignore[bad-override] """Single step of the training loop. Uses the training algorithm's update_params function to get the updated @@ -71,7 +71,7 @@ def update(self, batch, rng, metrics_state, training_cost): new_metrics_state = None if metrics_state is not None: - new_metrics_state = self._metrics_update_fn( + new_metrics_state = self._metrics_update_fn( # pyrefly: ignore[not-callable] metrics_state, step, cost_value, diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py index 72f2164b..bf9087f2 100644 --- a/init2winit/trainer_lib/training_algorithm.py +++ b/init2winit/trainer_lib/training_algorithm.py @@ -830,8 +830,8 @@ def get_default_training_hparams(cls, optimizer_name=None, model_name=None): model_defaults = dict(_MODEL_TRAINING_DEFAULTS[model_name]) model_optimizer = model_defaults.pop('optimizer', 'adam') # Start with the optimizer's own defaults, then overlay model-specific. - opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get(model_optimizer, {})) - opt_defaults.update(model_defaults.pop('opt_hparams', {})) + opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get(model_optimizer, {})) # pyrefly: ignore[no-matching-overload] + opt_defaults.update(model_defaults.pop('opt_hparams', {})) # pyrefly: ignore[no-matching-overload] training_hparams.update({ 'optimizer': model_optimizer, 'opt_hparams': opt_defaults, @@ -899,7 +899,7 @@ def update_params( new_params: Pytree of model parameters. new_model_state: Pytree of model state. """ - del ( + del ( # pyrefly: ignore[unsupported-delete] workload, hyperparameters, param_types, @@ -910,7 +910,7 @@ def update_params( grad_clip = self.hps.opt_hparams.get('grad_clip', None) # We pass the lr directly because the lr functions from sehedules.py # have numpy dependencies and can't be jitted. - lr = self._lr_fn(global_step) + lr = self._lr_fn(global_step) # pyrefly: ignore[not-callable] jitted_update_fn = jax.jit( optax_update_params_helper, static_argnames=( @@ -1009,11 +1009,11 @@ def get_ema_eval_params(self, optimizer_state, params): """ del params # Unused if isinstance(optimizer_state, optax.InjectStatefulHyperparamsState): - eval_params = optimizer_state.inner_state[0][0].ema + eval_params = optimizer_state.inner_state[0][0].ema # pyrefly: ignore[bad-index] elif isinstance( optimizer_state, gradient_accumulator.GradientAccumulatorState ): - eval_params = optimizer_state.base_state.inner_state[0][0].ema + eval_params = optimizer_state.base_state.inner_state[0][0].ema # pyrefly: ignore[missing-attribute] else: raise ValueError( 'EMA computation should be the very first transformation in defined' diff --git a/init2winit/utils.py b/init2winit/utils.py index c9dad95a..c97a7489 100644 --- a/init2winit/utils.py +++ b/init2winit/utils.py @@ -153,7 +153,7 @@ def run_in_parallel(function, list_of_kwargs_to_function, num_workers): for f in concurrent.futures.as_completed(futures): if f.exception(): # Propagate exception to main thread. - raise f.exception() + raise f.exception() # pyrefly: ignore[bad-raise] return [f.result() for f in futures]