diff --git a/gwinferno/distributions.py b/gwinferno/distributions.py index 40db691..4b889dd 100644 --- a/gwinferno/distributions.py +++ b/gwinferno/distributions.py @@ -18,7 +18,8 @@ def smooth(dx, x, xmin): func = jnp.exp(dx / (x - xmin) + dx / (x - xmin - dx)) s1 = jnp.where(jnp.less(x, xmin), 0, 1) s2 = jnp.where(jnp.less(x, xmin + dx) | jnp.greater_equal(x, xmin), (func + 1) ** (-1), s1) - return s2 + s3 = jnp.where(jnp.greater_equal(x, xmin + dx), 1, s2) + return s3 def logistic_function(x, L, k, x0): @@ -124,19 +125,22 @@ def truncnorm_pdf(xx, mu, sig, low, high, log=False): $$ p(x) \propto \mathcal{N}(x | \mu, \sigma)\Theta(x-x_\mathrm{min})\Theta(x_\mathrm{max}-x) $$ `log=True` makes this a log-normal distribution! + + If 'low == -jnp.inf', then return a right-truncated norm + If 'high == jnp.inf', then return a left-truncated norm """ if log: prob = jnp.exp(-jnp.power(jnp.log(xx) - mu, 2) / (2 * sig**2)) continuous_norm = 1 / (xx * sig * (2 * jnp.pi) ** 0.5) - left_tail_cdf = 0.5 * (1 + erf((jnp.log(low) - mu) / (sig * (2**0.5)))) - right_tail_cdf = 0.5 * (1 + erf((jnp.log(high) - mu) / (sig * (2**0.5)))) + left_tail_cdf = jnp.where(low > -jnp.inf, 0.5 * (1 + erf((jnp.log(low) - mu) / (sig * (2**0.5)))), 0) + right_tail_cdf = jnp.where(high < jnp.inf, 0.5 * (1 + erf((jnp.log(high) - mu) / (sig * (2**0.5)))), 1) denom = right_tail_cdf - left_tail_cdf else: prob = jnp.exp(-jnp.power(xx - mu, 2) / (2 * sig**2)) continuous_norm = 1 / (sig * (2 * jnp.pi) ** 0.5) - left_tail_cdf = 0.5 * (1 + erf((low - mu) / (sig * (2**0.5)))) - right_tail_cdf = 0.5 * (1 + erf((high - mu) / (sig * (2**0.5)))) + left_tail_cdf = jnp.where(low > -jnp.inf, 0.5 * (1 + erf((low - mu) / (sig * (2**0.5)))), 0) + right_tail_cdf = jnp.where(high < jnp.inf, 0.5 * (1 + erf((high - mu) / (sig * (2**0.5)))), 1) denom = right_tail_cdf - left_tail_cdf norm = continuous_norm / denom diff --git a/gwinferno/models/parametric/parametric.py b/gwinferno/models/parametric/parametric.py index 1bc312e..c537ce5 100644 --- a/gwinferno/models/parametric/parametric.py +++ b/gwinferno/models/parametric/parametric.py @@ -50,7 +50,14 @@ def plpeak_primary_pdf(m1, alpha, mmin, mmax, mpp, sigpp, lam, delta=None): if delta is None: return (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax) + lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax) else: - return (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax) * smooth(delta, m1, mmin) + lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax) + return ( (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax) + lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax) ) * smooth(delta, m1, mmin) + + +def brokenpl_twopeaks_primary_pdf(m1, alpha1, alpha2, mmin, mmax, mbreak, mpp1, sigpp1, mpp2, sigpp2, lam0, lam1, delta=None): + if delta is None: + return lam0 * broken_powerlaw_pdf(m1, alpha1, alpha2, mmin, mbreak, mmax) + lam1 * truncnorm_pdf(m1, mpp1, sigpp1, mmin, high=jnp.inf) + (1-lam0-lam1) * truncnorm_pdf(m1, mpp2, sigpp2, mmin, high=jnp.inf) + else: + return ( lam0 * broken_powerlaw_pdf(m1, alpha1, alpha2, mmin, mbreak, mmax) + lam1 * truncnorm_pdf(m1, mpp1, sigpp1, mmin, high=jnp.inf) + (1-lam0-lam1) * truncnorm_pdf(m1, mpp2, sigpp2, mmin, high=jnp.inf) ) * smooth(delta, m1, mmin) """