From be8050e2d8e7b5fde509f7b88e0a14bd8b039039 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Mon, 4 May 2026 15:12:51 -0700 Subject: [PATCH 01/14] refactored BG and hyrex to pull hyrex into its own jit context --- abcmb/background.py | 518 ++++++++++++++++---------------- abcmb/hyrex/helium.py | 61 ++-- abcmb/hyrex/hydrogen.py | 80 ++--- abcmb/hyrex/hyrex.py | 53 +++- abcmb/hyrex/recomb_functions.py | 27 +- abcmb/main.py | 187 +++++++++--- abcmb/species.py | 4 +- 7 files changed, 548 insertions(+), 382 deletions(-) diff --git a/abcmb/background.py b/abcmb/background.py index 0b5c18c..dbdf91a 100644 --- a/abcmb/background.py +++ b/abcmb/background.py @@ -8,6 +8,7 @@ from .hyrex.array_with_padding import array_with_padding from .hyrex import recomb_functions +from .hyrex.hyrex import RecombInputs from . import ABCMBTools as tools from . import constants as cnst @@ -15,106 +16,58 @@ file_dir = os.path.dirname(__file__) config.update("jax_enable_x64", True) -class Background(eqx.Module): + +class BackgroundPreRecomb(eqx.Module): """ - Background cosmology module for cosmological calculations. + Pre-recombination background-cosmology object (Phase 2 of HyRex CPU lift). - Computes background quantities including Hubble parameter, conformal time, - recombination history, and optical depth evolution. + Holds everything HyRex needs to run on CPU: the conformal-time tabulation, + the species list, and a ``RecombInputs`` struct that bundles HyRex's input + arrays sampled on the recombination grid. None of these depend on xe, Tm, + or the optical depth, so this object is the natural input to the CPU-pinned + HyRex solve and the natural input to the post-recombination Background + construction (which inherits from this class). Attributes: ----------- species_list : tuple A list of all fluids in the cosmology lna_tau_tab : jnp.array - Log scale factor axis used to tabulate conformal time - tau_tab : jnp.array - Tabulated conformal time. - tau0 : float + Log scale factor axis used to tabulate conformal time (class attribute) + tau_tab : jnp.array + Tabulated conformal time. + tau0 : float Conformal time today in Mpc. - xe_tab : array_with_padding - Tabulated free electron fraction xe during recombination - lna_xe_tab : array_with_padding - Log scale factor axis corresponding to tabulated xe values. - Tm_tab : array_with_padding - Tabulated matter temperature Tm during recombination - lna_Tm_tab : array_with_padding - Log scale factor axis corresponding to tabulated Tm values. - kappa_func : diffrax.solution - Visibility function - z_reion : float - Redshift of hydrogen reionization in the CAMB parameterization. - tau_reion : float - Optical depth to reionization - lna_rec : float - Log scale factor of recombination - rA_rec : float - Comoving angular diameter distance at recombination in Mpc - rs_d : float - Sound horizon at baryon decoupling in Mpc - z_d : float - Redshift of baryon decoupling - lna_transfer_start : float - Log scale factor at which to begin integrating transfer functions. - lna_visibility_stop : float - Log scale factor at which to stop integrating T1, T2, and E sources due to small visibility functions. - Only used for l<400. - - Recombination Unrelated Methods: - -------------------------------- - rho_tot : Compute total energy density (units: eV cm^{-3}) - P_tot : Compute total pressure (units: eV cm^{-3}) - H : Compute Hubble parameter (units: s^{-1}) - aH : Compute conformal Hubble parameter (units: Mpc^{-1}) - aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1}) - d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2}) - tau : Compute conformal time (units: Mpc) - z_d : Compute baryon decoupling redshift (units: dimensionless) - rs_d : Compute sound horizon at decoupling (units: Mpc) - - Recombination Related Methods: - ------------------------------ - xe : Compute free electron fraction (units: dimensionless) - Tm : Compute matter temperature (units: eV) - mu_bar : Compute mean molecular mass (units: eV) - cs2 : Compute baryon sound speed squared (units: dimensionless) - nH : Compute hydrogen number density (units: cm^{-3}) - TCMB : Compute CMB temperature (units: eV) - tau_c : Compute Thomson scattering time (units: Mpc) - kappa : Compute optical depth (units: dimensionless) - visibility : Compute visibility function (units: Mpc^{-1}) + recomb_inputs : RecombInputs + Bundle of background quantities (TCMB, nH, H) sampled on + ``RecModel.lna_axis_full``; consumed by HyRex. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). + + Methods: + -------- + rho_tot, P_tot, H, aH, aH_prime, d2adtau2_over_a + tau, nH, TCMB, R_ratio_lna """ - # params : dict species_list : tuple - - lna_tau_tab = jnp.linspace(-33.0, 0.0, 10000) # Axis for tabulating conformal time. - tau_tab : jnp.array # Tabulated conformal time. - tau0 : float # Conformal time today - # Recombination related - xe_tab : "array_with_padding" - lna_xe_tab : "array_with_padding" - Tm_tab : "array_with_padding" - lna_Tm_tab : "array_with_padding" - kappa_func : "diffrax.solution" - z_reion : float - tau_reion : float - lna_rec : float - rA_rec : float # Comoving angular diameter distance at recombination. + lna_tau_tab = jnp.linspace(-33.0, 0.0, 10000) + tau_tab : jnp.array + tau0 : float - # Transfer related - lna_transfer_start : float # Time where transfer functions start integrating. - lna_visibility_stop : float # Time to stop integrating T1, T2, and E sources due to small visibility functions. Only used for l<400 + recomb_inputs : "RecombInputs" adjoint : "diffrax.adjoint" = eqx.field(static=True) - def __init__(self, params, species_list, RecModel, ReionModel, adjoint=ForwardMode): + def __init__(self, params, species_list, RecModel, adjoint=ForwardMode): """ - Initialize Background cosmology module. + Initialize pre-recombination background. - Computes and tabulates conformal time, recombination history, - optical depth, and key cosmological epochs. + Tabulates conformal time and builds the ``RecombInputs`` struct + HyRex consumes. No reionization correction or optical-depth + integration is done here — those depend on the recombination + history and live on the post-recomb ``Background`` subclass. Parameters: ----------- @@ -122,54 +75,29 @@ def __init__(self, params, species_list, RecModel, ReionModel, adjoint=ForwardMo Cosmological parameters species_list : tuple List of fluid species for energy density calculations - RecModel : callable - Recombination module for computing xe and Tm histories - ReionModel : callable - Reionization module for computing the xe correction. + RecModel : hyrex.recomb_model + Used for its ``lna_axis_full`` sampling grid (not called here). + adjoint : diffrax.adjoint, optional + Adjoint class for diffrax solves (default: ForwardMode) """ self.adjoint = adjoint - # self.params = params self.species_list = species_list self.tau_tab = self._tabulate_conformal_time(params) - self.tau0 =self.tau(0.) - - self.run_recombination(params, RecModel, ReionModel) - - # Find approximate maximum of visibility function. - lna_vals = jnp.linspace(-8.0, -4.0, 1500) # Decoupling should have happened at some time in this interval. - vis_vals = vmap(self.visibility,in_axes=[0,None])(lna_vals, params) - self.lna_rec = lna_vals[jnp.argmax(vis_vals)] - self.lna_visibility_stop = lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)] - self.rA_rec = self.tau0 - self.tau(self.lna_rec) - - # Find approximate early time when aH x tau_c = 0.008 - lna_vals = jnp.linspace(-15.0, -6.0, 5000) - aH_tau_c_vals = vmap(self.aH,in_axes=[0,None])(lna_vals,params)*self.tau_c(lna_vals,params) - self.lna_transfer_start = lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)] - - def run_recombination(self, params, RecModel, ReionModel): - """ - Call HyRex to get the primary recombination history, then patch on tanh reionization. - For now assumes the user will specify tau_reio. - """ - ### RECOMBINATION ### - - # Run hyrex to tabulate recombination output - xe, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = RecModel((self,params)) - - ### REIONIZATION ### - reion_model = ReionModel(self, params) - self.z_reion = reion_model.z_reion - self.tau_reion = reion_model.tau_reion - - xe_reion_correction = reion_model.xe_reion(self.lna_xe_tab.arr, self.z_reion, params) - xe_full_arr = xe_reion_correction + xe.arr - xe_full = array_with_padding(xe_full_arr) - - self.xe_tab = xe_full - - self.kappa_func = self._tabulate_optical_depth(params) + self.tau0 = self.tau(0.) + + # Bundle the background quantities HyRex needs onto its sampling + # grid. Phase 2 ships these to CPU (see ``Model.__call__``); for + # standard cosmologies the linear interpolation against this dense + # grid is accurate to ~3e-8 (h^2/8 with h=5e-4) — well below + # accuracy_test tolerances. + lna_axis = RecModel.lna_axis_full + self.recomb_inputs = RecombInputs( + lna_grid = lna_axis, + TCMB_arr = vmap(self.TCMB, in_axes=[0, None])(lna_axis, params), + nH_arr = vmap(self.nH, in_axes=[0, None])(lna_axis, params), + H_arr = vmap(self.H, in_axes=[0, None])(lna_axis, params), + ) def rho_tot(self, lna, params): """ @@ -188,16 +116,12 @@ def rho_tot(self, lna, params): -------- float Total energy density (units: eV cm^{-3}) - - Notes: - ------ - User should not modify this function without careful consideration. """ rho_tot = 0. for i in range(len(self.species_list)): rho_tot += self.species_list[i].rho(lna, params) return rho_tot - + def P_tot(self, lna, params): """ Compute total pressure. @@ -215,10 +139,6 @@ def P_tot(self, lna, params): -------- float Total pressure (units: eV cm^{-3}) - - Notes: - ------ - User should not modify this function without careful consideration. """ P_tot = 0. for i in range(len(self.species_list)): @@ -266,7 +186,7 @@ def aH(self, lna, params): Conformal Hubble parameter (units: Mpc^{-1}) """ return jnp.exp(lna)*self.H(lna, params) / cnst.c_Mpc_over_s - + def aH_prime(self, lna, params): """ Compute derivative of conformal Hubble parameter. @@ -307,9 +227,8 @@ def d2adtau2_over_a(self, lna, params): float Second derivative of scale factor (units: Mpc^{-2}) """ - return self.aH(lna, params)**2 + self.aH(lna, params)*self.aH_prime(lna, params) - + def _dtau_dlna(self, lna, y, args): """ Compute derivative of conformal time with respect to ln(a). @@ -416,18 +335,189 @@ def tau(self, lna): -------- float Conformal time (units: Mpc) + """ + return tools.fast_interp(lna, self.lna_tau_tab[0], self.lna_tau_tab[-1], self.tau_tab) - Notes: - ------ - IDEA: Make Background a repeatedly initiated module with both - species_list and params stored. Upon initiation, a full history - of conformal time is calculated with diffrax and stored for - interpolation. This can be done by approximating early time with - radiation approximation, and starting diffrax integration at the - early time with appropriate initial conditions. + def nH(self, lna, params): """ + Compute hydrogen number density. - return tools.fast_interp(lna, self.lna_tau_tab[0], self.lna_tau_tab[-1], self.tau_tab) + Calculates total hydrogen number density at given redshift. + + Parameters: + ----------- + lna : float + Logarithm of scale factor + params : dict + Cosmological parameters + + Returns: + -------- + float + Hydrogen number density (units: cm^{-3}) + """ + return (1-params['YHe']) * 3. * params['omega_b'] * cnst.H0_over_h**2 / 8 / jnp.pi / cnst.G / cnst.mH / jnp.exp(lna)**3 + + def TCMB(self, lna, params): + """ + Compute CMB temperature. + + Calculates CMB temperature at given redshift using T ∝ 1/a scaling. + + Parameters: + ----------- + lna : float + Logarithm of scale factor + params : dict + Cosmological parameters + + Returns: + -------- + float + CMB temperature (units: eV) + """ + return params['TCMB0'] / jnp.exp(lna) + + def R_ratio_lna(self, lna, params): + """ + Compute baryon drag ratio. + + Calculates R = 3ρ_b/(4ρ_γ), the ratio of baryon to photon + energy densities that appears in baryon drag calculations. + + Parameters: + ----------- + lna : float + Logarithm of scale factor + params : dict + Cosmological parameters + + Returns: + -------- + float + Baryon drag ratio (units: dimensionless) + """ + rho_b = 0. + rho_g = 0. + + for s in self.species_list: + if s.name == "Photon": + rho_g += s.rho(lna, params) + elif s.name == "Baryon": + rho_b += s.rho(lna, params) + + return 3. * rho_b / (4 * rho_g) + + +class Background(BackgroundPreRecomb): + """ + Full background-cosmology object: pre-recombination state plus + the recombination + reionization history and the optical-depth + tabulation. + + Inherits all cosmology fields and methods from ``BackgroundPreRecomb``. + Construction takes a ``BackgroundPreRecomb`` (output of the GPU pre-recomb + stage) and the recombination output produced by HyRex on CPU, then + applies the reionization correction and integrates the optical depth. + + Attributes: + ----------- + xe_tab : array_with_padding + Tabulated free electron fraction xe with reionization correction. + lna_xe_tab : array_with_padding + Log scale factor axis corresponding to tabulated xe values. + Tm_tab : array_with_padding + Tabulated matter temperature Tm during recombination. + lna_Tm_tab : array_with_padding + Log scale factor axis corresponding to tabulated Tm values. + kappa_func : diffrax.solution + Optical depth function (dense interpolation). + z_reion : float + Redshift of hydrogen reionization in the CAMB parameterization. + tau_reion : float + Optical depth to reionization. + lna_rec : float + Log scale factor of recombination. + rA_rec : float + Comoving angular diameter distance at recombination in Mpc. + lna_transfer_start : float + Log scale factor at which to begin integrating transfer functions. + lna_visibility_stop : float + Log scale factor at which to stop integrating T1, T2, and E sources + due to small visibility functions. Only used for l<400. + + Recombination Related Methods: + ------------------------------ + xe : Compute free electron fraction (units: dimensionless) + Tm : Compute matter temperature (units: eV) + tau_c : Compute Thomson scattering time (units: Mpc) + expmkappa : Compute exp(-kappa) (units: dimensionless) + visibility : Compute visibility function (units: Mpc^{-1}) + z_d : Compute baryon decoupling redshift (units: dimensionless) + rs_d : Compute sound horizon at decoupling (units: Mpc) + """ + + xe_tab : "array_with_padding" + lna_xe_tab : "array_with_padding" + Tm_tab : "array_with_padding" + lna_Tm_tab : "array_with_padding" + kappa_func : "diffrax.solution" + z_reion : float + tau_reion : float + lna_rec : float + rA_rec : float + + lna_transfer_start : float + lna_visibility_stop : float + + def __init__(self, pre_BG, recomb_output, params, ReionModel): + """ + Construct full Background from a pre-recomb stage and the HyRex output. + + Parameters: + ----------- + pre_BG : BackgroundPreRecomb + Output of the GPU pre-recomb stage; provides species_list, + tau_tab, tau0, recomb_inputs, adjoint. + recomb_output : tuple + HyRex's ``(xe, lna_xe, Tm, lna_Tm)`` quadruple — the result of + running ``RecModel((pre_BG.recomb_inputs, params))`` on CPU. + params : dict + Cosmological parameters. + ReionModel : type + ``ReionizationModelFromZ`` or ``ReionizationModelFromTau``. + """ + # Copy pre-recomb fields onto self. + self.adjoint = pre_BG.adjoint + self.species_list = pre_BG.species_list + self.tau_tab = pre_BG.tau_tab + self.tau0 = pre_BG.tau0 + self.recomb_inputs = pre_BG.recomb_inputs + + # Unpack HyRex output and apply reionization correction. + xe, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = recomb_output + + reion_model = ReionModel(self, params) + self.z_reion = reion_model.z_reion + self.tau_reion = reion_model.tau_reion + + xe_reion_correction = reion_model.xe_reion(self.lna_xe_tab.arr, self.z_reion, params) + xe_full_arr = xe_reion_correction + xe.arr + self.xe_tab = array_with_padding(xe_full_arr) + + self.kappa_func = self._tabulate_optical_depth(params) + + # Find approximate maximum of visibility function. + lna_vals = jnp.linspace(-8.0, -4.0, 1500) # Decoupling falls in here. + vis_vals = vmap(self.visibility, in_axes=[0, None])(lna_vals, params) + self.lna_rec = lna_vals[jnp.argmax(vis_vals)] + self.lna_visibility_stop = lna_vals[jnp.argmin((vis_vals - 1.e-3)**2)] + self.rA_rec = self.tau0 - self.tau(self.lna_rec) + + # Find approximate early time when aH x tau_c = 0.008 + lna_vals = jnp.linspace(-15.0, -6.0, 5000) + aH_tau_c_vals = vmap(self.aH, in_axes=[0, None])(lna_vals, params) * self.tau_c(lna_vals, params) + self.lna_transfer_start = lna_vals[jnp.argmin((aH_tau_c_vals-0.008)**2)] ### RECOMBINATION RELATED ### @@ -447,16 +537,6 @@ def xe(self, lna): -------- float Free electron fraction (units: dimensionless) - - Notes: - ------ - The logic flow is equivalent to: - - if lna < self.lna_xe_tab.arr[0]: return self.xe_tab[0] - - elif lna > self.lna_xe_tab.lastval: return self.xe_tab.lastval - - else: return jnp.interp(lna, self.lna_xe_tab, self.xe_tab) """ return jnp.where( lna < self.lna_xe_tab.arr[0], @@ -524,46 +604,6 @@ def Tm(self, lna, params): ) ) - def nH(self, lna, params): - """ - Compute hydrogen number density. - - Calculates total hydrogen number density at given redshift. - - Parameters: - ----------- - lna : float - Logarithm of scale factor - params : dict - Cosmological parameters - - Returns: - -------- - float - Hydrogen number density (units: cm^{-3}) - """ - return (1-params['YHe']) * 3. * params['omega_b'] * cnst.H0_over_h**2 / 8 / jnp.pi / cnst.G / cnst.mH / jnp.exp(lna)**3 - - def TCMB(self,lna, params): - """ - Compute CMB temperature. - - Calculates CMB temperature at given redshift using T ∝ 1/a scaling. - - Parameters: - ----------- - lna : float - Logarithm of scale factor - params : dict - Cosmological parameters - - Returns: - -------- - float - CMB temperature (units: eV) - """ - return params['TCMB0'] / jnp.exp(lna) - def tau_c(self, lna, params): """ Compute Thomson scattering time. @@ -603,11 +643,6 @@ def _tabulate_optical_depth(self, params): -------- array Tabulated optical depth values (units: dimensionless) - - Notes: - ------ - Also computes time derivative of optical depth, which is the - integrand involving the free electron fraction. """ integrand = lambda lna, y, args: -1./self.tau_c(lna, params)/self.aH(lna, params) term = ODETerm(integrand) @@ -615,13 +650,13 @@ def _tabulate_optical_depth(self, params): adjoint=self.adjoint() sol = diffeqsolve( term, - solver=Kvaerno5(), + solver=Kvaerno5(), stepsize_controller=stepsize_controller, - t0=0., - t1=-10., - dt0=-1.e-3, + t0=0., + t1=-10., + dt0=-1.e-3, max_steps=2048, - y0=0.0, + y0=0.0, saveat=SaveAt(dense=True), adjoint=adjoint ) @@ -629,7 +664,7 @@ def _tabulate_optical_depth(self, params): def expmkappa(self, lna): """ - Compute optical depth. + Compute exp(-optical depth). Interpolates from pre-tabulated optical depth history. @@ -641,9 +676,8 @@ def expmkappa(self, lna): Returns: -------- float - Optical depth (units: dimensionless) + exp(-κ) (units: dimensionless) """ - return jnp.where( lna < -10., 0., @@ -669,10 +703,6 @@ def visibility(self, lna, params): -------- float Visibility function (units: Mpc^{-1}) - - Notes: - ------ - Used in computing source functions for CMB anisotropies. """ return self.expmkappa(lna)/self.tau_c(lna, params) @@ -699,12 +729,10 @@ def find_z_at_kappad_equals_one(self, z, kappa_d): float Decoupling redshift (units: dimensionless) """ - # ensure sorted ascending idx = jnp.argsort(z) z_sorted = z[idx] kappa_d_sorted = jnp.abs(kappa_d)[idx] - # interpolate z_d = jnp.interp(1.0, kappa_d_sorted, z_sorted) return z_d @@ -731,37 +759,6 @@ def interp_rs_at_z(self, z_bg, r_s, z_d): rs_sorted = r_s[idx] return jnp.interp(z_d, z_sorted, rs_sorted) - def R_ratio_lna(self, lna, params): - """ - Compute baryon drag ratio. - - Calculates R = 3ρ_b/(4ρ_γ), the ratio of baryon to photon - energy densities that appears in baryon drag calculations. - - Parameters: - ----------- - lna : float - Logarithm of scale factor - params : dict - Cosmological parameters - - Returns: - -------- - float - Baryon drag ratio (units: dimensionless) - """ - - rho_b = 0. - rho_g = 0. - - for s in self.species_list: - if s.name == "Photon": - rho_g += s.rho(lna, params) - elif s.name == "Baryon": - rho_b += s.rho(lna, params) - - return 3. * rho_b / (4 * rho_g) - @jax.named_scope("tabulate kappa d") def _tabulate_kappa_d(self, params): """ @@ -784,17 +781,17 @@ def _tabulate_kappa_d(self, params): term = ODETerm(integrand) stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-3, atol=1.e-6) adjoint=self.adjoint() - + solution = diffeqsolve( term, - solver=Tsit5(), # Kvaerno5 is just slower but gives same result + solver=Tsit5(), stepsize_controller=stepsize_controller, - t0=self.lna_tau_tab[-1], # Initial x value (~0 in this case) - t1=self.lna_tau_tab[0], # Final x value (smallest x value) - dt0=-1e-3, # Initial step size + t0=self.lna_tau_tab[-1], + t1=self.lna_tau_tab[0], + dt0=-1e-3, max_steps=2048, - y0=0.0, # Initial value tau(x=0) = 0 - saveat=SaveAt(ts=self.lna_tau_tab[::-1]), # Save at all points in x, reverse order since integrating backwards + y0=0.0, + saveat=SaveAt(ts=self.lna_tau_tab[::-1]), adjoint=adjoint ) result = solution.ys[::-1] @@ -818,19 +815,18 @@ def _tabulate_rs(self, params): array Tabulated sound horizon values (units: Mpc) """ - # initial condition assuming cs**2 = 1/3 at early times rs0 = 1./jnp.sqrt(3) / (self.aH( self.lna_tau_tab[0], params )) integrand = lambda lna, y, args: 1./jnp.sqrt(3*(1+self.R_ratio_lna(lna, params))) / (self.aH(lna, params)) term = ODETerm(integrand) stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=1.e-3, atol=1.e-6) adjoint=self.adjoint() - + solution = diffeqsolve( term, solver=Tsit5(), stepsize_controller=stepsize_controller, - t0=self.lna_tau_tab[0], # reversed direction since I know rs at early times + t0=self.lna_tau_tab[0], t1=self.lna_tau_tab[-1], dt0=1e-3, max_steps=2048, @@ -840,7 +836,6 @@ def _tabulate_rs(self, params): ) result = solution.ys return result - def z_d(self, params): """ @@ -879,14 +874,15 @@ def rs_d(self, params): """ return self.interp_rs_at_z(1/jnp.exp(self.lna_tau_tab) - 1, self._tabulate_rs(params), self.z_d(params)) + class ReionizationModel(eqx.Module): """ Object for computing the reionization correction to the free electron fraction. - Provides the base methods + Provides the base methods xe_reion : calculates the tanh electron fraction correction at redshifts lna, given z_reion and params tau_reion_fn : calculates the optical depth to reionization. - + At the moment we only support the CAMB tanh parameterization, but we need different approaches based on whether the use inputs the optical depth tau_reion or the reionization redshift z_reion. @@ -898,7 +894,7 @@ class ReionizationModel(eqx.Module): def xe_reion(self, lna, z_reion, params): """ Passing in an lna array should get you the correct tanh patching based on the - reionization parameter. + reionization parameter. """ fHe = params['YHe'] / 4 / (1-params['YHe']) z = 1/jnp.exp(lna) - 1 @@ -919,7 +915,7 @@ def xe_reion(self, lna, z_reion, params): def tau_reion_fn(self, z_reion, BG, params): lna_axis = jnp.linspace(-5., 0., 2000) xe_reion_correction = self.xe_reion(lna_axis, z_reion, params) - # Free electron number density belonging only to reionized hydrogen. + # Free electron number density belonging only to reionized hydrogen. ne = BG.nH(lna_axis, params) * xe_reion_correction Gamma = jnp.exp(lna_axis)*ne*cnst.thomson_xsec*cnst.c/cnst.c_Mpc_over_s aH = BG.aH(lna_axis, params) @@ -957,4 +953,4 @@ def tau_target_fn(z_reion, args): solver = optx.Newton(rtol=1e-5, atol=1e-5) sol = optx.root_find(tau_target_fn, solver, 7.6, params.get("tau_reion", jnp.array(0.05430842))) self.z_reion = sol.value - self.tau_reion = params.get("tau_reion", jnp.array(0.05430842)) \ No newline at end of file + self.tau_reion = params.get("tau_reion", jnp.array(0.05430842)) diff --git a/abcmb/hyrex/helium.py b/abcmb/hyrex/helium.py index 51d96c1..78740a7 100644 --- a/abcmb/hyrex/helium.py +++ b/abcmb/hyrex/helium.py @@ -71,7 +71,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -98,7 +99,8 @@ def get_helium_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_st Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -160,7 +162,8 @@ def xesaha_HeII_III(self, lna_axis, args, threshold=1e-9): lna_axis : array Log scale factor grid args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). threshold : float, optional Threshold for HeIII fraction to stop calculation (default: 1e-9) @@ -169,7 +172,7 @@ def xesaha_HeII_III(self, lna_axis, args, threshold=1e-9): tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Pre-allocate xe_output xe_output = jnp.ones_like(lna_axis)*jnp.inf lna_output = jnp.ones_like(lna_axis)*jnp.inf @@ -183,8 +186,8 @@ def compute_xe(carry): lna = lna_axis[iz] # Cosmological parameters - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) # compute xHeIII fHe = params['YHe']/(1.-params['YHe'])/3.97153 # abundance of helium by number @@ -240,18 +243,19 @@ def xHeII_post_Saha(self, lna, args): lna : float Log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float HeII fraction (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args fHe = params['YHe']/(1.-params['YHe'])/3.97153 - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) # Saha ratio xe * xHeII / xHeI s = 4 * 2.414194e15 * TCMB/cnst.kB * jnp.sqrt(TCMB/cnst.kB) * jnp.exp(-285325. / (TCMB/cnst.kB)) / nH @@ -271,16 +275,17 @@ def xH1_Saha(self, lna, args): lna : float Log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float Neutral hydrogen fraction (units: dimensionless) """ - BG, params = args - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) + recomb_inputs, params = args + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) xHeII = self.xHeII_post_Saha(lna, args) s = 2.4127161187130e15* TCMB/cnst.kB * jnp.sqrt(TCMB/cnst.kB)*jnp.exp(-157801.37882/(TCMB/cnst.kB))/nH xH1 = jnp.where(s>1e5,(1.+xHeII)/s - (xHeII**2 + 3.*xHeII + 2.)/s**2,\ @@ -301,7 +306,8 @@ def post_saha_xHeII(self, starting_lna, args, threshold=1e-5): starting_lna : float Initial log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). threshold : float, optional Threshold for deviation from Saha (default: 1e-5) @@ -310,7 +316,7 @@ def post_saha_xHeII(self, starting_lna, args, threshold=1e-5): tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Pre-allocate xe_output xe_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf lna_output = jnp.ones_like(self.concrete_axis_size_postSahaHe)*jnp.inf @@ -383,22 +389,23 @@ def helium_dxHeIIdlna(self, xe, lna, args): lna : float Log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float HeII recombination rate dxHeII/dlna (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args fHe = params['YHe']/(1.-params['YHe'])/3.97153 # abundance of helium by number # cosmology #lna = -jnp.log(1+z) - TCMB = BG.TCMB(lna, params) # eV - nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3 - H = BG.H(lna, params) # Hubble parameter, 1/s + TCMB = recomb_inputs.TCMB(lna) # eV + nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3 + H = recomb_inputs.H(lna) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s # compute xH1 in Saha equilibrium, xHeII in post-saha @@ -462,7 +469,8 @@ def xe_derivative_HeII(self, lna, state, args): state : float Current HeII ionization state args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -470,7 +478,7 @@ def xe_derivative_HeII(self, lna, state, args): Time derivative of HeII fraction (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args #z = 1. / jnp.exp(lna) - 1. # use xe = xHeII + (1.-xH1) xe = state + self.xH1_Saha(lna, args) @@ -491,7 +499,8 @@ def solve_HeII_full(self, starting_lna, xe0, args, rtol=1e-6, atol=1e-9,solver=K xe0 : float Initial ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-6) atol : float, optional @@ -506,10 +515,10 @@ def solve_HeII_full(self, starting_lna, xe0, args, rtol=1e-6, atol=1e-9,solver=K tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Initial conditions - TCMB_init = BG.TCMB(starting_lna, params) # Initial matter temperature + TCMB_init = recomb_inputs.TCMB(starting_lna) # Initial matter temperature initial_state = jnp.array([xe0]) term = ODETerm(self.xe_derivative_HeII) diff --git a/abcmb/hyrex/hydrogen.py b/abcmb/hyrex/hydrogen.py index 6a335e4..76ccec0 100644 --- a/abcmb/hyrex/hydrogen.py +++ b/abcmb/hyrex/hydrogen.py @@ -93,7 +93,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -122,7 +123,8 @@ def get_hydrogen_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_ Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -195,7 +197,8 @@ def post_Saha_expansion(self, starting_lna, args, threshold=1e-5): starting_lna : float Initial log scale factor args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). threshold : float, optional Threshold for deviation from Saha (default: 1e-5) @@ -204,10 +207,10 @@ def post_Saha_expansion(self, starting_lna, args, threshold=1e-5): tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Initial conditions - TCMB = BG.TCMB(starting_lna, params) - nH = BG.nH(starting_lna, params) + TCMB = recomb_inputs.TCMB(starting_lna) + nH = recomb_inputs.nH(starting_lna) xe0, _ = recomb_functions.xe_Saha(TCMB, nH) # Saha equilibrium is our intial condition # Pre-allocate xe_output @@ -223,9 +226,9 @@ def compute_xe(carry): lna = starting_lna + iz*self.integration_spacing # Cosmological parameters - TCMB = BG.TCMB(lna, params) - nH = BG.nH(lna, params) - H = BG.H(lna, params) + TCMB = recomb_inputs.TCMB(lna) + nH = recomb_inputs.nH(lna) + H = recomb_inputs.H(lna) # Saha equilibrium for xe xe_Saha, s = recomb_functions.xe_Saha(TCMB, nH) @@ -287,19 +290,20 @@ def xe_derivative_twophoton(self, lna, xe, args): xe : float Current ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- float Time derivative dxe/dlna (units: dimensionless) """ - BG, params = args + recomb_inputs, params = args x1s = 1. - xe # fraction of neutral hydrogen - TCMB = BG.TCMB(lna, params) # eV - nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3 - H = BG.H(lna, params) # Hubble parameter, 1/s + TCMB = recomb_inputs.TCMB(lna) # eV + nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3 + H = recomb_inputs.H(lna) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s Tm = TCMB * (1.-H/GammaC) @@ -325,7 +329,8 @@ def solve_emla_twophoton(self, lna_axis_init, lna_axis_final, xe0, args, rtol=1e xe0 : float Initial ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-6) atol : float, optional @@ -340,10 +345,10 @@ def solve_emla_twophoton(self, lna_axis_init, lna_axis_final, xe0, args, rtol=1e tuple (xe_output, lna_output) - ionization fraction and log scale factor arrays """ - BG, params = args + recomb_inputs, params = args # Initial conditions - TCMB_init = BG.TCMB(lna_axis_init, params) # Initial CMB temperature + TCMB_init = recomb_inputs.TCMB(lna_axis_init) # Initial CMB temperature initial_state = xe0 term = ODETerm(self.xe_derivative_twophoton) @@ -392,7 +397,8 @@ def xe_tm_derivative(self, lna, state, args): state : array Current state [xe, Tm] args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -400,11 +406,11 @@ def xe_tm_derivative(self, lna, state, args): Time derivatives [dxe/dlna, dTm/dlna] (units: dimensionless, eV) """ xe, Tm = state - BG, params = args + recomb_inputs, params = args - TCMB = BG.TCMB(lna, params) # eV - nH = BG.nH(lna, params) # hydrogen number density, 1/cm^3 - H = BG.H(lna, params) # Hubble parameter, 1/s + TCMB = recomb_inputs.TCMB(lna) # eV + nH = recomb_inputs.nH(lna) # hydrogen number density, 1/cm^3 + H = recomb_inputs.H(lna) # Hubble parameter, 1/s GammaC = recomb_functions.Gamma_compton(xe, TCMB, params['YHe']) # Compton scattering rate, 1/s Delta = 0.0 @@ -427,7 +433,8 @@ def solve_emla(self, lna0, xe0, args, rtol=1e-7,atol=1e-9,solver=Tsit5(),max_ste xe0 : float Initial ionization fraction args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-7) atol : float, optional @@ -447,15 +454,15 @@ def solve_emla(self, lna0, xe0, args, rtol=1e-7,atol=1e-9,solver=Tsit5(),max_ste t0 = lna0 t1 = jnp.inf - BG,params = args + recomb_inputs, params = args # need to go at least twice max_steps to make sure we catch the t1 we actually want t_arr = jnp.linspace(t0+self.integration_spacing, t0+2*max_steps*self.integration_spacing, 2*max_steps) save_at = SaveAt(ts=t_arr) - TCMB_init = BG.TCMB(t0, params) # Initial CMB temperature - Tm0 = TCMB_init * (1.-BG.H(t0, params)/recomb_functions.Gamma_compton(xe0, TCMB_init, params['YHe'])) + TCMB_init = recomb_inputs.TCMB(t0) # Initial CMB temperature + Tm0 = TCMB_init * (1.-recomb_inputs.H(t0)/recomb_functions.Gamma_compton(xe0, TCMB_init, params['YHe'])) initial_state = jnp.array([xe0, Tm0]) term = ODETerm(self.xe_tm_derivative) @@ -464,7 +471,7 @@ def solve_emla(self, lna0, xe0, args, rtol=1e-7,atol=1e-9,solver=Tsit5(),max_ste def temperature_check(t, y, args, **kwargs): lna = t _, Tm = y - TCMB = BG.TCMB(lna, params) + TCMB = recomb_inputs.TCMB(lna) TR_MIN = recomb_functions.TR_MIN # Minimum Tcmb in eV T_RATIO_MIN = recomb_functions.T_RATIO_MIN # Minimum Tratio ratio = jnp.minimum(Tm / TCMB, TCMB / Tm) @@ -541,7 +548,8 @@ def get_current_correction_func(self, TCMB, args): TCMB : float CMB temperature (units: eV) args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -553,7 +561,7 @@ def get_current_correction_func(self, TCMB, args): omega_cb_fid = 0.14175 Neff_fid = 3.046 - BG, params = args + recomb_inputs, params = args # For the user inputed cosmology currently scanned over. @@ -666,7 +674,8 @@ def TLA_xe_deriv(self, lna, state, args): state : array Current state [xe, Tm] args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). Returns: -------- @@ -674,12 +683,12 @@ def TLA_xe_deriv(self, lna, state, args): Time derivatives [dxe/dlna, dTm/dlna] (units: dimensionless, eV) """ xe, Tm = state - BG, params = args + recomb_inputs, params = args xHII = xe # since everything else is fully recombined - nH = BG.nH(lna, params) - TCMB = BG.TCMB(lna, params) # eV - H = BG.H(lna, params) + nH = recomb_inputs.nH(lna) + TCMB = recomb_inputs.TCMB(lna) # eV + H = recomb_inputs.H(lna) C = recomb_functions.peebles_C(jnp.exp(-lna) - 1.0, xHII, H, nH, args) alpha = recomb_functions.alpha_H(Tm) @@ -709,7 +718,8 @@ def solve_TLA(self, lna0, xe0, Tm0, args, rtol=1e-7, atol=1e-9, solver=Kvaerno3( Tm0: float Starting matter temperature args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance (default: 1e-7) atol : float, optional diff --git a/abcmb/hyrex/hyrex.py b/abcmb/hyrex/hyrex.py index 8c4b819..8e16c47 100644 --- a/abcmb/hyrex/hyrex.py +++ b/abcmb/hyrex/hyrex.py @@ -9,8 +9,51 @@ from .hydrogen import hydrogen_model from .helium import helium_model from .array_with_padding import array_with_padding +from ..ABCMBTools import fast_interp config.update("jax_enable_x64", True) + +class RecombInputs(eqx.Module): + """ + Bundle of background quantities sampled on a fixed lna grid, consumed by + the recombination model in place of a full ``Background`` instance. + + Phase 1 of the HyRex CPU lift refactors HyRex to depend only on these + arrays, so the recombination kernel becomes physics-agnostic and the + GPU/CPU device boundary has a clean interface. + + Attributes + ---------- + lna_grid : jnp.array + Uniform log scale-factor sampling axis. + TCMB_arr : jnp.array + Photon-bath temperature TCMB(lna), eV. + nH_arr : jnp.array + Hydrogen number density nH(lna), cm^-3. + H_arr : jnp.array + Hubble parameter H(lna), s^-1. + + Methods + ------- + TCMB(lna), nH(lna), H(lna) + Linear interpolation of the corresponding stored array at lna, + using ``ABCMBTools.fast_interp`` (uniform-grid path). + """ + + lna_grid : jnp.array + TCMB_arr : jnp.array + nH_arr : jnp.array + H_arr : jnp.array + + def TCMB(self, lna): + return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.TCMB_arr) + + def nH(self, lna): + return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.nH_arr) + + def H(self, lna): + return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.H_arr) + class recomb_model(eqx.Module): """ Complete recombination model implementation. @@ -69,7 +112,8 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -86,7 +130,7 @@ def __call__(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): with reionization, log scale factor, matter temperature, and temperature grid """ return self.get_history(args, rtol, atol, solver, max_steps) - + def get_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=1024): """ Compute complete recombination and reionization history. @@ -97,7 +141,8 @@ def get_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=102 Parameters: ----------- args : tuple - Background cosmology and cosmological parameters (BG, params) + Recombination input arrays and cosmological parameters + (recomb_inputs, params). rtol : float, optional Relative tolerance for ODE solver (default: 1e-6) atol : float, optional @@ -115,7 +160,7 @@ def get_history(self, args, rtol=1e-6, atol=1e-9,solver=Kvaerno3(),max_steps=102 matter temperature, and temperature grid """ - BG, params = args + recomb_inputs, params = args lna_axis_4Heequil = self.lna_axis_full[self.idx_4He_equil] xe_4He, lna_4He = helium_model(lna_axis_4Heequil, adjoint=self.adjoint)(args) diff --git a/abcmb/hyrex/recomb_functions.py b/abcmb/hyrex/recomb_functions.py index 13a81d5..7d9a017 100644 --- a/abcmb/hyrex/recomb_functions.py +++ b/abcmb/hyrex/recomb_functions.py @@ -13,13 +13,15 @@ #Tabulated values of 2s-2p transition rates to interpolate. alpha_tab = jnp.array(np.loadtxt(file_dir+"/tabs/Alpha_inf.dat")) +# Phase 2 of the HyRex CPU lift: the recombination model now executes +# under ``eqx.filter_jit(backend='cpu')`` (orchestrated from +# ``Model.__call__``), so these lookup tables should live on CPU. Pinning +# them to CPU here avoids per-call device migration when HyRex traces. try: - gpus = devices('gpu') - R_tab = device_put( - R_tab, device=gpus[0]) - alpha_tab = device_put( - alpha_tab, device=gpus[0]) -except: + cpus = devices('cpu') + R_tab = device_put(R_tab, device=cpus[0]) + alpha_tab = device_put(alpha_tab, device=cpus[0]) +except Exception: pass # File handling and interpolating related constants. @@ -94,7 +96,7 @@ def beta_H(TCMB): def peebles_C(z, xHII, H, nH, args): """ Peebles C factor, probability for an n=2 hydrogen atom to reach the ground state before becoming photoionized, a function of redshift and ionized proton fraction. - + Dimensions: None Parameters @@ -107,23 +109,24 @@ def peebles_C(z, xHII, H, nH, args): Value(s) of Hubble at given redshift(s). nH : float/jnp.array Value(s) of hydrogen number density at given redshift(s) - BG: cosmology.Background - Background cosmology object + args : tuple + (recomb_inputs, params) — recombination input arrays and + cosmological parameters. Returns ------- C : float/jnp.array Peebles C factor. """ - BG, params = args + recomb_inputs, params = args # (2p to 1s rate) x (1 - xHII), in s^{-1} rate_2p1s_times_x1s = 8*jnp.pi*H / (3 * (nH*(cnst.c/cnst.lya_freq)**3)) # s^{-1} rate_exc = 3. * rate_2p1s_times_x1s/4. + (1.-xHII) * cnst.R2s1s/4. - + # Ionization rate, in s^{-1} - rate_ion = (1-xHII) * beta_H( BG.TCMB( jnp.log(1/(1+z)) , params ) ) + rate_ion = (1-xHII) * beta_H( recomb_inputs.TCMB( jnp.log(1/(1+z)) ) ) return rate_exc / (rate_exc + rate_ion) diff --git a/abcmb/main.py b/abcmb/main.py index ee717a3..da32539 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -15,6 +15,7 @@ from . import background, perturbations, spectrum, model_specs from . import constants as cnst from .ABCMBTools import bilinear_interp +from .background import BackgroundPreRecomb, Background, ReionizationModelFromZ, ReionizationModelFromTau from .linx.background import BackgroundModel from .linx.abundances import AbundanceModel @@ -72,9 +73,9 @@ class Model(eqx.Module): specs : dict species_list : tuple = () - species_dict : dict - - PArthENoPE_CLASS_table : Array + species_dict : dict + + PArthENoPE_CLASS_table : Array thermo_model_DNeff : BackgroundModel abundanceModel : AbundanceModel @@ -142,7 +143,14 @@ def __init__(self, scale_pol=specs["scale_pol"] ) - # Initialize recombination model + # Initialize recombination model. Phase 2 of the HyRex CPU lift + # invokes RecModel under ``eqx.filter_jit(backend='cpu')``; we do + # NOT device_put RecModel to CPU here, because doing so would mix + # device platforms inside Model's pytree (RecModel on CPU, PE/SS + # on GPU) and the GPU-pinned ``_run_post_recomb`` jit would reject + # ``self`` as having "incompatible devices". JAX migrates RecModel + # to CPU lazily on first trace of the CPU jit — same pattern as + # the LINX call in ``add_derived_parameters``. self.RecModel = hyrex.recomb_model(adjoint=adjoint) # DO NOT CHANGE z1 FROM 0 # Initialize BBN model @@ -157,14 +165,23 @@ def __init__(self, self.adjoint = adjoint - # need this outside of the jit context - # since we want LINX to run on CPU + # NOTE on jit nesting (Phase 2 of the HyRex CPU lift): + # ``__call__`` is plain Python because it orchestrates a CPU-pinned + # HyRex jit between two GPU-pinned jits. Wrapping ``__call__`` in an + # outer ``eqx.filter_jit`` (default backend = GPU) would attempt to + # inline the inner CPU jit and fails with + # "Received incompatible devices for jitted computation". + # Same constraint applies to ``run_cosmology_abbr``. Mirrors the LINX + # pattern in ``add_derived_parameters``. def __call__(self, params : dict = {}): """ Compute CMB angular power spectra for given parameters. - Runs the full pipeline from background evolution through - perturbation integration to CMB power spectrum computation. + Runs the full pipeline: + params ─► add_derived_parameters (CPU LINX, unjitted) + ─► get_BG_pre_recomb (GPU JIT) + ─► RecModel ((recomb_inputs, params)) (CPU JIT) + ─► run_cosmology_abbr (GPU JIT) Parameters: ----------- @@ -173,23 +190,86 @@ def __call__(self, params : dict = {}): Returns: -------- - tuple - (ℓ values, (C_ℓ^TT, C_ℓ^TE, C_ℓ^EE)) for computed multipoles + Output + ClTT/ClTE/ClEE/Pk plus the BG and PT objects. """ - - full_params = self.add_derived_parameters(params) return self.run_cosmology_abbr(full_params) - - ### JITTED OR JITTABLE FUNCTIONS ### - @eqx.filter_jit + ### Top-level orchestration (Phase 2): GPU → CPU → GPU. ### def run_cosmology_abbr(self, params : dict): """ - Compute CMB angular power spectra for given parameters. + Orchestrate the full pipeline given derived params. NOT jit-wrapped; + contains a CPU-pinned HyRex call sandwiched between two GPU jits. + + Parameters: + ----------- + params : dict + Cosmological parameters (must already have derived keys). + + Returns: + -------- + Output + CMB power spectra and friends. + """ + # Cast int/bool params to float64 BEFORE entering any + # ``eqx.filter_jit``. Without this, filter_jit's diff/non-diff + # custom_vjp partition routes int leaves to "non-diff" and + # asserts ``perturbed=False`` (equinox/_ad.py:859). Under outer + # ``jax.grad``, those leaves are tracers with perturbed=True and + # the assertion trips. The defaults (e.g. ``N_nu_massive=0``, + # ``omega_Lambda=0``) are unused as integers anywhere downstream + # so the cast is safe. + def _to_float(v): + arr = jnp.asarray(v) + if arr.dtype.kind in 'iub': + return arr.astype(jnp.float64) + return arr + params = jax.tree_util.tree_map(_to_float, params) + + # Stage 1 (GPU JIT): tabulate conformal time + bundle recomb_inputs. + pre_BG = self.get_BG_pre_recomb(params) + + # Stage 2 (CPU JIT): HyRex consumes recomb_inputs, returns + # (xe, lna_xe, Tm, lna_Tm) — see ``hyrex.recomb_model.get_history``. + try: + cpu_dev = jax.devices('cpu')[0] + recomb_inputs_cpu = jax.device_put(pre_BG.recomb_inputs, cpu_dev) + params_cpu = jax.device_put(params, cpu_dev) + except Exception: + recomb_inputs_cpu = pre_BG.recomb_inputs + params_cpu = params + + recomb_output = eqx.filter_jit(self.RecModel, backend='cpu')((recomb_inputs_cpu, params_cpu)) + + try: + recomb_output = jax.device_put(recomb_output, jax.devices('gpu')[0]) + except Exception: + # No GPU: leave recomb_output where it is. + pass + + # ``recomb_output`` contains ``array_with_padding`` objects whose + # ``padding_size`` and ``lastnum`` are JAX int arrays from + # ``jnp.argmax``. Inside HyRex's CPU jit they are used as indices + # in ``lax.dynamic_update_slice`` (concat), but downstream of + # HyRex the only fields touched are ``arr`` and ``lastval``. The + # ``checkpointed_while_loop`` filter_custom_vjp inside + # ``_run_post_recomb``'s diffrax solves trips + # ``_get_value_assert_unperturbed`` on int leaves under outer + # AD; cast them to float at the boundary to suppress this. + recomb_output = jax.tree_util.tree_map(_to_float, recomb_output) + + # Stage 3 (GPU JIT): apply reionization, integrate optical depth, + # locate decoupling, integrate perturbations, build CMB spectra. + return self._run_post_recomb(params, pre_BG, recomb_output) + + @eqx.filter_jit + def get_BG_pre_recomb(self, params : dict): + """ + Pre-recomb stage: tabulate conformal time and bundle ``recomb_inputs``. - Runs the full pipeline from background evolution through - perturbation integration to CMB power spectrum computation. + This is the only piece of ``__call__`` that fires before HyRex; it + runs on whatever device JAX defaults to (typically GPU on Perlmutter). Parameters: ----------- @@ -198,10 +278,29 @@ def run_cosmology_abbr(self, params : dict): Returns: -------- - tuple - (ℓ values, (C_ℓ^TT, C_ℓ^TE, C_ℓ^EE)) for computed multipoles + BackgroundPreRecomb + """ + return BackgroundPreRecomb(params, self.species_list, self.RecModel, adjoint=self.adjoint) + + @eqx.filter_jit + def _run_post_recomb(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): """ + Post-recomb GPU stage: full Background construction (reionization, + optical depth, decoupling), perturbation evolution, CMB spectra. + + Parameters: + ----------- + params : dict + Cosmological parameters + pre_BG : BackgroundPreRecomb + Output of :meth:`get_BG_pre_recomb`. + recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. + Returns: + -------- + Output + """ # let the user know the code is compiling print("") print(' /\\ ') @@ -216,12 +315,12 @@ def run_cosmology_abbr(self, params : dict): print("") # Compute background and linear perturbations - PT, BG = self.get_PTBG(params) + PT, BG = self.get_PTBG(params, pre_BG, recomb_output) # Compute CMB power spectra Cls = self.SS.get_Cl(PT, BG, params) l = self.SS.ells - + # Compute linear matter power spectrum Pk = self.SS.Pk_lin(self.SS.k_axis_Pk_output, 0., PT, params) k = self.SS.k_axis_Pk_output @@ -235,61 +334,65 @@ def run_cosmology_abbr(self, params : dict): return output @eqx.filter_jit - def get_PTBG(self, params : dict): + def get_PTBG(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): """ - Get perturbation table and background. + Get perturbation table and full Background. - Computes background and evolves perturbations for the given parameters. + Constructs the post-recomb Background from ``pre_BG`` + ``recomb_output`` + and runs the perturbation evolver. Parameters: ----------- params : dict Cosmological parameters + pre_BG : BackgroundPreRecomb + Pre-recomb stage object. + recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. Returns: -------- tuple - (PerturbationTable, Background) objects + (PerturbationTable, Background) """ - BG = self.get_BG(params) + BG = self.get_BG(params, pre_BG, recomb_output) PT = self.PE.full_evolution((BG, params)) - return PT, BG - @eqx.filter_jit - def get_BG(self, params : dict): + def get_BG(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): """ - Get background for given parameters. + Construct the full ``Background`` from pre-recomb + HyRex output. + + Selects the reionization model (z-input vs tau-input) via ``lax.cond``. + NOT directly ``@eqx.filter_jit``-decorated; called from inside + ``_run_post_recomb`` (which is jit-wrapped). Parameters: ----------- params : dict Cosmological parameters + pre_BG : BackgroundPreRecomb + recomb_output : tuple Returns: -------- background.Background - Background object """ - # Bind to a local so both closures capture a plain class rather than - # an attribute lookup on self. The class is never placed in the - # lax.cond operand tuple (keeping it a valid JAX pytree). - adjoint = self.adjoint def get_BG_z_reion(args): - params, species_list, RecModel = args - return background.Background(params, species_list, RecModel, background.ReionizationModelFromZ, adjoint=adjoint) + params, pre_BG, recomb_output = args + return Background(pre_BG, recomb_output, params, ReionizationModelFromZ) def get_BG_tau_reion(args): - params, species_list, RecModel = args - return background.Background(params, species_list, RecModel, background.ReionizationModelFromTau, adjoint=adjoint) + params, pre_BG, recomb_output = args + return Background(pre_BG, recomb_output, params, ReionizationModelFromTau) BG = lax.cond( self.specs["input_tau_reion"], get_BG_tau_reion, get_BG_z_reion, - (params, self.species_list, self.RecModel) + (params, pre_BG, recomb_output) ) - + return BG def add_derived_parameters(self, param_in : dict) -> dict: diff --git a/abcmb/species.py b/abcmb/species.py index 95a178c..3c79701 100644 --- a/abcmb/species.py +++ b/abcmb/species.py @@ -41,10 +41,10 @@ class Fluid(eqx.Module): rho_plus_P_sigma : Compute standard shear perturbation (units: eV cm^{-3}) """ - delta_idx : int = eqx.field(default=0) + delta_idx : int = eqx.field(default=0, static=True) num_moments : int = eqx.field(default=0, static=True) name : str = eqx.field(default="", static=True) - is_matter : bool = eqx.field(default=False) # Does the fluid contribute towards matter overdensity today. + is_matter : bool = eqx.field(default=False, static=True) # Does the fluid contribute towards matter overdensity today. def __init__(self, delta_idx, specs): self.delta_idx = delta_idx From 35eb84ecbfcb28b72a846d7104b62b2c0ef800c2 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 5 May 2026 14:53:00 -0700 Subject: [PATCH 02/14] fix reverseAD lensing bug --- abcmb/spectrum.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/abcmb/spectrum.py b/abcmb/spectrum.py index dbe3ef6..60833f9 100644 --- a/abcmb/spectrum.py +++ b/abcmb/spectrum.py @@ -411,15 +411,27 @@ def lensing_Cl(self, ells, PT, BG, params): coeff = 8.*jnp.pi**2/(ells+0.5)**3 chi = lambda lna : BG.tau0 - BG.tau(lna) + # The previous jnp.nan_to_num(integrand, nan=0.) here masked the + # forward NaN but left a 0*NaN cotangent in the backward through + # the where-mask that nan_to_num secretly expands to, which + # propagated through BG.tau. Fix: substitute lna_safe everywhere, + # then mask the result to 0 at the boundary. + lna_axis = jnp.linspace(BG.lna_rec, 0., 500) + lna_floor = lna_axis[-2] + def integrand_func(lna): - k = (ells+0.5)/chi(lna) - window = (chi(BG.lna_rec) - chi(lna))/chi(BG.lna_rec)/chi(lna) - res = chi(lna)/BG.aH(lna, params) * window**2 * self.lensing_power_spectrum(k, lna, PT, BG, params) - return res + lna_safe = jnp.where(lna < 0., lna, lna_floor) + chi_safe = chi(lna_safe) + k = (ells+0.5)/chi_safe + window = (chi(BG.lna_rec) - chi_safe)/chi(BG.lna_rec)/chi_safe + res = ( + chi_safe / BG.aH(lna_safe, params) + * window**2 + * self.lensing_power_spectrum(k, lna_safe, PT, BG, params) + ) + return jnp.where(lna < 0., res, 0.) - lna_axis = jnp.linspace(BG.lna_rec, 0., 500) integrand = vmap(integrand_func)(lna_axis) - integrand = jnp.nan_to_num(integrand, nan=0.) return coeff*jnp.trapezoid(integrand, lna_axis, axis=0) def lensed_Cls(self, ells, ClTT_unlensed, ClTE_unlensed, ClEE_unlensed, PT, BG, params): From 11af8801b3b1770b5df32d399dc2b179df81747c Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 14:00:00 -0700 Subject: [PATCH 03/14] fixes remaining reverseAD bug --- abcmb/model_specs.py | 11 +++++++++++ abcmb/perturbations.py | 17 ++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/abcmb/model_specs.py b/abcmb/model_specs.py index 6cd57b1..2a82acd 100644 --- a/abcmb/model_specs.py +++ b/abcmb/model_specs.py @@ -67,6 +67,17 @@ def load_specs(input_specs): specs["pcoeff_PE"] = input_specs.get("pcoeff_PE", 0.25) specs["icoeff_PE"] = input_specs.get("icoeff_PE", 0.8) specs["dcoeff_PE"] = input_specs.get("dcoeff_PE", 0.) + # Newton convergence threshold for Kvaerno5's VeryChord root finder. + # diffrax default 0.01 is fine for FORWARD-only use, but the + # corresponding Newton iterate has a non-tight residual that can + # NaN reverse-mode AD on the lensing=True extension k-axis. + # If you encounter NaN cotangents under jax.grad / eqx.filter_grad, + # DECREASE this value (i.e. make it smaller, e.g. 1e-3) to tighten + # the convergence check. Tighter values cost a few extra Newton + # iterations per step in marginal cases; the forward result is also + # slightly more accurate but the effect on integrated outputs + # (Cl's, Pk) is small. + specs["kappa_PE"] = input_specs.get("kappa_PE", 0.01) ### Physical contributions to CMB temperature transfer function ### specs["scale_sw"] = input_specs.get("scale_sw", 1) diff --git a/abcmb/perturbations.py b/abcmb/perturbations.py index 3da858e..2c289ff 100644 --- a/abcmb/perturbations.py +++ b/abcmb/perturbations.py @@ -3,6 +3,8 @@ import numpy as np from jax import vmap, lax import diffrax +from diffrax import with_stepsize_controller_tols +from diffrax._root_finder import VeryChord import equinox as eqx from . import constants as cnst @@ -281,7 +283,20 @@ def evolution_one_k(self, k, lna, args): # Settings for post-tight coupling term = diffrax.ODETerm(self.get_derivatives) - solver = diffrax.Kvaerno5() + # Root finder: VeryChord (Kvaerno5's default Chord-method root + # finder), with kappa pulled from specs["kappa_PE"] (default + # 0.01 = diffrax default; user-tunable). If reverse-AD produces + # NaN cotangents on omega_b/omega_cdm, decrease kappa_PE in specs + # (e.g. to 1e-3) to tighten Newton convergence. The default kappa + # silently accepts a Newton residual that is not tight enough + # for ABCMB's stiff linear PE Jacobian under reverse-AD; see + # CHANGELOG 2026-05-12. + _rf = eqx.tree_at( + lambda s: s.kappa, + with_stepsize_controller_tols(VeryChord)(), + replace=self.specs["kappa_PE"], + ) + solver = diffrax.Kvaerno5(root_finder=_rf) rtol=jnp.where( k > self.specs["k_split_PE"], From 15a2fc55c7da627e8c14b13a919ee6e3b96afa88 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 15:05:39 -0700 Subject: [PATCH 04/14] auto-set kappa_PE --- abcmb/__init__.pyc | Bin 260 -> 0 bytes abcmb/main.py | 5 +++++ abcmb/model_specs.py | 14 ++++---------- 3 files changed, 9 insertions(+), 10 deletions(-) delete mode 100644 abcmb/__init__.pyc diff --git a/abcmb/__init__.pyc b/abcmb/__init__.pyc deleted file mode 100644 index a3ad9a0dcf1a43fddafa65490eb9a2ab63211dff..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 260 zcmYL>!Ab)`42Ea7Sj&K*Z;<1HrM`d^yFKVZy?EP8$nMO#1G6(Sld05)@MV1g(;^ti zm;X=rv-r8vOa1Kz^;eb52kCq*FiLg8N_CYowR&q7r^{BLckTYA-Dtgov3P$3`D!`$B?y5Jab^Ks;rZ13l2inrSf(fRas GS^ole@;)j6 diff --git a/abcmb/main.py b/abcmb/main.py index da32539..ced8f6d 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -108,6 +108,11 @@ def __init__(self, # tracing). Default preserves prior ForwardMode behavior. adjoint = kwargs.pop("adjoint", diffrax.ForwardMode) + # If user requested RecursiveCheckpointAdjoint, auto- + # tighten kappa_PE to a reverse-AD-safe value (unless another value is specified) + if adjoint is diffrax.RecursiveCheckpointAdjoint and "kappa_PE" not in kwargs: + kwargs["kappa_PE"] = 1e-3 + # Fill in all user defined and missing specs parameters specs = model_specs.load_specs(kwargs) self.specs = specs diff --git a/abcmb/model_specs.py b/abcmb/model_specs.py index 2a82acd..6a3dfba 100644 --- a/abcmb/model_specs.py +++ b/abcmb/model_specs.py @@ -67,16 +67,10 @@ def load_specs(input_specs): specs["pcoeff_PE"] = input_specs.get("pcoeff_PE", 0.25) specs["icoeff_PE"] = input_specs.get("icoeff_PE", 0.8) specs["dcoeff_PE"] = input_specs.get("dcoeff_PE", 0.) - # Newton convergence threshold for Kvaerno5's VeryChord root finder. - # diffrax default 0.01 is fine for FORWARD-only use, but the - # corresponding Newton iterate has a non-tight residual that can - # NaN reverse-mode AD on the lensing=True extension k-axis. - # If you encounter NaN cotangents under jax.grad / eqx.filter_grad, - # DECREASE this value (i.e. make it smaller, e.g. 1e-3) to tighten - # the convergence check. Tighter values cost a few extra Newton - # iterations per step in marginal cases; the forward result is also - # slightly more accurate but the effect on integrated outputs - # (Cl's, Pk) is small. + # Newton convergence threshold for Kvaerno5 VeryChord root finder. + # If you encounter NaNs in gradients, decrease this value (e.g. 1e-3). + # Small (<1%) performance hit for doing so, so default stays at 1e-2 + # unless reverseAD is requested (see main). specs["kappa_PE"] = input_specs.get("kappa_PE", 0.01) ### Physical contributions to CMB temperature transfer function ### From 24d8507ca39d546c1e44b8c74362213900e2b3e8 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 15:53:41 -0700 Subject: [PATCH 05/14] cleaned and reviewed background changes --- abcmb/background.py | 144 +++++++++++++++++++++++++++++--------------- 1 file changed, 95 insertions(+), 49 deletions(-) diff --git a/abcmb/background.py b/abcmb/background.py index dbdf91a..ace9d43 100644 --- a/abcmb/background.py +++ b/abcmb/background.py @@ -19,21 +19,17 @@ class BackgroundPreRecomb(eqx.Module): """ - Pre-recombination background-cosmology object (Phase 2 of HyRex CPU lift). + Pre-recombination background-cosmology object. - Holds everything HyRex needs to run on CPU: the conformal-time tabulation, - the species list, and a ``RecombInputs`` struct that bundles HyRex's input - arrays sampled on the recombination grid. None of these depend on xe, Tm, - or the optical depth, so this object is the natural input to the CPU-pinned - HyRex solve and the natural input to the post-recombination Background - construction (which inherits from this class). + Holds everything HyRex needs to run on CPU: (conformal-time tabulation, + species list, and HyRex input arrays via ``RecombInputs`` object). Attributes: ----------- species_list : tuple A list of all fluids in the cosmology lna_tau_tab : jnp.array - Log scale factor axis used to tabulate conformal time (class attribute) + Log scale factor axis used to tabulate conformal time tau_tab : jnp.array Tabulated conformal time. tau0 : float @@ -46,15 +42,23 @@ class BackgroundPreRecomb(eqx.Module): Methods: -------- - rho_tot, P_tot, H, aH, aH_prime, d2adtau2_over_a - tau, nH, TCMB, R_ratio_lna + rho_tot : Compute total energy density (units: eV cm^{-3}) + P_tot : Compute total pressure (units: eV cm^{-3}) + H : Compute Hubble parameter (units: s^{-1}) + aH : Compute conformal Hubble parameter (units: Mpc^{-1}) + aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1}) + d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2}) + tau : Compute conformal time (units: Mpc) + nH : Compute hydrogen number density (units: cm^{-3}) + TCMB : Compute CMB temperature (units: eV) + R_ratio_lna : Compute baryon drag ratio (units: dimensionless) """ species_list : tuple - lna_tau_tab = jnp.linspace(-33.0, 0.0, 10000) - tau_tab : jnp.array - tau0 : float + lna_tau_tab = jnp.linspace(-33.0, 0.0, 10000) # Axis for tabulating conformal time. + tau_tab : jnp.array # Tabulated conformal time. + tau0 : float # Conformal time today recomb_inputs : "RecombInputs" @@ -64,10 +68,8 @@ def __init__(self, params, species_list, RecModel, adjoint=ForwardMode): """ Initialize pre-recombination background. - Tabulates conformal time and builds the ``RecombInputs`` struct - HyRex consumes. No reionization correction or optical-depth - integration is done here — those depend on the recombination - history and live on the post-recomb ``Background`` subclass. + Tabulates conformal time and builds the RecombInputs object for + HyRex. Parameters: ----------- @@ -76,7 +78,7 @@ def __init__(self, params, species_list, RecModel, adjoint=ForwardMode): species_list : tuple List of fluid species for energy density calculations RecModel : hyrex.recomb_model - Used for its ``lna_axis_full`` sampling grid (not called here). + Recombination module for computing xe and Tm histories adjoint : diffrax.adjoint, optional Adjoint class for diffrax solves (default: ForwardMode) """ @@ -87,10 +89,7 @@ def __init__(self, params, species_list, RecModel, adjoint=ForwardMode): self.tau0 = self.tau(0.) # Bundle the background quantities HyRex needs onto its sampling - # grid. Phase 2 ships these to CPU (see ``Model.__call__``); for - # standard cosmologies the linear interpolation against this dense - # grid is accurate to ~3e-8 (h^2/8 with h=5e-4) — well below - # accuracy_test tolerances. + # grid (acccording to the input RecModel) lna_axis = RecModel.lna_axis_full self.recomb_inputs = RecombInputs( lna_grid = lna_axis, @@ -335,6 +334,15 @@ def tau(self, lna): -------- float Conformal time (units: Mpc) + + Notes: + ------ + IDEA: Make Background a repeatedly initiated module with both + species_list and params stored. Upon initiation, a full history + of conformal time is calculated with diffrax and stored for + interpolation. This can be done by approximating early time with + radiation approximation, and starting diffrax integration at the + early time with appropriate initial conditions. """ return tools.fast_interp(lna, self.lna_tau_tab[0], self.lna_tau_tab[-1], self.tau_tab) @@ -411,17 +419,29 @@ def R_ratio_lna(self, lna, params): class Background(BackgroundPreRecomb): """ - Full background-cosmology object: pre-recombination state plus - the recombination + reionization history and the optical-depth - tabulation. + Full Background cosmology module for cosmological calculations. Inherits all cosmology fields and methods from ``BackgroundPreRecomb``. - Construction takes a ``BackgroundPreRecomb`` (output of the GPU pre-recomb - stage) and the recombination output produced by HyRex on CPU, then - applies the reionization correction and integrates the optical depth. + Construction takes a ``BackgroundPreRecomb`` and the recombination output + from HyRex, then applies reionization and integrates the optical depth. + + This factorization allows HyRex to always run on CPU (its faster backend). Attributes: ----------- + species_list : tuple + A list of all fluids in the cosmology + lna_tau_tab : jnp.array + Log scale factor axis used to tabulate conformal time + tau_tab : jnp.array + Tabulated conformal time. + tau0 : float + Conformal time today in Mpc. + recomb_inputs : RecombInputs + Bundle of background quantities (TCMB, nH, H) sampled on + ``RecModel.lna_axis_full``; consumed by HyRex. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). xe_tab : array_with_padding Tabulated free electron fraction xe with reionization correction. lna_xe_tab : array_with_padding @@ -446,8 +466,15 @@ class Background(BackgroundPreRecomb): Log scale factor at which to stop integrating T1, T2, and E sources due to small visibility functions. Only used for l<400. - Recombination Related Methods: - ------------------------------ + Methods: + -------- + rho_tot : Compute total energy density (units: eV cm^{-3}) + P_tot : Compute total pressure (units: eV cm^{-3}) + H : Compute Hubble parameter (units: s^{-1}) + aH : Compute conformal Hubble parameter (units: Mpc^{-1}) + aH_prime : Compute derivative of conformal Hubble (units: Mpc^{-1}) + d2adtau2_over_a : Compute second derivative of scale factor (units: Mpc^{-2}) + tau : Compute conformal time (units: Mpc) xe : Compute free electron fraction (units: dimensionless) Tm : Compute matter temperature (units: eV) tau_c : Compute Thomson scattering time (units: Mpc) @@ -465,27 +492,29 @@ class Background(BackgroundPreRecomb): z_reion : float tau_reion : float lna_rec : float - rA_rec : float + rA_rec : float # Comoving angular diameter distance at recombination. - lna_transfer_start : float - lna_visibility_stop : float + # Transfer related + lna_transfer_start : float # Time where transfer functions start integrating. + lna_visibility_stop : float # Time to stop integrating T1, T2, and E sources due to small visibility functions. Only used for l<400 def __init__(self, pre_BG, recomb_output, params, ReionModel): """ - Construct full Background from a pre-recomb stage and the HyRex output. + Initialize Background cosmology module. + + Consolidates pre-recombination and recombination elements of background cosmology. Parameters: ----------- pre_BG : BackgroundPreRecomb - Output of the GPU pre-recomb stage; provides species_list, + Output of the pre-recomb stage; provides species_list, tau_tab, tau0, recomb_inputs, adjoint. recomb_output : tuple - HyRex's ``(xe, lna_xe, Tm, lna_Tm)`` quadruple — the result of - running ``RecModel((pre_BG.recomb_inputs, params))`` on CPU. + HyRex output ``(xe, lna_xe, Tm, lna_Tm)`` quadruple params : dict Cosmological parameters. - ReionModel : type - ``ReionizationModelFromZ`` or ``ReionizationModelFromTau``. + ReionModel : callable + Reionization module for computing the xe correction. """ # Copy pre-recomb fields onto self. self.adjoint = pre_BG.adjoint @@ -494,7 +523,7 @@ def __init__(self, pre_BG, recomb_output, params, ReionModel): self.tau0 = pre_BG.tau0 self.recomb_inputs = pre_BG.recomb_inputs - # Unpack HyRex output and apply reionization correction. + # Unpack HyRex output and apply reionization. xe, self.lna_xe_tab, self.Tm_tab, self.lna_Tm_tab = recomb_output reion_model = ReionModel(self, params) @@ -537,6 +566,14 @@ def xe(self, lna): -------- float Free electron fraction (units: dimensionless) + + Notes: + ------ + The logic flow is equivalent to: + + if lna < self.lna_xe_tab.arr[0]: return self.xe_tab[0] + elif lna > self.lna_xe_tab.lastval: return self.xe_tab.lastval + else: return jnp.interp(lna, self.lna_xe_tab, self.xe_tab) """ return jnp.where( lna < self.lna_xe_tab.arr[0], @@ -643,6 +680,11 @@ def _tabulate_optical_depth(self, params): -------- array Tabulated optical depth values (units: dimensionless) + + Notes: + ------ + Also computes time derivative of optical depth, which is the + integrand involving the free electron fraction. """ integrand = lambda lna, y, args: -1./self.tau_c(lna, params)/self.aH(lna, params) term = ODETerm(integrand) @@ -676,7 +718,7 @@ def expmkappa(self, lna): Returns: -------- float - exp(-κ) (units: dimensionless) + exp(-(optical depth)) (units: dimensionless) """ return jnp.where( lna < -10., @@ -703,6 +745,10 @@ def visibility(self, lna, params): -------- float Visibility function (units: Mpc^{-1}) + + Notes: + ------ + Used in computing source functions for CMB anisotropies. """ return self.expmkappa(lna)/self.tau_c(lna, params) @@ -729,6 +775,7 @@ def find_z_at_kappad_equals_one(self, z, kappa_d): float Decoupling redshift (units: dimensionless) """ + # ensure sorted ascending idx = jnp.argsort(z) z_sorted = z[idx] kappa_d_sorted = jnp.abs(kappa_d)[idx] @@ -759,7 +806,6 @@ def interp_rs_at_z(self, z_bg, r_s, z_d): rs_sorted = r_s[idx] return jnp.interp(z_d, z_sorted, rs_sorted) - @jax.named_scope("tabulate kappa d") def _tabulate_kappa_d(self, params): """ Tabulate baryon optical depth. @@ -784,20 +830,19 @@ def _tabulate_kappa_d(self, params): solution = diffeqsolve( term, - solver=Tsit5(), + solver=Tsit5(), # Kvaerno5 is just slower but gives same result stepsize_controller=stepsize_controller, - t0=self.lna_tau_tab[-1], - t1=self.lna_tau_tab[0], + t0=self.lna_tau_tab[-1], # Initial x value (~0 in this case) + t1=self.lna_tau_tab[0], # Final x value (smallest x value) dt0=-1e-3, max_steps=2048, - y0=0.0, - saveat=SaveAt(ts=self.lna_tau_tab[::-1]), + y0=0.0, # Initial value tau(x=0) = 0 + saveat=SaveAt(ts=self.lna_tau_tab[::-1]), # Save at all points in x, reverse order since integrating backwards adjoint=adjoint ) result = solution.ys[::-1] return result - @jax.named_scope("tabulate rs") def _tabulate_rs(self, params): """ Tabulate sound horizon evolution. @@ -815,6 +860,7 @@ def _tabulate_rs(self, params): array Tabulated sound horizon values (units: Mpc) """ + # initial condition assuming cs**2 = 1/3 at early times rs0 = 1./jnp.sqrt(3) / (self.aH( self.lna_tau_tab[0], params )) integrand = lambda lna, y, args: 1./jnp.sqrt(3*(1+self.R_ratio_lna(lna, params))) / (self.aH(lna, params)) @@ -826,7 +872,7 @@ def _tabulate_rs(self, params): term, solver=Tsit5(), stepsize_controller=stepsize_controller, - t0=self.lna_tau_tab[0], + t0=self.lna_tau_tab[0], # reversed direction since I know rs at early times t1=self.lna_tau_tab[-1], dt0=1e-3, max_steps=2048, From 63045cc3ee5395022b6e2805f23d80f0c58c8376 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 15:58:49 -0700 Subject: [PATCH 06/14] review and cleaning --- abcmb/hyrex/helium.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/abcmb/hyrex/helium.py b/abcmb/hyrex/helium.py index 78740a7..5272a6d 100644 --- a/abcmb/hyrex/helium.py +++ b/abcmb/hyrex/helium.py @@ -55,6 +55,8 @@ def __init__(self, lna_axis_4Heequil, integration_spacing = 5.0e-4, Nsteps=800, Initial redshift (default: 8000.) z1 : float, optional Final redshift (default: 20.) + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). """ self.integration_spacing = integration_spacing self.concrete_axis_size_postSahaHe = jnp.zeros(Nsteps_postSahaHe) @@ -215,8 +217,7 @@ def stop_condition(state): # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) - # Run the while loop until the stop condition is met. - # eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp + # eqx.internal.while_loop with kind='checkpointed' uses a custom_vjp # so reverse-mode AD can traverse this dynamic-stop loop via treeverse # checkpointing. max_steps must be a static upper bound. final_state = eqx.internal.while_loop( @@ -359,7 +360,6 @@ def stop_condition(state): # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) - # Run the while loop until the stop condition is met. # eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp # so reverse-mode AD can traverse this dynamic-stop loop via treeverse # checkpointing. max_steps must be a static upper bound. From 8b69c92b952fde6845d6fd76e33cd0ba5a56144f Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:01:19 -0700 Subject: [PATCH 07/14] hydrogen review and cleaning --- abcmb/hyrex/hydrogen.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/abcmb/hyrex/hydrogen.py b/abcmb/hyrex/hydrogen.py index 76ccec0..b9d914a 100644 --- a/abcmb/hyrex/hydrogen.py +++ b/abcmb/hyrex/hydrogen.py @@ -70,6 +70,9 @@ def __init__(self, xe_4He, lna_4He, lna_end, last_4He_lna, twog_redshift, integr Maximum number of integration steps (default: 800) swift : array, optional SWIFT correction function tabulation + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). Defaults + to ForwardMode. """ self.integration_spacing = integration_spacing self.swift = swift @@ -259,11 +262,9 @@ def stop_condition(state): # Initial state: (xe_output, xe, iz, stop flag) initial_state = (xe_output, lna_output, xe, iz, stop) - # Run the while loop until the stop condition is met. # eqx.internal.while_loop with kind='checkpointed' installs a custom_vjp # so reverse-mode AD can traverse this dynamic-stop loop via treeverse - # checkpointing. max_steps must be a static upper bound; the output - # axis size serves that role here. + # checkpointing. max_steps must be a static upper bound final_state = eqx.internal.while_loop( stop_condition, compute_xe, initial_state, max_steps=self.concrete_axis_size.size, From 4ce6d98ce863b685d66caf4fcf638eedef6e1f08 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:09:13 -0700 Subject: [PATCH 08/14] review and cleanup for hyrex main module --- abcmb/hyrex/hyrex.py | 64 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/abcmb/hyrex/hyrex.py b/abcmb/hyrex/hyrex.py index 8e16c47..90b3857 100644 --- a/abcmb/hyrex/hyrex.py +++ b/abcmb/hyrex/hyrex.py @@ -15,29 +15,25 @@ class RecombInputs(eqx.Module): """ - Bundle of background quantities sampled on a fixed lna grid, consumed by - the recombination model in place of a full ``Background`` instance. - - Phase 1 of the HyRex CPU lift refactors HyRex to depend only on these - arrays, so the recombination kernel becomes physics-agnostic and the - GPU/CPU device boundary has a clean interface. + Bundle of pre-recombination background quantities sampled on a + fixed lna grid for computing recombination. Attributes ---------- lna_grid : jnp.array - Uniform log scale-factor sampling axis. + Uniform log scale-factor sampling axis. (units: dimensionless) TCMB_arr : jnp.array - Photon-bath temperature TCMB(lna), eV. + Photon-bath temperature TCMB(lna) (units: eV) nH_arr : jnp.array - Hydrogen number density nH(lna), cm^-3. + Hydrogen number density nH(lna) (units: cm^-3) H_arr : jnp.array - Hubble parameter H(lna), s^-1. + Hubble parameter H(lna) (units: s^-1) Methods ------- - TCMB(lna), nH(lna), H(lna) - Linear interpolation of the corresponding stored array at lna, - using ``ABCMBTools.fast_interp`` (uniform-grid path). + TCMB : Linear interpolation of CMB temperature over lna (units: eV) + nH : Linear interpolation of hydrogen number density over lna (units: cm^-3) + H : Linear interpolation of Hubble over lna (units: s^-1) """ lna_grid : jnp.array @@ -46,12 +42,51 @@ class RecombInputs(eqx.Module): H_arr : jnp.array def TCMB(self, lna): + """ + Linearly interpolate CMB temperature at lna. + + Parameters: + ----------- + lna : float + Logarithm of scale factor. + + Returns: + -------- + float + CMB temperature TCMB(lna) (units: eV). + """ return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.TCMB_arr) def nH(self, lna): + """ + Linearly interpolate hydrogen number density at lna. + + Parameters: + ----------- + lna : float + Logarithm of scale factor. + + Returns: + -------- + float + Hydrogen number density nH(lna) (units: cm^-3). + """ return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.nH_arr) def H(self, lna): + """ + Linearly interpolate Hubble parameter at lna. + + Parameters: + ----------- + lna : float + Logarithm of scale factor. + + Returns: + -------- + float + Hubble parameter H(lna) (units: s^-1). + """ return fast_interp(lna, self.lna_grid[0], self.lna_grid[-1], self.H_arr) class recomb_model(eqx.Module): @@ -92,6 +127,9 @@ def __init__(self, integration_spacing = 5.0e-4, z0=8000., z1=0., adjoint = Forw Initial redshift (default: 8000.) z1 : float, optional Final redshift (default: 0.) + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves (static field). Defaults + to ForwardMode. """ self.integration_spacing = integration_spacing self.adjoint = adjoint From 1cb052f7b5d955e9df2550c7ee5d766debde6518 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:12:05 -0700 Subject: [PATCH 09/14] review and cleaning for the rest of hyrex --- abcmb/hyrex/helium.py | 3 ++- abcmb/hyrex/recomb_functions.py | 5 +---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/abcmb/hyrex/helium.py b/abcmb/hyrex/helium.py index 5272a6d..d8f5ed5 100644 --- a/abcmb/hyrex/helium.py +++ b/abcmb/hyrex/helium.py @@ -56,7 +56,8 @@ def __init__(self, lna_axis_4Heequil, integration_spacing = 5.0e-4, Nsteps=800, z1 : float, optional Final redshift (default: 20.) adjoint : diffrax.adjoint - Adjoint mode for diffrax solves (static field). + Adjoint mode for diffrax solves (static field). Defaults + to ForwardMode. """ self.integration_spacing = integration_spacing self.concrete_axis_size_postSahaHe = jnp.zeros(Nsteps_postSahaHe) diff --git a/abcmb/hyrex/recomb_functions.py b/abcmb/hyrex/recomb_functions.py index 7d9a017..adbe0e0 100644 --- a/abcmb/hyrex/recomb_functions.py +++ b/abcmb/hyrex/recomb_functions.py @@ -13,10 +13,7 @@ #Tabulated values of 2s-2p transition rates to interpolate. alpha_tab = jnp.array(np.loadtxt(file_dir+"/tabs/Alpha_inf.dat")) -# Phase 2 of the HyRex CPU lift: the recombination model now executes -# under ``eqx.filter_jit(backend='cpu')`` (orchestrated from -# ``Model.__call__``), so these lookup tables should live on CPU. Pinning -# them to CPU here avoids per-call device migration when HyRex traces. +# pin to CPU try: cpus = devices('cpu') R_tab = device_put(R_tab, device=cpus[0]) From 79f21c067d78b14dc8ed4fa3c58f1a59eaf33889 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:14:34 -0700 Subject: [PATCH 10/14] review and cleaning for linx --- abcmb/linx/abundances.py | 3 +++ abcmb/linx/background.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/abcmb/linx/abundances.py b/abcmb/linx/abundances.py index f0559f8..cfde07d 100644 --- a/abcmb/linx/abundances.py +++ b/abcmb/linx/abundances.py @@ -40,6 +40,9 @@ class AbundanceModel(eqx.Module): Binding energy of each species. species_mass : list Mass of each species. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Defaults + to ForwardMode. """ nuclear_net : nucl.NuclearRates diff --git a/abcmb/linx/background.py b/abcmb/linx/background.py index c102856..f75df7b 100644 --- a/abcmb/linx/background.py +++ b/abcmb/linx/background.py @@ -31,6 +31,8 @@ class BackgroundModel(eqx.Module): Whether to use leading order QED correction. Default is `True`. NLO : bool, optional Whether to use next-to-leading order QED correction. Default is True. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Default is ForwardMode. """ decoupled : bool From d7100e01da09ea9519c80b428f29aed841db8f5e Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:40:46 -0700 Subject: [PATCH 11/14] clean and review main --- abcmb/main.py | 126 +++++++++++++++++++------------------------------- 1 file changed, 48 insertions(+), 78 deletions(-) diff --git a/abcmb/main.py b/abcmb/main.py index ced8f6d..4d8e673 100644 --- a/abcmb/main.py +++ b/abcmb/main.py @@ -58,6 +58,8 @@ class Model(eqx.Module): A LINX abundance model used for computing the helium-4 mass fraction given the user's input baryon density, Neff, neutron lifetime, and nuclear reaction rates. + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Default is ForwardMode. Methods: -------- @@ -105,11 +107,11 @@ def __init__(self, # Pull adjoint out of kwargs before load_specs — it must NOT end up # inside self.specs (a non-JAX pytree leaf breaks lax.cond / filter_jit - # tracing). Default preserves prior ForwardMode behavior. + # tracing). adjoint = kwargs.pop("adjoint", diffrax.ForwardMode) - # If user requested RecursiveCheckpointAdjoint, auto- - # tighten kappa_PE to a reverse-AD-safe value (unless another value is specified) + # If user requested RecursiveCheckpointAdjoint, auto-tighten kappa_PE + # to a reverse-AD-safe value (unless another value is specified) if adjoint is diffrax.RecursiveCheckpointAdjoint and "kappa_PE" not in kwargs: kwargs["kappa_PE"] = 1e-3 @@ -148,14 +150,7 @@ def __init__(self, scale_pol=specs["scale_pol"] ) - # Initialize recombination model. Phase 2 of the HyRex CPU lift - # invokes RecModel under ``eqx.filter_jit(backend='cpu')``; we do - # NOT device_put RecModel to CPU here, because doing so would mix - # device platforms inside Model's pytree (RecModel on CPU, PE/SS - # on GPU) and the GPU-pinned ``_run_post_recomb`` jit would reject - # ``self`` as having "incompatible devices". JAX migrates RecModel - # to CPU lazily on first trace of the CPU jit — same pattern as - # the LINX call in ``add_derived_parameters``. + # Initialize recombination model. self.RecModel = hyrex.recomb_model(adjoint=adjoint) # DO NOT CHANGE z1 FROM 0 # Initialize BBN model @@ -170,23 +165,12 @@ def __init__(self, self.adjoint = adjoint - # NOTE on jit nesting (Phase 2 of the HyRex CPU lift): - # ``__call__`` is plain Python because it orchestrates a CPU-pinned - # HyRex jit between two GPU-pinned jits. Wrapping ``__call__`` in an - # outer ``eqx.filter_jit`` (default backend = GPU) would attempt to - # inline the inner CPU jit and fails with - # "Received incompatible devices for jitted computation". - # Same constraint applies to ``run_cosmology_abbr``. Mirrors the LINX - # pattern in ``add_derived_parameters``. + # need this outside of the main jit context + # since we want LINX/HyRex to run on CPU def __call__(self, params : dict = {}): """ - Compute CMB angular power spectra for given parameters. - - Runs the full pipeline: - params ─► add_derived_parameters (CPU LINX, unjitted) - ─► get_BG_pre_recomb (GPU JIT) - ─► RecModel ((recomb_inputs, params)) (CPU JIT) - ─► run_cosmology_abbr (GPU JIT) + Runs the full pipeline from background evolution through + perturbation integration to CMB power spectrum computation. Parameters: ----------- @@ -196,16 +180,21 @@ def __call__(self, params : dict = {}): Returns: -------- Output - ClTT/ClTE/ClEE/Pk plus the BG and PT objects. + Bundle of CMB power spectra (ClTT, ClTE, ClEE) and their + multipole grid l, matter power spectrum Pk and its k-grid, + the Background and PerturbationTable objects, and the + full parameter dict including derived keys. """ full_params = self.add_derived_parameters(params) return self.run_cosmology_abbr(full_params) - ### Top-level orchestration (Phase 2): GPU → CPU → GPU. ### + def run_cosmology_abbr(self, params : dict): """ - Orchestrate the full pipeline given derived params. NOT jit-wrapped; - contains a CPU-pinned HyRex call sandwiched between two GPU jits. + Compute CMB angular power spectra for given parameters. + + Runs the full pipeline from background evolution through + perturbation integration to CMB power spectrum computation. Parameters: ----------- @@ -217,14 +206,9 @@ def run_cosmology_abbr(self, params : dict): Output CMB power spectra and friends. """ - # Cast int/bool params to float64 BEFORE entering any - # ``eqx.filter_jit``. Without this, filter_jit's diff/non-diff - # custom_vjp partition routes int leaves to "non-diff" and - # asserts ``perturbed=False`` (equinox/_ad.py:859). Under outer - # ``jax.grad``, those leaves are tracers with perturbed=True and - # the assertion trips. The defaults (e.g. ``N_nu_massive=0``, - # ``omega_Lambda=0``) are unused as integers anywhere downstream - # so the cast is safe. + # Cast int/bool params to float64 before entering any + # ``eqx.filter_jit`` for custom_vjp/AD safety in + # checkpointed_while_loop def _to_float(v): arr = jnp.asarray(v) if arr.dtype.kind in 'iub': @@ -232,49 +216,33 @@ def _to_float(v): return arr params = jax.tree_util.tree_map(_to_float, params) - # Stage 1 (GPU JIT): tabulate conformal time + bundle recomb_inputs. pre_BG = self.get_BG_pre_recomb(params) - # Stage 2 (CPU JIT): HyRex consumes recomb_inputs, returns - # (xe, lna_xe, Tm, lna_Tm) — see ``hyrex.recomb_model.get_history``. - try: - cpu_dev = jax.devices('cpu')[0] - recomb_inputs_cpu = jax.device_put(pre_BG.recomb_inputs, cpu_dev) - params_cpu = jax.device_put(params, cpu_dev) - except Exception: - recomb_inputs_cpu = pre_BG.recomb_inputs - params_cpu = params + cpu_dev = jax.devices('cpu')[0] + recomb_inputs_cpu = jax.device_put(pre_BG.recomb_inputs, cpu_dev) + params_cpu = jax.device_put(params, cpu_dev) recomb_output = eqx.filter_jit(self.RecModel, backend='cpu')((recomb_inputs_cpu, params_cpu)) try: recomb_output = jax.device_put(recomb_output, jax.devices('gpu')[0]) except Exception: - # No GPU: leave recomb_output where it is. pass - # ``recomb_output`` contains ``array_with_padding`` objects whose - # ``padding_size`` and ``lastnum`` are JAX int arrays from - # ``jnp.argmax``. Inside HyRex's CPU jit they are used as indices - # in ``lax.dynamic_update_slice`` (concat), but downstream of - # HyRex the only fields touched are ``arr`` and ``lastval``. The - # ``checkpointed_while_loop`` filter_custom_vjp inside - # ``_run_post_recomb``'s diffrax solves trips - # ``_get_value_assert_unperturbed`` on int leaves under outer - # AD; cast them to float at the boundary to suppress this. + # recomb_output contains array_with_padding objects whose + # padding_size and lastnum int arrays. The + # checkpointed_while_loop's filter_custom_vjp inside + # _run_post_recomb's diffrax solves trips an internal + # _get_value_assert_unperturbed on int leaves under outer + # AD; convert to float to avoid. recomb_output = jax.tree_util.tree_map(_to_float, recomb_output) - # Stage 3 (GPU JIT): apply reionization, integrate optical depth, - # locate decoupling, integrate perturbations, build CMB spectra. return self._run_post_recomb(params, pre_BG, recomb_output) @eqx.filter_jit def get_BG_pre_recomb(self, params : dict): """ - Pre-recomb stage: tabulate conformal time and bundle ``recomb_inputs``. - - This is the only piece of ``__call__`` that fires before HyRex; it - runs on whatever device JAX defaults to (typically GPU on Perlmutter). + Pre-recomb stage: tabulate conformal time and bundle H, T, nH for recombination. Parameters: ----------- @@ -285,12 +253,24 @@ def get_BG_pre_recomb(self, params : dict): -------- BackgroundPreRecomb """ + # let the user know the code is compiling + print("") + print(' /\\ ') + print(' / \\ ') + print(' / /\\ \\ ') + print(' / /__\\ \\ ___ ___ ') + print(' / ______ \\ | _ \\ / __\\ _ _ ') + print(' / / \\ \\ | _// / | \\/ | __ ') + print(' / / \\ \\| _ \\\\ \\___||\\/||| -) ') + print(' /_/ \\_|___/ \\___/|| |||_-) is compiling...') + print('\\_____/ ') + print("") return BackgroundPreRecomb(params, self.species_list, self.RecModel, adjoint=self.adjoint) @eqx.filter_jit def _run_post_recomb(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): """ - Post-recomb GPU stage: full Background construction (reionization, + Post-recombination stage: full Background construction (reionization, optical depth, decoupling), perturbation evolution, CMB spectra. Parameters: @@ -306,18 +286,6 @@ def _run_post_recomb(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb -------- Output """ - # let the user know the code is compiling - print("") - print(' /\\ ') - print(' / \\ ') - print(' / /\\ \\ ') - print(' / /__\\ \\ ___ ___ ') - print(' / ______ \\ | _ \\ / __\\ _ _ ') - print(' / / \\ \\ | _// / | \\/ | __ ') - print(' / / \\ \\| _ \\\\ \\___||\\/||| -) ') - print(' /_/ \\_|___/ \\___/|| |||_-) is compiling...') - print('\\_____/ ') - print("") # Compute background and linear perturbations PT, BG = self.get_PTBG(params, pre_BG, recomb_output) @@ -351,7 +319,7 @@ def get_PTBG(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output) params : dict Cosmological parameters pre_BG : BackgroundPreRecomb - Pre-recomb stage object. + Pre-recombination stage object. recomb_output : tuple HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. @@ -377,7 +345,9 @@ def get_BG(self, params : dict, pre_BG : "BackgroundPreRecomb", recomb_output): params : dict Cosmological parameters pre_BG : BackgroundPreRecomb + Pre-recombination stage object. recomb_output : tuple + HyRex output ``(xe, lna_xe, Tm, lna_Tm)``. Returns: -------- From 51d0024b78ee1cfc355846c790c5b6654f25b27d Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:41:06 -0700 Subject: [PATCH 12/14] no try/except --- abcmb/hyrex/recomb_functions.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/abcmb/hyrex/recomb_functions.py b/abcmb/hyrex/recomb_functions.py index adbe0e0..383da61 100644 --- a/abcmb/hyrex/recomb_functions.py +++ b/abcmb/hyrex/recomb_functions.py @@ -14,12 +14,9 @@ alpha_tab = jnp.array(np.loadtxt(file_dir+"/tabs/Alpha_inf.dat")) # pin to CPU -try: - cpus = devices('cpu') - R_tab = device_put(R_tab, device=cpus[0]) - alpha_tab = device_put(alpha_tab, device=cpus[0]) -except Exception: - pass +cpus = devices('cpu') +R_tab = device_put(R_tab, device=cpus[0]) +alpha_tab = device_put(alpha_tab, device=cpus[0]) # File handling and interpolating related constants. # Do not change these unless something about the tabulated files have changed. From 99658d8d37aeca25536438079117b84fd9d8ec40 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:52:54 -0700 Subject: [PATCH 13/14] clean and review perturbations --- abcmb/perturbations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/abcmb/perturbations.py b/abcmb/perturbations.py index 2c289ff..7757cf3 100644 --- a/abcmb/perturbations.py +++ b/abcmb/perturbations.py @@ -40,6 +40,8 @@ class PerturbationEvolver(eqx.Module): A list of wavenumbers k at which to compute perturbations specs : dict A dictionary containing run options + adjoint : diffrax.adjoint + Adjoint mode for diffrax solves. Default is ForwardMode. Methods: -------- @@ -283,14 +285,12 @@ def evolution_one_k(self, k, lna, args): # Settings for post-tight coupling term = diffrax.ODETerm(self.get_derivatives) - # Root finder: VeryChord (Kvaerno5's default Chord-method root - # finder), with kappa pulled from specs["kappa_PE"] (default - # 0.01 = diffrax default; user-tunable). If reverse-AD produces - # NaN cotangents on omega_b/omega_cdm, decrease kappa_PE in specs - # (e.g. to 1e-3) to tighten Newton convergence. The default kappa - # silently accepts a Newton residual that is not tight enough - # for ABCMB's stiff linear PE Jacobian under reverse-AD; see - # CHANGELOG 2026-05-12. + # LCDM defaults for very high k sometimes failed with reverseAD. + # The reason for that was the default precision parameter in + # the Kvaerno5 rootfinder (via VeryChord). Parameter now read + # in from specs; if reverse-AD produces NaNs (especially on omega_b/ + # omega_cdm), decrease kappa_PE in specs (e.g. to 1e-3) to tighten + # convergence. _rf = eqx.tree_at( lambda s: s.kappa, with_stepsize_controller_tols(VeryChord)(), From 00873e01cc4ff28de158ad7589f141de9d62af48 Mon Sep 17 00:00:00 2001 From: Cara Giovanetti Date: Tue, 12 May 2026 16:58:48 -0700 Subject: [PATCH 14/14] increment version --- abcmb/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abcmb/version.py b/abcmb/version.py index 7a17bdd..845be45 100644 --- a/abcmb/version.py +++ b/abcmb/version.py @@ -1 +1 @@ -__version__ = "0.2.4" \ No newline at end of file +__version__ = "0.2.5" \ No newline at end of file