diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index af02d43fc..d4c68d641 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -65,28 +65,6 @@ Non-negative least squares .. autofunction:: nnls -Second Order Optimization -------------------------- - -.. currentmodule:: optax.second_order - -.. autosummary:: - fisher_diag - hessian_diag - hvp - -Fisher diagonal -~~~~~~~~~~~~~~~ -.. autofunction:: fisher_diag - -Hessian diagonal -~~~~~~~~~~~~~~~~ -.. autofunction:: hessian_diag - -Hessian vector product -~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: hvp - Tree ---- diff --git a/optax/CHANGELOG.md b/optax/CHANGELOG.md index bc2477718..b8392d2c6 100644 --- a/optax/CHANGELOG.md +++ b/optax/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Removed + +- Removed deprecated `optax.second_order.hvp`, `optax.second_order.hessian_diag`, + and `optax.second_order.fisher_diag`. These were deprecated in 0.2.7 and + scheduled for removal in 0.2.9. + ## [0.2.8] - 2026-03-20 ### Changed diff --git a/optax/second_order/__init__.py b/optax/second_order/__init__.py index a2cebadd1..a17a6dce6 100644 --- a/optax/second_order/__init__.py +++ b/optax/second_order/__init__.py @@ -13,9 +13,3 @@ # limitations under the License. # ============================================================================== """The second order optimization sub-package.""" - -# pylint: disable=g-importing-member - -from optax.second_order._deprecated import fisher_diag -from optax.second_order._deprecated import hessian_diag -from optax.second_order._deprecated import hvp diff --git a/optax/second_order/_deprecated.py b/optax/second_order/_deprecated.py deleted file mode 100644 index 9875c153f..000000000 --- a/optax/second_order/_deprecated.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Deprecated second order utilities kept for backward compatibility. -""" - -import abc -import functools -from typing import Any, Protocol - -import jax -from jax import flatten_util -import jax.numpy as jnp -from optax._src.deprecations import warn_deprecated_function # pylint: disable=g-importing-member - - -def _ravel(p: Any) -> jax.Array: - return flatten_util.ravel_pytree(p)[0] - - -class LossFn(Protocol): - """A loss function to be optimized.""" - - @abc.abstractmethod - def __call__( - self, params: Any, inputs: jax.Array, targets: jax.Array - ) -> jax.Array: - ... - - -@functools.partial(warn_deprecated_function, version_removed='0.2.9') -def hvp( - loss: LossFn, - v: jax.Array, - params: Any, - inputs: jax.Array, - targets: jax.Array, -) -> jax.Array: - """Performs an efficient vector-Hessian (of `loss`) product. - - .. deprecated: 0.2.7. This function will be removed in 0.2.9 - - Args: - loss: the loss function. - v: a vector of size `ravel(params)`. - params: model parameters. - inputs: inputs at which `loss` is evaluated. - targets: targets at which `loss` is evaluated. - - Returns: - An Array corresponding to the product of `v` and the Hessian of `loss` - evaluated at `(params, inputs, targets)`. - """ - _, unravel_fn = flatten_util.ravel_pytree(params) - loss_fn = lambda p: loss(p, inputs, targets) - return jax.jvp(jax.grad(loss_fn), [params], [unravel_fn(v)])[1] - - -@functools.partial(warn_deprecated_function, version_removed='0.2.9') -def hessian_diag( - loss: LossFn, - params: Any, - inputs: jax.Array, - targets: jax.Array, -) -> jax.Array: - """Computes the diagonal hessian of `loss` at (`inputs`, `targets`). - - .. deprecated: 0.2.7. This function will be removed in 0.2.9 - - Args: - loss: the loss function. - params: model parameters. - inputs: inputs at which `loss` is evaluated. - targets: targets at which `loss` is evaluated. - - Returns: - A DeviceArray corresponding to the product to the Hessian of `loss` - evaluated at `(params, inputs, targets)`. - """ - vs = jnp.eye(_ravel(params).size) - comp = lambda v: jnp.vdot(v, _ravel(hvp(loss, v, params, inputs, targets))) - return jax.vmap(comp)(vs) - - -@functools.partial(warn_deprecated_function, version_removed='0.2.9') -def fisher_diag( - negative_log_likelihood: LossFn, - params: Any, - inputs: jax.Array, - targets: jax.Array, -) -> jax.Array: - """Computes the diagonal of the (observed) Fisher information matrix. - - .. deprecated: 0.2.7. This function will be removed in 0.2.9 - - Args: - negative_log_likelihood: the negative log likelihood function with expected - signature `loss = fn(params, inputs, targets)`. - params: model parameters. - inputs: inputs at which `negative_log_likelihood` is evaluated. - targets: targets at which `negative_log_likelihood` is evaluated. - - Returns: - An Array corresponding to the product to the Hessian of - `negative_log_likelihood` evaluated at `(params, inputs, targets)`. - """ - return jnp.square( - _ravel(jax.grad(negative_log_likelihood)(params, inputs, targets)) - )