Skip to content
Merged
Binary file removed abcmb/__init__.pyc
Binary file not shown.
496 changes: 269 additions & 227 deletions abcmb/background.py

Large diffs are not rendered by default.

68 changes: 39 additions & 29 deletions abcmb/hyrex/helium.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(self, lna_axis_4Heequil, integration_spacing = 5.0e-4, Nsteps=800,
Initial redshift (default: 8000.)
z1 : float, optional
Final redshift (default: 20.)
adjoint : diffrax.adjoint
Adjoint mode for diffrax solves (static field). Defaults
to ForwardMode.
"""
self.integration_spacing = integration_spacing
self.concrete_axis_size_postSahaHe = jnp.zeros(Nsteps_postSahaHe)
Expand All @@ -71,7 +74,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024):
Parameters:
-----------
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).
rtol : float, optional
Relative tolerance for ODE solver (default: 1e-6)
atol : float, optional
Expand All @@ -98,7 +102,8 @@ def get_helium_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_st
Parameters:
-----------
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).
rtol : float, optional
Relative tolerance for ODE solver (default: 1e-6)
atol : float, optional
Expand Down Expand Up @@ -160,7 +165,8 @@ def xesaha_HeII_III(self, lna_axis, args, threshold=1e-9):
lna_axis : array
Log scale factor grid
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).
threshold : float, optional
Threshold for HeIII fraction to stop calculation (default: 1e-9)

Expand All @@ -169,7 +175,7 @@ def xesaha_HeII_III(self, lna_axis, args, threshold=1e-9):
tuple
(xe_output, lna_output) - ionization fraction and log scale factor arrays
"""
BG, params = args
recomb_inputs, params = args
# Pre-allocate xe_output
xe_output = jnp.ones_like(lna_axis)*jnp.inf
lna_output = jnp.ones_like(lna_axis)*jnp.inf
Expand All @@ -183,8 +189,8 @@ def compute_xe(carry):
lna = lna_axis[iz]

# Cosmological parameters
TCMB = BG.TCMB(lna, params)
nH = BG.nH(lna, params)
TCMB = recomb_inputs.TCMB(lna)
nH = recomb_inputs.nH(lna)

# compute xHeIII
fHe = params['YHe']/(1.-params['YHe'])/3.97153 # abundance of helium by number
Expand Down Expand Up @@ -212,8 +218,7 @@ def stop_condition(state):
# Initial state: (xe_output, xe, iz, stop flag)
initial_state = (xe_output, lna_output, xe, iz, stop)

# Run the while loop until the stop condition is met.
# eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp
# eqx.internal.while_loop with kind='checkpointed' uses a custom_vjp
# so reverse-mode AD can traverse this dynamic-stop loop via treeverse
# checkpointing. max_steps must be a static upper bound.
final_state = eqx.internal.while_loop(
Expand All @@ -240,18 +245,19 @@ def xHeII_post_Saha(self, lna, args):
lna : float
Log scale factor
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).

Returns:
--------
float
HeII fraction (units: dimensionless)
"""
BG, params = args
recomb_inputs, params = args
fHe = params['YHe']/(1.-params['YHe'])/3.97153

TCMB = BG.TCMB(lna, params)
nH = BG.nH(lna, params)
TCMB = recomb_inputs.TCMB(lna)
nH = recomb_inputs.nH(lna)

# Saha ratio xe * xHeII / xHeI
s = 4 * 2.414194e15 * TCMB/cnst.kB * jnp.sqrt(TCMB/cnst.kB) * jnp.exp(-285325. / (TCMB/cnst.kB)) / nH
Expand All @@ -271,16 +277,17 @@ def xH1_Saha(self, lna, args):
lna : float
Log scale factor
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).

Returns:
--------
float
Neutral hydrogen fraction (units: dimensionless)
"""
BG, params = args
TCMB = BG.TCMB(lna, params)
nH = BG.nH(lna, params)
recomb_inputs, params = args
TCMB = recomb_inputs.TCMB(lna)
nH = recomb_inputs.nH(lna)
xHeII = self.xHeII_post_Saha(lna, args)
s = 2.4127161187130e15* TCMB/cnst.kB * jnp.sqrt(TCMB/cnst.kB)*jnp.exp(-157801.37882/(TCMB/cnst.kB))/nH
xH1 = jnp.where(s>1e5,(1.+xHeII)/s - (xHeII**2 + 3.*xHeII + 2.)/s**2,\
Expand All @@ -301,7 +308,8 @@ def post_saha_xHeII(self, starting_lna, args, threshold=1e-5):
starting_lna : float
Initial log scale factor
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).
threshold : float, optional
Threshold for deviation from Saha (default: 1e-5)

Expand All @@ -310,7 +318,7 @@ def post_saha_xHeII(self, starting_lna, args, threshold=1e-5):
tuple
(xe_output, lna_output) - ionization fraction and log scale factor arrays
"""
BG, params = args
recomb_inputs, params = args
# Pre-allocate xe_output
xe_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf
lna_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf
Expand Down Expand Up @@ -353,7 +361,6 @@ def stop_condition(state):
# Initial state: (xe_output, xe, iz, stop flag)
initial_state = (xe_output, lna_output, xe, iz, stop)

# Run the while loop until the stop condition is met.
# eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp
# so reverse-mode AD can traverse this dynamic-stop loop via treeverse
# checkpointing. max_steps must be a static upper bound.
Expand Down Expand Up @@ -383,22 +390,23 @@ def helium_dxHeIIdlna(self, xe, lna, args):
lna : float
Log scale factor
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).

Returns:
--------
float
HeII recombination rate dxHeII/dlna (units: dimensionless)
"""
BG, params = args
recomb_inputs, params = args

fHe = params['YHe']/(1.-params['YHe'])/3.97153 # abundance of helium by number

# cosmology
#lna = -jnp.log(1+z)
TCMB = BG.TCMB(lna, params) # eV
nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3
H = BG.H(lna, params) # Hubble parameter, 1/s
TCMB = recomb_inputs.TCMB(lna) # eV
nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3
H = recomb_inputs.H(lna) # Hubble parameter, 1/s
GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s

# compute xH1 in Saha equilibrium, xHeII in post-saha
Expand Down Expand Up @@ -462,15 +470,16 @@ def xe_derivative_HeII(self, lna, state, args):
state : float
Current HeII ionization state
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).

