Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import equinox as eqx
import jax.numpy as jnp
import lineax as lx
import numpy as np
from jaxtyping import Array, Int

from flowjax.bijections.bijection import AbstractBijection
from flowjax.bijections.chain import Chain
from flowjax.utils import arraylike_to_array, check_shapes_match, merge_cond_shapes

import lineax as lx

class Invert(AbstractBijection):
"""Invert a bijection.
Expand Down Expand Up @@ -301,6 +301,7 @@ def __init__(
self,
bijection: AbstractBijection,
inverter: Callable[[AbstractBijection, Array, Array | None], Array],
*,
use_implicit_differentation: bool = True,
):
self.bijection = bijection
Expand All @@ -312,25 +313,6 @@ def __init__(
else:
self.inverter = inverter

@staticmethod
def _wrap_inverter_with_error_on_grad(inverter):
@eqx.filter_custom_jvp
def wrapped_inverter(bijection, y, condition=None):
return inverter(bijection, y, condition)

@wrapped_inverter.def_jvp
def wrapped_inverter_jvp(*_args, **_kwargs):
raise RuntimeError(
"Computing gradients through the numerical inverse would lead to "
"misleading results. If you are using a flow with the analytical "
"transform only defined in one direction, consider inverting the "
"bijection by flipping the ``invert`` argument to the flow. If this is "
"not possible, consider using implicit differentation (not yet "
"supported)."
)

return wrapped_inverter

@staticmethod
def _wrap_inverter_with_implicit_jvp(inverter):
@eqx.filter_custom_jvp
Expand Down