Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions linx/abundances.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,14 @@ def __call__(
# Default SaveAt
saveat = SaveAt(t1=True)

# Precompute rate tables once, outside the ODE RHS.
rate_tables = self.nuclear_net.precompute_rate_tables(nuclear_rates_q)
sol = diffeqsolve(
ODETerm(self.Y_prime), solver,
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
args=(
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q
nTOp_bkwrd, eta_fac, tau_n_fac, rate_tables
),
saveat=saveat,
stepsize_controller=PIDController(
Expand Down Expand Up @@ -488,7 +490,18 @@ def Y_prime(self, t, Y, args):
Y : array
Array of abundances for evaluating :math:`dY_i/dt`.
args : tuple of arrays
Other relevant information for evaluating the derivative. These are respectively, 0) an array of scale factors; 1) an array of times; 2) an array of EM sector temperatures; 3) an array representing the abscissa of EM sector temperatures for evaluating weak rates; 4) an array of n -> p rates to interpolate over; 5) an array of p -> n rates to interpolate over; 6) the rescaling factor for baryon-to-photon ratio `eta_fac`; 7) the rescaling factor for neutron decay lifetime `tau_n_fac` and 8) the array rescaling nuclear rates, `nuclear_rates_q`.
Other relevant information for evaluating the derivative.
These are respectively,
0) an array of scale factors;
1) an array of times;
2) an array of EM sector temperatures;
3) an array representing the abscissa of EM sector temperatures
for evaluating weak rates;
4) an array of n -> p rates to interpolate over;
5) an array of p -> n rates to interpolate over;
6) the rescaling factor for baryon-to-photon ratio `eta_fac`;
7) the rescaling factor for neutron decay lifetime `tau_n_fac` and
8) the pre-computed rate tables, `rate_tables` (built once from nuclear_rates_q)

Returns
-------
Expand All @@ -505,7 +518,7 @@ def Y_prime(self, t, Y, args):
nTOp_bkwrd_vec_in = args[5]
eta_fac = args[6]
tau_n_fac = args[7]
nuclear_rates_q = args[8]
rate_tables = args[8]

a_in = a_vec_in[0]
a_fin = a_vec_in[-1]
Expand All @@ -528,7 +541,7 @@ def Y_prime(self, t, Y, args):
dY = self.nuclear_net(
Y, T_t, rhoBBN, T_interval_in, nTOp_frwrd_vec_in,
nTOp_bkwrd_vec_in, tau_n_fac=tau_n_fac,
nuclear_rates_q=nuclear_rates_q
rate_tables = rate_tables
)

