From bb33a5e03c51f9ef235632601d367d094068e041 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Fri, 24 Oct 2025 15:07:05 +0200 Subject: [PATCH 1/3] add fugw batch loss to ot.batch._quadratic --- ot/batch/_quadratic.py | 152 ++++++++++++++++++++++ test/batch/test_solve_unbalanced_batch.py | 0 2 files changed, 152 insertions(+) create mode 100644 test/batch/test_solve_unbalanced_batch.py diff --git a/ot/batch/_quadratic.py b/ot/batch/_quadratic.py index 0da4b8962..e549c7e72 100644 --- a/ot/batch/_quadratic.py +++ b/ot/batch/_quadratic.py @@ -152,6 +152,90 @@ def h2(C2): return compute_tensor_batch(f1, f2, h1, h2, a, b, C1, C2, symmetric=symmetric) +def div_to_product_batch( + T, a, b, T1=None, T2=None, divergence="kl", mass=True, nx=None +): + r"""Fast computation of the Bregman divergence between a batch of arbitrary measures and a product measures. + Only support for Kullback-Leibler and half-squared L2 divergences. + + - For half-squared L2 divergence: + + .. math:: + \frac{1}{2} || \pi - a \otimes b ||^2 + = \frac{1}{2} \Big[ \sum_{i, j} \pi_{ij}^2 + (\sum_i a_i^2) ( \sum_j b_j^2) - 2 \sum_{i, j} a_i \pi_{ij} b_j \Big] + + - For Kullback-Leibler divergence: + + .. math:: + KL(\pi | a \otimes b) + = \langle \pi, \log \pi \rangle - \langle \pi_1, \log a \rangle + - \langle \pi_2, \log b \rangle - m(\pi) + m(a) m(b) + + where : + + - :math:`\pi` is the (`dim_a`, `dim_b`) transport plan + - :math:`\pi_1` and :math:`\pi_2` are the marginal distributions + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`m` denotes the mass of the measure + + Parameters + ---------- + pi : array-like (B, n, m) + Transport plan for each problem in the batch + a : array-like (B,n) + Unnormalized histogram of dimension `n` for each problem in the batch + b : array-like (B,m) + Unnormalized histogram of dimension `m` for each problem in the batch + T1 : array-like (B, n), optional (default = None) + Marginal distribution with respect to the first dimension of the transport plan for each problem in the batch + Only used in case of Kullback-Leibler divergence. + T2 : array-like (B, m), optional (default = None) + Marginal distribution with respect to the second dimension of the transport plan for each problem in the batch + Only used in case of Kullback-Leibler divergence. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + mass : bool, optional. Default is False. + Only used in case of Kullback-Leibler divergence. + If False, calculate the relative entropy. + If True, calculate the Kullback-Leibler divergence. + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ------- + Bregman divergence between an arbitrary measure and a product measure for each problem in the batch. + """ + + arr = [T, a, b, T1, T2] + + if nx is None: + nx = get_backend(*arr, T1, T2) + + if divergence == "kl": + if T1 is None: + T1 = nx.sum(T, 2) + if T2 is None: + T2 = nx.sum(T, 1) + + if divergence == "kl": + res = ( + nx.sum((T * nx.log(T + 1.0 * (T == 0))), (1, 2)) + - nx.sum(T1 * nx.log(a), 1) + - nx.sum(T2 * nx.log(b), 1) + ) + if mass: + res = res - nx.sum(T1, 1) + nx.sum(a, 1) * nx.sum(b, 1) + + elif divergence == "l2": + res = ( + nx.sum(T**2, (1, 2)) + + nx.sum(a**2, 1) * nx.sum(b**2, 1) + - 2 * nx.sum((a * (T @ b[:, :, None]).squeeze(-1)), 1) + ) / 2 + + return res + + def loss_quadratic_batch(L, T, recompute_const=False, symmetric=True, nx=None): r""" Computes the gromov-wasserstein cost given a cost tensor and transport plan. Batched version. @@ -266,6 +350,74 @@ def loss_quadratic_samples_batch( ) +def loss_fugw_batch( + L, M, T, alpha=0.5, reg_marginals=1, symmetric=True, divergence="kl", nx=None +): + r""" + Computes the fused unbalanced gromov-wasserstein cost given a cost tensor (Gromov term), a cost matrix between features across domains (linear term) and a transport plan. Batched version. + + Parameters + ---------- + L : dict + Cost tensor as returned by `tensor_batch`. + M : array-like, shape (B, n, m) + Cost matrix between features across domains. + T : array-like, shape (B, n, m) + Transport plan. + alpha : float or array-like( B,) optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha + a scalar it is used for all problems in the batch. + reg_marginals : float or array-like( B,) optional + Marginal relaxation terms. If rho is + a scalar it is used for all problems in the batch. + symmetric : bool, optional + Whether to use symmetric version. Default is True. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + nx : module, optional + Backend to use. Default is None. + + Examples + -------- + >>> import numpy as np + >>> from ot.batch import tensor_batch, loss_quadratic_batch + >>> # Create batch of cost matrices + >>> C1 = np.random.rand(3, 5, 5) # 3 problems, 5x5 source matrices + >>> C2 = np.random.rand(3, 4, 4) # 3 problems, 4x4 target matrices + >>> a = np.ones((3, 5)) / 5 # Uniform source distributions + >>> b = np.ones((3, 4)) / 4 # Uniform target distributions + >>> L = tensor_batch(a, b, C1, C2, loss='sqeuclidean') + >>> # Use the uniform transport plan for testing + >>> T = np.ones((3, 5, 4)) / (5 * 4) + >>> loss = loss_quadratic_batch(L, T, recompute_const=True) + >>> loss.shape + (3,) + + See Also + -------- + ot.batch.tensor_batch : From computing the cost tensor L. + ot.batch.solve_gromov_batch : For finding the optimal transport plan T. + """ + if nx is None: + nx = get_backend(T) + + Q = loss_quadratic_batch(L, T, recompute_const=True, symmetric=symmetric, nx=nx) + + L = loss_linear_batch(M, T, nx=nx) + + unbalanced = div_to_product_batch( + T, + a=nx.sum(T, axis=2), + b=nx.sum(T, axis=1), + divergence=divergence, + mass=True, + nx=nx, + ) + + return (1 - alpha) * L + alpha * Q + reg_marginals * unbalanced + + def solve_gromov_batch( C1, C2, diff --git a/test/batch/test_solve_unbalanced_batch.py b/test/batch/test_solve_unbalanced_batch.py new file mode 100644 index 000000000..e69de29bb From e58591e5e5a8404aea09eaa3e234afb64a6eee03 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Fri, 22 May 2026 16:37:19 +0200 Subject: [PATCH 2/3] add tests and fix functions in batch module --- ot/batch/_linear.py | 2 +- ot/batch/_quadratic.py | 183 ++++++++++++++++++---- test/batch/test_solve_batch.py | 2 +- test/batch/test_solve_gromov_batch.py | 77 ++++++++- test/batch/test_solve_unbalanced_batch.py | 0 5 files changed, 228 insertions(+), 36 deletions(-) delete mode 100644 test/batch/test_solve_unbalanced_batch.py diff --git a/ot/batch/_linear.py b/ot/batch/_linear.py index a63fcb404..1a9ec1955 100644 --- a/ot/batch/_linear.py +++ b/ot/batch/_linear.py @@ -147,7 +147,7 @@ def loss_linear_batch(M, T, nx=None): return nx.sum(M * T, axis=(1, 2)) -def loss_linear_samples_batch(X, Y, T, metric="l2"): +def loss_linear_samples_batch(X, Y, T, metric="sqeuclidean"): r"""Computes the linear optimal transport loss given samples and transport plan. This is the equivalent of calling `dist_batch` and then `loss_linear_batch`. diff --git a/ot/batch/_quadratic.py b/ot/batch/_quadratic.py index e549c7e72..982ef7cfd 100644 --- a/ot/batch/_quadratic.py +++ b/ot/batch/_quadratic.py @@ -10,8 +10,9 @@ from ..utils import OTResult from ot.backend import get_backend -from ot.batch._linear import loss_linear_batch +from ot.batch._linear import loss_linear_batch, loss_linear_samples_batch from ot.batch._utils import bmv, bop, bregman_log_projection_batch +from ot.utils import list_to_array def tensor_batch( @@ -289,7 +290,7 @@ def loss_quadratic_samples_batch( C2, T, loss="sqeuclidean", - symmetric=None, + symmetric=True, nx=None, logits=None, recompute_const=False, @@ -351,7 +352,17 @@ def loss_quadratic_samples_batch( def loss_fugw_batch( - L, M, T, alpha=0.5, reg_marginals=1, symmetric=True, divergence="kl", nx=None + a, + b, + L, + M, + T, + alpha=0.5, + reg_marginals=1, + symmetric=True, + divergence="kl", + recompute_const=True, + nx=None, ): r""" Computes the fused unbalanced gromov-wasserstein cost given a cost tensor (Gromov term), a cost matrix between features across domains (linear term) and a transport plan. Batched version. @@ -364,58 +375,172 @@ def loss_fugw_batch( Cost matrix between features across domains. T : array-like, shape (B, n, m) Transport plan. - alpha : float or array-like( B,) optional + alpha : float, array-like or list (B,) optional Weight the quadratic term (alpha*Gromov) and the linear term ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha a scalar it is used for all problems in the batch. - reg_marginals : float or array-like( B,) optional + reg_marginals : float array-like or list(B,) optional Marginal relaxation terms. If rho is a scalar it is used for all problems in the batch. symmetric : bool, optional Whether to use symmetric version. Default is True. divergence : string, default = "kl" Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + recompute_const : bool, optional + Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints. nx : module, optional Backend to use. Default is None. + """ + if nx is None: + nx = get_backend(T) - Examples - -------- - >>> import numpy as np - >>> from ot.batch import tensor_batch, loss_quadratic_batch - >>> # Create batch of cost matrices - >>> C1 = np.random.rand(3, 5, 5) # 3 problems, 5x5 source matrices - >>> C2 = np.random.rand(3, 4, 4) # 3 problems, 4x4 target matrices - >>> a = np.ones((3, 5)) / 5 # Uniform source distributions - >>> b = np.ones((3, 4)) / 4 # Uniform target distributions - >>> L = tensor_batch(a, b, C1, C2, loss='sqeuclidean') - >>> # Use the uniform transport plan for testing - >>> T = np.ones((3, 5, 4)) / (5 * 4) - >>> loss = loss_quadratic_batch(L, T, recompute_const=True) - >>> loss.shape - (3,) + B = T.shape[0] - See Also - -------- - ot.batch.tensor_batch : From computing the cost tensor L. - ot.batch.solve_gromov_batch : For finding the optimal transport plan T. + if isinstance(alpha, list): + alpha = list_to_array(alpha, nx=nx) + + if isinstance(reg_marginals, list): + reg_marginals = list_to_array(reg_marginals, nx=nx) + + if hasattr(alpha, "ndim") and alpha.ndim > 0: + if alpha.ndim != 1 or alpha.shape[0] != B: + raise ValueError( + f"If alpha is not a scalar, it must have shape ({B},), got {alpha.shape}" + ) + + if hasattr(reg_marginals, "ndim") and reg_marginals.ndim > 0: + if reg_marginals.ndim != 1 or reg_marginals.shape[0] != B: + raise ValueError( + f"If reg_marginals is not a scalar, it must have shape ({B},), got {reg_marginals.shape}" + ) + + quadratic = loss_quadratic_batch( + L, T, recompute_const=recompute_const, symmetric=symmetric, nx=nx + ) + + linear = loss_linear_batch(M, T, nx=nx) + + unbalanced = div_to_product_batch( + T, + a, + b, + divergence=divergence, + mass=True, + nx=nx, + ) + + return (1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced + + +def loss_fugw_samples_batch( + a, + b, + C1, + C2, + X, + Y, + T, + alpha=0.5, + reg_marginals=1, + symmetric=True, + divergence="kl", + recompute_const=True, + metric_linear="sqeuclidean", + metric_quadratic="sqeuclidean", + logits=None, + nx=None, +): + r""" + Computes the fused unbalanced gromov-wasserstein cost given a cost tensor (quadratic term), a cost matrix between features across domains (linear term) and a transport plan. Batched version. + + Parameters + ---------- + a : array-like, shape (B, n) + Source distributions. + b : array-like, shape (B, m) + Target distributions. + C1 : array-like, shape (B, n, n) or (B, n, n, d) + Source cost matrices for the quadratic term. + C2 : array-like, shape (B, m, m) or (B, n, n, d) + Target cost matrices for the quadratic term. + X : array-like, shape (B, n, d) + Samples from source distribution for the linear term + Y : array-like, shape (B, m, d) + Samples from target distribution for the linear term + T : array-like, shape (B, n, m) + Transport plan. + alpha : float or array-like or list(B,) optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha + a scalar it is used for all problems in the batch. + reg_marginals : float or array-like or list(B,) optional + Marginal relaxation terms. If rho is + a scalar it is used for all problems in the batch. + symmetric : bool, optional + Whether to use symmetric version. Default is True. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + recompute_const : bool, optional + Whether to recompute the constant term. Default is True. This should be set to True if T does not satisfy the marginal constraints. + metric_linear : str, optional + Metric for the linear term, 'sqeuclidean', 'euclidean', 'minkowski' or 'kl' + metric_quadratic : str, optional + Metric to use for the quadratic term. Supported values: 'sqeuclidean', 'kl'. + Default is 'sqeuclidean'. + logits : bool, optional + For KL divergence, whether inputs are logits (unnormalized log probabilities). + If True, inputs are treated as logits. Default is None. + nx : module, optional + Backend to use. Default is None. """ if nx is None: nx = get_backend(T) - Q = loss_quadratic_batch(L, T, recompute_const=True, symmetric=symmetric, nx=nx) + B = T.shape[0] + + if isinstance(alpha, list): + alpha = list_to_array(alpha, nx=nx) + + if isinstance(reg_marginals, list): + reg_marginals = list_to_array(reg_marginals, nx=nx) + + if hasattr(alpha, "ndim") and alpha.ndim > 0: + if alpha.ndim != 1 or alpha.shape[0] != B: + raise ValueError( + f"If alpha is not a scalar, it must have shape ({B},), got {alpha.shape}" + ) + + if hasattr(reg_marginals, "ndim") and reg_marginals.ndim > 0: + if reg_marginals.ndim != 1 or reg_marginals.shape[0] != B: + raise ValueError( + f"If reg_marginals is not a scalar, it must have shape ({B},), got {reg_marginals.shape}" + ) + + quadratic = loss_quadratic_samples_batch( + a, + b, + C1, + C2, + T, + loss=metric_quadratic, + symmetric=symmetric, + nx=nx, + logits=logits, + recompute_const=recompute_const, + ) - L = loss_linear_batch(M, T, nx=nx) + linear = loss_linear_samples_batch(X, Y, T, metric=metric_linear) unbalanced = div_to_product_batch( T, - a=nx.sum(T, axis=2), - b=nx.sum(T, axis=1), + a, + b, divergence=divergence, mass=True, nx=nx, ) - return (1 - alpha) * L + alpha * Q + reg_marginals * unbalanced + return (1 - alpha) * linear + alpha * quadratic + reg_marginals * unbalanced def solve_gromov_batch( diff --git a/test/batch/test_solve_batch.py b/test/batch/test_solve_batch.py index 45a7e69fe..fc9a74492 100644 --- a/test/batch/test_solve_batch.py +++ b/test/batch/test_solve_batch.py @@ -1,4 +1,4 @@ -"""Tests for module bregman on OT with bregman projections""" +"""Tests for module batch""" # Author: Remi Flamary # Kilian Fatras diff --git a/test/batch/test_solve_gromov_batch.py b/test/batch/test_solve_gromov_batch.py index e0029689b..1518b41d8 100644 --- a/test/batch/test_solve_gromov_batch.py +++ b/test/batch/test_solve_gromov_batch.py @@ -1,19 +1,24 @@ -"""Tests for module bregman on OT with bregman projections""" +"""Tests for module batch""" # Author: Remi Flamary -# Kilian Fatras -# Quang Huy Tran -# Eduardo Fernandes Montesuma +# Sonia Mazelet + # # License: MIT License import numpy as np -from ot.batch import solve_gromov_batch, loss_quadratic_samples_batch +from ot.batch import ( + solve_gromov_batch, + loss_quadratic_batch, + loss_linear_batch, + loss_quadratic_samples_batch, +) from ot import solve_gromov from ot.batch._linear import dist_batch import pytest from itertools import product from ot.backend import torch +from ot.batch._quadratic import tensor_batch, loss_fugw_batch, loss_fugw_samples_batch def test_solve_gromov_batch(): @@ -133,3 +138,65 @@ def test_backend(nx): C = np.random.randn(batchsize, n, n, d) C = nx.from_numpy(C) solve_gromov_batch(C1=C, C2=C, a=None, b=None, loss="sqeuclidean", logits=False) + + +def test_fugw_loss(): + """Check that loss_fugw_batch and loss_fugw_samples_batch run without error.""" + batchsize = 2 + n = 4 + d = 2 + rng = np.random.RandomState(0) + C1 = rng.rand(batchsize, n, n, d) + C2 = rng.rand(batchsize, n, n, d) + X = rng.rand(batchsize, n, d) + Y = rng.rand(batchsize, n, d) + M = rng.rand(batchsize, n, n) + a = np.ones((batchsize, n)) + reg_marginals = 0 + T = rng.rand(batchsize, n, n) + L = tensor_batch(a=a, b=a, C1=C1, C2=C2, loss="sqeuclidean") + alpha = rng.rand() + reg_marginals = rng.rand() + + loss_fugw = loss_fugw_batch(a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals) + loss_fugw_sample = loss_fugw_samples_batch( + a, a, C1, C2, X, Y, T, alpha=alpha, reg_marginals=reg_marginals + ) + assert np.isfinite(loss_fugw).all() + assert np.isfinite(loss_fugw_sample).all() + + alpha = rng.rand(batchsize) + reg_marginals = rng.rand(batchsize) + loss_fugw = loss_fugw_batch(a, a, L, M, T, alpha=alpha, reg_marginals=reg_marginals) + loss_fugw_sample = loss_fugw_samples_batch( + a, a, C1, C2, X, Y, T, alpha=alpha, reg_marginals=reg_marginals + ) + assert np.isfinite(loss_fugw).all() + assert np.isfinite(loss_fugw_sample).all() + + +def test_valid_fugw_loss_endpoints(): + """Check that loss_fugw_batch gives the same results as solve_gromov_batch and solve_linear_batch for alpha=0 and alpha=1.""" + batchsize = 2 + n = 4 + d = 2 + rng = np.random.RandomState(0) + C1 = rng.rand(batchsize, n, n, d) + C2 = rng.rand(batchsize, n, n, d) + M = rng.rand(batchsize, n, n) + a = np.ones((batchsize, n)) + reg_marginals = 0 + T = rng.rand(batchsize, n, n) + L = tensor_batch(a=a, b=a, C1=C1, C2=C2, loss="sqeuclidean") + + loss_fugw = loss_fugw_batch( + a, a, L, M, T, alpha=0.0, divergence="l2", reg_marginals=reg_marginals + ) + loss_linear = loss_linear_batch(M, T) + np.testing.assert_allclose(loss_fugw, loss_linear, atol=1e-5) + + loss_fugw = loss_fugw_batch( + a, a, L, M, T, alpha=1.0, divergence="l2", reg_marginals=reg_marginals + ) + loss_gromov = loss_quadratic_batch(L, T, recompute_const=True) + np.testing.assert_allclose(loss_fugw, loss_gromov, atol=1e-5) diff --git a/test/batch/test_solve_unbalanced_batch.py b/test/batch/test_solve_unbalanced_batch.py deleted file mode 100644 index e69de29bb..000000000 From 29e1378a8ca80a5f6b371ca055f5314460ad2211 Mon Sep 17 00:00:00 2001 From: SoniaMazelet <121769948+SoniaMaz8@users.noreply.github.com> Date: Fri, 22 May 2026 17:03:40 +0200 Subject: [PATCH 3/3] update RELEASES.md --- RELEASES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASES.md b/RELEASES.md index 0f8918cac..6ff2b2acf 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765) - Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765) +- Add batch FUGW loss to `ot.batch` (PR #775) #### Closed issues