Returns:
--------
float
Time derivative of HeII fraction (units: dimensionless)
"""

BG, params = args
recomb_inputs, params = args
#z = 1. / jnp.exp(lna) - 1.
# use xe = xHeII + (1.-xH1)
xe = state + self.xH1_Saha(lna, args)
Expand All @@ -491,7 +500,8 @@ def solve_HeII_full(self, starting_lna, xe0, args, rtol=1e-6, atol=1e-9,solver=K
xe0 : float
Initial ionization fraction
args : tuple
Background cosmology and cosmological parameters (BG, params)
Recombination input arrays and cosmological parameters
(recomb_inputs, params).
rtol : float, optional
Relative tolerance (default: 1e-6)
atol : float, optional
Expand All @@ -506,10 +516,10 @@ def solve_HeII_full(self, starting_lna, xe0, args, rtol=1e-6, atol=1e-9,solver=K
tuple
(xe_output, lna_output) - ionization fraction and log scale factor arrays
"""
BG, params = args
recomb_inputs, params = args

# Initial conditions
TCMB_init = BG.TCMB(starting_lna, params) # Initial matter temperature
TCMB_init = recomb_inputs.TCMB(starting_lna) # Initial matter temperature
initial_state = jnp.array([xe0])
term = ODETerm(self.xe_derivative_HeII)

Expand Down
Loading
Loading