From bc13905ac66be20a96afa9c3fb572e2fcec2a27f Mon Sep 17 00:00:00 2001 From: Birdy Phathanapirom Date: Tue, 12 May 2026 16:03:23 -0400 Subject: [PATCH 1/6] Adding third-party mmdfuse code, reimplemented in numpy to remove jax dependency --- reno/third_party/__init__.py | 8 + reno/third_party/mmdfuse/LICENSE.md | 21 ++ reno/third_party/mmdfuse/__init__.py | 0 reno/third_party/mmdfuse/mmdfuse.py | 422 +++++++++++++++++++++++++++ 4 files changed, 451 insertions(+) create mode 100644 reno/third_party/__init__.py create mode 100644 reno/third_party/mmdfuse/LICENSE.md create mode 100644 reno/third_party/mmdfuse/__init__.py create mode 100644 reno/third_party/mmdfuse/mmdfuse.py diff --git a/reno/third_party/__init__.py b/reno/third_party/__init__.py new file mode 100644 index 0000000..1b24214 --- /dev/null +++ b/reno/third_party/__init__.py @@ -0,0 +1,8 @@ +""" +Public API is the mmdfuse function, re-exported for convenient import: + from reno.third_party.mmdfuse import mmdfuse +""" + +from .mmdfuse import mmdfuse + +__all__ = ["mmdfuse"] diff --git a/reno/third_party/mmdfuse/LICENSE.md b/reno/third_party/mmdfuse/LICENSE.md new file mode 100644 index 0000000..4a19d1a --- /dev/null +++ b/reno/third_party/mmdfuse/LICENSE.md @@ -0,0 +1,21 @@ +# MIT License + +Copyright (c) 2023 Antonin Schrab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/reno/third_party/mmdfuse/__init__.py b/reno/third_party/mmdfuse/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py new file mode 100644 index 0000000..c530336 --- /dev/null +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -0,0 +1,422 @@ +# ============================================================================= +# mmdfuse.py – copy of the original MMD‑FUSE implementation +# ============================================================================= +# Copyright (c) 2023 Biggs, Schrab & Gretton +# Licensed under the MIT License (see LICENSE file in this directory). +# +# ------------------------------------------------------------------------- +# Modifications made for the Reno repository +# ------------------------------------------------------------------------- +# • Removed JAX dependency – all JAX‑specific imports and `jax.numpy` +# calls have been replaced with NumPy equivalents. The public API +# (`mmdfuse(...)`) and the statistical behaviour remain unchanged. +# • Updated type hints to use `numpy.ndarray` instead of `jax.numpy.ndarray`. +# • Minor refactoring to avoid JAX random‑key handling. +# +# ------------------------------------------------------------------------- +# Citation +# ------------------------------------------------------------------------- +# If you use this code in a publication, please cite the original work: +# @article{biggs2023mmdfuse, +# author = {Biggs, Felix and Schrab, Antonin and Gretton, Arthur}, +# title = {{MMD-FUSE}: {L}earning and Combining Kernels for Two-Sample Testing Without Data Splitting}, +# year = {2023}, +# journal = {Advances in Neural Information Processing Systems}, +# volume = {36} +# } +# ============================================================================= + +import numpy as np + + +def _logsumexp(a, axis=None, b=1.0): + """ + NumPy-only stable logsumexp. + + Equivalent to scipy.special.logsumexp(a, axis=axis, b=b) + for positive scalar b. + """ + a = np.asarray(a) + a_max = np.max(a, axis=axis, keepdims=True) + + # Avoid invalid operations if all values are -inf + shifted = np.exp(a - a_max) + s = np.sum(b * shifted, axis=axis, keepdims=True) + out = a_max + np.log(s) + + if axis is not None: + out = np.squeeze(out, axis=axis) + + return out + + +def kernel_matrix(pairwise_matrix, l, kernel, bandwidth, rq_kernel_exponent=0.5): + """ + Compute kernel matrix for a given kernel and bandwidth. + + Parameters + ---------- + pairwise_matrix : ndarray + Matrix of pairwise distances. + l : {"l1", "l2"} + Distance type. + kernel : str + Kernel name. + bandwidth : float + Kernel bandwidth. + rq_kernel_exponent : float + Exponent for rational quadratic kernel. + + Returns + ------- + ndarray + Kernel matrix. + """ + d = pairwise_matrix / bandwidth + + if kernel == "gaussian" and l == "l2": + return np.exp(-(d**2) / 2) + + elif kernel == "laplace" and l == "l1": + return np.exp(-d * np.sqrt(2)) + + elif kernel == "rq" and l == "l2": + return (1 + d**2 / (2 * rq_kernel_exponent)) ** (-rq_kernel_exponent) + + elif kernel == "imq" and l == "l2": + return (1 + d**2) ** (-0.5) + + elif (kernel == "matern_0.5_l1" and l == "l1") or ( + kernel == "matern_0.5_l2" and l == "l2" + ): + return np.exp(-d) + + elif (kernel == "matern_1.5_l1" and l == "l1") or ( + kernel == "matern_1.5_l2" and l == "l2" + ): + return (1 + np.sqrt(3) * d) * np.exp(-np.sqrt(3) * d) + + elif (kernel == "matern_2.5_l1" and l == "l1") or ( + kernel == "matern_2.5_l2" and l == "l2" + ): + return (1 + np.sqrt(5) * d + 5 / 3 * d**2) * np.exp(-np.sqrt(5) * d) + + elif (kernel == "matern_3.5_l1" and l == "l1") or ( + kernel == "matern_3.5_l2" and l == "l2" + ): + return ( + 1 + + np.sqrt(7) * d + + 2 * 7 / 5 * d**2 + + 7 * np.sqrt(7) / 3 / 5 * d**3 + ) * np.exp(-np.sqrt(7) * d) + + elif (kernel == "matern_4.5_l1" and l == "l1") or ( + kernel == "matern_4.5_l2" and l == "l2" + ): + return ( + 1 + + 3 * d + + 3 * (6**2) / 28 * d**2 + + (6**3) / 84 * d**3 + + (6**4) / 1680 * d**4 + ) * np.exp(-3 * d) + + else: + raise ValueError('The values of "l" and "kernel" are not valid.') + + +def np_distances(X, Y, l, max_samples=None, matrix=False): + """ + NumPy replacement for jax_distances. + + Computes pairwise l1 or l2 distances using broadcasting. + + Parameters + ---------- + X : ndarray, shape (m, d) + Y : ndarray, shape (n, d) + l : {"l1", "l2"} + max_samples : int or None + matrix : bool + If True, return full distance matrix. + If False, return upper-triangular entries. + + Returns + ------- + ndarray + """ + X = np.asarray(X) + Y = np.asarray(Y) + + Xs = X[:max_samples] + Ys = Y[:max_samples] + + diff = Xs[:, None, :] - Ys[None, :, :] + + if l == "l1": + output = np.sum(np.abs(diff), axis=-1) + elif l == "l2": + output = np.sqrt(np.sum(diff**2, axis=-1)) + else: + raise ValueError("Value of 'l' must be either 'l1' or 'l2'.") + + if matrix: + return output + else: + return output[np.triu_indices(output.shape[0])] + + +def compute_bandwidths(X, Y, l, number_bandwidths, only_median=False): + """ + NumPy replacement for the JAX/JIT compute_bandwidths function. + """ + Z = np.concatenate((X, Y), axis=0) + distances = np_distances(Z, Z, l, matrix=False) + + median = np.median(distances) + + if only_median: + return median + + distances = distances + (distances == 0) * median + dd = np.sort(distances) + + lambda_min = dd[int(np.floor(len(dd) * 0.05))] / 2 + lambda_max = dd[int(np.floor(len(dd) * 0.95))] * 2 + + bandwidths = np.linspace(lambda_min, lambda_max, number_bandwidths) + return bandwidths + + +def _make_rng(key=None): + """ + Convert a key-like input into a NumPy random Generator. + + Parameters + ---------- + key : None, int, or np.random.Generator + + Returns + ------- + np.random.Generator + """ + if isinstance(key, np.random.Generator): + return key + return np.random.default_rng(key) + + +def mmdfuse( + X, + Y, + key=None, + alpha=0.05, + kernels=("laplace", "gaussian"), + lambda_multiplier=1, + number_bandwidths=10, + number_permutations=2000, + return_p_val=False, +): + """ + Two-Sample MMD-FUSE test, NumPy-only version. + + Parameters + ---------- + X : array_like, shape (m, d) + Y : array_like, shape (n, d) + key : None, int, or np.random.Generator + Random seed or NumPy Generator. + Example: key=0 + alpha : float + Test level. + kernels : str or tuple/list of str + Kernel names. + lambda_multiplier : float + number_bandwidths : int + number_permutations : int + return_p_val : bool + + Returns + ------- + int + 0 if the test fails to reject the null. + 1 if the test rejects the null. + + Or, if return_p_val=True: + + output : int + p_val : float + all_statistics : ndarray + """ + X = np.asarray(X, dtype=float) + Y = np.asarray(Y, dtype=float) + + rng = _make_rng(key) + + # Match original behavior: ensure n <= m + if Y.shape[0] > X.shape[0]: + X, Y = Y, X + + m = X.shape[0] + n = Y.shape[0] + + assert n <= m + assert n >= 2 and m >= 2 + assert 0 < alpha < 1 + assert lambda_multiplier > 0 + assert number_bandwidths > 1 and type(number_bandwidths) is int + assert number_permutations > 0 and type(number_permutations) is int + + if type(kernels) is str: + kernels = (kernels,) + + valid_kernels = ( + "imq", + "rq", + "gaussian", + "matern_0.5_l2", + "matern_1.5_l2", + "matern_2.5_l2", + "matern_3.5_l2", + "matern_4.5_l2", + "laplace", + "matern_0.5_l1", + "matern_1.5_l1", + "matern_2.5_l1", + "matern_3.5_l1", + "matern_4.5_l1", + ) + + for kernel in kernels: + assert kernel in valid_kernels + + all_kernels_l1 = ( + "laplace", + "matern_0.5_l1", + "matern_1.5_l1", + "matern_2.5_l1", + "matern_3.5_l1", + "matern_4.5_l1", + ) + + all_kernels_l2 = ( + "imq", + "rq", + "gaussian", + "matern_0.5_l2", + "matern_1.5_l2", + "matern_2.5_l2", + "matern_3.5_l2", + "matern_4.5_l2", + ) + + number_kernels = len(kernels) + kernels_l1 = [k for k in kernels if k in all_kernels_l1] + kernels_l2 = [k for k in kernels if k in all_kernels_l2] + + # Setup for permutations + B = number_permutations + total = m + n + + # Shape: (B + 1, m + n) + idx = np.empty((B + 1, total), dtype=int) + for b in range(B + 1): + idx[b] = rng.permutation(total) + + # 11 + v11 = np.concatenate((np.ones(m), -np.ones(n))) + V11i = np.tile(v11, (B + 1, 1)) + V11 = np.take_along_axis(V11i, idx, axis=1) + V11[B] = v11 + V11 = V11.T + + # 10 + v10 = np.concatenate((np.ones(m), np.zeros(n))) + V10i = np.tile(v10, (B + 1, 1)) + V10 = np.take_along_axis(V10i, idx, axis=1) + V10[B] = v10 + V10 = V10.T + + # 01 + v01 = np.concatenate((np.zeros(m), -np.ones(n))) + V01i = np.tile(v01, (B + 1, 1)) + V01 = np.take_along_axis(V01i, idx, axis=1) + V01[B] = v01 + V01 = V01.T + + # Compute all permuted MMD estimates + N = number_bandwidths * number_kernels + M = np.zeros((N, B + 1)) + + kernel_count = -1 + + Z = np.concatenate((X, Y), axis=0) + + for r in range(2): + kernels_l = (kernels_l1, kernels_l2)[r] + l = ("l1", "l2")[r] + + if len(kernels_l) > 0: + # Pairwise distance matrix + pairwise_matrix = np_distances(Z, Z, l, matrix=True) + + # Collection of bandwidths + distances = pairwise_matrix[np.triu_indices(pairwise_matrix.shape[0])] + + median = np.median(distances) + distances = distances + (distances == 0) * median + + dd = np.sort(distances) + lambda_min = dd[int(np.floor(len(dd) * 0.05))] / 2 + lambda_max = dd[int(np.floor(len(dd) * 0.95))] * 2 + + bandwidths = np.linspace(lambda_min, lambda_max, number_bandwidths) + + # Compute all permuted MMD estimates for either l1 or l2 + for kernel in kernels_l: + kernel_count += 1 + + for i in range(number_bandwidths): + bandwidth = bandwidths[i] + + # Compute kernel matrix and set diagonal to zero + K = kernel_matrix(pairwise_matrix, l, kernel, bandwidth) + np.fill_diagonal(K, 0) + + # Compute standard deviation + unscaled_std = np.sqrt(np.sum(K**2)) + + # Matrix products + KV10 = K @ V10 + KV01 = K @ V01 + KV11 = K @ V11 + + values = ( + np.sum(V10 * KV10, axis=0) + * (n - m + 1) + * (n - 1) + / (m * (m - 1)) + + np.sum(V01 * KV01, axis=0) + * (m - n + 1) + / m + + np.sum(V11 * KV11, axis=0) + * (n - 1) + / m + ) + + values = values / unscaled_std * np.sqrt(n * (n - 1)) + + M[kernel_count * number_bandwidths + i] = values + + # Compute permuted and original statistics + all_statistics = _logsumexp(lambda_multiplier * M, axis=0, b=1 / N) + original_statistic = all_statistics[-1] + + # Compute p-value and test output + p_val = np.mean(all_statistics >= original_statistic) + output = int(p_val <= alpha) + + if return_p_val: + return output, p_val, all_statistics + else: + return output From 169a7a5972f64122603651aa626650baf974d007 Mon Sep 17 00:00:00 2001 From: Birdy Phathanapirom Date: Mon, 18 May 2026 08:22:02 -0400 Subject: [PATCH 2/6] added third-party mmdfuse code to repo for use in voi tools --- reno/third_party/mmdfuse/mmdfuse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py index c530336..64371d8 100644 --- a/reno/third_party/mmdfuse/mmdfuse.py +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -24,6 +24,7 @@ # journal = {Advances in Neural Information Processing Systems}, # volume = {36} # } +# @TODO include link to repo # ============================================================================= import numpy as np From 5792866a9607f22f7e23166b7403e68b9038b181 Mon Sep 17 00:00:00 2001 From: "Martindale, Nathan" Date: Mon, 18 May 2026 10:21:56 -0400 Subject: [PATCH 3/6] Update docstring style in mmdfuse.py --- reno/third_party/mmdfuse/mmdfuse.py | 137 ++++++++++------------------ 1 file changed, 48 insertions(+), 89 deletions(-) diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py index 64371d8..28da0b0 100644 --- a/reno/third_party/mmdfuse/mmdfuse.py +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -31,11 +31,9 @@ def _logsumexp(a, axis=None, b=1.0): - """ - NumPy-only stable logsumexp. + """NumPy-only stable logsumexp. - Equivalent to scipy.special.logsumexp(a, axis=axis, b=b) - for positive scalar b. + Equivalent to ``scipy.special.logsumexp(a, axis=axis, b=b)`` for positive scalar b. """ a = np.asarray(a) a_max = np.max(a, axis=axis, keepdims=True) @@ -52,25 +50,16 @@ def _logsumexp(a, axis=None, b=1.0): def kernel_matrix(pairwise_matrix, l, kernel, bandwidth, rq_kernel_exponent=0.5): - """ - Compute kernel matrix for a given kernel and bandwidth. - - Parameters - ---------- - pairwise_matrix : ndarray - Matrix of pairwise distances. - l : {"l1", "l2"} - Distance type. - kernel : str - Kernel name. - bandwidth : float - Kernel bandwidth. - rq_kernel_exponent : float - Exponent for rational quadratic kernel. - - Returns - ------- - ndarray + """Compute kernel matrix for a given kernel and bandwidth. + + Args: + pairwise_matrix (ndarray): Matrix of pairwise distances. + l (str): {"l1", "l2"} Distance type. + kernel (str): Kernel name. + bandwidth (float): Kernel bandwidth. + rq_kernel_exponent (float): Exponent for rational quadratic kernel. + + Returns: Kernel matrix. """ d = pairwise_matrix / bandwidth @@ -106,10 +95,7 @@ def kernel_matrix(pairwise_matrix, l, kernel, bandwidth, rq_kernel_exponent=0.5) kernel == "matern_3.5_l2" and l == "l2" ): return ( - 1 - + np.sqrt(7) * d - + 2 * 7 / 5 * d**2 - + 7 * np.sqrt(7) / 3 / 5 * d**3 + 1 + np.sqrt(7) * d + 2 * 7 / 5 * d**2 + 7 * np.sqrt(7) / 3 / 5 * d**3 ) * np.exp(-np.sqrt(7) * d) elif (kernel == "matern_4.5_l1" and l == "l1") or ( @@ -128,24 +114,17 @@ def kernel_matrix(pairwise_matrix, l, kernel, bandwidth, rq_kernel_exponent=0.5) def np_distances(X, Y, l, max_samples=None, matrix=False): - """ - NumPy replacement for jax_distances. + """NumPy replacement for jax_distances. Computes pairwise l1 or l2 distances using broadcasting. - Parameters - ---------- - X : ndarray, shape (m, d) - Y : ndarray, shape (n, d) - l : {"l1", "l2"} - max_samples : int or None - matrix : bool - If True, return full distance matrix. - If False, return upper-triangular entries. - - Returns - ------- - ndarray + Args: + X (ndarray): shape (m, d) + Y (ndarray): shape (n, d) + l (str): {"l1", "l2"} Distance type. + max_samples (int): Maximum number of pairs to draw for computing distances. + matrix (bool): Returns the full distance matrix if ``True``, otherwise just the + upper-triangular entries. """ X = np.asarray(X) Y = np.asarray(Y) @@ -169,9 +148,7 @@ def np_distances(X, Y, l, max_samples=None, matrix=False): def compute_bandwidths(X, Y, l, number_bandwidths, only_median=False): - """ - NumPy replacement for the JAX/JIT compute_bandwidths function. - """ + """NumPy replacement for the JAX/JIT compute_bandwidths function.""" Z = np.concatenate((X, Y), axis=0) distances = np_distances(Z, Z, l, matrix=False) @@ -191,22 +168,20 @@ def compute_bandwidths(X, Y, l, number_bandwidths, only_median=False): def _make_rng(key=None): - """ - Convert a key-like input into a NumPy random Generator. + """Convert a key-like input into a NumPy random Generator. - Parameters - ---------- - key : None, int, or np.random.Generator + Args: + key (int | np.random.Generator): The key or existing generator. - Returns - ------- - np.random.Generator + Returns: + np.random.Generator """ if isinstance(key, np.random.Generator): return key return np.random.default_rng(key) +# NOTE: typehint for array-like is np.typing.ArrayLike def mmdfuse( X, Y, @@ -218,36 +193,24 @@ def mmdfuse( number_permutations=2000, return_p_val=False, ): - """ - Two-Sample MMD-FUSE test, NumPy-only version. - - Parameters - ---------- - X : array_like, shape (m, d) - Y : array_like, shape (n, d) - key : None, int, or np.random.Generator - Random seed or NumPy Generator. - Example: key=0 - alpha : float - Test level. - kernels : str or tuple/list of str - Kernel names. - lambda_multiplier : float - number_bandwidths : int - number_permutations : int - return_p_val : bool - - Returns - ------- - int - 0 if the test fails to reject the null. - 1 if the test rejects the null. - - Or, if return_p_val=True: - - output : int - p_val : float - all_statistics : ndarray + """Two-Sample MMD-FUSE test, NumPy-only version. + + Args: + X (array_like): shape (m, d) + Y (array_like): shape (n, d) + key (int | np.random.Generator) Random seed or NumPy Generator. + Example: ``key=0`` + alpha (float): Test level. + kernels (str | tuple[str] | list[str]): Kernel names. + lambda_multiplier (float): ??? + number_bandwidths (int): ??? + number_permutations (int): ??? + return_p_val (bool): ??? + + Returns: + 0 if the test fails to reject the null. 1 if the test rejects the null. + Or, if return_p_val=True, returns a tuple of the int output, float p_val, and + the numpy array containing all statistics. """ X = np.asarray(X, dtype=float) Y = np.asarray(Y, dtype=float) @@ -397,12 +360,8 @@ def mmdfuse( * (n - m + 1) * (n - 1) / (m * (m - 1)) - + np.sum(V01 * KV01, axis=0) - * (m - n + 1) - / m - + np.sum(V11 * KV11, axis=0) - * (n - 1) - / m + + np.sum(V01 * KV01, axis=0) * (m - n + 1) / m + + np.sum(V11 * KV11, axis=0) * (n - 1) / m ) values = values / unscaled_std * np.sqrt(n * (n - 1)) From b1459bfbbded4a8fd83ebd20ac790708d7a7e69a Mon Sep 17 00:00:00 2001 From: "Martindale, Nathan" Date: Mon, 18 May 2026 10:25:33 -0400 Subject: [PATCH 4/6] Update mmdfuse module docstring --- reno/third_party/mmdfuse/mmdfuse.py | 61 ++++++++++++++++------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py index 28da0b0..d5b8c38 100644 --- a/reno/third_party/mmdfuse/mmdfuse.py +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -1,31 +1,36 @@ -# ============================================================================= -# mmdfuse.py – copy of the original MMD‑FUSE implementation -# ============================================================================= -# Copyright (c) 2023 Biggs, Schrab & Gretton -# Licensed under the MIT License (see LICENSE file in this directory). -# -# ------------------------------------------------------------------------- -# Modifications made for the Reno repository -# ------------------------------------------------------------------------- -# • Removed JAX dependency – all JAX‑specific imports and `jax.numpy` -# calls have been replaced with NumPy equivalents. The public API -# (`mmdfuse(...)`) and the statistical behaviour remain unchanged. -# • Updated type hints to use `numpy.ndarray` instead of `jax.numpy.ndarray`. -# • Minor refactoring to avoid JAX random‑key handling. -# -# ------------------------------------------------------------------------- -# Citation -# ------------------------------------------------------------------------- -# If you use this code in a publication, please cite the original work: -# @article{biggs2023mmdfuse, -# author = {Biggs, Felix and Schrab, Antonin and Gretton, Arthur}, -# title = {{MMD-FUSE}: {L}earning and Combining Kernels for Two-Sample Testing Without Data Splitting}, -# year = {2023}, -# journal = {Advances in Neural Information Processing Systems}, -# volume = {36} -# } -# @TODO include link to repo -# ============================================================================= +"""Third party tool for a value of information algorithm. + +============================================================================= +mmdfuse.py - copy of the original MMD-FUSE implementation +============================================================================= +Copyright (c) 2023 Biggs, Schrab & Gretton +Licensed under the MIT License (see LICENSE file in this directory). + +------------------------------------------------------------------------- + Modifications made for the Reno repository +------------------------------------------------------------------------- +* Removed JAX dependency – all JAX‑specific imports and `jax.numpy` +calls have been replaced with NumPy equivalents. The public API +(`mmdfuse(...)`) and the statistical behaviour remain unchanged. +* Updated type hints to use `numpy.ndarray` instead of `jax.numpy.ndarray`. +* Minor refactoring to avoid JAX random‑key handling. +* Update docstring style. + +------------------------------------------------------------------------- + Citation +------------------------------------------------------------------------- +If you use this code in a publication, please cite the original work: + @article{biggs2023mmdfuse, + author = {Biggs, Felix and Schrab, Antonin and Gretton, Arthur}, + title = {{MMD-FUSE}: {L}earning and Combining Kernels for Two-Sample Testing Without Data Splitting}, + year = {2023}, + journal = {Advances in Neural Information Processing Systems}, + volume = {36} + } + +Repo link: https://github.com/antoninschrab/mmdfuse +============================================================================= +""" import numpy as np From 8b39e9afbc7416b7f8a8558c37db5e9c9a2fe0fe Mon Sep 17 00:00:00 2001 From: "Martindale, Nathan" Date: Mon, 18 May 2026 10:36:29 -0400 Subject: [PATCH 5/6] Add typehints to mmdfuse.py --- reno/third_party/mmdfuse/mmdfuse.py | 177 +++++++++++++++------------- 1 file changed, 96 insertions(+), 81 deletions(-) diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py index d5b8c38..a843d37 100644 --- a/reno/third_party/mmdfuse/mmdfuse.py +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -1,41 +1,45 @@ -"""Third party tool for a value of information algorithm. - -============================================================================= -mmdfuse.py - copy of the original MMD-FUSE implementation -============================================================================= -Copyright (c) 2023 Biggs, Schrab & Gretton -Licensed under the MIT License (see LICENSE file in this directory). - -------------------------------------------------------------------------- - Modifications made for the Reno repository -------------------------------------------------------------------------- -* Removed JAX dependency – all JAX‑specific imports and `jax.numpy` -calls have been replaced with NumPy equivalents. The public API -(`mmdfuse(...)`) and the statistical behaviour remain unchanged. -* Updated type hints to use `numpy.ndarray` instead of `jax.numpy.ndarray`. -* Minor refactoring to avoid JAX random‑key handling. -* Update docstring style. - -------------------------------------------------------------------------- - Citation -------------------------------------------------------------------------- -If you use this code in a publication, please cite the original work: - @article{biggs2023mmdfuse, - author = {Biggs, Felix and Schrab, Antonin and Gretton, Arthur}, - title = {{MMD-FUSE}: {L}earning and Combining Kernels for Two-Sample Testing Without Data Splitting}, - year = {2023}, - journal = {Advances in Neural Information Processing Systems}, - volume = {36} - } - -Repo link: https://github.com/antoninschrab/mmdfuse -============================================================================= +"""Third party library for MMD-FUSE, a value of information algorithm. + +See https://github.com/antoninschrab/mmdfuse for the original. """ +# ============================================================================= +# mmdfuse.py - copy of the original MMD-FUSE implementation +# ============================================================================= +# Copyright (c) 2023 Biggs, Schrab & Gretton +# Licensed under the MIT License (see LICENSE file in this directory). +# +# ------------------------------------------------------------------------- +# Modifications made for the Reno repository +# ------------------------------------------------------------------------- +# * Removed JAX dependency - all JAX-specific imports and `jax.numpy` +# calls have been replaced with NumPy equivalents. The public API +# (`mmdfuse(...)`) and the statistical behaviour remain unchanged. +# * Updated type hints to use `numpy.ndarray` instead of `jax.numpy.ndarray`. +# * Minor refactoring to avoid JAX random-key handling. +# * Update docstring style. +# * Add typehints and some minor variable renames. +# +# ------------------------------------------------------------------------- +# Citation +# ------------------------------------------------------------------------- +# If you use this code in a publication, please cite the original work: +# @article{biggs2023mmdfuse, +# author = {Biggs, Felix and Schrab, Antonin and Gretton, Arthur}, +# title = {{MMD-FUSE}: {L}earning and Combining Kernels for Two-Sample Testing Without Data Splitting}, +# year = {2023}, +# journal = {Advances in Neural Information Processing Systems}, +# volume = {36} +# } +# +# Repo link: https://github.com/antoninschrab/mmdfuse +# ============================================================================= + import numpy as np +from numpy.typing import ArrayLike -def _logsumexp(a, axis=None, b=1.0): +def _logsumexp(a: ArrayLike, axis: int = None, b: float = 1.0) -> np.ndarray: """NumPy-only stable logsumexp. Equivalent to ``scipy.special.logsumexp(a, axis=axis, b=b)`` for positive scalar b. @@ -54,57 +58,60 @@ def _logsumexp(a, axis=None, b=1.0): return out -def kernel_matrix(pairwise_matrix, l, kernel, bandwidth, rq_kernel_exponent=0.5): +def kernel_matrix( + pairwise_matrix: np.ndarray, + dist_metric: str, + kernel: str, + bandwidth: float, + rq_kernel_exponent: float = 0.5, +) -> np.ndarray: """Compute kernel matrix for a given kernel and bandwidth. Args: pairwise_matrix (ndarray): Matrix of pairwise distances. - l (str): {"l1", "l2"} Distance type. + dist_metric (str): {"l1", "l2"} Distance type. kernel (str): Kernel name. bandwidth (float): Kernel bandwidth. rq_kernel_exponent (float): Exponent for rational quadratic kernel. - - Returns: - Kernel matrix. """ d = pairwise_matrix / bandwidth - if kernel == "gaussian" and l == "l2": + if kernel == "gaussian" and dist_metric == "l2": return np.exp(-(d**2) / 2) - elif kernel == "laplace" and l == "l1": + elif kernel == "laplace" and dist_metric == "l1": return np.exp(-d * np.sqrt(2)) - elif kernel == "rq" and l == "l2": + elif kernel == "rq" and dist_metric == "l2": return (1 + d**2 / (2 * rq_kernel_exponent)) ** (-rq_kernel_exponent) - elif kernel == "imq" and l == "l2": + elif kernel == "imq" and dist_metric == "l2": return (1 + d**2) ** (-0.5) - elif (kernel == "matern_0.5_l1" and l == "l1") or ( - kernel == "matern_0.5_l2" and l == "l2" + elif (kernel == "matern_0.5_l1" and dist_metric == "l1") or ( + kernel == "matern_0.5_l2" and dist_metric == "l2" ): return np.exp(-d) - elif (kernel == "matern_1.5_l1" and l == "l1") or ( - kernel == "matern_1.5_l2" and l == "l2" + elif (kernel == "matern_1.5_l1" and dist_metric == "l1") or ( + kernel == "matern_1.5_l2" and dist_metric == "l2" ): return (1 + np.sqrt(3) * d) * np.exp(-np.sqrt(3) * d) - elif (kernel == "matern_2.5_l1" and l == "l1") or ( - kernel == "matern_2.5_l2" and l == "l2" + elif (kernel == "matern_2.5_l1" and dist_metric == "l1") or ( + kernel == "matern_2.5_l2" and dist_metric == "l2" ): return (1 + np.sqrt(5) * d + 5 / 3 * d**2) * np.exp(-np.sqrt(5) * d) - elif (kernel == "matern_3.5_l1" and l == "l1") or ( - kernel == "matern_3.5_l2" and l == "l2" + elif (kernel == "matern_3.5_l1" and dist_metric == "l1") or ( + kernel == "matern_3.5_l2" and dist_metric == "l2" ): return ( 1 + np.sqrt(7) * d + 2 * 7 / 5 * d**2 + 7 * np.sqrt(7) / 3 / 5 * d**3 ) * np.exp(-np.sqrt(7) * d) - elif (kernel == "matern_4.5_l1" and l == "l1") or ( - kernel == "matern_4.5_l2" and l == "l2" + elif (kernel == "matern_4.5_l1" and dist_metric == "l1") or ( + kernel == "matern_4.5_l2" and dist_metric == "l2" ): return ( 1 @@ -118,7 +125,13 @@ def kernel_matrix(pairwise_matrix, l, kernel, bandwidth, rq_kernel_exponent=0.5) raise ValueError('The values of "l" and "kernel" are not valid.') -def np_distances(X, Y, l, max_samples=None, matrix=False): +def np_distances( + X: ArrayLike, + Y: ArrayLike, + dist_metric: str, + max_samples: int = None, + matrix: bool = False, +) -> np.ndarray: """NumPy replacement for jax_distances. Computes pairwise l1 or l2 distances using broadcasting. @@ -126,7 +139,7 @@ def np_distances(X, Y, l, max_samples=None, matrix=False): Args: X (ndarray): shape (m, d) Y (ndarray): shape (n, d) - l (str): {"l1", "l2"} Distance type. + dist_metric (str): {"l1", "l2"} Distance type. max_samples (int): Maximum number of pairs to draw for computing distances. matrix (bool): Returns the full distance matrix if ``True``, otherwise just the upper-triangular entries. @@ -139,9 +152,9 @@ def np_distances(X, Y, l, max_samples=None, matrix=False): diff = Xs[:, None, :] - Ys[None, :, :] - if l == "l1": + if dist_metric == "l1": output = np.sum(np.abs(diff), axis=-1) - elif l == "l2": + elif dist_metric == "l2": output = np.sqrt(np.sum(diff**2, axis=-1)) else: raise ValueError("Value of 'l' must be either 'l1' or 'l2'.") @@ -152,10 +165,16 @@ def np_distances(X, Y, l, max_samples=None, matrix=False): return output[np.triu_indices(output.shape[0])] -def compute_bandwidths(X, Y, l, number_bandwidths, only_median=False): +def compute_bandwidths( + X: np.ndarray, + Y: np.ndarray, + dist_metric: str, + number_bandwidths: int, + only_median: bool = False, +) -> np.ndarray: """NumPy replacement for the JAX/JIT compute_bandwidths function.""" Z = np.concatenate((X, Y), axis=0) - distances = np_distances(Z, Z, l, matrix=False) + distances = np_distances(Z, Z, dist_metric, matrix=False) median = np.median(distances) @@ -172,38 +191,34 @@ def compute_bandwidths(X, Y, l, number_bandwidths, only_median=False): return bandwidths -def _make_rng(key=None): - """Convert a key-like input into a NumPy random Generator. +def _make_rng(key: int | np.random.Generator = None) -> np.random.Generator: + """Convert a key-like input into a NumPy random Generator, if not one already. Args: - key (int | np.random.Generator): The key or existing generator. - - Returns: - np.random.Generator + key (int | np.random.Generator): The key to use or an existing generator. """ if isinstance(key, np.random.Generator): return key return np.random.default_rng(key) -# NOTE: typehint for array-like is np.typing.ArrayLike def mmdfuse( - X, - Y, - key=None, - alpha=0.05, - kernels=("laplace", "gaussian"), - lambda_multiplier=1, - number_bandwidths=10, - number_permutations=2000, - return_p_val=False, -): + X: ArrayLike, + Y: ArrayLike, + key: int | np.random.Generator = None, + alpha: float = 0.05, + kernels: str | tuple[str] | list[str] = ("laplace", "gaussian"), + lambda_multiplier: float = 1.0, + number_bandwidths: int = 10, + number_permutations: int = 2000, + return_p_val: bool = False, +) -> int | tuple[int, float, np.ndarray]: """Two-Sample MMD-FUSE test, NumPy-only version. Args: - X (array_like): shape (m, d) - Y (array_like): shape (n, d) - key (int | np.random.Generator) Random seed or NumPy Generator. + X (ArrayLike): shape (m, d) + Y (ArrayLike): shape (n, d) + key (int | np.random.Generator): Random seed or NumPy Generator. Example: ``key=0`` alpha (float): Test level. kernels (str | tuple[str] | list[str]): Kernel names. @@ -323,11 +338,11 @@ def mmdfuse( for r in range(2): kernels_l = (kernels_l1, kernels_l2)[r] - l = ("l1", "l2")[r] + dist_metric = ("l1", "l2")[r] if len(kernels_l) > 0: # Pairwise distance matrix - pairwise_matrix = np_distances(Z, Z, l, matrix=True) + pairwise_matrix = np_distances(Z, Z, dist_metric, matrix=True) # Collection of bandwidths distances = pairwise_matrix[np.triu_indices(pairwise_matrix.shape[0])] @@ -349,7 +364,7 @@ def mmdfuse( bandwidth = bandwidths[i] # Compute kernel matrix and set diagonal to zero - K = kernel_matrix(pairwise_matrix, l, kernel, bandwidth) + K = kernel_matrix(pairwise_matrix, dist_metric, kernel, bandwidth) np.fill_diagonal(K, 0) # Compute standard deviation From 124bd5399517d5fada289fd9fa953f19c95d2cb7 Mon Sep 17 00:00:00 2001 From: "Martindale, Nathan" Date: Mon, 18 May 2026 10:42:06 -0400 Subject: [PATCH 6/6] Switch dist_metric to literal typehints --- reno/third_party/__init__.py | 5 +++-- reno/third_party/mmdfuse/mmdfuse.py | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/reno/third_party/__init__.py b/reno/third_party/__init__.py index 1b24214..007a777 100644 --- a/reno/third_party/__init__.py +++ b/reno/third_party/__init__.py @@ -1,5 +1,6 @@ -""" -Public API is the mmdfuse function, re-exported for convenient import: +"""Public API for the mmdfuse function, re-exported for convenient import. + +Example: from reno.third_party.mmdfuse import mmdfuse """ diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py index a843d37..eb65801 100644 --- a/reno/third_party/mmdfuse/mmdfuse.py +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -35,6 +35,8 @@ # Repo link: https://github.com/antoninschrab/mmdfuse # ============================================================================= +from typing import Literal + import numpy as np from numpy.typing import ArrayLike @@ -60,7 +62,7 @@ def _logsumexp(a: ArrayLike, axis: int = None, b: float = 1.0) -> np.ndarray: def kernel_matrix( pairwise_matrix: np.ndarray, - dist_metric: str, + dist_metric: Literal["l1", "l2"], kernel: str, bandwidth: float, rq_kernel_exponent: float = 0.5, @@ -69,7 +71,7 @@ def kernel_matrix( Args: pairwise_matrix (ndarray): Matrix of pairwise distances. - dist_metric (str): {"l1", "l2"} Distance type. + dist_metric (Literal["l1", "l2"]): Distance type. kernel (str): Kernel name. bandwidth (float): Kernel bandwidth. rq_kernel_exponent (float): Exponent for rational quadratic kernel. @@ -128,7 +130,7 @@ def kernel_matrix( def np_distances( X: ArrayLike, Y: ArrayLike, - dist_metric: str, + dist_metric: Literal["l1", "l2"], max_samples: int = None, matrix: bool = False, ) -> np.ndarray: @@ -139,7 +141,7 @@ def np_distances( Args: X (ndarray): shape (m, d) Y (ndarray): shape (n, d) - dist_metric (str): {"l1", "l2"} Distance type. + dist_metric (Literal["l1", "l2"]): Distance type. max_samples (int): Maximum number of pairs to draw for computing distances. matrix (bool): Returns the full distance matrix if ``True``, otherwise just the upper-triangular entries. @@ -168,7 +170,7 @@ def np_distances( def compute_bandwidths( X: np.ndarray, Y: np.ndarray, - dist_metric: str, + dist_metric: Literal["l1", "l2"], number_bandwidths: int, only_median: bool = False, ) -> np.ndarray: