Skip to content

Auto-differentiable numerical inverse & efficient materialization of BNAF masks#234

Merged
danielward27 merged 11 commits into
danielward27:mainfrom
noahewolfe:diffable-inverse-bnaf
Feb 18, 2026
Merged

Auto-differentiable numerical inverse & efficient materialization of BNAF masks#234
danielward27 merged 11 commits into
danielward27:mainfrom
noahewolfe:diffable-inverse-bnaf

Conversation

@noahewolfe

@noahewolfe noahewolfe commented Feb 9, 2026

Copy link
Copy Markdown
Contributor

Here, we make two updates which are particularly relevant for block-neural autoregressive flows:

  1. Auto-differentiable numerical inverse via the implicit function theorem. (Fixes Autodiff problem with block_neural_autoregressive_flow #176)
  • Using the implicit function theorem (I followed https://arxiv.org/abs/2111.00254 in particular), we define a custom jacobian-vector product for NumericalInverse transforms.
  • Adds lineax as a dependency for memory-efficient computation of the JVP.
  • I've tested that this works using the block_neural_autoregressive_flow and the default greedy bisection search (in a publication to be on the arxiv in the next week or two).
  • I added very light unit tests, which pass.
  1. Efficient materialization of block masks with Kronecker products
  • I found that it was slow to initialize (e.g., just put into memory---no training or anything) block-neural autoregressive flows with large nn_block_dim.
  • I traced this to block_diag_mask and block_tril_mask, and reimplemented these with kronecker products.
  • These changes pass the unit tests in test_masks.py

Let me know what you think, happy to answer any questions and take any feedback!

@danielward27

Copy link
Copy Markdown
Owner

Nice! Cheers. I'm a bit busy at the moment but will try to check it out before the end of the week.

@danielward27 danielward27 left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much again for the contribution! It looks good to me, I've just left some comments with some minor suggestions, let me know what you think.

Comment thread flowjax/bijections/block_autoregressive_network.py Outdated
Comment thread flowjax/bijections/utils.py Outdated
bijection: AbstractBijection,
inverter: Callable[[AbstractBijection, Array, Array | None], Array],
diffable_inverter: bool = False,
raise_old_error: bool = False,

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove the raise_old_error. If we don't bother using the legacy behavior for a deprecation cycle then we should probably just not bother including it to simplify the code a bit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This make sense; I've removed this flag. Should we also remove the staticmethod that replicates the old wrapper? (_wrap_inverter_with_error_on_grad)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice thanks! I'll merge this and remove _wrap_inverter_with_error_on_grad in a separate commit. I don't think it should be needed.

Comment thread flowjax/bijections/utils.py Outdated
Noah Wolfe added 2 commits February 17, 2026 11:26
…fferentation flag; remove unused make_layer in BlockAutoregressiveNetwork __init__
@danielward27 danielward27 merged commit 85cde8f into danielward27:main Feb 18, 2026
1 check passed
@noahewolfe

Copy link
Copy Markdown
Contributor Author

Just wanted to follow-up here and record for future reference; the publication for which I developed this feature is now on the arXiv!
https://arxiv.org/abs/2602.20277v1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Autodiff problem with block_neural_autoregressive_flow

2 participants