Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ jobs:
- name: Check types with pyrefly
run: |
python3 -m pip install -q pyrefly
python3 -m pyrefly check "optax"
python3 -m pyrefly check
ruff-lint:
name: "Lint check with ruff"
runs-on: "ubuntu-latest"
Expand Down
8 changes: 5 additions & 3 deletions docs/ext/coverage_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ def optax_public_symbols():
"""Collect all optax public symbols."""
names = set()
for module_name, module in find_internal_python_modules(optax):
for name in module.__all__:
names.add(module_name + "." + name)
module_all = getattr(module, "__all__", None)
if module_all is not None:
for name in module_all:
names.add(module_name + "." + name)
return names


class OptaxCoverageCheck(builders.Builder):
class OptaxCoverageCheck(builders.Builder): # pytype: disable=final-error
"""Builder that checks all public symbols are included."""

name = "coverage_check"
Expand Down
2 changes: 2 additions & 0 deletions optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def matrix_inverse_pth_root(

# We use float32 for the matrix inverse pth root.
# Switch to f64 if you have hardware that supports it.
# pyrefly: ignore [missing-attribute]
matrix_size = matrix.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501
alpha = jnp.asarray(-1.0 / p, jnp.float32)
identity = jnp.eye(matrix_size, dtype=jnp.float32)
Expand Down Expand Up @@ -277,6 +278,7 @@ def _iter_body(state):
error = jnp.max(jnp.abs(mat_m - identity))
is_converged = jnp.asarray(convergence, old_mat_h.dtype) # pytype: disable=attribute-error # lax-types # noqa: E501
resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
# pyrefly: ignore [missing-attribute]
resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
return resultant_mat_h, error

Expand Down
3 changes: 2 additions & 1 deletion optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _comb(g, m):
elif mode == 'refined':
# Keep small values linear, saturate to sign for large values.
return jnp.where(jnp.abs(x) < 1.0, x, jnp.sign(x))
else: # pytype: disable=unreachable
else:
raise ValueError(
f'Unknown lion mode: {mode}. '
'It needs to be one of ["hard", "smooth", "refined"].'
Expand Down Expand Up @@ -1596,6 +1596,7 @@ def _precondition_by_lbfgs(
, 1999
"""
rhos = weights_memory
# pyrefly: ignore [missing-attribute]
memory_size = weights_memory.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501
indices = (memory_idx + jnp.arange(memory_size)) % memory_size

Expand Down
3 changes: 3 additions & 0 deletions optax/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array:
NxDxD tensor, with the same contents as `a` but with the diagonal
changed to `new_diags`.
"""
# pyrefly: ignore [missing-attribute]
a_dim, new_diags_dim = len(a.shape), len(new_diags.shape) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
if a_dim != 3:
raise ValueError(f'Expected `a` to be a 3D tensor, got {a_dim}D instead')
Expand All @@ -96,6 +97,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array:
f'Expected `new_diags` to be a 2D array, got {new_diags_dim}D instead'
)
n, d, d1 = a.shape
# pyrefly: ignore [missing-attribute]
n_diags, d_diags = new_diags.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501
if d != d1:
raise ValueError(
Expand All @@ -113,6 +115,7 @@ def set_diags(a: jax.Array, new_diags: jax.typing.ArrayLike) -> jax.Array:
indices3 = indices2

# Use numpy array setting
# pyrefly: ignore [missing-attribute]
a = a.at[indices1, indices2, indices3].set(new_diags.flatten()) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
return a

Expand Down
1 change: 1 addition & 0 deletions optax/contrib/_dadapt_adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def update_fn(
count_inc = numerics.safe_increment(count)
bc = ((1 - beta2**count_inc) ** 0.5) / (1 - beta1**count_inc)
dlr = state.estim_lr * sched * bc
# pyrefly: ignore [missing-attribute]
dlr = dlr.astype(numerator_weighted.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
s_weighted = jax.tree.map(
lambda sk, eas: sk / (jnp.sqrt(eas) + eps), grad_sum, state.exp_avg_sq
Expand Down
1 change: 1 addition & 0 deletions optax/contrib/_mechanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def add_weight_decay(gi, pi):
clipped_h = jax.lax.clamp(-state.m, jnp.ones_like(state.m) * h, state.m)
betas = jnp.array(
[1.0 - 0.1**betai for betai in range(1, num_betas + 1)],
# pyrefly: ignore [missing-attribute]
dtype=state.s.dtype, # pytype: disable=attribute-error # jax-arraylike
)

Expand Down
8 changes: 6 additions & 2 deletions optax/contrib/_momo.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def update_fn(
# initialize at first gradient, and loss
bt = jnp.where(count == 0, 0.0, beta)
barf = bt * state.barf + (1 - bt) * jnp.asarray(
value, dtype=state.barf.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501
value,
# pyrefly: ignore [missing-attribute]
dtype=state.barf.dtype, # pytype: disable=attribute-error # jax-arraylike # noqa: E501
)
exp_avg = jax.tree.map(
lambda ea, g: bt * ea + (1 - bt) * g, state.exp_avg, updates
Expand Down Expand Up @@ -297,7 +299,9 @@ def update_fn(
count = state.count
count_inc = numerics.safe_increment(count)
barf = b1 * state.barf + (1 - b1) * jnp.asarray(
value, dtype=state.barf.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501
value,
# pyrefly: ignore [missing-attribute]
dtype=state.barf.dtype, # pytype: disable=attribute-error # jax-arraylike # noqa: E501
)
exp_avg = jax.tree.map(
lambda ea, g: b1 * ea + (1 - b1) * g, state.exp_avg, updates
Expand Down
1 change: 1 addition & 0 deletions optax/contrib/_prodigy.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def update_fn(
estim_lr = state.estim_lr
numerator_weighted = state.numerator_weighted
bc = ((1 - beta2**count_inc) ** 0.5) / (1 - beta1**count_inc)
# pyrefly: ignore [missing-attribute]
dlr = jnp.asarray(estim_lr * sched * bc, dtype=estim_lr.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
dg = jax.tree.map(lambda g: estim_lr * g, updates)
param_diff = jax.tree.map(lambda p0, p: p0 - p, params0, params)
Expand Down
8 changes: 6 additions & 2 deletions optax/contrib/_reduce_on_plateau.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,14 @@ def update_fn(
count = state.count
new_count = numerics.safe_increment(count)
new_avg_value = (
count * state.avg_value + jnp.astype(value, state.avg_value.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
count * state.avg_value
# pyrefly: ignore [missing-attribute]
+ jnp.astype(value, state.avg_value.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
) / new_count
new_state = state._replace(
avg_value=new_avg_value.astype(state.avg_value.dtype), count=new_count # pytype: disable=attribute-error # jax-arraylike # noqa: E501
# pyrefly: ignore [missing-attribute]
avg_value=new_avg_value.astype(state.avg_value.dtype),
count=new_count, # pytype: disable=attribute-error # jax-arraylike # noqa: E501
)

new_state = jax.lax.cond(
Expand Down
4 changes: 4 additions & 0 deletions optax/contrib/_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def transparent_update_fn(
updates = pick_one(last_step, opt_updates, adv_updates)

if reset_state:
# pyrefly: ignore [bad-argument-type]
initial_state = adv_optimizer.init(params)
adv_state = pick_one(last_step, initial_state, adv_state)
else:
Expand Down Expand Up @@ -256,12 +257,14 @@ def opaque_update_fn(
)
adv_updates = jax.tree.map(lambda x: -x, adv_updates)

# pyrefly: ignore [bad-argument-type]
adv_params = update.apply_updates(adv_params, adv_updates)
adv_updates = grad_fn(adv_params, i)
if batch_axis_name is not None:
adv_updates = jax.lax.pmean(adv_updates, axis_name=batch_axis_name)

if reset_state:
# pyrefly: ignore [bad-argument-type]
adv_state = adv_optimizer.init(outer_params)

updates, opt_state = optimizer.update(
Expand All @@ -280,4 +283,5 @@ def opaque_update_fn(
else:
update_fn = transparent_update_fn

# pyrefly: ignore [bad-argument-type]
return base.GradientTransformationExtraArgs(init_fn, update_fn)
4 changes: 3 additions & 1 deletion optax/contrib/_schedule_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def update_fn(
lr = learning_rate
if callable(learning_rate):
lr = jnp.asarray(
learning_rate(state.step_count), dtype=state.max_lr.dtype # pytype: disable=attribute-error # jax-arraylike # noqa: E501
learning_rate(state.step_count),
# pyrefly: ignore [missing-attribute]
dtype=state.max_lr.dtype, # pytype: disable=attribute-error # jax-arraylike # noqa: E501
)
max_lr = jnp.maximum(state.max_lr, lr) # pyrefly: ignore[bad-argument-type]

Expand Down
14 changes: 14 additions & 0 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class is an independent binary prediction and different classes are not
<http://www.deeplearningbook.org/contents/prob.html>`_, 2016
"""
utils.check_subdtype(logits, jnp.floating)
# pyrefly: ignore [missing-attribute]
labels = jnp.astype(labels, logits.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
log_p = jax.nn.log_sigmoid(logits)
# log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable
Expand Down Expand Up @@ -299,6 +300,7 @@ def softmax_cross_entropy(
Added ``axis`` and ``where`` arguments.
"""
utils.check_subdtype(logits, jnp.floating)
# pyrefly: ignore [missing-attribute]
if where is not None and where.ndim != logits.ndim: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
where = jnp.expand_dims(where, axis) # pyrefly: ignore[bad-argument-type]
log_probs = jax.nn.log_softmax(logits, axis, where)
Expand Down Expand Up @@ -386,19 +388,25 @@ def softmax_cross_entropy_with_integer_labels(
"""
utils.check_subdtype(logits, jnp.floating)
utils.check_subdtype(labels, jnp.integer)
# pyrefly: ignore [missing-attribute]
if where is not None and where.ndim != logits.ndim: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
where = jnp.expand_dims(where, axis)
if isinstance(axis, int):
# pyrefly: ignore [missing-attribute]
axis = canonicalize_axis(axis, logits.ndim) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
elif isinstance(axis, tuple):
# Move all "feature" dimensions to the end preserving axis ordering and
# subsequent flattening "feature" dimensions to a single one.
# pyrefly: ignore [missing-attribute]
logit_axis = canonicalize_axes(axis, logits.ndim) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
# pyrefly: ignore [missing-attribute]
batch_axis = tuple(x for x in range(logits.ndim) if x not in logit_axis) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
axis = len(batch_axis)
# pyrefly: ignore [missing-attribute]
logits = logits.transpose(batch_axis + logit_axis) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
logits = logits.reshape(logits.shape[:len(batch_axis)] + (-1,))
if where is not None:
# pyrefly: ignore [missing-attribute]
where = where.transpose(batch_axis + logit_axis) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
where = where.reshape(where.shape[:len(batch_axis)] + (-1,))
else:
Expand Down Expand Up @@ -449,6 +457,7 @@ def multiclass_hinge_loss(

.. versionadded:: 0.2.3
"""
# pyrefly: ignore [missing-attribute]
one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1]) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
return jnp.max(scores + 1.0 - one_hot_labels, axis=-1) - _dot_last_dim(
scores, one_hot_labels
Expand All @@ -474,6 +483,7 @@ def multiclass_perceptron_loss(

.. versionadded:: 0.2.2
"""
# pyrefly: ignore [missing-attribute]
one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1]) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
return jnp.max(scores, axis=-1) - _dot_last_dim(scores, one_hot_labels)

Expand Down Expand Up @@ -733,7 +743,9 @@ def ctc_loss_with_forward_probs(
utils.check_shapes_equal(labels, label_paddings)
# pyrefly: ignore[bad-index]
utils.check_shapes_equal(logits[..., 0], logit_paddings)
# pyrefly: ignore [missing-attribute]
batchsize, unused_maxinputlen, num_classes = logits.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501
# pyrefly: ignore [missing-attribute]
batchsize_of_labels, maxlabellen = labels.shape # pytype: disable=attribute-error # jax-arraylike # noqa: E501
if batchsize_of_labels != batchsize:
raise ValueError(
Expand Down Expand Up @@ -793,6 +805,7 @@ def loop_body(prev, x):

return (next_phi, next_emit), (next_phi, next_emit)

# pyrefly: ignore [missing-attribute]
xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) # pytype: disable=attribute-error # jax-arraylike # noqa: E501
_, (logalpha_phi, logalpha_emit) = jax.lax.scan(
loop_body, (logalpha_phi_init, logalpha_emit_init), xs
Expand Down Expand Up @@ -917,6 +930,7 @@ def sigmoid_focal_loss(
Added support for continuous labels in `[0, 1]`.
"""
utils.check_subdtype(logits, jnp.floating)
# pyrefly: ignore [missing-attribute]
labels = jnp.astype(labels, logits.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501

# Cross-entropy loss
Expand Down
2 changes: 2 additions & 0 deletions optax/losses/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _safe_reduce(
# `jnp.sum(loss_fn(reduce_fn=None)) == loss_fn(reduce_fn=jnp.sum)`
output = jnp.where(where, output, 0.0)

# pyrefly: ignore [bad-return]
return output # pytype: disable=bad-return-type # jax-arraylike


Expand Down Expand Up @@ -139,6 +140,7 @@ def ranking_softmax_loss(
The ranking softmax loss.
"""
utils.check_subdtype(logits, jnp.floating)
# pyrefly: ignore [missing-attribute]
labels = labels.astype(logits.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501

# Applies mask so that masked elements do not count towards the loss.
Expand Down
2 changes: 2 additions & 0 deletions optax/losses/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def squared_error(
utils.check_shapes_equal(predictions, targets)
# pyrefly: ignore[unsupported-operation]
errors = predictions - targets if targets is not None else predictions
# pyrefly: ignore [bad-return]
return errors**2 # pytype: disable=bad-return-type # jax-arraylike


Expand Down Expand Up @@ -133,6 +134,7 @@ def log_cosh(
# pyrefly: ignore[unsupported-operation]
errors = (predictions - targets) if (targets is not None) else predictions
# log(cosh(x)) = log((exp(x) + exp(-x))/2) = log(exp(x) + exp(-x)) - log(2)
# pyrefly: ignore [missing-attribute, unsupported-operation]
return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) # pytype: disable=attribute-error # jax-arraylike # noqa: E501


Expand Down
16 changes: 11 additions & 5 deletions optax/losses/_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ def dice_loss(
Volumetric Medical Image Segmentation" (2016).
"""

# pyrefly: ignore [missing-attribute]
if predictions.ndim == targets.ndim - 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
predictions = predictions[..., None] # pyrefly: ignore[bad-index]
# pyrefly: ignore [missing-attribute]
if targets.ndim == predictions.ndim - 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
targets = targets[..., None] # pyrefly: ignore[bad-index]
utils.check_shapes_equal(predictions, targets)
Expand All @@ -164,13 +166,14 @@ def dice_loss(
# Convert logits to probabilities
probs = predictions
if apply_softmax:
probs = (
jax.nn.sigmoid(predictions)
if predictions.shape[-1] == 1 # pytype: disable=attribute-error # jax-arraylike # noqa: E501
else jax.nn.softmax(predictions, axis=-1)
)
# pyrefly: ignore [missing-attribute]
if predictions.shape[-1] == 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
probs = jax.nn.sigmoid(predictions)
else:
probs = jax.nn.softmax(predictions, axis=-1)

# Default behavior: sum over all spatial dimensions (except first/last)
# pyrefly: ignore [bad-assignment, missing-attribute]
axis = tuple(range(1, probs.ndim - 1)) if axis is None else axis # pytype: disable=attribute-error # jax-arraylike # noqa: E501

# Compute intersection and sums over specified axes
Expand All @@ -196,6 +199,7 @@ def dice_loss(
dice_l = dice_l * class_weights

# Handle background class ignoring
# pyrefly: ignore [missing-attribute]
if ignore_background and probs.shape[-1] > 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
# Exclude the first class (background) from loss computation
dice_l = dice_l[..., 1:]
Expand Down Expand Up @@ -238,6 +242,7 @@ def multiclass_generalized_dice_loss(
utils.check_shapes_equal(predictions, targets)

# Compute class frequencies for weighting
# pyrefly: ignore [missing-attribute]
class_frequencies = jnp.sum(targets, axis=tuple(range(targets.ndim - 1))) # pytype: disable=attribute-error # jax-arraylike # noqa: E501

# Compute weights as inverse of squared frequencies
Expand Down Expand Up @@ -280,6 +285,7 @@ def binary_dice_loss(
Loss values of shape [...] (batch dimensions only).
"""
# Ensure both have channel dimension
# pyrefly: ignore [missing-attribute]
if predictions.ndim == targets.ndim and predictions.shape[-1] != 1: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
predictions = predictions[..., None] # pyrefly: ignore[bad-index]
targets = targets[..., None] # pyrefly: ignore[bad-index]
Expand Down
3 changes: 3 additions & 0 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def ntxent(
.. versionadded:: 0.2.3
"""
utils.check_subdtype(embeddings, jnp.floating)
# pyrefly: ignore [missing-attribute]
if labels.shape[0] != embeddings.shape[0]: # pytype: disable=attribute-error # jax-arraylike # noqa: E501
raise ValueError(
'Labels and embeddings must have the same leading dimension, found'
# pyrefly: ignore [missing-attribute]
f' {labels.shape[0]} for labels and {embeddings.shape[0]} for' # pytype: disable=attribute-error # jax-arraylike # noqa: E501
' embeddings.'
)
Expand All @@ -91,6 +93,7 @@ def ntxent(
_regression.cosine_similarity(
embeddings[None, :, :], # pyrefly: ignore[bad-index]
embeddings[:, None, :], # pyrefly: ignore[bad-index]
# pyrefly: ignore [missing-attribute]
epsilon=jnp.finfo(embeddings.dtype).eps, # pytype: disable=attribute-error # jax-arraylike # noqa: E501
)
/ temperature
Expand Down
1 change: 1 addition & 0 deletions optax/losses/_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ def smooth_labels(
num_categories = jnp.size(labels, axis)
else:
num_categories = jnp.sum(where, axis, keepdims=True)
# pyrefly: ignore [bad-return]
return (1.0 - alpha) * labels + alpha / num_categories # pytype: disable=bad-return-type # jax-arraylike # noqa: E501
1 change: 1 addition & 0 deletions optax/projections/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def projection_hypercube(tree: Any, scale: Any = 1) -> Any:
def _projection_unit_simplex(values: jax.typing.ArrayLike) -> jax.Array:
"""Projection onto the unit simplex."""
s = 1
# pyrefly: ignore [missing-attribute]
n_features = values.shape[0] # pytype: disable=attribute-error # jax-arraylike # noqa: E501
u = jnp.sort(values)[::-1]
cumsum_u = jnp.cumsum(u)
Expand Down
1 change: 1 addition & 0 deletions optax/transforms/_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def skip_not_finite(
not_finite = jax.tree.map(lambda x: ~jnp.isfinite(x), updates)
num_not_finite = optax.tree.sum(not_finite)
should_skip = num_not_finite > 0 # pyrefly: ignore[unsupported-operation]
# pyrefly: ignore [bad-return]
return should_skip, { # pytype: disable=bad-return-type
'should_skip': should_skip,
'num_not_finite': num_not_finite,
Expand Down
Loading
Loading