From 01c94986837b30d2125ace4f735d45a7b3d07e85 Mon Sep 17 00:00:00 2001 From: apaleyes-bot <287535055+apaleyes-bot@users.noreply.github.com> Date: Mon, 25 May 2026 01:33:08 +0100 Subject: [PATCH] docs: Add comprehensive docstrings to epmgp.py functions - Added detailed Google-style docstrings to min_factor() - Added comprehensive docstrings to lt_factor() - Added complete docstrings to log_relative_gauss() All functions now have full parameter documentation, return/yield value descriptions, and type hints. Co-authored-by: Andrei Paleyes <2852301+apaleyes@users.noreply.github.com> --- emukit/bayesian_optimization/epmgp.py | 48 ++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/emukit/bayesian_optimization/epmgp.py b/emukit/bayesian_optimization/epmgp.py index a7cc975b..2d4b1dd7 100644 --- a/emukit/bayesian_optimization/epmgp.py +++ b/emukit/bayesian_optimization/epmgp.py @@ -4,6 +4,12 @@ # Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +"""Expectation Propagation for Minimum in Gaussian Processes (EPMGP). + +Implements the EPMGP algorithm from Cunningham et al. (2011) for computing +the probability that each point is the minimum in a Gaussian process model. +Uses iterative expectation propagation with likelihood tempering factors. +""" import numpy as np from scipy import special @@ -74,7 +80,20 @@ def joint_min(mu: np.ndarray, var: np.ndarray, with_derivatives: bool = False) - return logP, dlogPdMu, dlogPdSigma, dlogPdMudMu -def min_factor(Mu, Sigma, k, gamma=1): +def min_factor(Mu: np.ndarray, Sigma: np.ndarray, k: int, gamma: float = 1) -> np.ndarray: + """Compute the factor for the probability of minimum using expectation propagation. + + Implements part of the EPMGP algorithm to compute factors needed for the joint + minimum probability calculation. Uses iterative expectation propagation to refine + the approximation, with up to 50 iterations until convergence. + + :param Mu: Mean values of the Gaussian distribution, shape (D,). + :param Sigma: Covariance matrix, shape (D, D). + :param k: Index of the point for which to compute the minimum factor. + :param gamma: Damping factor for the expectation propagation update (default 1.0). + :yields: Log normalization constant and derivatives (logZ, dlogZdMu, dlogZdMudMu, + dlogZdSigma). Returns -inf if convergence fails. + """ D = Mu.shape[0] logS = np.zeros((D - 1,)) # mean time first moment @@ -160,7 +179,21 @@ def min_factor(Mu, Sigma, k, gamma=1): yield dlogZdSigma -def lt_factor(s, l, M, V, mp, p, gamma): +def lt_factor(s: int, l: int, M: np.ndarray, V: np.ndarray, mp: float, p: float, gamma: float) -> tuple: + """Compute a single likelihood term factor for expectation propagation update. + + Updates the Gaussian approximation by incorporating the constraint s < l. + + :param s: Index of the first element in the constraint (s < l). + :param l: Index of the second element in the constraint (s < l). + :param M: Current mean vector, shape (D,). + :param V: Current covariance matrix, shape (D, D). + :param mp: Current message precision term. + :param p: Current precision message. + :param gamma: Damping factor controlling the update step size. + :returns: Tuple of (Mnew, Vnew, pnew, mpnew, logS, d) with updated parameters + and convergence difference d. Returns NaN for d if numerical issues occur. + """ cVc = (V[l, l] - 2 * V[s, l] + V[s, s]) / 2.0 Vc = (V[:, l] - V[:, s]) / sq2 cM = (M[l] - M[s]) / sq2 @@ -224,9 +257,14 @@ def lt_factor(s, l, M, V, mp, p, gamma): return Mnew, Vnew, pnew, mpnew, logS, d -def log_relative_gauss(z): - """ - log_relative_gauss +def log_relative_gauss(z: float) -> tuple: + """Compute log(phi(z) / Phi(z)) where phi is the standard normal PDF and Phi is the CDF. + + Handles extreme values to avoid numerical overflow/underflow. + + :param z: Input value to the standard normal distribution. + :returns: Tuple of (e, logPhi, exit_flag) with the ratio, log CDF, and exit status. + exit_flag is -1 for z < -6 (lower tail), 1 for z > 6 (upper tail), 0 otherwise. """ if z < -6: return 1, -1.0e12, -1