Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions init2winit/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _run(
meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
if jax.process_index() == 0:
logging.info('rng: %s', rng)
makedirs(trial_dir, mode=0o775)
makedirs(trial_dir, mode=0o775) # pyrefly: ignore[unexpected-keyword]
# Set up the metric loggers for host 0.
metrics_logger, init_logger = utils.set_up_loggers(trial_dir, xm_work_unit)
hparams_fname = os.path.join(trial_dir, 'hparams.json')
Expand Down Expand Up @@ -291,11 +291,11 @@ def main(unused_argv):
# CNS2 cell, as it's just the parent for the trial directories.
# The trial directories themselves will get the correct cell placement.
kwargs = {}
makedirs(experiment_dir, mode=0o775, **kwargs)
makedirs(experiment_dir, mode=0o775, **kwargs) # pyrefly: ignore[unexpected-keyword]
log_dir = os.path.join(
experiment_dir, CNS_LOGS_ENCODING, 'logs', str(worker_id)
)
makedirs(log_dir, mode=0o775)
makedirs(log_dir, mode=0o775) # pyrefly: ignore[unexpected-keyword]
log_path = os.path.join(
log_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index())
)
Expand Down
10 changes: 5 additions & 5 deletions init2winit/optimizer_lib/linalg/paterson_stockmeyer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
def _powers(x: T, n: int, product: Callable[[T, T], T]) -> List[T]:
"""Returns the list [x, x², ..., xⁿ]."""
xp = [None] * (n + 1)
xp[1] = x
xp[1] = x # pyrefly: ignore[unsupported-operation]
for j in range(2, n + 1):
# To reduce round-off, compute xʲ as the result of O(log j) multiplies
xp[j] = product(xp[j // 2], xp[(j + 1) // 2])
return xp[1:]
xp[j] = product(xp[j // 2], xp[(j + 1) // 2]) # pyrefly: ignore[bad-argument-type, unsupported-operation]
return xp[1:] # pyrefly: ignore[bad-return]


def polynomial_no_constant(
Expand Down Expand Up @@ -82,5 +82,5 @@ def polynomial_no_constant(
i = (n + s - 1) // s - 1
y = inner_poly(i)
for i in reversed(range(i)):
y = inner_poly(i) + product(xp[s - 1], y)
return y
y = inner_poly(i) + product(xp[s - 1], y) # pyrefly: ignore[bad-argument-type, unsupported-operation]
return y # pyrefly: ignore[bad-return]
6 changes: 3 additions & 3 deletions init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _xxt_params(m: int, n: int) -> dict[str, int] | None:

def pthroot_xxt(x: chex.Array, scale: chex.Array | None = None) -> chex.Array:
"""(x * scale) @ x.T ."""
return (x if scale is None else x * scale) @ x.T
return (x if scale is None else x * scale) @ x.T # pyrefly: ignore[unsupported-operation]


def _scalar_power(x: chex.Array, n: int) -> chex.Array:
Expand Down Expand Up @@ -78,7 +78,7 @@ def _scalar_inverse_root(x: chex.Array, n: int) -> chex.Array:
r = x ** (1 / n)
# One step of Newton's method to polish the root
r = ((n - 1) / n) * r + (x / n) / _scalar_power(r, n - 1)
return 1 / r
return 1 / r # pyrefly: ignore[bad-return]


@functools.cache
Expand Down Expand Up @@ -219,7 +219,7 @@ def inside_iter2(s, w_sum):

if p > 1:
x += x @ paterson_stockmeyer.polynomial_no_constant(
bc[1:], w, operator.matmul
bc[1:], w, operator.matmul # pyrefly: ignore[bad-argument-type]
)
x = cpm1 * x
return x, y
Expand Down
2 changes: 1 addition & 1 deletion init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_random_diagonal_matrix(self, p):
n = 16
eps = np.finfo(np.float32).eps
rng = np.random.RandomState(seed=37)
s = _random_singular_values(n, eps, rng)
s = _random_singular_values(n, eps, rng) # pyrefly: ignore[bad-argument-type]
exact = s.astype(np.float64) ** (-1 / p)
x = _root(np.diag(s).astype(np.float32), p).astype(np.float64)
# since the matrix is diagonal, the error should be small despite the
Expand Down
2 changes: 1 addition & 1 deletion init2winit/optimizer_lib/sla.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def update_fn(
steps_since_sync=steps_since_sync,
)

return optax.GradientTransformation(init_fn, update_fn)
return optax.GradientTransformation(init_fn, update_fn) # pyrefly: ignore[bad-argument-type]


def super_lookahead(
Expand Down
10 changes: 5 additions & 5 deletions init2winit/projects/optlrschedule/notebook_utils/parquet_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ def load_parquet_file(
"""

if file_name:
path = epath.Path(path) / file_name
path = epath.Path(path) / file_name # pyrefly: ignore[bad-assignment]
else:
path = epath.Path(path)
path = epath.Path(path) # pyrefly: ignore[bad-assignment]

# Read the file
with path.open('rb') as in_f:
with path.open('rb') as in_f: # pyrefly: ignore[missing-attribute]
buf = io.BytesIO(in_f.read())
df = pd.read_parquet(buf)

Expand Down Expand Up @@ -105,7 +105,7 @@ def load_all_parquet_files(
if dfs:
# Concat will ignore empty DataFrames properly.
merged_df = pd.concat(dfs, ignore_index=True)
return merged_df
return merged_df # pyrefly: ignore[bad-return]
else:
return pd.DataFrame()

Expand Down Expand Up @@ -148,6 +148,6 @@ def load_all_parquet_files_sequentially(

if dfs:
merged_df = pd.concat(dfs, ignore_index=True)
return merged_df
return merged_df # pyrefly: ignore[bad-return]
else:
return pd.DataFrame()
18 changes: 9 additions & 9 deletions init2winit/projects/optlrschedule/notebook_utils/plot_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def plot_base_lr_heatmap(
)
ax.grid(False) # Turn off grid lines on heatmap

cbar = fig.colorbar(image, ax=ax)
cbar = fig.colorbar(image, ax=ax) # pyrefly: ignore[missing-attribute]
cbar.set_label(metric_label)

ax.set_yticks(np.arange(len(sched_names)))
Expand Down Expand Up @@ -970,13 +970,13 @@ def make_coordinate_descent_plots(

if subplot_cols is not None:
# Dimensions of subplot
n_rows = (n_plots - 1) // n_cols + 1
n_rows = (n_plots - 1) // n_cols + 1 # pyrefly: ignore[unsupported-operation]
if fig is None:
# Create subplots
fig, axes = plt.subplots(
n_rows,
n_cols,
figsize=(single_figsize[0] * n_cols, single_figsize[1] * n_rows),
figsize=(single_figsize[0] * n_cols, single_figsize[1] * n_rows), # pyrefly: ignore[unsupported-operation]
)
else:
axes = fig.subplots(
Expand All @@ -994,8 +994,8 @@ def make_coordinate_descent_plots(
for plot_idx, sweep_param in enumerate(param_list):
init_param_val = initial_param_dict[sweep_param]
if subplot_cols is not None:
row, col = divmod(plot_idx, n_cols)
ax = axes[row, col]
row, col = divmod(plot_idx, n_cols) # pyrefly: ignore[no-matching-overload]
ax = axes[row, col] # pyrefly: ignore[unsupported-operation]
else:
ax = None
ax = make_single_descent_plot(
Expand All @@ -1006,10 +1006,10 @@ def make_coordinate_descent_plots(
if subplot_cols is None:
return ax_list
else:
for plot_idx in range(n_plots, n_rows * n_cols):
row, col = divmod(plot_idx, n_cols)
fig.delaxes(axes[row, col])
fig.tight_layout()
for plot_idx in range(n_plots, n_rows * n_cols): # pyrefly: ignore[unbound-name, unsupported-operation]
row, col = divmod(plot_idx, n_cols) # pyrefly: ignore[no-matching-overload]
fig.delaxes(axes[row, col]) # pyrefly: ignore[missing-attribute, unsupported-operation]
fig.tight_layout() # pyrefly: ignore[missing-attribute]
return fig


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _compute_linear_interpolation(
# Create a piecewise linear interpolation
lr_values = np.interp(np.arange(total_steps), x_steps, y_points)

return lr_values
return lr_values # pyrefly: ignore[bad-return]

def list_schedule_parameter_keys(self) -> list[str]:
"""List the keys of the schedule parameters."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def update(
if best_score_in_gen < self.internal_state["best_score"]:
self.internal_state["best_score"] = best_score_in_gen
self.internal_state["best_augumented_param"] = (
current_best_augmented_param
current_best_augmented_param # pyrefly: ignore[bad-assignment]
)
self.internal_state["generation"] = gen_idx

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update(
# Update best solution if current generation found a better one
if best_score_in_gen < self._best_score:
self._best_score = best_score_in_gen
self._best_augmented_param = current_best_augmented_param
self._best_augmented_param = current_best_augmented_param # pyrefly: ignore[bad-assignment]
self._generation = gen_idx

def get_best_solution(self) -> Tuple[Dict[str, float], float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, config: ConfigType) -> None:

self.config = copy.deepcopy(config)
if isinstance(self.config, ml_collections.ConfigDict):
self.config = config.to_dict()
self.config = config.to_dict() # pyrefly: ignore[missing-attribute]

# Add optimizer configuration
default_optimizer_config = {
Expand Down
6 changes: 3 additions & 3 deletions init2winit/projects/optlrschedule/workload/cifar10_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _load_dummy_data(
x_train = x_train.astype(jnp.float32) / 255.0
x_test = x_test.astype(jnp.float32) / 255.0

return x_train, y_train, x_test, y_test
return x_train, y_train, x_test, y_test # pyrefly: ignore[bad-return]

def _load_data(
self,
Expand Down Expand Up @@ -330,7 +330,7 @@ def create_train_state(
model = CNN()
params = model.init(init_param_rng, np.ones([1, 32, 32, 3]))

tx = optimizers.get_optimizer_from_config(self.config)
tx = optimizers.get_optimizer_from_config(self.config) # pyrefly: ignore[bad-argument-type]

state = train_state.TrainState.create(
apply_fn=model.apply, params=params, tx=tx
Expand Down Expand Up @@ -524,7 +524,7 @@ def train_and_evaluate_models(
self.make_global_array(batch_labels[start:end]),
)
vmap_states, _ = self.train_step(
vmap_states, batch, schedules[:, global_step]
vmap_states, batch, schedules[:, global_step] # pyrefly: ignore[bad-argument-type]
)

# Evaluate based on configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def get_wikitext103_dataset(
test_path = os.path.join(data_dir, TEST_FILENAME)

# Get TextLineDataset from raw files
train_text_dataset = tf.data.TextLineDataset(train_path)
valid_text_dataset = tf.data.TextLineDataset(valid_path)
test_text_dataset = tf.data.TextLineDataset(test_path)
train_text_dataset = tf.data.TextLineDataset(train_path) # pyrefly: ignore[bad-instantiation]
valid_text_dataset = tf.data.TextLineDataset(valid_path) # pyrefly: ignore[bad-instantiation]
test_text_dataset = tf.data.TextLineDataset(test_path) # pyrefly: ignore[bad-instantiation]

# Tokenize data
tokenizer = get_trained_tokenizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def generate_ntk_matrix(
) -> jnp.ndarray:
"""Generates an NTK (Neural Tangent Kernel) matrix with a known spectrum."""
if spectrum is None:
weight_matrix = random.normal(key, shape=(num_data, num_params))
return weight_matrix @ weight_matrix.T / num_data
weight_matrix = random.normal(key, shape=(num_data, num_params)) # pyrefly: ignore[bad-argument-type]
return weight_matrix @ weight_matrix.T / num_data # pyrefly: ignore[unsupported-operation]
else:
orthogonal_matrix = random.orthogonal(key, len(spectrum))
return orthogonal_matrix @ jnp.diag(spectrum) @ orthogonal_matrix.T
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _load_dummy_data(
)
y_test = x_test

return x_train, y_train, x_test, y_test
return x_train, y_train, x_test, y_test # pyrefly: ignore[bad-return]

def _load_data(
self,
Expand All @@ -280,7 +280,7 @@ def _load_data(
wikitext_103.get_wikitext103_dataset()
)

return (
return ( # pyrefly: ignore[bad-return]
train_dataset['inputs'].astype(jnp.float32),
train_dataset['targets'].astype(jnp.float32),
validation_dataset['inputs'].astype(jnp.float32),
Expand Down Expand Up @@ -325,7 +325,7 @@ def create_train_state(
train=False,
)

tx = optimizers.get_optimizer_from_config(self.config)
tx = optimizers.get_optimizer_from_config(self.config) # pyrefly: ignore[bad-argument-type]

state = train_state.TrainState.create(
apply_fn=transformer.apply, params=variables['params'], tx=tx
Expand Down
4 changes: 2 additions & 2 deletions init2winit/trainer_lib/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class Trainer(base_trainer.BaseTrainer):
"""Default trainer."""

def update(self, batch, rng, metrics_state, training_cost):
def update(self, batch, rng, metrics_state, training_cost): # pyrefly: ignore[bad-override]
"""Single step of the training loop.

Uses the training algorithm's update_params function to get the updated
Expand Down Expand Up @@ -71,7 +71,7 @@ def update(self, batch, rng, metrics_state, training_cost):

new_metrics_state = None
if metrics_state is not None:
new_metrics_state = self._metrics_update_fn(
new_metrics_state = self._metrics_update_fn( # pyrefly: ignore[not-callable]
metrics_state,
step,
cost_value,
Expand Down
12 changes: 6 additions & 6 deletions init2winit/trainer_lib/training_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,8 +830,8 @@ def get_default_training_hparams(cls, optimizer_name=None, model_name=None):
model_defaults = dict(_MODEL_TRAINING_DEFAULTS[model_name])
model_optimizer = model_defaults.pop('optimizer', 'adam')
# Start with the optimizer's own defaults, then overlay model-specific.
opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get(model_optimizer, {}))
opt_defaults.update(model_defaults.pop('opt_hparams', {}))
opt_defaults = dict(_OPTAX_OPTIMIZER_DEFAULTS.get(model_optimizer, {})) # pyrefly: ignore[no-matching-overload]
opt_defaults.update(model_defaults.pop('opt_hparams', {})) # pyrefly: ignore[no-matching-overload]
training_hparams.update({
'optimizer': model_optimizer,
'opt_hparams': opt_defaults,
Expand Down Expand Up @@ -899,7 +899,7 @@ def update_params(
new_params: Pytree of model parameters.
new_model_state: Pytree of model state.
"""
del (
del ( # pyrefly: ignore[unsupported-delete]
workload,
hyperparameters,
param_types,
Expand All @@ -910,7 +910,7 @@ def update_params(
grad_clip = self.hps.opt_hparams.get('grad_clip', None)
# We pass the lr directly because the lr functions from schedules.py
# have numpy dependencies and can't be jitted.
lr = self._lr_fn(global_step)
lr = self._lr_fn(global_step) # pyrefly: ignore[not-callable]
jitted_update_fn = jax.jit(
optax_update_params_helper,
static_argnames=(
Expand Down Expand Up @@ -1009,11 +1009,11 @@ def get_ema_eval_params(self, optimizer_state, params):
"""
del params # Unused
if isinstance(optimizer_state, optax.InjectStatefulHyperparamsState):
eval_params = optimizer_state.inner_state[0][0].ema
eval_params = optimizer_state.inner_state[0][0].ema # pyrefly: ignore[bad-index]
elif isinstance(
optimizer_state, gradient_accumulator.GradientAccumulatorState
):
eval_params = optimizer_state.base_state.inner_state[0][0].ema
eval_params = optimizer_state.base_state.inner_state[0][0].ema # pyrefly: ignore[missing-attribute]
else:
raise ValueError(
'EMA computation should be the very first transformation in defined'
Expand Down
2 changes: 1 addition & 1 deletion init2winit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def run_in_parallel(function, list_of_kwargs_to_function, num_workers):
for f in concurrent.futures.as_completed(futures):
if f.exception():
# Propagate exception to main thread.
raise f.exception()
raise f.exception() # pyrefly: ignore[bad-raise]

return [f.result() for f in futures]

Expand Down