Revert "gwinferno review" (#131)
* Revert "gwinferno review (#130)"

This reverts commit 0a995f2.

* change actions/upload-artifact from v3 to v4

See [this page](https://github.blog/changelog/2024-04-16-deprecation-notice-v3-of-the-artifact-actions/) for more details
jaxengodfrey authored Jan 23, 2025
1 parent: 0a995f2 · commit: 2ac1fa6
Showing 10 changed files with 39 additions and 162 deletions.
.github/workflows/ci-tests.yml (2 changes: 1 addition & 1 deletion)

@@ -61,7 +61,7 @@ jobs:
         run: python -m coverage report --show-missing
       - name: Upload test results
         if: always()
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: pytest-${{ matrix.python-version }}
           path: pytest.xml
gwinferno/cosmology.py (12 changes: 0 additions & 12 deletions)

@@ -16,11 +16,6 @@
 PLANCK_2015_OmegaLambda = 1.0 - PLANCK_2015_OmegaMatter
 PLANCK_2015_OmegaRadiation = 0.0
 
-PLANCK_2015_LVK_Ho = 67.90 / 1e-3
-PLANCK_2015_LVK_OmegaMatter = 0.3065
-PLANCK_2015_LVK_OmegaLambda = 1.0 - PLANCK_2015_LVK_OmegaMatter
-PLANCK_2015_LVK_OmegaRadiation = PLANCK_2015_OmegaRadiation
-
 DEFAULT_DZ = 1e-3  # should be good enough for most numeric integrations we want to do
 
 
@@ -146,10 +141,3 @@ def z2DL(self, z, dz=DEFAULT_DZ):
     PLANCK_2015_OmegaRadiation,
     PLANCK_2015_OmegaLambda,
 )
-
-PLANCK_2015_LVK_Cosmology = Cosmology(
-    PLANCK_2015_LVK_Ho,
-    PLANCK_2015_LVK_OmegaMatter,
-    PLANCK_2015_LVK_OmegaRadiation,
-    PLANCK_2015_LVK_OmegaLambda,
-)
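
A minimal usage sketch for the cosmology object that survives this revert, based only on what the hunks above show: `PLANCK_2015_Cosmology` is built from the Planck 2015 constants and exposes a `z2DL(z, dz=DEFAULT_DZ)` method. The example redshift is arbitrary and the distance units are not specified by this diff.

from gwinferno.cosmology import PLANCK_2015_Cosmology as Planck15

# z2DL appears in the hunk header above with signature z2DL(self, z, dz=DEFAULT_DZ);
# the redshift value here is only illustrative.
z = 0.2
d_l = Planck15.z2DL(z)
print(f"luminosity distance at z = {z}: {d_l}")
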
gwinferno/distributions.py (2 changes: 1 addition & 1 deletion)

@@ -93,7 +93,7 @@ def powerlaw_logit_pdf(xx, alpha, low=None, high=None, low_fall_off=4.0, high_fa
     if low is not None:
         prob *= logistic_unit(xx, low, sgn=-1.0, sc=low_fall_off)
     if high is not None:
-        prob *= logistic_unit(xx, high, sgn=1.0, sc=high_fall_off)
+        logistic_unit(xx, high, sgn=1.0, sc=high_fall_off)
     return prob


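
For intuition about what `powerlaw_logit_pdf` computes, here is a self-contained sketch of a power law smoothly tapered by logistic roll-offs at its edges. The function and argument names mirror the hunk above; gwinferno's `logistic_unit` is not shown in this diff, so the `_logistic_taper` helper is an assumption, and the result below is left unnormalized.

import jax.numpy as jnp

def _logistic_taper(xx, edge, sgn, sc):
    # Hypothetical stand-in for gwinferno's logistic_unit: a sigmoid that rolls
    # the density off beyond `edge`, with sharpness controlled by `sc`.
    return 1.0 / (1.0 + jnp.exp(sgn * sc * (xx - edge)))

def tapered_powerlaw(xx, alpha, low=None, high=None, low_fall_off=4.0, high_fall_off=4.0):
    # Unnormalized power law xx**alpha with smooth logistic cutoffs,
    # mirroring the structure of powerlaw_logit_pdf in the hunk above.
    prob = xx**alpha
    if low is not None:
        prob *= _logistic_taper(xx, low, sgn=-1.0, sc=low_fall_off)
    if high is not None:
        prob *= _logistic_taper(xx, high, sgn=1.0, sc=high_fall_off)
    return prob

xx = jnp.linspace(3.0, 100.0, 5)
print(tapered_powerlaw(xx, alpha=-2.5, low=5.0, high=80.0))
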
gwinferno/models/bsplines/single.py (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@
 import jax.numpy as jnp
 from jax.scipy.integrate import trapezoid
 
-from gwinferno.cosmology import PLANCK_2015_LVK_Cosmology as Planck15
+from gwinferno.cosmology import PLANCK_2015_Cosmology as Planck15
 
 from ...interpolation import BSpline
 from ...interpolation import LogXBSpline
gwinferno/models/parametric/parametric.py (19 changes: 5 additions & 14 deletions)

@@ -1,7 +1,7 @@
 import jax.numpy as jnp
 from jax.scipy.integrate import trapezoid
 
-from gwinferno.cosmology import PLANCK_2015_LVK_Cosmology as Planck15
+from gwinferno.cosmology import PLANCK_2015_Cosmology as Planck15
 
 from ...distributions import betadist
 from ...distributions import powerlaw_logit_pdf
@@ -64,6 +64,10 @@ def beta_spin_magnitude(a, alpha, beta, amax=1):
     return betadist(a, alpha, beta, scale=amax)
 
 
+def mixture_isoalign_spin_tilt(ct, xi_tilt, sigma_tilt):
+    return (1 - xi_tilt) / 2 + xi_tilt * truncnorm_pdf(ct, 1, sigma_tilt, -1, 1)
+
+
 def iid_spin_magnitude(a1, a2, alpha_mag, beta_mag, amax=1):
     return betadist(a1, alpha_mag, beta_mag, scale=amax) * betadist(a2, alpha_mag, beta_mag, scale=amax)
 
@@ -81,11 +85,6 @@ def independent_spin_magnitude_beta_dist(
     return betadist(a1, alpha_mag1, beta_mag1, scale=amax1) * betadist(a2, alpha_mag2, beta_mag2, scale=amax2)
 
 
-def mixture_isoalign_spin_tilt(ct, xi_tilt, sigma_tilt):
-    cut = jnp.where(jnp.greater(ct, 1) | jnp.less(ct, -1), 0, 1)
-    return cut * (1 - xi_tilt) / 2 + xi_tilt * truncnorm_pdf(ct, 1, sigma_tilt, -1, 1)
-
-
 def iid_spin_tilt(ct1, ct2, xi_tilt, sigma_tilt):
     return mixture_isoalign_spin_tilt(ct1, xi_tilt, sigma_tilt) * mixture_isoalign_spin_tilt(ct2, xi_tilt, sigma_tilt)
 
@@ -94,14 +93,6 @@ def independent_spin_tilt(ct1, ct2, xi_tilt_1, xi_tilt_2, sigma_tilt1, sigma_til
     return mixture_isoalign_spin_tilt(ct1, xi_tilt_1, sigma_tilt1) * mixture_isoalign_spin_tilt(ct2, xi_tilt_2, sigma_tilt2)
 
 
-def default_spin_tilt(ct1, ct2, xi_tilt, sigma_tilt):
-    iso1 = jnp.where(jnp.greater(ct1, 1) | jnp.less(ct1, -1), 0, 0.5)
-    iso2 = jnp.where(jnp.greater(ct2, 1) | jnp.less(ct2, -1), 0, 0.5)
-    ali1 = truncnorm_pdf(ct1, 1, sigma_tilt, -1, 1)
-    ali2 = truncnorm_pdf(ct2, 1, sigma_tilt, -1, 1)
-    return (1 - xi_tilt) * iso1 * iso2 + xi_tilt * ali1 * ali2
-
-
 """
 ***************************************
 REDSHIFT MODELS
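
The `mixture_isoalign_spin_tilt` model restored above is a mixture over cos(tilt): an isotropic (uniform on [-1, 1]) component with weight 1 - xi_tilt plus a preferentially aligned truncated Gaussian peaked at cos(tilt) = 1 with weight xi_tilt. A self-contained sketch follows; gwinferno's `truncnorm_pdf` is not shown in this diff, so the helper below (with assumed argument order mu, sigma, low, high, matching the call in the hunk) is an approximation.

import jax.numpy as jnp
from jax.scipy.stats import norm

def _truncnorm_pdf(x, mu, sigma, low, high):
    # Assumed stand-in for gwinferno's truncnorm_pdf: a normal density
    # renormalized to [low, high] and set to zero outside that interval.
    z = norm.cdf(high, mu, sigma) - norm.cdf(low, mu, sigma)
    pdf = norm.pdf(x, mu, sigma) / z
    return jnp.where((x >= low) & (x <= high), pdf, 0.0)

def mixture_isoalign_spin_tilt(ct, xi_tilt, sigma_tilt):
    # (1 - xi_tilt) * Uniform(-1, 1) + xi_tilt * TruncNorm(mean=1, sigma_tilt) on [-1, 1],
    # matching the expression restored in the hunk above.
    return (1 - xi_tilt) / 2 + xi_tilt * _truncnorm_pdf(ct, 1.0, sigma_tilt, -1.0, 1.0)

ct = jnp.linspace(-1.0, 1.0, 5)
print(mixture_isoalign_spin_tilt(ct, xi_tilt=0.5, sigma_tilt=0.3))
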
gwinferno/pipeline/analysis.py (79 changes: 19 additions & 60 deletions)
@@ -70,8 +70,6 @@ def per_event_log_bayes_factors(weights, log=False):
         Array of per-event log bayes factors.
     jax.DeviceArray
         Array of per-event log effective samples sizes from Monte Carlo integrals.
-    jax.DeviceArray
-        Array of per-event estimated variances from log of Monte Carlo integrals.
     """
     if log:
         logweights = weights
@@ -84,8 +82,7 @@
     BFs /= weights.shape[1]
     logBFs = jnp.log(BFs)
     logn_effs = jnp.log(n_effs)
-    variances = 1 / jnp.exp(logn_effs) - 1 / weights.shape[1]
-    return logBFs, logn_effs, variances
+    return logBFs, logn_effs
 
 
 @partial(jit, static_argnames=["log"])
@@ -117,9 +114,7 @@ def detection_efficiency(weights, Ninj, log=False):
     jax.DeviceArray
         Array of log detection efficiency.
     jax.DeviceArray
-        Array of log N_eff from Monte Carlo integral.
-    jax.DeviceArray
-        Array of variance estimated from log of Monte Carlo integral.
+        Array of log N_eff from Monte Carlo Integral.
     """
     if log:
         logweights = weights
@@ -132,8 +127,7 @@
     var = jnp.sum(weights**2) / Ninj**2 - mu**2 / Ninj
     logmu = jnp.log(mu)
     logn_eff = 2 * logmu - jnp.log(var)
-    variance = 1 / jnp.exp(logn_eff) - 1 / Ninj
-    return logmu, logn_eff, variance
+    return logmu, logn_eff
 
 
 def hierarchical_likelihood(
@@ -142,16 +136,15 @@
     total_inj,
     Nobs,
     Tobs,
-    surveyed_hypervolume=None,
+    surveyed_hypervolume_function,
     categorical=False,
     marginal_qs=False,
     indv_weights=None,
     rngkey=None,
     pop_frac=None,
-    reconstruct_rate=True,
     marginalize_selection=False,
+    reconstruct_rate=True,
     min_neff_cut=True,
-    max_variance_cut=False,
     posterior_predictive_check=False,
     param_names=None,
     pedata=None,
@@ -181,8 +174,6 @@ def hierarchical_likelihood(
         Total number of observed events analyzing.
     Tobs : float
         Time spent observing to produce catalog (in yrs).
-    surveyed_hypervolume : float
-        Total VT (normalization of the redshift model).
     categorical : bool, optional
         If `True` use latent categorical parameters to assign
         each event to one of many subpopulations. Defaults to `False`.
@@ -197,6 +188,12 @@
         Tuple of true astrophysical population fractions.
         Shape is number of categorical subpopulations, needs to sum to 1, and is
         needed if `categorical=True`. Defaults to `None`.
+    surv_hypervolume_fct : callable, optional
+        Callable function to calculate total VT (normalization of the redshift model).
+        Defaults to `TotalVTCalculator()`.
+    vtfct_kwargs : dict, optional
+        Diction of args needed to call `surv_hypervolume_fct()`.
+        Defaults to `{"lamb": 0}`.
     marginalize_selection : bool, optional
         Flag to marginalize over uncertainty in selection monte carlo integral.
         Defaults to `True`.
@@ -205,11 +202,6 @@
     min_neff_cut : bool, optional
         Flag to use the `min_neff` cut on the likelihood ensuring Monte Carlo
         integrals converge. Defaults to `True`.
-    max_variance_cut : bool, optional
-        Flag to use a cut on the maximum allowed variance < 1 estimated for the
-        total log likelihood. If this is `True`, then `marginalize_selection` and
-        `min_neff_cut` must be `False`.
-        Defaults to `False`.
     posterior_predictive_check : bool, optional
         Flag to sample from the PE/injection data to perform posterior predictive check.
         Defaults to `False`.
@@ -234,14 +226,6 @@
     float
         Marginalized merger rate in units of `Gpc^-3 yr^-1`.
     """
-    if max_variance_cut and (marginalize_selection or min_neff_cut):
-        raise ValueError(
-            "max_variance_cut is True which requires marginalize_selection and "
-            "min_neff_cut to be False but got "
-            f"marginalize_selection = {marginalize_selection} "
-            f"and min_neff_cut = {min_neff_cut}",
-        )
-
     rate = None
     if categorical:
         with numpyro.plate("nObs", Nobs) as i:
@@ -251,28 +235,24 @@
                 rng_key=rngkey,
             ).reshape((-1, 1))
             mix_pe_weights = jnp.where(Qs[i] == 0, pe_weights[0][i], pe_weights[1][i])
-            logBFs, logn_effs, variances = per_event_log_bayes_factors(mix_pe_weights, log=log)
+            logBFs, logn_effs = per_event_log_bayes_factors(mix_pe_weights, log=log)
     else:
 
-        logBFs, logn_effs, variances = per_event_log_bayes_factors(pe_weights, log=log)
+        logBFs, logn_effs = per_event_log_bayes_factors(pe_weights, log=log)
 
-    log_det_eff, logn_eff_inj, variance = detection_efficiency(inj_weights, total_inj, log=log)
+    log_det_eff, logn_eff_inj = detection_efficiency(inj_weights, total_inj, log=log)
     numpyro.deterministic("log_nEff_inj", logn_eff_inj)
     numpyro.deterministic("log_nEffs", logn_effs)
     numpyro.deterministic("logBFs", logBFs)
     numpyro.deterministic("detection_efficiency", jnp.exp(log_det_eff))
-    numpyro.deterministic("variance_log_BFs", variances)
-    numpyro.deterministic("variance_log_detection_efficiency", variance)
     if reconstruct_rate:
-        total_vt = numpyro.deterministic("surveyed_hypervolume", surveyed_hypervolume / 1.0e9 * Tobs)
+        total_vt = numpyro.deterministic("surveyed_hypervolume", surveyed_hypervolume_function / 1.0e9 * Tobs)
         unscaled_rate = numpyro.sample("unscaled_rate", dist.Gamma(Nobs))
         rate = numpyro.deterministic("rate", unscaled_rate / jnp.exp(log_det_eff) / total_vt)
-    if marginalize_selection:
-        log_det_eff = log_det_eff - (3 + Nobs) / (2 * jnp.exp(logn_eff_inj))
     if min_neff_cut:
         log_det_eff = jnp.where(
             jnp.greater_equal(logn_eff_inj, jnp.log(4 * Nobs)),
-            log_det_eff,
+            log_det_eff - (3 + Nobs) / (2 * jnp.exp(logn_eff_inj)),
             jnp.inf,
         )
     sel = numpyro.deterministic(
@@ -302,20 +282,6 @@ def hierarchical_likelihood(
         ),
     )
 
-    variance = numpyro.deterministic(
-        "variance_log_likelihood",
-        Nobs**2 * variance + variances.sum(),
-    )
-    if max_variance_cut:
-        log_l = numpyro.deterministic(
-            "variance_less_1",
-            jnp.where(
-                jnp.less_equal(variance, 1),
-                log_l,
-                jnp.nan_to_num(-jnp.inf),
-            ),
-        )
-
     numpyro.factor("log_likelihood", log_l)
 
     if posterior_predictive_check:
@@ -356,14 +322,7 @@ def hierarchical_likelihood(
     return rate
 
 
-def construct_hierarchical_model(
-    model_dict,
-    prior_dict,
-    marginalize_selection=False,
-    min_neff_cut=True,
-    max_variance_cut=False,
-    posterior_predictive_check=True,
-):
+def construct_hierarchical_model(model_dict, prior_dict, min_neff_cut=True, marginalize_selection=False, posterior_predictive_check=True):
     source_param_names = [k for k in model_dict.keys()]
     hyper_params = {k: None for k in prior_dict.keys()}
     pop_models = {k: None for k in model_dict.keys()}
@@ -407,10 +366,10 @@ def model(samps, injs, Ninj, Nobs, Tobs):
             total_inj=Ninj,
             Nobs=Nobs,
             Tobs=Tobs,
-            surveyed_hypervolume=pop_models["redshift"].norm,
+            surv_hypervolume_fct=lambda *_: pop_models["redshift"].norm,
+            vtfct_kwargs={},
             marginalize_selection=marginalize_selection,
             min_neff_cut=min_neff_cut,
-            max_variance_cut=max_variance_cut,
             posterior_predictive_check=posterior_predictive_check,
             pedata=samps,
             injdata=injs,
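
For context on the functions edited above: `per_event_log_bayes_factors` and `detection_efficiency` are Monte Carlo estimators, returning the log of a mean of population weights together with the log effective sample size used for the convergence cuts. A rough standalone sketch of the non-log branch is below; the detection-efficiency variance expression follows the hunk, while the per-event n_eff convention (a Kish-style estimate) is an assumption.

import jax.numpy as jnp

def per_event_log_bayes_factors(weights):
    # weights: array of shape (n_events, n_samples) of population weights
    # (target density over the PE prior) at each event's posterior samples.
    BFs = jnp.sum(weights, axis=1) / weights.shape[1]  # Monte Carlo mean per event
    # Kish effective sample size per event; an assumed convention, since the
    # diff only shows that logn_effs = jnp.log(n_effs).
    n_effs = jnp.sum(weights, axis=1) ** 2 / jnp.sum(weights**2, axis=1)
    return jnp.log(BFs), jnp.log(n_effs)

def detection_efficiency(weights, Ninj):
    # weights: population weights for found injections; Ninj: total injections drawn.
    mu = jnp.sum(weights) / Ninj
    var = jnp.sum(weights**2) / Ninj**2 - mu**2 / Ninj  # as in the hunk above
    logmu = jnp.log(mu)
    logn_eff = 2 * logmu - jnp.log(var)  # n_eff = mu**2 / var
    return logmu, logn_eff

w = jnp.ones((3, 1000)) * 0.5
print(per_event_log_bayes_factors(w))
print(detection_efficiency(jnp.ones(200) * 0.1, Ninj=1000))
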
gwinferno/preprocess/data_collection.py (4 changes: 2 additions & 2 deletions)

@@ -69,7 +69,7 @@ def processed_catalog_dataset_from_dict(catalog_dict, mmax=100.0):
         sel = np.ones_like(catalog_dict[ev]["posterior"][param_mapping["mass_1"]], dtype=bool)
         if np.sum(catalog_dict[ev]["posterior"][param_mapping["mass_1"]] > mmax) > 0:
             print(np.sum(catalog_dict[ev]["posterior"][param_mapping["mass_1"]] > mmax))
-            print(f"removing samples from {ev} with mass_1 > {mmax}")
+            print(f"removing samples from {ev} with mass_1 > 100")
             sel = catalog_dict[ev]["posterior"][param_mapping["mass_1"]] < mmax
 
         data = np.array([catalog_dict[ev]["posterior"][param_mapping[param]][sel] for param in list(param_mapping.keys())])
@@ -169,7 +169,7 @@ def load_posterior_dataset(maximum_mass=100.0, catalog_metadata=None, key_file=N
     return full_catalog_array.to_dataset(name="posteriors", promote_attrs=True)
 
 
-def load_injection_dataset(injfile, param_names, through_o4a=False, through_o3=True, ifar_threshold=1, snr_threshold=10, additional_cuts=None):
+def load_injection_dataset(injfile, param_names, through_o4a=False, through_o3=True, ifar_threshold=1, snr_threshold=11, additional_cuts=None):
 
     if through_o4a:
         injs = get_o4a_cumulative_injection_dict(
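
The mass cut exercised in `processed_catalog_dataset_from_dict` simply masks posterior samples whose primary mass exceeds `mmax`. A small illustrative sketch with made-up numbers:

import numpy as np

# Sketch of the sample cut shown in the hunk above: drop posterior samples
# whose primary mass exceeds mmax (the values here are illustrative only).
mmax = 100.0
mass_1 = np.array([12.3, 45.6, 101.2, 88.0, 130.5])

sel = np.ones_like(mass_1, dtype=bool)
if np.sum(mass_1 > mmax) > 0:
    print(f"removing {np.sum(mass_1 > mmax)} samples with mass_1 > {mmax}")
    sel = mass_1 < mmax

mass_1 = mass_1[sel]
print(mass_1)
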
gwinferno/preprocess/selection.py (24 changes: 6 additions & 18 deletions)

@@ -18,23 +18,17 @@ def get_o4a_cumulative_injection_dict(file, param_names, ifar=1, snr=10):
         Valid parameter options are: 'mass_1', 'mass_2', 'mass_ratio', 'redshift', 'a_1', 'a_2', 'cos_tilt_1', 'cos_tilt_2'.
         NOTE: 'chi_eff' and 'chi_p' cannot be accounted for here.
         Please use `gwinferno.preprocess.data_collection.load_injection_dataset` if you wish to work in 'chi_eff' and 'chi_p'.
-        ifar (int or float, optional): Inverse false alarm rate threshold for found injections. Defaults to 1.
-        snr (int or float, optional): signal to noise ratio threshold for found injections. Defaults to 10.
+        ifar (int, optional): Inverse false alarm rate threshold for found injections. Defaults to 1.
+        snr (int, optional): signal to noise ratio threshold for found injections. Defaults to 10.
     Returns:
         DataArray: xarray DataArray of injection data.
     """
     with h5py.File(file, "r") as ff:
         total_generated = ff.attrs["total_generated"]
+        analysis_time = ff.attrs["analysis_time"]
         injections = np.asarray(ff["events"][:])
 
-        analysis_time = None
-        for key in "analysis_time", "total_analysis_time", "analysis_time_s":
-            if key in ff.attrs:
-                analysis_time = ff.attrs[key]
-        if analysis_time is None:
-            raise Exception("analysis time not found")
-
     found = injections["semianalytic_observed_phase_maximized_snr_net"] >= snr
 
     for key in injections.dtype.names:
@@ -51,7 +45,7 @@
     )
 
     inj_weights = inj_weights
-    total_generated = total_generated
+    total_generated = int(total_generated)
     analysis_time = analysis_time / 365.25 / 24 / 60 / 60
 
     injs["prior"] = jnp.exp(injections["lnpdraw_mass1_source_mass2_source_redshift_spin1x_spin1y_spin1z_spin2x_spin2y_spin2z"][found]) / inj_weights
@@ -104,14 +98,8 @@ def get_o3_cumulative_injection_dict(fi, param_names, ifar=1, snr=10, additional
         redshift=data["redshift"][()][found],
     )
 
-    total_generated = data.attrs["total_generated"][()]
-
-    analysis_time = None
-    for key in "analysis_time", "total_analysis_time", "analysis_time_s":
-        if key in ff.attrs:
-            analysis_time = ff.attrs[key][()] / 365.25 / 24 / 60 / 60
-    if analysis_time is None:
-        raise Exception("analysis time not found")
+    total_generated = int(data.attrs["total_generated"][()])
+    analysis_time = data.attrs["analysis_time_s"][()] / 365.25 / 24 / 60 / 60
 
     injs["prior"] = data["sampling_pdf"][()][found]
 
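
The O4a loader shown above selects "found" injections by thresholding the network SNR stored in the injection file and converts the analysis time from seconds to years. A condensed sketch of that bookkeeping, using only the attribute and field names that appear in the hunks (the rest of the file layout, and the example threshold, are assumptions):

import h5py
import numpy as np

def load_found_injections(path, snr=10.0):
    # Condensed sketch of get_o4a_cumulative_injection_dict's bookkeeping;
    # attribute/field names follow the hunks above, everything else is assumed.
    with h5py.File(path, "r") as ff:
        total_generated = ff.attrs["total_generated"]
        analysis_time = ff.attrs["analysis_time"]  # stored in seconds
        injections = np.asarray(ff["events"][:])

    found = injections["semianalytic_observed_phase_maximized_snr_net"] >= snr
    analysis_time_yr = analysis_time / 365.25 / 24 / 60 / 60
    return injections[found], int(total_generated), analysis_time_yr
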
[Diffs for the remaining 2 changed files did not load and are not shown.]
