A small Python package for computing empirical coverage curves of posterior samples produced by stochastic reconstruction algorithms for inverse problems (e.g. PnP-ULA, DPS, diffusion posterior samplers, MCMC).
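Given a stochastic posterior sampler that, for an observation $y$, draws samples $x_1, \dots, x_S$ from its approximate posterior $p(x \mid y)$, the coverage curve measures how often credible regions built from those samples contain the true $x$.

The construction proceeds in three steps.

1. Posterior mean. From the posterior samples we form the Monte-Carlo estimate of the posterior mean,

   $$\hat{x} = \frac{1}{S} \sum_{s=1}^{S} x_s.$$

2. Approximate credible balls. For a confidence level $\alpha \in [0, 1]$ we take $r_\alpha$ to be the empirical $\alpha$-quantile of the distances $\|x_s - \hat{x}\|$, $s = 1, \dots, S$, and define the ball

   $$B_\alpha = \{ x' : \|x' - \hat{x}\| \le r_\alpha \}.$$

   By construction $B_\alpha$ contains a fraction $\alpha$ of the posterior samples.

3. Compare against the reference. For a well-calibrated posterior, $B_\alpha$ should actually contain the true $x$ with probability $\alpha$. The empirical coverage $\hat{C}(\alpha)$ is the fraction of references whose true $x$ falls inside its ball $B_\alpha$.

Plotting $\hat{C}(\alpha)$ against $\alpha$ gives the coverage curve; a well-calibrated sampler follows the diagonal.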
- Curve below the diagonal → posterior is over-confident (the credible balls are too small; the true $x$ falls outside them more often than the sampler's credibility budget would allow).
- Curve above the diagonal → posterior is under-confident (the credible balls are too large; the true $x$ is contained more often than expected).
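
The three-step construction and the reading of the curve can be sketched in a few lines of NumPy. This is a minimal illustration, not the package's implementation: the function name `coverage_curve_sketch`, the choice of the L2 norm, and the synthetic Gaussian data are all assumptions made for the example.

```python
import numpy as np


def coverage_curve_sketch(x_samples, x, num_alphas=100):
    """Empirical coverage of quantile balls centred on the sample mean.

    x_samples: (N, S, ...) posterior samples, x: (N, ...) ground truth.
    Hypothetical helper for illustration; assumes the L2 distance.
    """
    n, s = x_samples.shape[:2]
    flat_samples = x_samples.reshape(n, s, -1)
    flat_x = x.reshape(n, -1)

    # Step 1: Monte-Carlo posterior mean, one per reference.
    mean = flat_samples.mean(axis=1)

    # Step 2: ball radii are the alpha-quantiles of the sample distances.
    d_samples = np.linalg.norm(flat_samples - mean[:, None, :], axis=-1)
    alphas = np.linspace(0.0, 1.0, num_alphas)
    radii = np.quantile(d_samples, alphas, axis=1)  # shape (num_alphas, N)

    # Step 3: fraction of references whose true x lies inside each ball.
    d_true = np.linalg.norm(flat_x - mean, axis=-1)
    coverage = (d_true[None, :] <= radii).mean(axis=1)
    return alphas, coverage


# Synthetic check: truth and samples drawn from the same Gaussian (calibrated),
# then samples with half the spread (over-confident).
rng = np.random.default_rng(0)
n_refs, n_samples, shape = 2000, 100, (1, 4, 4)
centres = rng.normal(size=(n_refs, 1, *shape))
x = centres[:, 0] + rng.normal(size=(n_refs, *shape))
calibrated = centres + rng.normal(size=(n_refs, n_samples, *shape))
overconfident = centres + 0.5 * rng.normal(size=(n_refs, n_samples, *shape))

alphas, cov_cal = coverage_curve_sketch(calibrated, x)
_, cov_over = coverage_curve_sketch(overconfident, x)
print("max deviation from diagonal (calibrated):", np.abs(cov_cal - alphas).max())
print("coverage near alpha = 0.5 (over-confident):", cov_over[50])
```

In this toy setup the calibrated curve should stay close to the diagonal, while the over-confident one should sit well below it.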
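Coverage can be computed in pixel space or, optionally, in the latent space of an embedding $\phi$ (for example a pre-trained VAE encoder), in which case distances are measured between embedded images rather than between raw pixels.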
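
One way to read the embedding option is to map every image through the embedding first and then measure distances between the resulting vectors. The sketch below assumes, hypothetically, that the embedding callable takes a batch `(B, C, H, W)` and returns `(B, K)` features; the package's actual calling convention may differ. `avg_pool_embedding` and `embed_samples_and_reference` are illustrative names, with the former standing in for a real encoder.

```python
import numpy as np


def avg_pool_embedding(batch):
    """Toy stand-in for an encoder: 2x2 average pooling, then flatten.

    Maps (B, C, H, W) -> (B, C * H/2 * W/2). Hypothetical, for illustration.
    """
    b, c, h, w = batch.shape
    pooled = batch.reshape(b, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))
    return pooled.reshape(b, -1)


def embed_samples_and_reference(phi, x_samples, x):
    """Apply phi to samples (N, S, C, H, W) and references (N, C, H, W)."""
    n, s = x_samples.shape[:2]
    # Fold the sample axis into the batch axis so phi sees ordinary batches.
    z_samples = phi(x_samples.reshape(n * s, *x_samples.shape[2:]))
    z_samples = z_samples.reshape(n, s, -1)  # (N, S, K)
    z = phi(x)                               # (N, K)
    return z_samples, z


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 1, 4, 4))
x_samples = x[:, None] + 0.1 * rng.normal(size=(8, 16, 1, 4, 4))
z_samples, z = embed_samples_and_reference(avg_pool_embedding, x_samples, x)

# Distances are now measured between embedded vectors instead of pixels.
d_embedded = np.linalg.norm(z_samples - z[:, None, :], axis=-1)  # (N, S)
print(z_samples.shape, z.shape, d_embedded.shape)
```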
The package also implements the TARP diagnostic of Lemos, Coogan et al. 2023, a global coverage test that does not assume the posterior is unimodal or centred on its mean. For each reference $x$, a random reference point $x_r$ is drawn and the fraction $f$ of posterior samples that lie closer to $x_r$ than the true $x$ does is computed. The expected coverage probability is the empirical CDF of $f$ over all references. A well-calibrated sampler again follows the diagonal. Like `Coverage`, TARP supports an optional embedding callable and a pluggable distance (default L2).
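
The idea behind TARP can be illustrated with a short NumPy sketch. It is a simplified stand-in, not the package's `run_tarp`: the name `tarp_sketch`, the `seed` parameter, the uniform draw of reference points over the global range of the samples, and the synthetic data are assumptions made for the example.

```python
import numpy as np


def tarp_sketch(x_samples, x, num_alphas=100, seed=0):
    """Simplified TARP-style coverage curve (hypothetical helper).

    x_samples: (N, S, ...) posterior samples, x: (N, ...) ground truth.
    Assumes the L2 distance.
    """
    n, s = x_samples.shape[:2]
    flat_samples = x_samples.reshape(n, s, -1)
    flat_x = x.reshape(n, -1)
    rng = np.random.default_rng(seed)

    # One random reference point per ground truth, drawn without looking at x.
    # Assumption: uniform over the global range of the samples.
    ref = rng.uniform(flat_samples.min(), flat_samples.max(), size=flat_x.shape)

    # Fraction of samples that are closer to the reference point than the truth.
    d_samples = np.linalg.norm(flat_samples - ref[:, None, :], axis=-1)  # (N, S)
    d_true = np.linalg.norm(flat_x - ref, axis=-1)                       # (N,)
    f = (d_samples < d_true[:, None]).mean(axis=1)

    # Expected coverage probability: empirical CDF of f over the references.
    alpha = np.linspace(0.0, 1.0, num_alphas)
    ecp = (f[None, :] <= alpha[:, None]).mean(axis=1)
    return alpha, ecp


# Calibrated toy problem: truth and samples share the same Gaussian.
rng = np.random.default_rng(1)
n_refs, n_samples, shape = 2000, 100, (1, 4, 4)
centres = rng.normal(size=(n_refs, 1, *shape))
x = centres[:, 0] + rng.normal(size=(n_refs, *shape))
x_samples = centres + rng.normal(size=(n_refs, n_samples, *shape))

alpha, ecp = tarp_sketch(x_samples, x)
print("max deviation from diagonal:", np.abs(ecp - alpha).max())
```

Because the truth and the samples are exchangeable in this toy problem, $f$ is close to uniform and the curve should track the diagonal.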
- `coverage_plots.Coverage`: the public class. Accumulates the statistic over many reference batches via `update(x_samples, x)` and returns the curve via `compute()`. Pixel-space coverage is always computed; an extra embedding-space curve is computed in parallel when an `embedding` callable is supplied.
- `coverage_plots.run_tarp`: stateless function implementing the TARP diagnostic. Takes the same `(x_samples, x)` shape contract as `Coverage` and returns `(alpha, ecp)` arrays ready to plot.
- `coverage_plots.load_mnist_vae_embedding`: factory that loads a small pre-trained MNIST VAE encoder and returns it as an embedding callable suitable for `Coverage(embedding=...)`.
- `example/example_coverage_uniform_blur.py`: runnable end-to-end example. MNIST uniform deblurring is solved with PnP-ULA (`deepinv`), then both a coverage curve and a TARP curve are plotted to `example/outputs/`. Pass `--use-vae` to overlay VAE-embedding curves.
To install the package with the example dependencies and run the example:

    pip install -e ".[example]"
    python example/example_coverage_uniform_blur.py            # pixel-space curve
    python example/example_coverage_uniform_blur.py --use-vae  # + VAE embedding

Basic usage:

    import torch
    from coverage_plots import Coverage

    cov = Coverage(num_alphas=200)  # pixel-space only
    # x:         (N, C, H, W)     ground-truth references
    # x_samples: (N, S, C, H, W)  S posterior samples per reference
    cov.update(x_samples, x)
    alphas, coverage, coverage_emb = cov.compute(csv_path="coverage.csv")
    # coverage_emb is None unless an embedding callable was passed at construction

TARP variant:

    from coverage_plots import run_tarp

    alpha, ecp = run_tarp(x_samples, x)  # pixel-space TARP
    # alpha, ecp = run_tarp(x_samples, x, embedding=phi)  # in embedding space

Tensor-shape contract: references are `(N, C, H, W)` and posterior samples are `(N, S, C, H, W)`. The metric is meant to be accumulated across many references via repeated `update()` calls before `compute()`.
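
To show why accumulation across batches is cheap, here is a minimal sketch of an accumulator that enforces the shape contract and keeps only per-level hit counts between calls. `StreamingCoverageSketch` is a hypothetical NumPy class written for this illustration; it is not the package's `Coverage` class, and its internals (L2 distance, quantile balls around the sample mean) are assumptions.

```python
import numpy as np


class StreamingCoverageSketch:
    """Hypothetical accumulator; not the package's Coverage class."""

    def __init__(self, num_alphas=200):
        self.alphas = np.linspace(0.0, 1.0, num_alphas)
        self.hits = np.zeros(num_alphas)  # references covered, per level
        self.count = 0                    # references seen so far

    def update(self, x_samples, x):
        # Enforce the contract: samples (N, S, C, H, W), references (N, C, H, W).
        if x_samples.ndim != 5 or x.ndim != 4:
            raise ValueError("expected x_samples (N, S, C, H, W) and x (N, C, H, W)")
        if x_samples.shape[0] != x.shape[0] or x_samples.shape[2:] != x.shape[1:]:
            raise ValueError("x_samples and x disagree on N or (C, H, W)")
        n, s = x_samples.shape[:2]
        flat = x_samples.reshape(n, s, -1)
        mean = flat.mean(axis=1)
        d_samples = np.linalg.norm(flat - mean[:, None, :], axis=-1)
        d_true = np.linalg.norm(x.reshape(n, -1) - mean, axis=-1)
        radii = np.quantile(d_samples, self.alphas, axis=1)  # (num_alphas, N)
        # Only counts are stored, so memory does not grow with the batches.
        self.hits += (d_true[None, :] <= radii).sum(axis=1)
        self.count += n

    def compute(self):
        if self.count == 0:
            raise RuntimeError("call update() at least once before compute()")
        return self.alphas, self.hits / self.count


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 1, 4, 4))
x_samples = x[:, None] + rng.normal(size=(64, 32, 1, 4, 4))

# Two half-size updates ...
streamed = StreamingCoverageSketch(num_alphas=50)
streamed.update(x_samples[:32], x[:32])
streamed.update(x_samples[32:], x[32:])
_, curve_streamed = streamed.compute()

# ... give the same curve as one update on the full batch.
once = StreamingCoverageSketch(num_alphas=50)
once.update(x_samples, x)
_, curve_once = once.compute()
print("batched and one-shot curves agree:", np.allclose(curve_streamed, curve_once))
```

Since each reference contributes independently to the hit counts, the split into batches does not affect the resulting curve in this sketch.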