diff --git a/abcmb/__init__.pyc b/abcmb/__init__.pyc deleted file mode 100644 index a3ad9a0..0000000 Binary files a/abcmb/__init__.pyc and /dev/null differ diff --git a/abcmb/background.py b/abcmb/background.py index 0b5c18c..ace9d43 100644 --- a/abcmb/background.py +++ b/abcmb/background.py @@ -8,6 +8,7 @@ from .hyrex.array_with_padding import array_with_padding from .hyrex import recomb_functions +from .hyrex.hyrex import RecombInputs from . import ABCMBTools as tools from . import constants as cnst @@ -15,53 +16,32 @@ file_dir = os.path.dirname(__file__) config.update("jax_enable_x64", True) -class Background(eqx.Module): + +class BackgroundPreRecomb(eqx.Module): """ - Background cosmology module for cosmological calculations. + Pre-recombination background-cosmology object. - Computes background quantities including Hubble parameter, conformal time, - recombination history, and optical depth evolution. + Holds everything HyRex needs to run on CPU: (conformal-time tabulation, + species list, and HyRex input arrays via ``RecombInputs`` object). Attributes: ----------- species_list : tuple A list of all fluids in the cosmology lna_tau_tab : jnp.array - Log scale factor axis used to tabulate conformal time - tau_tab : jnp.array - Tabulated conformal time. - tau0 : float + Log scale factor axis used to tabulate conformal time + tau_tab : jnp.array + Tabulated conformal time. + tau0 : float Conformal time today in Mpc. - xe_tab : array_with_padding - Tabulated free electron fraction xe during recombination - lna_xe_tab : array_with_padding - Log scale factor axis corresponding to tabulated xe values. - Tm_tab : array_with_padding - Tabulated matter temperature Tm during recombination - lna_Tm_tab : array_with_padding - Log scale factor axis corresponding to tabulated Tm values. - kappa_func : diffrax.solution - Visibility function - z_reion : float - Redshift of hydrogen reionization in the CAMB parameterization. - tau_reion : float - Optical depth to reionization - lna_rec : float - Log scale factor of recombination - rA_rec : float - Comoving angular diameter distance at recombination in Mpc - rs_d : float - Sound horizon at baryon decoupling in Mpc - z_d : float - Redshift of baryon decoupling - lna_transfer_start : float - Log scale factor at which to begin integrating transfer functions. - lna_visibility_stop : float - Log scale factor at which to stop integrating T1, T2, and E sources due to small visibility functions. - Only used for l<400. - - Recombination Unrelated Methods: - -------------------------------- + recomb_inputs : RecombInputs + Bundle of background quantities (TCMB, nH, H) sampled on + ``RecModel.lna_axis_full``; consumed by HyRex. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). + + Methods: + -------- rho_tot : Compute total energy density (units: eV cm^{-3}) P_tot : Compute total pressure (units: eV cm^{-3}) H : Compute Hubble parameter (units: s^{-1}) @@ -69,52 +49,27 @@ class Background(eqx.Module): aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1}) d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2}) tau : Compute conformal time (units: Mpc) - z_d : Compute baryon decoupling redshift (units: dimensionless) - rs_d : Compute sound horizon at decoupling (units: Mpc) - - Recombination Related Methods: - ------------------------------ - xe : Compute free electron fraction (units: dimensionless) - Tm : Compute matter temperature (units: eV) - mu_bar : Compute mean molecular mass (units: eV) - cs2 : Compute baryon sound speed squared (units: dimensionless) nH : Compute hydrogen number density (units: cm^{-3}) TCMB : Compute CMB temperature (units: eV) - tau_c : Compute Thomson scattering time (units: Mpc) - kappa : Compute optical depth (units: dimensionless) - visibility : Compute visibility function (units: Mpc^{-1}) + R_ratio_lna : Compute baryon drag ratio (units: dimensionless) """ - # params : dict species_list : tuple - + lna_tau_tab = jnp.linspace(-33.0, 0.0, 10000) # Axis for tabulating conformal time. - tau_tab : jnp.array # Tabulated conformal time. + tau_tab : jnp.array # Tabulated conformal time. tau0 : float # Conformal time today - # Recombination related - xe_tab : "array_with_padding" - lna_xe_tab : "array_with_padding" - Tm_tab : "array_with_padding" - lna_Tm_tab : "array_with_padding" - kappa_func : "diffrax.solution" - z_reion : float - tau_reion : float - lna_rec : float - rA_rec : float # Comoving angular diameter distance at recombination. - - # Transfer related - lna_transfer_start : float # Time where transfer functions start integrating. - lna_visibility_stop : float # Time to stop integrating T1, T2, and E sources due to small visibility functions. Only used for l<400 + recomb_inputs : "RecombInputs" adjoint : "diffrax.adjoint" = eqx.field(static=True) - def __init__(self, params, species_list, RecModel, ReionModel, adjoint=ForwardMode): + def __init__(self, params, species_list, RecModel, adjoint=ForwardMode): """ - Initialize Background cosmology module. + Initialize pre-recombination background. - Computes and tabulates conformal time, recombination history, - optical depth, and key cosmological epochs. + Tabulates conformal time and builds the RecombInputs object for + HyRex. Parameters: ----------- @@ -122,54 +77,26 @@ def __init__(self, params, species_list, RecModel, ReionModel, adjoint=ForwardMo Cosmological parameters species_list : tuple List of fluid species for energy density calculations - RecModel : callable + RecModel : hyrex.recomb_model Recombination module for computing xe and Tm histories - ReionModel : callable - Reionization module for computing the xe correction. + adjoint : diffrax.adjoint, optional + Adjoint class for diffrax solves (default: ForwardMode) """ self.adjoint = adjoint - # self.params = params self.species_list = species_list self.tau_tab = self._tabulate_conformal_time(params) - self.tau0 =self.tau(0.) - - self.run_recombination(params, RecModel, ReionModel) - - # Find approximate maximum of visibility function. - lna_vals = jnp.linspace(-8.0, -4.0, 1500) # Decoupling should have happened at some time in this interval. - vis_vals = vmap(self.visibility,in_axes=[0,None])(lna_vals, params) - self.lna_rec = lna_vals[jnp.argmax(vis_vals)] - self.lna_visibility_stop = lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)] - self.rA_rec = self.tau0 - self.tau(self.lna_rec) - - # Find approximate early time when aH x tau_c = 0.008 - lna_vals = jnp.linspace(-15.0, -6.0, 5000) - aH_tau_c_vals = vmap(self.aH,in_axes=[0,None])(lna_vals,params)*self.tau_c(lna_vals,params) - self.lna_transfer_start = lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)] - - def run_recombination(self, params, RecModel, ReionModel): - """ - Call HyRex to get the primary recombination history, then patch on tanh reionization. - For now assumes the user will specify tau_reio. - """ - ### RECOMBINATION ### - - # Run hyrex to tabulate recombination output - xe, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = RecModel((self,params)) - - ### REIONIZATION ### - reion_model = ReionModel(self, params) - self.z_reion = reion_model.z_reion - self.tau_reion = reion_model.tau_reion - - xe_reion_correction = reion_model.xe_reion(self.lna_xe_tab.arr, self.z_reion, params) - xe_full_arr = xe_reion_correction + xe.arr - xe_full = array_with_padding(xe_full_arr) - - self.xe_tab = xe_full - - self.kappa_func = self._tabulate_optical_depth(params) + self.tau0 = self.tau(0.) + + # Bundle the background quantities HyRex needs onto its sampling + # grid (acccording to the input RecModel) + lna_axis = RecModel.lna_axis_full + self.recomb_inputs = RecombInputs( + lna_grid = lna_axis, + TCMB_arr = vmap(self.TCMB, in_axes=[0, None])(lna_axis, params), + nH_arr = vmap(self.nH, in_axes=[0, None])(lna_axis, params), + H_arr = vmap(self.H, in_axes=[0, None])(lna_axis, params), + ) def rho_tot(self, lna, params): """ @@ -188,16 +115,12 @@ def rho_tot(self, lna, params): -------- float Total energy density (units: eV cm^{-3}) - - Notes: - ------ - User should not modify this function without careful consideration. """ rho_tot = 0. for i in range(len(self.species_list)): rho_tot += self.species_list[i].rho(lna, params) return rho_tot - + def P_tot(self, lna, params): """ Compute total pressure. @@ -215,10 +138,6 @@ def P_tot(self, lna, params): -------- float Total pressure (units: eV cm^{-3}) - - Notes: - ------ - User should not modify this function without careful consideration. """ P_tot = 0. for i in range(len(self.species_list)): @@ -266,7 +185,7 @@ def aH(self, lna, params): Conformal Hubble parameter (units: Mpc^{-1}) """ return jnp.exp(lna)*self.H(lna, params) / cnst.c_Mpc_over_s - + def aH_prime(self, lna, params): """ Compute derivative of conformal Hubble parameter. @@ -307,9 +226,8 @@ def d2adtau2_over_a(self, lna, params): float Second derivative of scale factor (units: Mpc^{-2}) """ - return self.aH(lna, params)**2 + self.aH(lna, params)*self.aH_prime(lna, params) - + def _dtau_dlna(self, lna, y, args): """ Compute derivative of conformal time with respect to ln(a). @@ -426,9 +344,210 @@ def tau(self, lna): radiation approximation, and starting diffrax integration at the early time with appropriate initial conditions. """ - return tools.fast_interp(lna, self.lna_tau_tab[0], self.lna_tau_tab[-1], self.tau_tab) + def nH(self, lna, params): + """ + Compute hydrogen number density. + + Calculates total hydrogen number density at given redshift. + + Parameters: + ----------- + lna : float + Logarithm of scale factor + params : dict + Cosmological parameters + + Returns: + -------- + float + Hydrogen number density (units: cm^{-3}) + """ + return (1-params['YHe']) * 3. * params['omega_b'] * cnst.H0_over_h**2 / 8 / jnp.pi / cnst.G / cnst.mH / jnp.exp(lna)**3 + + def TCMB(self, lna, params): + """ + Compute CMB temperature. + + Calculates CMB temperature at given redshift using T ∝ 1/a scaling. + + Parameters: + ----------- + lna : float + Logarithm of scale factor + params : dict + Cosmological parameters + + Returns: + -------- + float + CMB temperature (units: eV) + """ + return params['TCMB0'] / jnp.exp(lna) + + def R_ratio_lna(self, lna, params): + """ + Compute baryon drag ratio. + + Calculates R = 3ρ_b/(4ρ_γ), the ratio of baryon to photon + energy densities that appears in baryon drag calculations. + + Parameters: + ----------- + lna : float + Logarithm of scale factor + params : dict + Cosmological parameters + + Returns: + -------- + float + Baryon drag ratio (units: dimensionless) + """ + rho_b = 0. + rho_g = 0. + + for s in self.species_list: + if s.name == "Photon": + rho_g += s.rho(lna, params) + elif s.name == "Baryon": + rho_b += s.rho(lna, params) + + return 3. * rho_b / (4 * rho_g) + + +class Background(BackgroundPreRecomb): + """ + Full Background cosmology module for cosmological calculations. + + Inherits all cosmology fields and methods from ``BackgroundPreRecomb``. + Construction takes a ``BackgroundPreRecomb`` and the recombination output + from HyRex, then applies reionization and integrates the optical depth. + + This factorization allows HyRex to always run on CPU (its faster backend). + + Attributes: + ----------- + species_list : tuple + A list of all fluids in the cosmology + lna_tau_tab : jnp.array + Log scale factor axis used to tabulate conformal time + tau_tab : jnp.array + Tabulated conformal time. + tau0 : float + Conformal time today in Mpc. + recomb_inputs : RecombInputs + Bundle of background quantities (TCMB, nH, H) sampled on + ``RecModel.lna_axis_full``; consumed by HyRex. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). + xe_tab : array_with_padding + Tabulated free electron fraction xe with reionization correction. + lna_xe_tab : array_with_padding + Log scale factor axis corresponding to tabulated xe values. + Tm_tab : array_with_padding + Tabulated matter temperature Tm during recombination. + lna_Tm_tab : array_with_padding + Log scale factor axis corresponding to tabulated Tm values. + kappa_func : diffrax.solution + Optical depth function (dense interpolation). + z_reion : float + Redshift of hydrogen reionization in the CAMB parameterization. + tau_reion : float + Optical depth to reionization. + lna_rec : float + Log scale factor of recombination. + rA_rec : float + Comoving angular diameter distance at recombination in Mpc. + lna_transfer_start : float + Log scale factor at which to begin integrating transfer functions. + lna_visibility_stop : float + Log scale factor at which to stop integrating T1, T2, and E sources + due to small visibility functions. Only used for l<400. + + Methods: + -------- + rho_tot : Compute total energy density (units: eV cm^{-3}) + P_tot : Compute total pressure (units: eV cm^{-3}) + H : Compute Hubble parameter (units: s^{-1}) + aH : Compute conformal Hubble parameter (units: Mpc^{-1}) + aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1}) + d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2}) + tau : Compute conformal time (units: Mpc) + xe : Compute free electron fraction (units: dimensionless) + Tm : Compute matter temperature (units: eV) + tau_c : Compute Thomson scattering time (units: Mpc) + expmkappa : Compute exp(-kappa) (units: dimensionless) + visibility : Compute visibility function (units: Mpc^{-1}) + z_d : Compute baryon decoupling redshift (units: dimensionless) + rs_d : Compute sound horizon at decoupling (units: Mpc) + """ + + xe_tab : "array_with_padding" + lna_xe_tab : "array_with_padding" + Tm_tab : "array_with_padding" + lna_Tm_tab : "array_with_padding" + kappa_func : "diffrax.solution" + z_reion : float + tau_reion : float + lna_rec : float + rA_rec : float # Comoving angular diameter distance at recombination. + + # Transfer related + lna_transfer_start : float # Time where transfer functions start integrating. + lna_visibility_stop : float # Time to stop integrating T1, T2, and E sources due to small visibility functions. Only used for l<400 + + def __init__(self, pre_BG, recomb_output, params, ReionModel): + """ + Initialize Background cosmology module. + + Consolidates pre-recombination and recombination elements of background cosmology. + + Parameters: + ----------- + pre_BG : BackgroundPreRecomb + Output of the pre-recomb stage; provides species_list, + tau_tab, tau0, recomb_inputs, adjoint. + recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)`` quadruple + params : dict + Cosmological parameters. + ReionModel : callable + Reionization module for computing the xe correction. + """ + # Copy pre-recomb fields onto self. + self.adjoint = pre_BG.adjoint + self.species_list = pre_BG.species_list + self.tau_tab = pre_BG.tau_tab + self.tau0 = pre_BG.tau0 + self.recomb_inputs = pre_BG.recomb_inputs + + # Unpack HyRex output and apply reionization. + xe, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = recomb_output + + reion_model = ReionModel(self, params) + self.z_reion = reion_model.z_reion + self.tau_reion = reion_model.tau_reion + + xe_reion_correction = reion_model.xe_reion(self.lna_xe_tab.arr, self.z_reion, params) + xe_full_arr = xe_reion_correction + xe.arr + self.xe_tab = array_with_padding(xe_full_arr) + + self.kappa_func = self._tabulate_optical_depth(params) + + # Find approximate maximum of visibility function. + lna_vals = jnp.linspace(-8.0, -4.0, 1500) # Decoupling falls in here. + vis_vals = vmap(self.visibility, in_axes=[0, None])(lna_vals, params) + self.lna_rec = lna_vals[jnp.argmax(vis_vals)] + self.lna_visibility_stop = lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)] + self.rA_rec = self.tau0 - self.tau(self.lna_rec) + + # Find approximate early time when aH x tau_c = 0.008 + lna_vals = jnp.linspace(-15.0, -6.0, 5000) + aH_tau_c_vals = vmap(self.aH, in_axes=[0, None])(lna_vals, params) * self.tau_c(lna_vals, params) + self.lna_transfer_start = lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)] + ### RECOMBINATION RELATED ### def xe(self, lna): @@ -447,15 +566,13 @@ def xe(self, lna): -------- float Free electron fraction (units: dimensionless) - + Notes: ------ The logic flow is equivalent to: if lna < self.lna_xe_tab.arr[0]: return self.xe_tab[0] - elif lna > self.lna_xe_tab.lastval: return self.xe_tab.lastval - else: return jnp.interp(lna, self.lna_xe_tab, self.xe_tab) """ return jnp.where( @@ -524,46 +641,6 @@ def Tm(self, lna, params): ) ) - def nH(self, lna, params): - """ - Compute hydrogen number density. - - Calculates total hydrogen number density at given redshift. - - Parameters: - ----------- - lna : float - Logarithm of scale factor - params : dict - Cosmological parameters - - Returns: - -------- - float - Hydrogen number density (units: cm^{-3}) - """ - return (1-params['YHe']) * 3. * params['omega_b'] * cnst.H0_over_h**2 / 8 / jnp.pi / cnst.G / cnst.mH / jnp.exp(lna)**3 - - def TCMB(self,lna, params): - """ - Compute CMB temperature. - - Calculates CMB temperature at given redshift using T ∝ 1/a scaling. - - Parameters: - ----------- - lna : float - Logarithm of scale factor - params : dict - Cosmological parameters - - Returns: - -------- - float - CMB temperature (units: eV) - """ - return params['TCMB0'] / jnp.exp(lna) - def tau_c(self, lna, params): """ Compute Thomson scattering time. @@ -603,7 +680,7 @@ def _tabulate_optical_depth(self, params): -------- array Tabulated optical depth values (units: dimensionless) - + Notes: ------ Also computes time derivative of optical depth, which is the @@ -615,13 +692,13 @@ def _tabulate_optical_depth(self, params): adjoint=self.adjoint() sol = diffeqsolve( term, - solver=Kvaerno5(), + solver=Kvaerno5(), stepsize_controller=stepsize_controller, - t0=0., - t1=-10., - dt0=-1.e-3, + t0=0., + t1=-10., + dt0=-1.e-3, max_steps=2048, - y0=0.0, + y0=0.0, saveat=SaveAt(dense=True), adjoint=adjoint ) @@ -629,7 +706,7 @@ def _tabulate_optical_depth(self, params): def expmkappa(self, lna): """ - Compute optical depth. + Compute exp(-optical depth). Interpolates from pre-tabulated optical depth history. @@ -641,9 +718,8 @@ def expmkappa(self, lna): Returns: -------- float - Optical depth (units: dimensionless) + exp(-(optical depth)) (units: dimensionless) """ - return jnp.where( lna < -10., 0., @@ -704,7 +780,6 @@ def find_z_at_kappad_equals_one(self, z, kappa_d): z_sorted = z[idx] kappa_d_sorted = jnp.abs(kappa_d)[idx] - # interpolate z_d = jnp.interp(1.0, kappa_d_sorted, z_sorted) return z_d @@ -731,38 +806,6 @@ def interp_rs_at_z(self, z_bg, r_s, z_d): rs_sorted = r_s[idx] return jnp.interp(z_d, z_sorted, rs_sorted) - def R_ratio_lna(self, lna, params): - """ - Compute baryon drag ratio. - - Calculates R = 3ρ_b/(4ρ_γ), the ratio of baryon to photon - energy densities that appears in baryon drag calculations. - - Parameters: - ----------- - lna : float - Logarithm of scale factor - params : dict - Cosmological parameters - - Returns: - -------- - float - Baryon drag ratio (units: dimensionless) - """ - - rho_b = 0. - rho_g = 0. - - for s in self.species_list: - if s.name == "Photon": - rho_g += s.rho(lna, params) - elif s.name == "Baryon": - rho_b += s.rho(lna, params) - - return 3. * rho_b / (4 * rho_g) - - @jax.named_scope("tabulate kappa d") def _tabulate_kappa_d(self, params): """ Tabulate baryon optical depth. @@ -784,23 +827,22 @@ def _tabulate_kappa_d(self, params): term = ODETerm(integrand) stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-3, atol=1.e-6) adjoint=self.adjoint() - + solution = diffeqsolve( term, - solver=Tsit5(), # Kvaerno5 is just slower but gives same result + solver=Tsit5(), # Kvaerno5 is just slower but gives same result stepsize_controller=stepsize_controller, - t0=self.lna_tau_tab[-1], # Initial x value (~0 in this case) - t1=self.lna_tau_tab[0], # Final x value (smallest x value) - dt0=-1e-3, # Initial step size + t0=self.lna_tau_tab[-1], # Initial x value (~0 in this case) + t1=self.lna_tau_tab[0], # Final x value (smallest x value) + dt0=-1e-3, max_steps=2048, - y0=0.0, # Initial value tau(x=0) = 0 + y0=0.0, # Initial value tau(x=0) = 0 saveat=SaveAt(ts=self.lna_tau_tab[::-1]), # Save at all points in x, reverse order since integrating backwards adjoint=adjoint ) result = solution.ys[::-1] return result - @jax.named_scope("tabulate rs") def _tabulate_rs(self, params): """ Tabulate sound horizon evolution. @@ -818,19 +860,19 @@ def _tabulate_rs(self, params): array Tabulated sound horizon values (units: Mpc) """ - # initial condition assuming cs**2 = 1/3 at early times + # initial condition assuming cs**2 = 1/3 at early times rs0 = 1./jnp.sqrt(3) / (self.aH( self.lna_tau_tab[0], params )) integrand = lambda lna, y, args: 1./jnp.sqrt(3*(1+self.R_ratio_lna(lna, params))) / (self.aH(lna, params)) term = ODETerm(integrand) stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-3, atol=1.e-6) adjoint=self.adjoint() - + solution = diffeqsolve( term, solver=Tsit5(), stepsize_controller=stepsize_controller, - t0=self.lna_tau_tab[0], # reversed direction since I know rs at early times + t0=self.lna_tau_tab[0], # reversed direction since I know rs at early times t1=self.lna_tau_tab[-1], dt0=1e-3, max_steps=2048, @@ -840,7 +882,6 @@ def _tabulate_rs(self, params): ) result = solution.ys return result - def z_d(self, params): """ @@ -879,14 +920,15 @@ def rs_d(self, params): """ return self.interp_rs_at_z(1/jnp.exp(self.lna_tau_tab) - 1, self._tabulate_rs(params), self.z_d(params)) + class ReionizationModel(eqx.Module): """ Object for computing the reionization correction to the free electron fraction. - Provides the base methods + Provides the base methods xe_reion : calculates the tanh electron fraction correction at redshifts lna, given z_reion and params tau_reion_fn : calculates the optical depth to reionization. - + At the moment we only support the CAMB tanh parameterization, but we need different approaches based on whether the use inputs the optical depth tau_reion or the reionization redshift z_reion. @@ -898,7 +940,7 @@ class ReionizationModel(eqx.Module): def xe_reion(self, lna, z_reion, params): """ Passing in an lna array should get you the correct tanh patching based on the - reionization parameter. + reionization parameter. """ fHe = params['YHe'] / 4 / (1-params['YHe']) z = 1/jnp.exp(lna) - 1 @@ -919,7 +961,7 @@ def xe_reion(self, lna, z_reion, params): def tau_reion_fn(self, z_reion, BG, params): lna_axis = jnp.linspace(-5., 0., 2000) xe_reion_correction = self.xe_reion(lna_axis, z_reion, params) - # Free electron number density belonging only to reionized hydrogen. + # Free electron number density belonging only to reionized hydrogen. ne = BG.nH(lna_axis, params) * xe_reion_correction Gamma = jnp.exp(lna_axis)*ne*cnst.thomson_xsec*cnst.c/cnst.c_Mpc_over_s aH = BG.aH(lna_axis, params) @@ -957,4 +999,4 @@ def tau_target_fn(z_reion, args): solver = optx.Newton(rtol=1e-5, atol=1e-5) sol = optx.root_find(tau_target_fn, solver, 7.6, params.get("tau_reion", jnp.array(0.05430842))) self.z_reion = sol.value - self.tau_reion = params.get("tau_reion", jnp.array(0.05430842)) \ No newline at end of file + self.tau_reion = params.get("tau_reion", jnp.array(0.05430842)) diff --git a/abcmb/hyrex/helium.py b/abcmb/hyrex/helium.py index 51d96c1..d8f5ed5 100644 --- a/abcmb/hyrex/helium.py +++ b/abcmb/hyrex/helium.py @@ -55,6 +55,9 @@ def __init__(self, lna_axis_4Heequil, integration_spacing = 5.0e-4, Nsteps=800, Initial redshift (default: 8000.) z1 : float, optional Final redshift (default: 20.) + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). Defaults + to ForwardMode. """ self.integration_spacing = integration_spacing self.concrete_axis_size_postSahaHe = jnp.zeros(Nsteps_postSahaHe) @@ -71,7 +74,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -98,7 +102,8 @@ def get_helium_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_st Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -160,7 +165,8 @@ def xesaha_HeII_III(self, lna_axis, args, threshold=1e-9): lna_axis : array Log scale factor grid args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). threshold : float, optional Threshold for HeIII fraction to stop calculation (default: 1e-9) @@ -169,7 +175,7 @@ def xesaha_HeII_III(self, lna_axis, args, threshold=1e-9): tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Pre-allocate xe_output xe_output = jnp.ones_like(lna_axis)*jnp.inf lna_output = jnp.ones_like(lna_axis)*jnp.inf @@ -183,8 +189,8 @@ def compute_xe(carry): lna = lna_axis[iz] # Cosmological parameters - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) # compute xHeIII fHe = params['YHe']/(1.-params['YHe'])/3.97153 # abundance of helium by number @@ -212,8 +218,7 @@ def stop_condition(state): # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) - # Run the while loop until the stop condition is met. - # eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp + # eqx.internal.while_loop with kind='checkpointed' uses a custom_vjp # so reverse-mode AD can traverse this dynamic-stop loop via treeverse # checkpointing. max_steps must be a static upper bound. final_state = eqx.internal.while_loop( @@ -240,18 +245,19 @@ def xHeII_post_Saha(self, lna, args): lna : float Log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float HeII fraction (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args fHe = params['YHe']/(1.-params['YHe'])/3.97153 - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) # Saha ratio xe * xHeII / xHeI s = 4 * 2.414194e15 * TCMB/cnst.kB * jnp.sqrt(TCMB/cnst.kB) * jnp.exp(-285325. / (TCMB/cnst.kB)) / nH @@ -271,16 +277,17 @@ def xH1_Saha(self, lna, args): lna : float Log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float Neutral hydrogen fraction (units: dimensionless) """ - BG, params = args - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) + recomb_inputs, params = args + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) xHeII = self.xHeII_post_Saha(lna, args) s = 2.4127161187130e15* TCMB/cnst.kB * jnp.sqrt(TCMB/cnst.kB)*jnp.exp(-157801.37882/(TCMB/cnst.kB))/nH xH1 = jnp.where(s>1e5,(1.+xHeII)/s - (xHeII**2 + 3.*xHeII + 2.)/s**2,\ @@ -301,7 +308,8 @@ def post_saha_xHeII(self, starting_lna, args, threshold=1e-5): starting_lna : float Initial log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). threshold : float, optional Threshold for deviation from Saha (default: 1e-5) @@ -310,7 +318,7 @@ def post_saha_xHeII(self, starting_lna, args, threshold=1e-5): tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Pre-allocate xe_output xe_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf lna_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf @@ -353,7 +361,6 @@ def stop_condition(state): # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) - # Run the while loop until the stop condition is met. # eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp # so reverse-mode AD can traverse this dynamic-stop loop via treeverse # checkpointing. max_steps must be a static upper bound. @@ -383,22 +390,23 @@ def helium_dxHeIIdlna(self, xe, lna, args): lna : float Log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float HeII recombination rate dxHeII/dlna (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args fHe = params['YHe']/(1.-params['YHe'])/3.97153 # abundance of helium by number # cosmology #lna = -jnp.log(1+z) - TCMB = BG.TCMB(lna, params) # eV - nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3 - H = BG.H(lna, params) # Hubble parameter, 1/s + TCMB = recomb_inputs.TCMB(lna) # eV + nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3 + H = recomb_inputs.H(lna) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s # compute xH1 in Saha equilibrium, xHeII in post-saha @@ -462,7 +470,8 @@ def xe_derivative_HeII(self, lna, state, args): state : float Current HeII ionization state args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -470,7 +479,7 @@ def xe_derivative_HeII(self, lna, state, args): Time derivative of HeII fraction (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args #z = 1. / jnp.exp(lna) - 1. # use xe = xHeII + (1.-xH1) xe = state + self.xH1_Saha(lna, args) @@ -491,7 +500,8 @@ def solve_HeII_full(self, starting_lna, xe0, args, rtol=1e-6, atol=1e-9,solver=K xe0 : float Initial ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-6) atol : float, optional @@ -506,10 +516,10 @@ def solve_HeII_full(self, starting_lna, xe0, args, rtol=1e-6, atol=1e-9,solver=K tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Initial conditions - TCMB_init = BG.TCMB(starting_lna, params) # Initial matter temperature + TCMB_init = recomb_inputs.TCMB(starting_lna) # Initial matter temperature initial_state = jnp.array([xe0]) term = ODETerm(self.xe_derivative_HeII) diff --git a/abcmb/hyrex/hydrogen.py b/abcmb/hyrex/hydrogen.py index 6a335e4..b9d914a 100644 --- a/abcmb/hyrex/hydrogen.py +++ b/abcmb/hyrex/hydrogen.py @@ -70,6 +70,9 @@ def __init__(self, xe_4He, lna_4He, lna_end, last_4He_lna, twog_redshift, integr Maximum number of integration steps (default: 800) swift : array, optional SWIFT correction function tabulation + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). Defaults + to ForwardMode. """ self.integration_spacing = integration_spacing self.swift = swift @@ -93,7 +96,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -122,7 +126,8 @@ def get_hydrogen_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_ Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -195,7 +200,8 @@ def post_Saha_expansion(self, starting_lna, args, threshold=1e-5): starting_lna : float Initial log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). threshold : float, optional Threshold for deviation from Saha (default: 1e-5) @@ -204,10 +210,10 @@ def post_Saha_expansion(self, starting_lna, args, threshold=1e-5): tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Initial conditions - TCMB = BG.TCMB(starting_lna, params) - nH = BG.nH(starting_lna, params) + TCMB = recomb_inputs.TCMB(starting_lna) + nH = recomb_inputs.nH(starting_lna) xe0, _ = recomb_functions.xe_Saha(TCMB, nH) # Saha equilibrium is our intial condition # Pre-allocate xe_output @@ -223,9 +229,9 @@ def compute_xe(carry): lna = starting_lna + iz*self.integration_spacing # Cosmological parameters - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) - H = BG.H(lna, params) + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) + H = recomb_inputs.H(lna) # Saha equilibrium for xe xe_Saha, s = recomb_functions.xe_Saha(TCMB, nH) @@ -256,11 +262,9 @@ def stop_condition(state): # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) - # Run the while loop until the stop condition is met. # eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp # so reverse-mode AD can traverse this dynamic-stop loop via treeverse - # checkpointing. max_steps must be a static upper bound; the output - # axis size serves that role here. + # checkpointing. max_steps must be a static upper bound final_state = eqx.internal.while_loop( stop_condition, compute_xe, initial_state, max_steps=self.concrete_axis_size.size, @@ -287,19 +291,20 @@ def xe_derivative_twophoton(self, lna, xe, args): xe : float Current ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float Time derivative dxe/dlna (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args x1s = 1. - xe # fraction of neutral hydrogen - TCMB = BG.TCMB(lna, params) # eV - nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3 - H = BG.H(lna, params) # Hubble parameter, 1/s + TCMB = recomb_inputs.TCMB(lna) # eV + nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3 + H = recomb_inputs.H(lna) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s Tm = TCMB * (1.-H/GammaC) @@ -325,7 +330,8 @@ def solve_emla_twophoton(self, lna_axis_init, lna_axis_final, xe0, args, rtol=1e xe0 : float Initial ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-6) atol : float, optional @@ -340,10 +346,10 @@ def solve_emla_twophoton(self, lna_axis_init, lna_axis_final, xe0, args, rtol=1e tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Initial conditions - TCMB_init = BG.TCMB(lna_axis_init, params) # Initial CMB temperature + TCMB_init = recomb_inputs.TCMB(lna_axis_init) # Initial CMB temperature initial_state = xe0 term = ODETerm(self.xe_derivative_twophoton) @@ -392,7 +398,8 @@ def xe_tm_derivative(self, lna, state, args): state : array Current state [xe, Tm] args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -400,11 +407,11 @@ def xe_tm_derivative(self, lna, state, args): Time derivatives [dxe/dlna, dTm/dlna] (units: dimensionless, eV) """ xe, Tm = state - BG, params = args + recomb_inputs, params = args - TCMB = BG.TCMB(lna, params) # eV - nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3 - H = BG.H(lna, params) # Hubble parameter, 1/s + TCMB = recomb_inputs.TCMB(lna) # eV + nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3 + H = recomb_inputs.H(lna) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s Delta = 0.0 @@ -427,7 +434,8 @@ def solve_emla(self, lna0, xe0, args, rtol=1e-7,atol=1e-9,solver=Tsit5(),max_ste xe0 : float Initial ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-7) atol : float, optional @@ -447,15 +455,15 @@ def solve_emla(self, lna0, xe0, args, rtol=1e-7,atol=1e-9,solver=Tsit5(),max_ste t0 = lna0 t1 = jnp.inf - BG,params = args + recomb_inputs, params = args # need to go at least twice max_steps to make sure we catch the t1 we actually want t_arr = jnp.linspace(t0+self.integration_spacing, t0+2*max_steps*self.integration_spacing, 2*max_steps) save_at = SaveAt(ts=t_arr) - TCMB_init = BG.TCMB(t0, params) # Initial CMB temperature - Tm0 = TCMB_init * (1.-BG.H(t0, params)/recomb_functions.Gamma_compton(xe0, TCMB_init, params['YHe'])) + TCMB_init = recomb_inputs.TCMB(t0) # Initial CMB temperature + Tm0 = TCMB_init * (1.-recomb_inputs.H(t0)/recomb_functions.Gamma_compton(xe0, TCMB_init, params['YHe'])) initial_state = jnp.array([xe0, Tm0]) term = ODETerm(self.xe_tm_derivative) @@ -464,7 +472,7 @@ def solve_emla(self, lna0, xe0, args, rtol=1e-7,atol=1e-9,solver=Tsit5(),max_ste def temperature_check(t, y, args, **kwargs): lna = t _, Tm = y - TCMB = BG.TCMB(lna, params) + TCMB = recomb_inputs.TCMB(lna) TR_MIN = recomb_functions.TR_MIN # Minimum Tcmb in eV T_RATIO_MIN = recomb_functions.T_RATIO_MIN # Minimum Tratio ratio = jnp.minimum(Tm / TCMB, TCMB / Tm) @@ -541,7 +549,8 @@ def get_current_correction_func(self, TCMB, args): TCMB : float CMB temperature (units: eV) args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -553,7 +562,7 @@ def get_current_correction_func(self, TCMB, args): omega_cb_fid = 0.14175 Neff_fid = 3.046 - BG, params = args + recomb_inputs, params = args # For the user inputed cosmology currently scanned over. @@ -666,7 +675,8 @@ def TLA_xe_deriv(self, lna, state, args): state : array Current state [xe, Tm] args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -674,12 +684,12 @@ def TLA_xe_deriv(self, lna, state, args): Time derivatives [dxe/dlna, dTm/dlna] (units: dimensionless, eV) """ xe, Tm = state - BG, params = args + recomb_inputs, params = args xHII = xe # since everything else is fully recombined - nH = BG.nH(lna, params) - TCMB = BG.TCMB(lna, params) # eV - H = BG.H(lna, params) + nH = recomb_inputs.nH(lna) + TCMB = recomb_inputs.TCMB(lna) # eV + H = recomb_inputs.H(lna) C = recomb_functions.peebles_C(jnp.exp(-lna) - 1.0, xHII, H, nH, args) alpha = recomb_functions.alpha_H(Tm) @@ -709,7 +719,8 @@ def solve_TLA(self, lna0, xe0, Tm0, args, rtol=1e-7, atol=1e-9, solver=Kvaerno3( Tm0: float Starting matter temperature args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-7) atol : float, optional diff --git a/abcmb/hyrex/hyrex.py b/abcmb/hyrex/hyrex.py index 8c4b819..90b3857 100644 --- a/abcmb/hyrex/hyrex.py +++ b/abcmb/hyrex/hyrex.py @@ -9,8 +9,86 @@ from .hydrogen import hydrogen_model from .helium import helium_model from .array_with_padding import array_with_padding +from ..ABCMBTools import fast_interp config.update("jax_enable_x64", True) + +class RecombInputs(eqx.Module): + """ + Bundle of pre-recombination background quantities sampled on a + fixed lna grid for computing recombination. + + Attributes + ---------- + lna_grid : jnp.array + Uniform log scale-factor sampling axis. (units: dimensionless) + TCMB_arr : jnp.array + Photon-bath temperature TCMB(lna) (units: eV) + nH_arr : jnp.array + Hydrogen number density nH(lna) (units: cm^-3) + H_arr : jnp.array + Hubble parameter H(lna) (units: s^-1) + + Methods + ------- + TCMB : Linear interpolation of CMB temperature over lna (units: eV) + nH : Linear interpolation of hydrogen number density over lna (units: cm^-3) + H : Linear interpolation of Hubble over lna (units: s^-1) + """ + + lna_grid : jnp.array + TCMB_arr : jnp.array + nH_arr : jnp.array + H_arr : jnp.array + + def TCMB(self, lna): + """ + Linearly interpolate CMB temperature at lna. + + Parameters: + ----------- + lna : float + Logarithm of scale factor. + + Returns: + -------- + float + CMB temperature TCMB(lna) (units: eV). + """ + return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.TCMB_arr) + + def nH(self, lna): + """ + Linearly interpolate hydrogen number density at lna. + + Parameters: + ----------- + lna : float + Logarithm of scale factor. + + Returns: + -------- + float + Hydrogen number density nH(lna) (units: cm^-3). + """ + return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.nH_arr) + + def H(self, lna): + """ + Linearly interpolate Hubble parameter at lna. + + Parameters: + ----------- + lna : float + Logarithm of scale factor. + + Returns: + -------- + float + Hubble parameter H(lna) (units: s^-1). + """ + return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.H_arr) + class recomb_model(eqx.Module): """ Complete recombination model implementation. @@ -49,6 +127,9 @@ def __init__(self, integration_spacing = 5.0e-4, z0=8000., z1=0., adjoint = Forw Initial redshift (default: 8000.) z1 : float, optional Final redshift (default: 0.) + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). Defaults + to ForwardMode. """ self.integration_spacing = integration_spacing self.adjoint = adjoint @@ -69,7 +150,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -86,7 +168,7 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): with reionization, log scale factor, matter temperature, and temperature grid """ return self.get_history(args, rtol, atol, solver, max_steps) - + def get_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): """ Compute complete recombination and reionization history. @@ -97,7 +179,8 @@ def get_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=102 Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -115,7 +198,7 @@ def get_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=102 matter temperature, and temperature grid """ - BG, params = args + recomb_inputs, params = args lna_axis_4Heequil = self.lna_axis_full[self.idx_4He_equil] xe_4He, lna_4He = helium_model(lna_axis_4Heequil, adjoint=self.adjoint)(args) diff --git a/abcmb/hyrex/recomb_functions.py b/abcmb/hyrex/recomb_functions.py index 13a81d5..383da61 100644 --- a/abcmb/hyrex/recomb_functions.py +++ b/abcmb/hyrex/recomb_functions.py @@ -13,14 +13,10 @@ #Tabulated values of 2s-2p transition rates to interpolate. alpha_tab = jnp.array(np.loadtxt(file_dir+"/tabs/Alpha_inf.dat")) -try: - gpus = devices('gpu') - R_tab = device_put( - R_tab, device=gpus[0]) - alpha_tab = device_put( - alpha_tab, device=gpus[0]) -except: - pass +# pin to CPU +cpus = devices('cpu') +R_tab = device_put(R_tab, device=cpus[0]) +alpha_tab = device_put(alpha_tab, device=cpus[0]) # File handling and interpolating related constants. # Do not change these unless something about the tabulated files have changed. @@ -94,7 +90,7 @@ def beta_H(TCMB): def peebles_C(z, xHII, H, nH, args): """ Peebles C factor, probability for an n=2 hydrogen atom to reach the ground state before becoming photoionized, a function of redshift and ionized proton fraction. - + Dimensions: None Parameters @@ -107,23 +103,24 @@ def peebles_C(z, xHII, H, nH, args): Value(s) of Hubble at given redshift(s). nH : float/jnp.array Value(s) of hydrogen number density at given redshift(s) - BG: cosmology.Background - Background cosmology object + args : tuple + (recomb_inputs, params) — recombination input arrays and + cosmological parameters. Returns ------- C : float/jnp.array Peebles C factor. """ - BG, params = args + recomb_inputs, params = args # (2p to 1s rate) x (1 - xHII), in s^{-1} rate_2p1s_times_x1s = 8*jnp.pi*H / (3 * (nH*(cnst.c/cnst.lya_freq)**3)) # s^{-1} rate_exc = 3. * rate_2p1s_times_x1s/4. + (1.-xHII) * cnst.R2s1s/4. - + # Ionization rate, in s^{-1} - rate_ion = (1-xHII) * beta_H( BG.TCMB( jnp.log(1/(1+z)) , params ) ) + rate_ion = (1-xHII) * beta_H( recomb_inputs.TCMB( jnp.log(1/(1+z)) ) ) return rate_exc / (rate_exc + rate_ion) diff --git a/abcmb/linx/abundances.py b/abcmb/linx/abundances.py index f0559f8..cfde07d 100644 --- a/abcmb/linx/abundances.py +++ b/abcmb/linx/abundances.py @@ -40,6 +40,9 @@ class AbundanceModel(eqx.Module): Binding energy of each species. species_mass : list Mass of each species. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Defaults + to ForwardMode. """ nuclear_net : nucl.NuclearRates diff --git a/abcmb/linx/background.py b/abcmb/linx/background.py index c102856..f75df7b 100644 --- a/abcmb/linx/background.py +++ b/abcmb/linx/background.py @@ -31,6 +31,8 @@ class BackgroundModel(eqx.Module): Whether to use leading order QED correction. Default is `True`. NLO : bool, optional Whether to use next-to-leading order QED correction. Default is True. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Default is ForwardMode. """ decoupled : bool diff --git a/abcmb/main.py b/abcmb/main.py index ee717a3..4d8e673 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -15,6 +15,7 @@ from . import background, perturbations, spectrum, model_specs from . import constants as cnst from .ABCMBTools import bilinear_interp +from .background import BackgroundPreRecomb, Background, ReionizationModelFromZ, ReionizationModelFromTau from .linx.background import BackgroundModel from .linx.abundances import AbundanceModel @@ -57,6 +58,8 @@ class Model(eqx.Module): A LINX abundance model used for computing the helium-4 mass fraction given the user's input baryon density, Neff, neutron lifetime, and nuclear reaction rates. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Default is ForwardMode. Methods: -------- @@ -72,9 +75,9 @@ class Model(eqx.Module): specs : dict species_list : tuple = () - species_dict : dict - - PArthENoPE_CLASS_table : Array + species_dict : dict + + PArthENoPE_CLASS_table : Array thermo_model_DNeff : BackgroundModel abundanceModel : AbundanceModel @@ -104,9 +107,14 @@ def __init__(self, # Pull adjoint out of kwargs before load_specs — it must NOT end up # inside self.specs (a non-JAX pytree leaf breaks lax.cond / filter_jit - # tracing). Default preserves prior ForwardMode behavior. + # tracing). adjoint = kwargs.pop("adjoint", diffrax.ForwardMode) + # If user requested RecursiveCheckpointAdjoint, auto-tighten kappa_PE + # to a reverse-AD-safe value (unless another value is specified) + if adjoint is diffrax.RecursiveCheckpointAdjoint and "kappa_PE" not in kwargs: + kwargs["kappa_PE"] = 1e-3 + # Fill in all user defined and missing specs parameters specs = model_specs.load_specs(kwargs) self.specs = specs @@ -142,7 +150,7 @@ def __init__(self, scale_pol=specs["scale_pol"] ) - # Initialize recombination model + # Initialize recombination model. self.RecModel = hyrex.recomb_model(adjoint=adjoint) # DO NOT CHANGE z1 FROM 0 # Initialize BBN model @@ -157,12 +165,10 @@ def __init__(self, self.adjoint = adjoint - # need this outside of the jit context - # since we want LINX to run on CPU + # need this outside of the main jit context + # since we want LINX/HyRex to run on CPU def __call__(self, params : dict = {}): """ - Compute CMB angular power spectra for given parameters. - Runs the full pipeline from background evolution through perturbation integration to CMB power spectrum computation. @@ -173,17 +179,16 @@ def __call__(self, params : dict = {}): Returns: -------- - tuple - (ℓ values, (C_ℓ^TT, C_ℓ^TE, C_ℓ^EE)) for computed multipoles + Output + Bundle of CMB power spectra (ClTT, ClTE, ClEE) and their + multipole grid l, matter power spectrum Pk and its k-grid, + the Background and PerturbationTable objects, and the + full parameter dict including derived keys. """ - - full_params = self.add_derived_parameters(params) return self.run_cosmology_abbr(full_params) - - ### JITTED OR JITTABLE FUNCTIONS ### - @eqx.filter_jit + def run_cosmology_abbr(self, params : dict): """ Compute CMB angular power spectra for given parameters. @@ -194,14 +199,60 @@ def run_cosmology_abbr(self, params : dict): Parameters: ----------- params : dict - Cosmological parameters + Cosmological parameters (must already have derived keys). Returns: -------- - tuple - (ℓ values, (C_ℓ^TT, C_ℓ^TE, C_ℓ^EE)) for computed multipoles + Output + CMB power spectra and friends. + """ + # Cast int/bool params to float64 before entering any + # ``eqx.filter_jit`` for custom_vjp/AD safety in + # checkpointed_while_loop + def _to_float(v): + arr = jnp.asarray(v) + if arr.dtype.kind in 'iub': + return arr.astype(jnp.float64) + return arr + params = jax.tree_util.tree_map(_to_float, params) + + pre_BG = self.get_BG_pre_recomb(params) + + cpu_dev = jax.devices('cpu')[0] + recomb_inputs_cpu = jax.device_put(pre_BG.recomb_inputs, cpu_dev) + params_cpu = jax.device_put(params, cpu_dev) + + recomb_output = eqx.filter_jit(self.RecModel, backend='cpu')((recomb_inputs_cpu, params_cpu)) + + try: + recomb_output = jax.device_put(recomb_output, jax.devices('gpu')[0]) + except Exception: + pass + + # recomb_output contains array_with_padding objects whose + # padding_size and lastnum int arrays. The + # checkpointed_while_loop's filter_custom_vjp inside + # _run_post_recomb's diffrax solves trips an internal + # _get_value_assert_unperturbed on int leaves under outer + # AD; convert to float to avoid. + recomb_output = jax.tree_util.tree_map(_to_float, recomb_output) + + return self._run_post_recomb(params, pre_BG, recomb_output) + + @eqx.filter_jit + def get_BG_pre_recomb(self, params : dict): """ + Pre-recomb stage: tabulate conformal time and bundle H, T, nH for recombination. + + Parameters: + ----------- + params : dict + Cosmological parameters + Returns: + -------- + BackgroundPreRecomb + """ # let the user know the code is compiling print("") print(' /\\ ') @@ -214,14 +265,35 @@ def run_cosmology_abbr(self, params : dict): print(' /_/ \\_|___/ \\___/|| |||_-) is compiling...') print('\\_____/ ') print("") + return BackgroundPreRecomb(params, self.species_list, self.RecModel, adjoint=self.adjoint) + + @eqx.filter_jit + def _run_post_recomb(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): + """ + Post-recombination stage: full Background construction (reionization, + optical depth, decoupling), perturbation evolution, CMB spectra. + + Parameters: + ----------- + params : dict + Cosmological parameters + pre_BG : BackgroundPreRecomb + Output of :meth:`get_BG_pre_recomb`. + recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. + + Returns: + -------- + Output + """ # Compute background and linear perturbations - PT, BG = self.get_PTBG(params) + PT, BG = self.get_PTBG(params, pre_BG, recomb_output) # Compute CMB power spectra Cls = self.SS.get_Cl(PT, BG, params) l = self.SS.ells - + # Compute linear matter power spectrum Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params) k = self.SS.k_axis_Pk_output @@ -235,61 +307,67 @@ def run_cosmology_abbr(self, params : dict): return output @eqx.filter_jit - def get_PTBG(self, params : dict): + def get_PTBG(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): """ - Get perturbation table and background. + Get perturbation table and full Background. - Computes background and evolves perturbations for the given parameters. + Constructs the post-recomb Background from ``pre_BG`` + ``recomb_output`` + and runs the perturbation evolver. Parameters: ----------- params : dict Cosmological parameters + pre_BG : BackgroundPreRecomb + Pre-recombination stage object. + recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. Returns: -------- tuple - (PerturbationTable, Background) objects + (PerturbationTable, Background) """ - BG = self.get_BG(params) + BG = self.get_BG(params, pre_BG, recomb_output) PT = self.PE.full_evolution((BG, params)) - return PT, BG - @eqx.filter_jit - def get_BG(self, params : dict): + def get_BG(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): """ - Get background for given parameters. + Construct the full ``Background`` from pre-recomb + HyRex output. + + Selects the reionization model (z-input vs tau-input) via ``lax.cond``. + NOT directly ``@eqx.filter_jit``-decorated; called from inside + ``_run_post_recomb`` (which is jit-wrapped). Parameters: ----------- params : dict Cosmological parameters + pre_BG : BackgroundPreRecomb + Pre-recombination stage object. + recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. Returns: -------- background.Background - Background object """ - # Bind to a local so both closures capture a plain class rather than - # an attribute lookup on self. The class is never placed in the - # lax.cond operand tuple (keeping it a valid JAX pytree). - adjoint = self.adjoint def get_BG_z_reion(args): - params, species_list, RecModel = args - return background.Background(params, species_list, RecModel, background.ReionizationModelFromZ, adjoint=adjoint) + params, pre_BG, recomb_output = args + return Background(pre_BG, recomb_output, params, ReionizationModelFromZ) def get_BG_tau_reion(args): - params, species_list, RecModel = args - return background.Background(params, species_list, RecModel, background.ReionizationModelFromTau, adjoint=adjoint) + params, pre_BG, recomb_output = args + return Background(pre_BG, recomb_output, params, ReionizationModelFromTau) BG = lax.cond( self.specs["input_tau_reion"], get_BG_tau_reion, get_BG_z_reion, - (params, self.species_list, self.RecModel) + (params, pre_BG, recomb_output) ) - + return BG def add_derived_parameters(self, param_in : dict) -> dict: diff --git a/abcmb/model_specs.py b/abcmb/model_specs.py index 6cd57b1..6a3dfba 100644 --- a/abcmb/model_specs.py +++ b/abcmb/model_specs.py @@ -67,6 +67,11 @@ def load_specs(input_specs): specs["pcoeff_PE"] = input_specs.get("pcoeff_PE", 0.25) specs["icoeff_PE"] = input_specs.get("icoeff_PE", 0.8) specs["dcoeff_PE"] = input_specs.get("dcoeff_PE", 0.) + # Newton convergence threshold for Kvaerno5 VeryChord root finder. + # If you encounter NaNs in gradients, decrease this value (e.g. 1e-3). + # Small (<1%) performance hit for doing so, so default stays at 1e-2 + # unless reverseAD is requested (see main). + specs["kappa_PE"] = input_specs.get("kappa_PE", 0.01) ### Physical contributions to CMB temperature transfer function ### specs["scale_sw"] = input_specs.get("scale_sw", 1) diff --git a/abcmb/perturbations.py b/abcmb/perturbations.py index 3da858e..7757cf3 100644 --- a/abcmb/perturbations.py +++ b/abcmb/perturbations.py @@ -3,6 +3,8 @@ import numpy as np from jax import vmap, lax import diffrax +from diffrax import with_stepsize_controller_tols +from diffrax._root_finder import VeryChord import equinox as eqx from . import constants as cnst @@ -38,6 +40,8 @@ class PerturbationEvolver(eqx.Module): A list of wavenumbers k at which to compute perturbations specs : dict A dictionary containing run options + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Default is ForwardMode. Methods: -------- @@ -281,7 +285,18 @@ def evolution_one_k(self, k, lna, args): # Settings for post-tight coupling term = diffrax.ODETerm(self.get_derivatives) - solver = diffrax.Kvaerno5() + # LCDM defaults for very high k sometimes failed with reverseAD. + # The reason for that was the default precision parameter in + # the Kvaerno5 rootfinder (via VeryChord). Parameter now read + # in from specs; if reverse-AD produces NaNs (especially on omega_b/ + # omega_cdm), decrease kappa_PE in specs (e.g. to 1e-3) to tighten + # convergence. + _rf = eqx.tree_at( + lambda s: s.kappa, + with_stepsize_controller_tols(VeryChord)(), + replace=self.specs["kappa_PE"], + ) + solver = diffrax.Kvaerno5(root_finder=_rf) rtol=jnp.where( k > self.specs["k_split_PE"], diff --git a/abcmb/species.py b/abcmb/species.py index 95a178c..3c79701 100644 --- a/abcmb/species.py +++ b/abcmb/species.py @@ -41,10 +41,10 @@ class Fluid(eqx.Module): rho_plus_P_sigma : Compute standard shear perturbation (units: eV cm^{-3}) """ - delta_idx : int = eqx.field(default=0) + delta_idx : int = eqx.field(default=0, static=True) num_moments : int = eqx.field(default=0, static=True) name : str = eqx.field(default="", static=True) - is_matter : bool = eqx.field(default=False) # Does the fluid contribute towards matter overdensity today. + is_matter : bool = eqx.field(default=False, static=True) # Does the fluid contribute towards matter overdensity today. def __init__(self, delta_idx, specs): self.delta_idx = delta_idx diff --git a/abcmb/spectrum.py b/abcmb/spectrum.py index dbe3ef6..60833f9 100644 --- a/abcmb/spectrum.py +++ b/abcmb/spectrum.py @@ -411,15 +411,27 @@ def lensing_Cl(self, ells, PT, BG, params): coeff = 8.*jnp.pi**2/(ells+0.5)**3 chi = lambda lna : BG.tau0 - BG.tau(lna) + # The previous jnp.nan_to_num(integrand, nan=0.) here masked the + # forward NaN but left a 0*NaN cotangent in the backward through + # the where-mask that nan_to_num secretly expands to, which + # propagated through BG.tau. Fix: substitute lna_safe everywhere, + # then mask the result to 0 at the boundary. + lna_axis = jnp.linspace(BG.lna_rec, 0., 500) + lna_floor = lna_axis[-2] + def integrand_func(lna): - k = (ells+0.5)/chi(lna) - window = (chi(BG.lna_rec) - chi(lna))/chi(BG.lna_rec)/chi(lna) - res = chi(lna)/BG.aH(lna, params) * window**2 * self.lensing_power_spectrum(k, lna, PT, BG, params) - return res + lna_safe = jnp.where(lna < 0., lna, lna_floor) + chi_safe = chi(lna_safe) + k = (ells+0.5)/chi_safe + window = (chi(BG.lna_rec) - chi_safe)/chi(BG.lna_rec)/chi_safe + res = ( + chi_safe / BG.aH(lna_safe, params) + * window**2 + * self.lensing_power_spectrum(k, lna_safe, PT, BG, params) + ) + return jnp.where(lna < 0., res, 0.) - lna_axis = jnp.linspace(BG.lna_rec, 0., 500) integrand = vmap(integrand_func)(lna_axis) - integrand = jnp.nan_to_num(integrand, nan=0.) return coeff*jnp.trapezoid(integrand, lna_axis, axis=0) def lensed_Cls(self, ells, ClTT_unlensed, ClTE_unlensed, ClEE_unlensed, PT, BG, params): diff --git a/abcmb/version.py b/abcmb/version.py index 7a17bdd..845be45 100644 --- a/abcmb/version.py +++ b/abcmb/version.py @@ -1 +1 @@ -__version__ = "0.2.4" \ No newline at end of file +__version__ = "0.2.5" \ No newline at end of file