From 6047d208af7838c8d8329604f2e1f365dcb39a78 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Tue, 26 May 2026 09:38:03 -0700 Subject: [PATCH] Fix pytype and runtime errors in Optax causing GitHub CI failures. - Remove invalid pytype directive 'unreachable' in transform.py. - Safely access '__all__' in coverage_check.py to avoid runtime AttributeError on modules lacking it (e.g., optax.contrib). - Disable 'final-error' in pytype for OptaxCoverageCheck class in coverage_check.py. PiperOrigin-RevId: 921516936 --- .github/workflows/tests.yml | 2 +- docs/ext/coverage_check.py | 8 +++++--- optax/_src/linear_algebra.py | 2 ++ optax/_src/transform.py | 3 ++- optax/_src/utils.py | 3 +++ optax/contrib/_dadapt_adamw.py | 1 + optax/contrib/_mechanic.py | 1 + optax/contrib/_momo.py | 8 ++++++-- optax/contrib/_prodigy.py | 1 + optax/contrib/_reduce_on_plateau.py | 8 ++++++-- optax/contrib/_sam.py | 4 ++++ optax/contrib/_schedule_free.py | 4 +++- optax/losses/_classification.py | 14 ++++++++++++++ optax/losses/_ranking.py | 2 ++ optax/losses/_regression.py | 2 ++ optax/losses/_segmentation.py | 16 +++++++++++----- optax/losses/_self_supervised.py | 3 +++ optax/losses/_smoothing.py | 1 + optax/projections/_projections.py | 1 + optax/transforms/_accumulation.py | 1 + optax/transforms/_clipping.py | 3 +++ optax/transforms/_monitoring.py | 3 +++ optax/tree_utils/_random.py | 9 +++++++-- optax/tree_utils/_state_utils.py | 2 ++ pyproject.toml | 1 + 25 files changed, 86 insertions(+), 17 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 93c024257..40714dd17 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -82,7 +82,7 @@ jobs: - name: Check types with pyrefly run: | python3 -m pip install -q pyrefly - python3 -m pyrefly check "optax" + python3 -m pyrefly check ruff-lint: name: "Lint check with ruff" runs-on: "ubuntu-latest" diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py index df8396506..a79e9c1e8 100644 --- a/docs/ext/coverage_check.py +++ b/docs/ext/coverage_check.py @@ -52,12 +52,14 @@ def optax_public_symbols(): """Collect all optax public symbols.""" names = set() for module_name, module in find_internal_python_modules(optax): - for name in module.__all__: - names.add(module_name + "." + name) + module_all = getattr(module, "__all__", None) + if module_all is not None: + for name in module_all: + names.add(module_name + "." + name) return names -class OptaxCoverageCheck(builders.Builder): +class OptaxCoverageCheck(builders.Builder): # pytype: disable=final-error """Builder that checks all public symbols are included.""" name = "coverage_check" diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 4bc1d0275..f37ce938c 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -187,6 +187,7 @@ def matrix_inverse_pth_root( # We use float32 for the matrix inverse pth root. # Switch to f64 if you have hardware that supports it. + # pyrefly: ignore [missing-attribute] matrix_size = matrix.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501 alpha = jnp.asarray(-1.0 / p, jnp.float32) identity = jnp.eye(matrix_size, dtype=jnp.float32) @@ -277,6 +278,7 @@ def _iter_body(state): error = jnp.max(jnp.abs(mat_m - identity)) is_converged = jnp.asarray(convergence, old_mat_h.dtype) # pytype: disable=attribute-error # lax-types # noqa: E501 resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h + # pyrefly: ignore [missing-attribute] resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 return resultant_mat_h, error diff --git a/optax/_src/transform.py b/optax/_src/transform.py index d5ca14fd6..ec2e8a2b5 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -477,7 +477,7 @@ def _comb(g, m): elif mode == 'refined': # Keep small values linear, saturate to sign for large values. return jnp.where(jnp.abs(x) < 1.0, x, jnp.sign(x)) - else: # pytype: disable=unreachable + else: raise ValueError( f'Unknown lion mode: {mode}. ' 'It needs to be one of ["hard", "smooth", "refined"].' @@ -1596,6 +1596,7 @@ def _precondition_by_lbfgs( , 1999 """ rhos = weights_memory + # pyrefly: ignore [missing-attribute] memory_size = weights_memory.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501 indices = (memory_idx + jnp.arange(memory_size)) % memory_size diff --git a/optax/_src/utils.py b/optax/_src/utils.py index c3f63d2ef..311358596 100644 --- a/optax/_src/utils.py +++ b/optax/_src/utils.py @@ -88,6 +88,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array: NxDxD tensor, with the same contents as `a` but with the diagonal changed to `new_diags`. """ + # pyrefly: ignore [missing-attribute] a_dim, new_diags_dim = len(a.shape), len(new_diags.shape) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 if a_dim != 3: raise ValueError(f'Expected `a` to be a 3D tensor, got {a_dim}D instead') @@ -96,6 +97,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array: f'Expected `new_diags` to be a 2D array, got {new_diags_dim}D instead' ) n, d, d1 = a.shape + # pyrefly: ignore [missing-attribute] n_diags, d_diags = new_diags.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501 if d != d1: raise ValueError( @@ -113,6 +115,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array: indices3 = indices2 # Use numpy array setting + # pyrefly: ignore [missing-attribute] a = a.at[indices1, indices2, indices3].set(new_diags.flatten()) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 return a diff --git a/optax/contrib/_dadapt_adamw.py b/optax/contrib/_dadapt_adamw.py index fb1ef5519..2fc90c40e 100644 --- a/optax/contrib/_dadapt_adamw.py +++ b/optax/contrib/_dadapt_adamw.py @@ -106,6 +106,7 @@ def update_fn( count_inc = numerics.safe_increment(count) bc = ((1 - beta2**count_inc) ** 0.5) / (1 - beta1**count_inc) dlr = state.estim_lr * sched * bc + # pyrefly: ignore [missing-attribute] dlr = dlr.astype(numerator_weighted.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 s_weighted = jax.tree.map( lambda sk, eas: sk / (jnp.sqrt(eas) + eps), grad_sum, state.exp_avg_sq diff --git a/optax/contrib/_mechanic.py b/optax/contrib/_mechanic.py index 2149e2e30..1ea4dda80 100644 --- a/optax/contrib/_mechanic.py +++ b/optax/contrib/_mechanic.py @@ -202,6 +202,7 @@ def add_weight_decay(gi, pi): clipped_h = jax.lax.clamp(-state.m, jnp.ones_like(state.m) * h, state.m) betas = jnp.array( [1.0 - 0.1**betai for betai in range(1, num_betas + 1)], + # pyrefly: ignore [missing-attribute] dtype=state.s.dtype, # pytype: disable=attribute-error # jax-arraylike ) diff --git a/optax/contrib/_momo.py b/optax/contrib/_momo.py index c9e4124a4..cb3e0ef7a 100644 --- a/optax/contrib/_momo.py +++ b/optax/contrib/_momo.py @@ -134,7 +134,9 @@ def update_fn( # initialize at first gradient, and loss bt = jnp.where(count == 0, 0.0, beta) barf = bt * state.barf + (1 - bt) * jnp.asarray( - value, dtype=state.barf.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + value, + # pyrefly: ignore [missing-attribute] + dtype=state.barf.dtype, # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ) exp_avg = jax.tree.map( lambda ea, g: bt * ea + (1 - bt) * g, state.exp_avg, updates @@ -297,7 +299,9 @@ def update_fn( count = state.count count_inc = numerics.safe_increment(count) barf = b1 * state.barf + (1 - b1) * jnp.asarray( - value, dtype=state.barf.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + value, + # pyrefly: ignore [missing-attribute] + dtype=state.barf.dtype, # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ) exp_avg = jax.tree.map( lambda ea, g: b1 * ea + (1 - b1) * g, state.exp_avg, updates diff --git a/optax/contrib/_prodigy.py b/optax/contrib/_prodigy.py index 1c6f1270f..d7d7a88ad 100644 --- a/optax/contrib/_prodigy.py +++ b/optax/contrib/_prodigy.py @@ -127,6 +127,7 @@ def update_fn( estim_lr = state.estim_lr numerator_weighted = state.numerator_weighted bc = ((1 - beta2**count_inc) ** 0.5) / (1 - beta1**count_inc) + # pyrefly: ignore [missing-attribute] dlr = jnp.asarray(estim_lr * sched * bc, dtype=estim_lr.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 dg = jax.tree.map(lambda g: estim_lr * g, updates) param_diff = jax.tree.map(lambda p0, p: p0 - p, params0, params) diff --git a/optax/contrib/_reduce_on_plateau.py b/optax/contrib/_reduce_on_plateau.py index d1f13278b..0da3ef9de 100644 --- a/optax/contrib/_reduce_on_plateau.py +++ b/optax/contrib/_reduce_on_plateau.py @@ -175,10 +175,14 @@ def update_fn( count = state.count new_count = numerics.safe_increment(count) new_avg_value = ( - count * state.avg_value + jnp.astype(value, state.avg_value.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + count * state.avg_value + # pyrefly: ignore [missing-attribute] + + jnp.astype(value, state.avg_value.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ) / new_count new_state = state._replace( - avg_value=new_avg_value.astype(state.avg_value.dtype), count=new_count # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + # pyrefly: ignore [missing-attribute] + avg_value=new_avg_value.astype(state.avg_value.dtype), + count=new_count, # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ) new_state = jax.lax.cond( diff --git a/optax/contrib/_sam.py b/optax/contrib/_sam.py index 3c22545e8..012f8a59f 100644 --- a/optax/contrib/_sam.py +++ b/optax/contrib/_sam.py @@ -221,6 +221,7 @@ def transparent_update_fn( updates = pick_one(last_step, opt_updates, adv_updates) if reset_state: + # pyrefly: ignore [bad-argument-type] initial_state = adv_optimizer.init(params) adv_state = pick_one(last_step, initial_state, adv_state) else: @@ -256,12 +257,14 @@ def opaque_update_fn( ) adv_updates = jax.tree.map(lambda x: -x, adv_updates) + # pyrefly: ignore [bad-argument-type] adv_params = update.apply_updates(adv_params, adv_updates) adv_updates = grad_fn(adv_params, i) if batch_axis_name is not None: adv_updates = jax.lax.pmean(adv_updates, axis_name=batch_axis_name) if reset_state: + # pyrefly: ignore [bad-argument-type] adv_state = adv_optimizer.init(outer_params) updates, opt_state = optimizer.update( @@ -280,4 +283,5 @@ def opaque_update_fn( else: update_fn = transparent_update_fn + # pyrefly: ignore [bad-argument-type] return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/optax/contrib/_schedule_free.py b/optax/contrib/_schedule_free.py index 21bae28f9..8629e6c5d 100644 --- a/optax/contrib/_schedule_free.py +++ b/optax/contrib/_schedule_free.py @@ -162,7 +162,9 @@ def update_fn( lr = learning_rate if callable(learning_rate): lr = jnp.asarray( - learning_rate(state.step_count), dtype=state.max_lr.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + learning_rate(state.step_count), + # pyrefly: ignore [missing-attribute] + dtype=state.max_lr.dtype, # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ) max_lr = jnp.maximum(state.max_lr, lr) # pyrefly: ignore[bad-argument-type] diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py index 383f244f6..2f88dcafe 100644 --- a/optax/losses/_classification.py +++ b/optax/losses/_classification.py @@ -76,6 +76,7 @@ class is an independent binary prediction and different classes are not `_, 2016 """ utils.check_subdtype(logits, jnp.floating) + # pyrefly: ignore [missing-attribute] labels = jnp.astype(labels, logits.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 log_p = jax.nn.log_sigmoid(logits) # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable @@ -299,6 +300,7 @@ def softmax_cross_entropy( Added ``axis`` and ``where`` arguments. """ utils.check_subdtype(logits, jnp.floating) + # pyrefly: ignore [missing-attribute] if where is not None and where.ndim != logits.ndim: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 where = jnp.expand_dims(where, axis) # pyrefly: ignore[bad-argument-type] log_probs = jax.nn.log_softmax(logits, axis, where) @@ -386,19 +388,25 @@ def softmax_cross_entropy_with_integer_labels( """ utils.check_subdtype(logits, jnp.floating) utils.check_subdtype(labels, jnp.integer) + # pyrefly: ignore [missing-attribute] if where is not None and where.ndim != logits.ndim: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 where = jnp.expand_dims(where, axis) if isinstance(axis, int): + # pyrefly: ignore [missing-attribute] axis = canonicalize_axis(axis, logits.ndim) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 elif isinstance(axis, tuple): # Move all "feature" dimensions to the end preserving axis ordering and # subsequent flattening "feature" dimensions to a single one. + # pyrefly: ignore [missing-attribute] logit_axis = canonicalize_axes(axis, logits.ndim) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + # pyrefly: ignore [missing-attribute] batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 axis = len(batch_axis) + # pyrefly: ignore [missing-attribute] logits = logits.transpose(batch_axis + logit_axis) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1,)) if where is not None: + # pyrefly: ignore [missing-attribute] where = where.transpose(batch_axis + logit_axis) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 where = where.reshape(where.shape[:len(batch_axis)] + (-1,)) else: @@ -449,6 +457,7 @@ def multiclass_hinge_loss( .. versionadded:: 0.2.3 """ + # pyrefly: ignore [missing-attribute] one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1]) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 return jnp.max(scores + 1.0 - one_hot_labels, axis=-1) - _dot_last_dim( scores, one_hot_labels @@ -474,6 +483,7 @@ def multiclass_perceptron_loss( .. versionadded:: 0.2.2 """ + # pyrefly: ignore [missing-attribute] one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1]) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 return jnp.max(scores, axis=-1) - _dot_last_dim(scores, one_hot_labels) @@ -733,7 +743,9 @@ def ctc_loss_with_forward_probs( utils.check_shapes_equal(labels, label_paddings) # pyrefly: ignore[bad-index] utils.check_shapes_equal(logits[..., 0], logit_paddings) + # pyrefly: ignore [missing-attribute] batchsize, unused_maxinputlen, num_classes = logits.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + # pyrefly: ignore [missing-attribute] batchsize_of_labels, maxlabellen = labels.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501 if batchsize_of_labels != batchsize: raise ValueError( @@ -793,6 +805,7 @@ def loop_body(prev, x): return (next_phi, next_emit), (next_phi, next_emit) + # pyrefly: ignore [missing-attribute] xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 _, (logalpha_phi, logalpha_emit) = jax.lax.scan( loop_body, (logalpha_phi_init, logalpha_emit_init), xs @@ -917,6 +930,7 @@ def sigmoid_focal_loss( Added support for continuous labels in `[0, 1]`. """ utils.check_subdtype(logits, jnp.floating) + # pyrefly: ignore [missing-attribute] labels = jnp.astype(labels, logits.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 # Cross-entropy loss diff --git a/optax/losses/_ranking.py b/optax/losses/_ranking.py index b7ef0aa50..c30aeba19 100644 --- a/optax/losses/_ranking.py +++ b/optax/losses/_ranking.py @@ -103,6 +103,7 @@ def _safe_reduce( # `jnp.sum(loss_fn(reduce_fn=None)) == loss_fn(reduce_fn=jnp.sum)` output = jnp.where(where, output, 0.0) + # pyrefly: ignore [bad-return] return output # pytype: disable=bad-return-type # jax-arraylike @@ -139,6 +140,7 @@ def ranking_softmax_loss( The ranking softmax loss. """ utils.check_subdtype(logits, jnp.floating) + # pyrefly: ignore [missing-attribute] labels = labels.astype(logits.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 # Applies mask so that masked elements do not count towards the loss. diff --git a/optax/losses/_regression.py b/optax/losses/_regression.py index 635376d97..46bc0520c 100644 --- a/optax/losses/_regression.py +++ b/optax/losses/_regression.py @@ -48,6 +48,7 @@ def squared_error( utils.check_shapes_equal(predictions, targets) # pyrefly: ignore[unsupported-operation] errors = predictions - targets if targets is not None else predictions + # pyrefly: ignore [bad-return] return errors**2 # pytype: disable=bad-return-type # jax-arraylike @@ -133,6 +134,7 @@ def log_cosh( # pyrefly: ignore[unsupported-operation] errors = (predictions - targets) if (targets is not None) else predictions # log(cosh(x)) = log((exp(x) + exp(-x))/2) = log(exp(x) + exp(-x)) - log(2) + # pyrefly: ignore [missing-attribute, unsupported-operation] return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 diff --git a/optax/losses/_segmentation.py b/optax/losses/_segmentation.py index 52ee99f1b..bed4d4c30 100644 --- a/optax/losses/_segmentation.py +++ b/optax/losses/_segmentation.py @@ -143,8 +143,10 @@ def dice_loss( Volumetric Medical Image Segmentation" (2016). """ + # pyrefly: ignore [missing-attribute] if predictions.ndim == targets.ndim - 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 predictions = predictions[..., None] # pyrefly: ignore[bad-index] + # pyrefly: ignore [missing-attribute] if targets.ndim == predictions.ndim - 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 targets = targets[..., None] # pyrefly: ignore[bad-index] utils.check_shapes_equal(predictions, targets) @@ -164,13 +166,14 @@ def dice_loss( # Convert logits to probabilities probs = predictions if apply_softmax: - probs = ( - jax.nn.sigmoid(predictions) - if predictions.shape[-1] == 1 # pytype: disable=attribute-error # jax-arraylike # noqa: E501 - else jax.nn.softmax(predictions, axis=-1) - ) + # pyrefly: ignore [missing-attribute] + if predictions.shape[-1] == 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 + probs = jax.nn.sigmoid(predictions) + else: + probs = jax.nn.softmax(predictions, axis=-1) # Default behavior: sum over all spatial dimensions (except first/last) + # pyrefly: ignore [bad-assignment, missing-attribute] axis = tuple(range(1, probs.ndim - 1)) if axis is None else axis # pytype: disable=attribute-error # jax-arraylike # noqa: E501 # Compute intersection and sums over specified axes @@ -196,6 +199,7 @@ def dice_loss( dice_l = dice_l * class_weights # Handle background class ignoring + # pyrefly: ignore [missing-attribute] if ignore_background and probs.shape[-1] > 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 # Exclude the first class (background) from loss computation dice_l = dice_l[..., 1:] @@ -238,6 +242,7 @@ def multiclass_generalized_dice_loss( utils.check_shapes_equal(predictions, targets) # Compute class frequencies for weighting + # pyrefly: ignore [missing-attribute] class_frequencies = jnp.sum(targets, axis=tuple(range(targets.ndim - 1))) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 # Compute weights as inverse of squared frequencies @@ -280,6 +285,7 @@ def binary_dice_loss( Loss values of shape [...] (batch dimensions only). """ # Ensure both have channel dimension + # pyrefly: ignore [missing-attribute] if predictions.ndim == targets.ndim and predictions.shape[-1] != 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 predictions = predictions[..., None] # pyrefly: ignore[bad-index] targets = targets[..., None] # pyrefly: ignore[bad-index] diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py index c09c159eb..269a829c2 100644 --- a/optax/losses/_self_supervised.py +++ b/optax/losses/_self_supervised.py @@ -79,9 +79,11 @@ def ntxent( .. versionadded:: 0.2.3 """ utils.check_subdtype(embeddings, jnp.floating) + # pyrefly: ignore [missing-attribute] if labels.shape[0] != embeddings.shape[0]: # pytype: disable=attribute-error # jax-arraylike # noqa: E501 raise ValueError( 'Labels and embeddings must have the same leading dimension, found' + # pyrefly: ignore [missing-attribute] f' {labels.shape[0]} for labels and {embeddings.shape[0]} for' # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ' embeddings.' ) @@ -91,6 +93,7 @@ def ntxent( _regression.cosine_similarity( embeddings[None, :, :], # pyrefly: ignore[bad-index] embeddings[:, None, :], # pyrefly: ignore[bad-index] + # pyrefly: ignore [missing-attribute] epsilon=jnp.finfo(embeddings.dtype).eps, # pytype: disable=attribute-error # jax-arraylike # noqa: E501 ) / temperature diff --git a/optax/losses/_smoothing.py b/optax/losses/_smoothing.py index 21122eaca..277088922 100644 --- a/optax/losses/_smoothing.py +++ b/optax/losses/_smoothing.py @@ -52,4 +52,5 @@ def smooth_labels( num_categories = jnp.size(labels, axis) else: num_categories = jnp.sum(where, axis, keepdims=True) + # pyrefly: ignore [bad-return] return (1.0 - alpha) * labels + alpha / num_categories # pytype: disable=bad-return-type # jax-arraylike # noqa: E501 diff --git a/optax/projections/_projections.py b/optax/projections/_projections.py index a344715de..97209c9d8 100644 --- a/optax/projections/_projections.py +++ b/optax/projections/_projections.py @@ -94,6 +94,7 @@ def projection_hypercube(tree: Any, scale: Any = 1) -> Any: def _projection_unit_simplex(values: jax.typing.ArrayLike) -> jax.Array: """Projection onto the unit simplex.""" s = 1 + # pyrefly: ignore [missing-attribute] n_features = values.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501 u = jnp.sort(values)[::-1] cumsum_u = jnp.cumsum(u) diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index 474cbe74c..08931a2ca 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -180,6 +180,7 @@ def skip_not_finite( not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates) num_not_finite = optax.tree.sum(not_finite) should_skip = num_not_finite > 0 # pyrefly: ignore[unsupported-operation] + # pyrefly: ignore [bad-return] return should_skip, { # pytype: disable=bad-return-type 'should_skip': should_skip, 'num_not_finite': num_not_finite, diff --git a/optax/transforms/_clipping.py b/optax/transforms/_clipping.py index ebf31b98e..824998e53 100644 --- a/optax/transforms/_clipping.py +++ b/optax/transforms/_clipping.py @@ -281,6 +281,7 @@ def unitwise_norm( # Note that this assumes parameters with a shape of length 3 are multihead # linear parameters--if you wish to apply AGC to 1D convs, you may need # to modify this line. + # pyrefly: ignore [missing-attribute] elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear # pytype: disable=attribute-error # jax-arraylike # noqa: E501 squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True) elif x.ndim == 4: # Conv kernels of shape HWIO # pytype: disable=attribute-error # jax-arraylike # noqa: E501 @@ -289,9 +290,11 @@ def unitwise_norm( squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2, 3), keepdims=True) else: raise ValueError( + # pyrefly: ignore [missing-attribute] f"Expected parameter with shape in {1, 2, 3, 4, 5}, got {x.shape}. " # pytype: disable=attribute-error # jax-arraylike # noqa: E501 "Use axis parameter to specify reduction axes for other shapes." ) + # pyrefly: ignore [missing-attribute] return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape) # pytype: disable=attribute-error # jax-arraylike # noqa: E501 diff --git a/optax/transforms/_monitoring.py b/optax/transforms/_monitoring.py index 321de9573..5e227d29a 100644 --- a/optax/transforms/_monitoring.py +++ b/optax/transforms/_monitoring.py @@ -145,6 +145,7 @@ def monitor( measures_[measure_name] = base.with_extra_args_support(measure_) else: measures_[measure_name] = base.with_extra_args_support(measure) + # pyrefly: ignore [bad-assignment] measures = measures_ measure_names = tuple(measures.keys()) @@ -152,6 +153,7 @@ def init(params: base.Params) -> MonitorState: measurements = {} measure_states = [] for measure_name in measure_names: + # pyrefly: ignore [missing-attribute] measure_states.append(measures[measure_name].init(params)) return MonitorState(measurements, tuple(measure_states)) @@ -164,6 +166,7 @@ def update( measurements = {} new_measure_states = [] for i, measure_name in enumerate(measure_names): + # pyrefly: ignore [missing-attribute] measurement, measure_state = measures[measure_name].update( updates, state.measure_states[i], diff --git a/optax/tree_utils/_random.py b/optax/tree_utils/_random.py index 560ebdddd..3eb615fb3 100644 --- a/optax/tree_utils/_random.py +++ b/optax/tree_utils/_random.py @@ -80,8 +80,13 @@ def tree_random_like( ) # pytype: disable=wrong-arg-count return jax.tree.map( # pytype: disable=wrong-keyword-args - lambda leaf, key: sampler_(key, leaf.shape, dtype or leaf.dtype, - out_sharding=jax.typeof(leaf).sharding), + lambda leaf, key: sampler_( + key, + leaf.shape, + dtype or leaf.dtype, + # pyrefly: ignore [bad-argument-count, unexpected-keyword] + out_sharding=jax.typeof(leaf).sharding, + ), # pytype: enable=wrong-keyword-args target_tree, keys_tree, diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index 2028805c8..9550aefca 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -495,10 +495,12 @@ def _replace(path: _KeyPath, node: Any) -> Any: if isinstance(child, (dict, list, tuple)): # If the child is itself a pytree, further search in the child to # replace the given value + # pyrefly: ignore [no-matching-overload] new_children_with_keys.update({key: _replace(child_path, child)}) else: # If the child is just a leaf that does not contain the key or # satisfies the filtering operation, just return the child. + # pyrefly: ignore [no-matching-overload] new_children_with_keys.update({key: child}) return _set_children(node, new_children_with_keys) diff --git a/pyproject.toml b/pyproject.toml index b389670a6..5da9fadd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,4 +139,5 @@ include = '\.pyi?$' [tool.pyrefly] project-includes = ["optax"] +project-excludes = ["**/*_test.py"] python-version = "3.12"