diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 6721ced..2554686 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -6,6 +6,7 @@ import equinox as eqx import jax.numpy as jnp +import lineax as lx import numpy as np from jaxtyping import Array, Int @@ -13,7 +14,6 @@ from flowjax.bijections.chain import Chain from flowjax.utils import arraylike_to_array, check_shapes_match, merge_cond_shapes -import lineax as lx class Invert(AbstractBijection): """Invert a bijection. @@ -301,6 +301,7 @@ def __init__( self, bijection: AbstractBijection, inverter: Callable[[AbstractBijection, Array, Array | None], Array], + *, use_implicit_differentation: bool = True, ): self.bijection = bijection @@ -312,25 +313,6 @@ def __init__( else: self.inverter = inverter - @staticmethod - def _wrap_inverter_with_error_on_grad(inverter): - @eqx.filter_custom_jvp - def wrapped_inverter(bijection, y, condition=None): - return inverter(bijection, y, condition) - - @wrapped_inverter.def_jvp - def wrapped_inverter_jvp(*_args, **_kwargs): - raise RuntimeError( - "Computing gradients through the numerical inverse would lead to " - "misleading results. If you are using a flow with the analytical " - "transform only defined in one direction, consider inverting the " - "bijection by flipping the ``invert`` argument to the flow. If this is " - "not possible, consider using implicit differentation (not yet " - "supported)." - ) - - return wrapped_inverter - @staticmethod def _wrap_inverter_with_implicit_jvp(inverter): @eqx.filter_custom_jvp