return dY
Expand Down
41 changes: 34 additions & 7 deletions linx/nuclear.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,34 @@ def __init__(
self.bkwrd_reaction_by_particle[i].append(rxn.name)

self.reactions_names.append(rxn.name)

def precompute_rate_tables(self, nuclear_rates_q=None):
"""
Pre-compute the per-reaction rate tables.

Parameters
----------
nuclear_rates_q : array, optional
Rescaling parameter of expsigma in nuclear rate. Default
is None (no rescaling).

Returns
-------
list
One entry per reaction (in order of `self.reactions`), as returned
by `Reaction.frwrd_rate_table`.
"""

if nuclear_rates_q is None:
nuclear_rates_q = jnp.array([0. for _ in self.reactions])
return [rxn.frwrd_rate_table(nuclear_rates_q[i]) for i, rxn in enumerate(self.reactions)]


@eqx.filter_jit
def __call__(
self, Y, T_t, rhoBBN, T_interval,
nTOp_frwrd_vec, nTOp_bkwrd_vec, tau_n_fac=1., nuclear_rates_q=None
nTOp_frwrd_vec, nTOp_bkwrd_vec, tau_n_fac=1., nuclear_rates_q=None,
rate_tables=None
):
"""
Returns the rate of change of the abundances.
Expand All @@ -169,16 +191,21 @@ def __call__(
nuclear_rates_q : array
Rescaling parameter of expsigma in nuclear rate. If None,
no rescaling is assumed.
rate_tables : list, optional
Precomputed rate tables (avoid rebuilding them on every ODE step).
Defaults to None, meaning they built once given `nuclear_rates_q`.

Returns
-------
Array
dY/dt in s^-1. Same dimensions as Y.
"""

if nuclear_rates_q is None:
if rate_tables is None:

nuclear_rates_q = jnp.array([0. for _ in self.reactions])
rate_tables = self.precompute_rate_tables(nuclear_rates_q) # nuclear_rates_q = None
# is handled appropriately
# in this fn

dYdt_vec = jnp.zeros(len(Y))

Expand All @@ -204,13 +231,13 @@ def __call__(

# These functions take temperature in K.
frwrd_rate_params = {
rxn.name:self.frwrd_rate_param[rxn.name](
T_t / const.kB, nuclear_rates_q[i]
rxn.name:rxn.frwrd_rate_from_table(
T_t / const.kB, rate_tables[i]
) for i,rxn in enumerate(self.reactions)
}
bkwrd_rate_params = {
rxn.name:self.bkwrd_rate_param[rxn.name](
T_t / const.kB, nuclear_rates_q[i]
rxn.name:rxn.bkwrd_rate_from_table(
T_t / const.kB, rate_tables[i]
) for i,rxn in enumerate(self.reactions)
}

Expand Down
130 changes: 114 additions & 16 deletions linx/reactions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings

import numpy as np

Expand Down Expand Up @@ -59,11 +60,14 @@ class Reaction(eqx.Module):
alpha : float
beta : float
gamma : float
T9_vec : list
T9_vec : list
mu_median_vec : list
expsigma_vec : list
interp_type : str
frwrd_rate_param_func : callable
log_T9_vec : list
log_mu_median_vec : list
log_expsigma_vec : list
interp_type : str
frwrd_rate_param_func : callable

def __init__(
self, name, in_states, out_states, alpha, beta, gamma,
Expand Down Expand Up @@ -133,7 +137,10 @@ def __init__(
self.T9_vec = None
self.mu_median_vec = None
self.expsigma_vec = None
self.frwrd_rate_param_func = None
self.log_T9_vec = None
self.log_mu_median_vec = None
self.log_expsigma_vec = None
self.frwrd_rate_param_func = None

if spline_data:
self.T9_vec, self.mu_median_vec, self.expsigma_vec = np.loadtxt(
Expand All @@ -153,7 +160,13 @@ def __init__(
# No GPU available or no GPU devices found - data stays on CPU
pass

elif frwrd_rate_param_func is not None:
# Precompute the logs of the (constant) interpolation tables once,
# so frwrd_rate_param doesn't recompute them on every rate eval.
self.log_T9_vec = jnp.log(self.T9_vec)
self.log_mu_median_vec = jnp.log(self.mu_median_vec)
self.log_expsigma_vec = jnp.log(self.expsigma_vec)

elif frwrd_rate_param_func is not None:

self.frwrd_rate_param_func = frwrd_rate_param_func

Expand Down Expand Up @@ -189,32 +202,92 @@ def frwrd_rate_param(self, T, p):
The rate here is either <sigma v> or <sigma v^2> divided by
(1 amu)^(N_in-1)) for each reaction, units (cm^3/s/g or cm^6/s/g^2).
"""
warnings.warn('Reaction.frwrd_rate_param is deprecated. '
'Use table = frwrd_rate_table(q), frwrd_rate_from_table(T, table) instead.',
FutureWarning, stacklevel=2)

T9 = T*1e-9

if self.T9_vec is not None:
if self.T9_vec is not None:

rate_vec = self.mu_median_vec * jnp.exp(
p * jnp.log(self.expsigma_vec)
)

if self.interp_type == 'linear':
if self.interp_type == 'linear':

rate_vec = self.mu_median_vec * jnp.exp(
p * self.log_expsigma_vec # builds rate_vec at every step--wasteful
)
return jnp.interp(
T9, self.T9_vec, rate_vec, left=0., right=0.
)

elif self.interp_type == 'log':


elif self.interp_type == 'log':

# log(rate) = log(mu) + p*log(expsigma): no per-call vector logs.
log_rate_vec = (
self.log_mu_median_vec + p * self.log_expsigma_vec
)
return jnp.exp(jnp.interp(
jnp.log(T9), jnp.log(self.T9_vec), jnp.log(rate_vec),
jnp.log(T9), self.log_T9_vec, log_rate_vec,
left=0., right=0.
))

else:
else:

return self.frwrd_rate_param_func(T, p)

@eqx.filter_jit
def frwrd_rate_table(self, q):
"""
Precompute the q-dependent, T-independent part of the forward rate.

Parameters
----------
q : float
Rescaling parameter for expsigma

Returns
-------
Array or float
Forward rate with q rescaling

"""
if self.T9_vec is not None:
if self.interp_type == 'linear':
return self.mu_median_vec * jnp.exp(q * self.log_expsigma_vec)
elif self.interp_type == 'log':
return self.log_mu_median_vec + q * self.log_expsigma_vec
else:
return q


@eqx.filter_jit
def frwrd_rate_from_table(self, T, table):
"""
Forward rate at a temperature T given the precomputed table.

Parameters
----------
T : float
Temperature in K.
table : Array or float
Output scaled rate from `self.frwrd_rate_table(q)`. For interpolated
reactions this is the q-dependent rate table. For analytic reactions,
it is just the scalar q.

Returns
-------
float
Forward rate at temperature T.
"""
T9 = T*1e-9

if self.T9_vec is not None:
if self.interp_type == 'linear':
return jnp.interp(T9, self.T9_vec,table, left=0., right=0.)
elif self.interp_type == 'log':
return jnp.exp(jnp.interp(jnp.log(T9), self.log_T9_vec, table, left=0., right=0.))
else:
return self.frwrd_rate_param_func(T,table) # analytic

@eqx.filter_jit
def bkwrd_rate_param(self, T, p):
"""
Expand Down Expand Up @@ -242,9 +315,34 @@ def bkwrd_rate_param(self, T, p):
The rate here is either <sigma v> or <sigma v^2> divided by
(1 amu)^(N_in-1)) for each reaction, units (cm^3/s/g or cm^6/s/g^2).
"""
warnings.warn('Reaction.bkwrd_rate_param is deprecated. '
'Use table = frwrd_rate_table(q), bkwrd_rate_from_table(T, table) instead.',
FutureWarning, stacklevel=2)
T9 = T*1e-9

return self.alpha*T9**self.beta*jnp.exp(self.gamma/T9) * (
self.frwrd_rate_param(T, p)
)

@eqx.filter_jit
def bkwrd_rate_from_table(self, T, table):
"""
Backward rate at temperature T given the precomputed forward table.

Parameters
----------
T : float
Temperature in K.
table : Array or float
Output of `frwrd_rate_table(q)`.

Returns
-------
float
Backward rate at temperature T.
"""
T9 = T*1e-9

return self.alpha * T9**self.beta * jnp.exp(self.gamma/T9) * (
self.frwrd_rate_from_table(T,table)
)
Loading