Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
0.1.1 (2026-05-06)
-------------------
- Update definitions and naming conventions of weights in generators of weighted lightcones (https://github.com/ArgonneCPAC/diffhalos/pull/39)


0.1.0 (2026-02-10)
-------------------
- Implement lightcone_generators functionality (https://github.com/ArgonneCPAC/diffhalos/pull/29)
Expand Down
4 changes: 2 additions & 2 deletions diffhalos/ccshmf/ccshmf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,10 @@ def subhalo_lightcone_weights_kern(

# compute relative abundance of subhalos
_weights = 10 ** predict_diff_cshmf(ccshmf_params, lgmhost, lgmu)
weights = _weights / _weights.sum()
weights_unit_normalized = _weights / _weights.sum()

# compute relative number of subhalos
nsubhalos_in_host = subhalo_counts_per_halo * weights
nsubhalos_in_host = subhalo_counts_per_halo * weights_unit_normalized

return nsubhalos_in_host, lgmu

Expand Down
69 changes: 30 additions & 39 deletions diffhalos/lightcone_generators/mc_lightcone.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,6 @@ def mc_lc(
sky_area_degsq: float
sky area, in deg^2

nhalos_tot: int
total number of halos to generate in the lightcone

cosmo_params: namedtuple
dsps.cosmology.flat_wcdm cosmology
cosmo_params = (Om0, w0, wa, h)
Expand Down Expand Up @@ -401,25 +398,19 @@ def weighted_lc(
logt0: float
Base-10 log of z=0 age of the Universe for the input cosmology

nhalos: ndarray of shape (n_halos_tot, )
weight of the (sub)halo
the nhalos value has a different interpretation for hosts and subs
For host halos, this is the multiplicity factor by which
each halo should be upweighted in order for the generated lightcone
to have the correct host halo mass function across redshift
For subhalos, this is the multiplicity factor associated with the
conditional subhalo mass function, N_sub(logmp | Mhost)
Thus when computing the (unconditional) (sub)halo mass function,
the host halos should be weighted by nhalos,
but subhalos should be weighted by nhalos*nhalos_host.
cen_weight: ndarray of shape (n_halos_tot, )

For centrals, cen_weight is determined by the halo mass function (HMF)
For satellites, cen_weight is HMF weight of the associated central
For satellites, cen_weight = halopop.cen_weight[halopop.halo_indx]

central : ndarray of shape (n_halos_tot, )
Integer equals 1 for central halos and 0 for subhalos

nhalos_host: ndarray of shape (n_halos_tot, )
Multiplicity factor of the host halo
sat_weight: ndarray of shape (n_halos_tot, )
Multiplicity factor of the subhalo richness
Equals 1 for central halos
For subhalos, halopop.nhalos_host = halopop.nhalos[halopop.halo_indx]
For subhalos, halopop.sat_weight = <Nsat(Msub) | Mhost>

nsub_per_host: int
number of subhalos per host halo
Expand Down Expand Up @@ -518,17 +509,16 @@ def _weighted_lc_from_grid(
halo_indx = jnp.concatenate((host_indx, subhalo_indx)).astype(int)
central = jnp.concatenate((jnp.ones(n_host), jnp.zeros(n_sub))).astype(int)

z_obs_subs = jnp.repeat(cenpop.z_obs, subpop.nsub_per_host)
z_obs_all = jnp.concatenate((cenpop.z_obs, z_obs_subs))
z_obs_all = jnp.concatenate(
(cenpop.z_obs, jnp.repeat(cenpop.z_obs, subpop.nsub_per_host))
)
cenpop = cenpop._replace(z_obs=z_obs_all)

t_obs_subs = jnp.repeat(cenpop.t_obs, subpop.nsub_per_host)
t_obs_all = jnp.concatenate((cenpop.t_obs, t_obs_subs))
t_obs_all = jnp.concatenate(
(cenpop.t_obs, jnp.repeat(cenpop.t_obs, subpop.nsub_per_host))
)
cenpop = cenpop._replace(t_obs=t_obs_all)

nhalos_host_subs = jnp.repeat(cenpop.nhalos, subpop.nsub_per_host)
nhalos_host_all = jnp.concatenate((jnp.ones(n_host), nhalos_host_subs))

logmp_obs_all = jnp.concatenate((cenpop.logmp_obs, subpop.logmp_obs))
cenpop = cenpop._replace(logmp_obs=logmp_obs_all)

Expand All @@ -537,6 +527,12 @@ def _weighted_lc_from_grid(
logmp0_all = jnp.concatenate((cenpop.logmp0, logmp0_subs))
cenpop = cenpop._replace(logmp0=logmp0_all)

cenpop = cenpop._replace(
cen_weight=np.concatenate(
(cenpop.cen_weight, jnp.repeat(cenpop.cen_weight, subpop.nsub_per_host))
)
)

# combine halo and subhalo mah_params
mah_params_names = cenpop.mah_params._fields
mah_params_tot = np.zeros((len(mah_params_names), n_host + n_sub))
Expand All @@ -552,25 +548,20 @@ def _weighted_lc_from_grid(
)
cenpop = cenpop._replace(mah_params=mah_params_ntup)

# combine halo and subhalo weights
cenpop = cenpop._replace(nhalos=np.concatenate((cenpop.nhalos, subpop.nsubhalos)))
logmu_obs_all = jnp.concatenate((jnp.zeros(n_host), subpop.logmu_obs))

logmu_obs_host = jnp.zeros(n_host)
logmu_obs_all = jnp.concatenate((logmu_obs_host, subpop.logmu_obs))
sat_weight_all = jnp.concatenate((jnp.ones(n_host), subpop.sat_weight))

# create the output namedtuple containing host and subhalo information;
# this will contain all host halo information, updated to include
# the subhalo information and some fields are updated to new shapes
halopop = namedtuple(
"weighted_lc",
[
*cenpop._fields,
"central",
"nhalos_host",
"nsub_per_host",
"logmu_obs",
"halo_indx",
],
)(*cenpop, central, nhalos_host_all, subpop.nsub_per_host, logmu_obs_all, halo_indx)
halopop = WeightedLightcone(
*cenpop, central, sat_weight_all, subpop.nsub_per_host, logmu_obs_all, halo_indx
)

return halopop


SAT_FIELDS = ["central", "sat_weight", "nsub_per_host", "logmu_obs", "halo_indx"]
_FIELDS = list(mclch.CenPop._fields) + SAT_FIELDS
WeightedLightcone = namedtuple("WeightedLightcone", _FIELDS)
10 changes: 5 additions & 5 deletions diffhalos/lightcone_generators/mc_lightcone_halos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from jax import random as jran
from jax import vmap

from ..cosmology import DEFAULT_COSMOLOGY, flat_wcdm
from ..cosmology import DEFAULT_COSMOLOGY
from ..cosmology.cosmo_basics import get_tobs_from_zobs
from ..cosmology.geometry_utils import compute_volume_from_sky_area
from ..hmf import mc_hosts
Expand All @@ -39,7 +39,7 @@
"mah_params",
"logmp0",
"logt0",
"nhalos",
"cen_weight",
)
CenPop = namedtuple("CenPop", _CENPOP_FIELDS)

Expand Down Expand Up @@ -345,7 +345,7 @@ def weighted_lc_halos(
logt0: float
Base-10 log of z=0 age of the Universe for the input cosmology

nhalos: ndarray of shape (n_halos, )
cen_weight: ndarray of shape (n_halos, )
Multiplicity factor by which each halo should be upweighted
in order for the generated lightcone to have the correct
host halo mass function across redshift
Expand Down Expand Up @@ -384,7 +384,7 @@ def _weighted_lc_halos_from_grid(
centrals_model_key=DEFAULT_DIFFMAHNET_CEN_MODEL,
):
# get halo weights
nhalo_weights = halo_lightcone_weights(
cen_weight = halo_lightcone_weights(
logmp_obs,
z_obs,
sky_area_degsq,
Expand All @@ -410,7 +410,7 @@ def _weighted_lc_halos_from_grid(
logmp0 = _log_mah_kern(mah_params, 10**logt0, logt0)

# create output namedtuple
values = (z_obs, t_obs, logmp_obs, mah_params, logmp0, logt0, nhalo_weights)
values = (z_obs, t_obs, logmp_obs, mah_params, logmp0, logt0, cen_weight)
cenpop = CenPop(*values)

return cenpop
16 changes: 8 additions & 8 deletions diffhalos/lightcone_generators/mc_lightcone_subhalos.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,10 @@ def weighted_lc_subhalos(
-------
subpop: namedtuple
subhalo population with fields:
nsubhalos: ndarray of shape (n_nub, )
Multiplicity factor by which each subhalo should be upweighted
in order for the generated lightcone to have the correct
number of subhalos conditional subhalo mass function
sat_weight: ndarray of shape (n_nub, )
Multiplicity factor <Nsub | Mhost> by which each subhalo
should be weighted in order for the generated lightcone to have
the correct number of subhalos within each host

mah_params_subs: namedtuple of ndarray's with shape (n_subs, n_mah_params)
diffmah parameters for each subhalo in the lightcone
Expand All @@ -253,14 +253,14 @@ def weighted_lc_subhalos(
mah_key, w_key = jran.split(ran_key, 2)

# get subhalo weights
nsubhalo_weights, lgmu = subhalo_lightcone_weights(
sat_weight, lgmu = subhalo_lightcone_weights(
w_key,
cenpop.logmp_obs,
lgmsub_min,
n_mu_per_host,
ccshmf_params,
)
nsubhalo_weights = nsubhalo_weights.reshape(n_host * n_mu_per_host)
sat_weight = sat_weight.reshape(n_host * n_mu_per_host)
lgmu = lgmu.reshape(n_host * n_mu_per_host)

# get the subhalo mass and time of observation for MAH computations
Expand All @@ -282,8 +282,8 @@ def weighted_lc_subhalos(
logmu_obs = logmsub_obs - jnp.repeat(cenpop.logmp_obs, n_mu_per_host)

# add subhalo weights to the dictionary
fields = ("nsubhalos", "mah_params", "logmu_obs", "logmp_obs", "nsub_per_host")
data = (nsubhalo_weights, mah_params_subs, logmu_obs, logmsub_obs, n_mu_per_host)
fields = ("sat_weight", "mah_params", "logmu_obs", "logmp_obs", "nsub_per_host")
data = (sat_weight, mah_params_subs, logmu_obs, logmsub_obs, n_mu_per_host)
subpop = namedtuple("subpop", fields)(*data)

return subpop
66 changes: 61 additions & 5 deletions diffhalos/lightcone_generators/tests/test_mc_lightcone.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_weighted_lc_tpeak_clip():
assert np.allclose(logmp0_subs, logmsub_obs)


def test_weighted_lc_nhalos_host():
def test_weighted_lc_central():
ran_key = jran.key(0)

n_host_halos = 100
Expand All @@ -268,9 +268,65 @@ def test_weighted_lc_nhalos_host():
assert np.allclose(halopop.central[:n_host_halos], 1)
assert np.allclose(halopop.central[n_host_halos:], 0)

assert np.allclose(halopop.nhalos_host[:n_host_halos], 1)

assert np.allclose(
halopop.nhalos_host[n_host_halos:],
halopop.nhalos[halopop.halo_indx][n_host_halos:],
def test_weighted_lc_sat_weight_is_unity_for_centrals():
"""Enforce sat_weight=1 for centrals"""
ran_key = jran.key(0)

n_host_halos = 100
z_min, z_max = 0.1, 3.1
sky_area_degsq = 10.0
lgmp_min, lgmp_max = 10.0, 15.0
args = (ran_key, n_host_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq)
halopop = mclc.weighted_lc(*args)

assert np.allclose(halopop.sat_weight[:n_host_halos], 1)


def test_weighted_lc_cen_weight():
"""Enforce sat_weight=1 for centrals"""
ran_key = jran.key(0)

n_host_halos = 100
z_min, z_max = 0.1, 3.1
sky_area_degsq = 10.0
lgmp_min, lgmp_max = 10.0, 15.0
args = (ran_key, n_host_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq)
halopop = mclc.weighted_lc(*args)

correct_cen_weight_sats = np.repeat(
halopop.cen_weight[:n_host_halos], halopop.nsub_per_host
)
assert np.allclose(halopop.cen_weight[n_host_halos:], correct_cen_weight_sats)


def test_weighted_lc_logmu_obs():
"""Enforce logmu = logmp_obs - logmp_host for all (sub)halos"""
ran_key = jran.key(0)

n_host_halos = 100
z_min, z_max = 0.1, 3.1
sky_area_degsq = 10.0
lgmp_min, lgmp_max = 10.0, 15.0
args = (ran_key, n_host_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq)
halopop = mclc.weighted_lc(*args)

correct_logmp_host_sats = halopop.logmp_obs[halopop.halo_indx]
correct_logmu_sats = halopop.logmp_obs - correct_logmp_host_sats
assert np.allclose(halopop.logmu_obs, correct_logmu_sats)
assert np.allclose(halopop.logmu_obs[:n_host_halos], 0.0)


def test_weighted_lc_gal_weight():
ran_key = jran.key(0)

n_host_halos = 100
z_min, z_max = 0.1, 3.1
sky_area_degsq = 10.0
lgmp_min, lgmp_max = 10.0, 15.0
args = (ran_key, n_host_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq)
halopop = mclc.weighted_lc(*args)

gal_weight = halopop.cen_weight * halopop.sat_weight
assert np.allclose(gal_weight[:n_host_halos], halopop.cen_weight[:n_host_halos])
assert not np.any(gal_weight[n_host_halos:] == halopop.cen_weight[n_host_halos:])
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_mc_weighted_halo_lightcone_stratified():
sky_area_degsq,
)

assert np.all(np.isfinite(cenpop.nhalos))
assert np.all(np.isfinite(cenpop.cen_weight))
assert cenpop.logmp_obs.size == num_halos

assert np.all(cenpop.z_obs >= z_min)
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_mc_weighted_halo_lightcone_input_grid():
sky_area_degsq,
)

assert np.all(np.isfinite(cenpop.nhalos))
assert np.all(np.isfinite(cenpop.cen_weight))
assert cenpop.logmp_obs.size == num_halos

assert np.all(cenpop.z_obs >= z_min)
Expand All @@ -309,7 +309,7 @@ def test_weighted_lc_halos():
for field in mclh.CenPop._fields:
assert hasattr(cenpop, field)

assert np.all(np.isfinite(cenpop.nhalos))
assert np.all(np.isfinite(cenpop.cen_weight))
assert cenpop.logmp_obs.size == n_halos

assert np.all(cenpop.z_obs >= z_min)
Expand Down
18 changes: 6 additions & 12 deletions diffhalos/lightcone_generators/tests/test_mc_lightcone_subhalos.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_mc_weighted_lc_subhalos_behaves_as_expected():
for _field in subpop._fields:
assert np.all(np.isfinite(subpop._asdict()[_field]))

assert subpop.nsubhalos.shape == (n_cens * n_sub_per_host,)
assert subpop.sat_weight.shape == (n_cens * n_sub_per_host,)
assert subpop.logmu_obs.shape == (n_cens * n_sub_per_host,)

nsub_tot = int(subpop.nsub_per_host * n_cens)
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_mc_weighted_lc_subhalos_with_different_nsubs_per_host():
for _field in subpop._fields:
assert np.all(np.isfinite(subpop._asdict()[_field]))

assert subpop.nsubhalos.shape == (n_cens * n_sub_per_host,)
assert subpop.sat_weight.shape == (n_cens * n_sub_per_host,)
assert subpop.logmu_obs.shape == (n_cens * n_sub_per_host,)

nsub_tot = int(subpop.nsub_per_host * n_cens)
Expand Down Expand Up @@ -278,14 +278,8 @@ def test_mc_weighted_lc_subhalos_agrees_with_mc_subhalopop():
for lgmp_min in lgmp_min_arr:
halopop = mclsh.weighted_lc_subhalos(ran_key, cenpop, lgmp_min)

mc_lg_mu_pop = mc_subs.generate_subhalopop(
ran_key,
cenpop.logmp_obs,
lgmp_min,
)[0]
mc_lg_mu_pop = mc_subs.generate_subhalopop(ran_key, cenpop.logmp_obs, lgmp_min)[
0
]

assert np.allclose(
mc_lg_mu_pop.size,
halopop.nsubhalos.sum(),
rtol=0.1,
)
assert np.allclose(mc_lg_mu_pop.size, halopop.sat_weight.sum(), rtol=0.1)
Loading