diff --git a/.gitignore b/.gitignore index 75bef041..8b5c3b87 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,9 @@ docs/_build docs/api/** docs/_build/** docs/_build/_sources -docs/_build/_static \ No newline at end of file +docs/_build/_static +run-experimentation +**.h5 +**.ipynb +/posterior_samples_and_injections_spin_magnitude +**.txt \ No newline at end of file diff --git a/examples/basis_spline_example.py b/examples/basis_spline_example.py index b5413dda..5f5d37cd 100755 --- a/examples/basis_spline_example.py +++ b/examples/basis_spline_example.py @@ -39,12 +39,13 @@ def load_parser(): parser.add_argument("--q-knots", type=int, default=30) parser.add_argument("--tilt-knots", type=int, default=25) parser.add_argument("--z-knots", type=int, default=20) + parser.add_argument("--skip-prior", action="store_true", default=True) return parser.parse_args() def setup_mass_BSpline_model(injdata, pedata, pmap, nknots, qknots, mmin=3.0, mmax=100.0): - print(f"Basis Spline model in m1 w/ {nknots} knots logspaced from {mmin} to {mmax}...") - print(f"Basis Spline model in q w/ {qknots} knots linspaced from {mmin/mmax} to 1...") + print(f"Basis Spline model in m1 w/ {nknots} number of bases. Knots are logspaced from {mmin} to {mmax}...") + print(f"Basis Spline model in q w/ {qknots} number of bases. Knots are linspaced from {mmin/mmax} to 1...") model = BSplinePrimaryBSplineRatio( nknots, @@ -152,47 +153,53 @@ def model( Tobs, sample_prior=False, ): - mass_knots = mass_model.primary_model.nknots - q_knots = mass_model.ratio_model.nknots + mass_knots = mass_model.primary_model.n_splines + q_knots = mass_model.ratio_model.n_splines mag_model = spin_models["mag"] tilt_model = spin_models["tilt"] - mag_knots = mag_model.primary_model.nknots - tilt_knots = tilt_model.primary_model.nknots + mag_knots = mag_model.primary_model.n_splines + tilt_knots = tilt_model.primary_model.n_splines z_knots = z_model.nknots mass_cs = numpyro.sample("mass_cs", dist.Normal(0, 6), sample_shape=(mass_knots,)) - mass_tau = numpyro.sample("mass_tau", dist.Uniform(1, 1000)) - numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_tau, degree=2)) + mass_tau_squared = numpyro.sample("mass_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.01), low = 0, high = 1)) + mass_lambda = numpyro.deterministic("mass_lambda", 1/mass_tau_squared) + numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_lambda, degree=2)) q_cs = numpyro.sample("q_cs", dist.Normal(0, 4), sample_shape=(q_knots,)) - q_tau = numpyro.sample("q_tau", dist.Uniform(1, 25)) - numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_tau, degree=2)) + q_tau_squared = numpyro.sample("q_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + q_lambda = numpyro.deterministic("q_lambda", 1/q_tau_squared) + numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_lambda, degree=2)) mag_cs = numpyro.sample("mag_cs", dist.Normal(0, 2), sample_shape=(mag_knots,)) - mag_tau = numpyro.sample("mag_tau", dist.Uniform(1, 10)) - numpyro.factor("mag_log_smoothing_prior", apply_difference_prior(mag_cs, mag_tau, degree=2)) + mag_tau_squared = numpyro.sample("mag_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + mag_lambda = numpyro.deterministic("mag_lambda", 1/mag_tau_squared) + numpyro.factor("mag_log_smoothing_prior", apply_difference_prior(mag_cs, mag_lambda, degree=2)) tilt_cs = numpyro.sample("tilt_cs", dist.Normal(0, 2), sample_shape=(tilt_knots,)) - tilt_tau = numpyro.sample("tilt_tau", dist.Uniform(1, 10)) - numpyro.factor("tilt_log_smoothing_prior", apply_difference_prior(tilt_cs, tilt_tau, degree=2)) + tilt_tau_squared = numpyro.sample("tilt_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + tilt_lambda = numpyro.deterministic("tilt_lambda", 1/tilt_tau_squared) + numpyro.factor("tilt_log_smoothing_prior", apply_difference_prior(tilt_cs, tilt_lambda, degree=2)) lamb = numpyro.sample("lamb", dist.Normal(0, 3)) z_cs = numpyro.sample("z_cs", dist.Normal(), sample_shape=(z_knots,)) - z_tau = numpyro.sample("z_tau", dist.Uniform(1, 5)) - numpyro.factor("z_log_smoothing_prior", apply_difference_prior(z_cs, z_tau, degree=2)) + z_tau_squared = numpyro.sample("z_tau_squared", dist.Uniform(1, 10)) + z_lambda = numpyro.deterministic("z_lambda", 1/z_tau_squared) + numpyro.factor("z_log_smoothing_prior", apply_difference_prior(z_cs, z_lambda, degree=2)) if not sample_prior: - def get_weights(z, prior): - p_m1q = mass_model(len(z.shape), mass_cs, q_cs) - p_a1a2 = mag_model(len(z.shape), mag_cs) - p_ct1ct2 = tilt_model(len(z.shape), tilt_cs) + def get_weights(z, prior, pe_samples = True): + p_m1q = mass_model(mass_cs, q_cs, pe_samples) + p_a1a2 = mag_model(mag_cs, pe_samples) + p_ct1ct2 = tilt_model(tilt_cs, pe_samples) p_z = z_model(z, lamb, z_cs) wts = p_m1q * p_a1a2 * p_ct1ct2 * p_z / prior + return jnp.where(jnp.isnan(wts) | jnp.isinf(wts), 0, wts) peweights = get_weights(pedict["redshift"], pedict["prior"]) - injweights = get_weights(injdict["redshift"], injdict["prior"]) + injweights = get_weights(injdict["redshift"], injdict["prior"], pe_samples=False) hierarchical_likelihood( peweights, injweights, @@ -202,7 +209,7 @@ def get_weights(z, prior): surv_hypervolume_fct=z_model.normalization, vtfct_kwargs=dict(lamb=lamb, cs=z_cs), marginalize_selection=False, - min_neff_cut=True, + min_neff_cut=False, posterior_predictive_check=True, pedata=pedict, injdata=injdict, @@ -281,17 +288,17 @@ def main(): "logBFs", "log_l", "mag_cs", - "mag_tau", + "mag_lambda", "mass_cs", - "mass_tau", + "mass_lambda", "q_cs", - "q_tau", + "q_lambda", "rate", "surveyed_hypervolume", "tilt_cs", - "tilt_tau", + "tilt_lambda", "z_cs", - "z_tau", + "z_lambda", ] fig = az.plot_trace(az.from_numpyro(mcmc), var_names=plot_params) plt.savefig(f"{label}_trace_plot.png") @@ -311,6 +318,8 @@ def main(): mmin=args.mmin, m1mmin=args.mmin, mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, ) print("calculating mass posterior ppds...") pm1s, pqs, ms, qs = calculate_m1q_bspline_ppds( @@ -322,52 +331,63 @@ def main(): mmin=args.mmin, m1mmin=args.mmin, mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, ) - print("calculating mag prior ppds...") - prior_pmags, mags = calculate_iid_spin_bspline_ppds(prior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1) + if not args.skip_prior: + print("calculating mag prior ppds...") + prior_pmags, mags = calculate_iid_spin_bspline_ppds(prior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1, basis=LogYBSpline) print("calculating mag posterior ppds...") - pmags, mags = calculate_iid_spin_bspline_ppds(posterior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1) + pmags, mags = calculate_iid_spin_bspline_ppds(posterior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1, basis=LogYBSpline) - print("calculating tilt prior ppds...") - prior_ptilts, tilts = calculate_iid_spin_bspline_ppds(prior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1) + if not args.skip_prior: + print("calculating tilt prior ppds...") + prior_ptilts, tilts = calculate_iid_spin_bspline_ppds(prior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1, basis=LogYBSpline) print("calculating tilt posterior ppds...") - ptilts, tilts = calculate_iid_spin_bspline_ppds(posterior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1) + ptilts, tilts = calculate_iid_spin_bspline_ppds(posterior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1, basis=LogYBSpline) - print("calculating rate prior ppds...") - prior_Rofz, zs = calculate_powerbspline_rate_of_z_ppds(prior["lamb"], prior["z_cs"], jnp.ones_like(prior["lamb"]), z) + if not args.skip_prior: + print("calculating rate prior ppds...") + prior_Rofz, zs = calculate_powerbspline_rate_of_z_ppds(prior["lamb"], prior["z_cs"], jnp.ones_like(prior["lamb"]), z) print("calculating rate posterior ppds...") Rofz, zs = calculate_powerbspline_rate_of_z_ppds(posterior["lamb"], posterior["z_cs"], posterior["rate"], z) - ppd_dict = { - "dRdm1": pm1s, - "dRdq": pqs, - "m1s": ms, - "qs": qs, - "dRda": pmags, - "mags": mags, - "dRdct": ptilts, - "tilts": tilts, - "Rofz": Rofz, - "zs": zs, - } - dd.io.save(f"{label}_ppds.h5", ppd_dict) - prior_ppd_dict = { - "pm1": prior_pm1s, - "pq": prior_pqs, - "pa": prior_pmags, - "pct": prior_ptilts, - "m1s": ms, - "qs": qs, - "mags": mags, - "tilts": tilts, - "Rofz": prior_Rofz, - "zs": zs, - } - dd.io.save(f"{label}_prior_ppds.h5", prior_ppd_dict) - del ppd_dict, prior_ppd_dict +# Lines (357-383) are commented out due to deepdish errors + # if not args.skip_prior: + # prior_ppd_dict = { + # "pm1": prior_pm1s, + # "pq": prior_pqs, + # "pa": prior_pmags, + # "pct": prior_ptilts, + # "m1s": ms, + # "qs": qs, + # "mags": mags, + # "tilts": tilts, + # "Rofz": prior_Rofz, + # "zs": zs, + # } + + # dd.io.save(f"{label}_prior_ppds.h5", prior_ppd_dict) + # del prior_ppd_dict + + # ppd_dict = { + # "dRdm1": pm1s, + # "dRdq": pqs, + # "m1s": ms, + # "qs": qs, + # "dRda": pmags, + # "mags": mags, + # "dRdct": ptilts, + # "tilts": tilts, + # "Rofz": Rofz, + # "zs": zs, + # } + # dd.io.save(f"{label}_ppds.h5", ppd_dict) + # del ppd_dict print("plotting mass distribution...") + priors = None if args.skip_prior else {"m1": prior_pm1s, "q": prior_pqs} fig = plot_mass_dist( pm1s, pqs, @@ -375,21 +395,24 @@ def main(): qs, mmin=5.0, mmax=args.mmax, - priors={"m1": prior_pm1s, "q": prior_pqs}, + priors=priors, ) plt.savefig(f"{label}_mass_distribution.png") del fig print("plotting spin distributions...") - fig = plot_iid_spin_dist(pmags, ptilts, mags, tilts, priors={"mags": prior_pmags, "tilts": prior_ptilts}) + priors = None if args.skip_prior else {"mags": prior_pmags, "tilts": prior_ptilts} + fig = plot_iid_spin_dist(pmags, ptilts, mags, tilts, priors=priors) plt.savefig(f"{label}_iid_component_spin_distribution.png") del fig print("plotting R(z)...") - fig = plot_rofz(Rofz, zs, prior=prior_Rofz) + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, prior=prior) plt.savefig(f"{label}_rate_vs_z.png") del fig - fig = plot_rofz(Rofz, zs, logx=True, prior=prior_Rofz) + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, logx=True, prior=prior) plt.savefig(f"{label}_rate_vs_z_logscale.png") del fig @@ -405,4 +428,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/examples/basis_spline_example_chieff.py b/examples/basis_spline_example_chieff.py new file mode 100755 index 00000000..f56aa02a --- /dev/null +++ b/examples/basis_spline_example_chieff.py @@ -0,0 +1,355 @@ +import arviz as az +import deepdish as dd +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpyro +from jax import random +from numpyro import distributions as dist +from numpyro.infer import MCMC +from numpyro.infer import NUTS + +from gwinferno.interpolation import LogXLogYBSpline +from gwinferno.interpolation import LogYBSpline +from gwinferno.models.bsplines.separable import BSplinePrimaryBSplineRatio +from gwinferno.models.bsplines.single import BSplineChiEffective +from gwinferno.models.bsplines.smoothing import apply_difference_prior +from gwinferno.models.gwpopulation.gwpopulation import PowerlawRedshiftModel +from gwinferno.pipeline.analysis import hierarchical_likelihood +from gwinferno.pipeline.parser import load_base_parser +from gwinferno.postprocess.calculate_ppds import calculate_m1q_bspline_ppds +from gwinferno.postprocess.calculate_ppds import calculate_powerlaw_rate_of_z_ppds +from gwinferno.postprocess.calculate_ppds import calculate_chieff_bspline_ppds +from gwinferno.postprocess.plotting import plot_m1_vs_z_ppc +from gwinferno.postprocess.plotting import plot_mass_dist +from gwinferno.postprocess.plotting import plot_rofz +from gwinferno.postprocess.plotting import plot_chieff_dist + +az.style.use("arviz-darkgrid") + +def load_parser(): + parser = load_base_parser() + parser.add_argument("--mass-knots", type=int, default=100) + parser.add_argument("--mag-knots", type=int, default=30) + parser.add_argument("--q-knots", type=int, default=30) + parser.add_argument("--tilt-knots", type=int, default=25) + parser.add_argument("--z-knots", type=int, default=20) + parser.add_argument("--chieff-nsplines", type=int, default=30) + parser.add_argument("--skip-prior", action="store_true", default=True) + return parser.parse_args() + + +def setup_mass_BSpline_model(injdata, pedata, pmap, nknots, qknots, mmin=3.0, mmax=100.0): + print(f"Basis Spline model in m1 w/ {nknots} number of bases. Knots are logspaced from {mmin} to {mmax}...") + print(f"Basis Spline model in q w/ {qknots} number of bases. Knots are linspaced from {mmin/mmax} to 1...") + + model = BSplinePrimaryBSplineRatio( + nknots, + qknots, + pedata[pmap["mass_1"]], + injdata[pmap["mass_1"]], + pedata[pmap["mass_ratio"]], + injdata[pmap["mass_ratio"]], + m1min=mmin, + m2min=mmin, + mmax=mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, + ) + return model + +def setup_chieff_BSpline_model(nsplines, injdata, pedata, pmap): + print(f"Basis spline model in chieff w/ {nsplines} bases. Knots are linearly spaced.") + model = BSplineChiEffective( + n_splines=nsplines, + chieff=pedata[pmap['chi_eff']], + chieff_inj=injdata[pmap['chi_eff']], + basis=LogYBSpline, + ) + return model + + +def setup_redshift_model(injdata, pedata, pmap): + print(f"Powerlaw redshift model set up.") + z_pe = pedata[pmap["redshift"]] + z_inj = injdata[pmap["redshift"]] + model = PowerlawRedshiftModel(z_pe, z_inj) + return model + + +def setup(args): + #Provide location to PE and injection samples below. + inj_pe_path = "./saved-pe-and-injs/posterior_samples_and_injections_chi_effective.h5" + df = dd.io.load(inj_pe_path) + pedata = df['pedata'] + injdata = df['injdata'] + param_map = df['param_map'] + param_names = [ + "mass_1", "mass_ratio", "redshift", "chi_eff", "prior" + ] + param_map = {p: i for i, p in enumerate(param_names)} + injdict = {k: injdata[param_map[k]] for k in param_names} + pedict = {k: pedata[param_map[k]] for k in param_names} + nObs = pedata.shape[1] + total_inj = df["total_generated"] + obs_time = df["analysis_time"] + + mass_model = setup_mass_BSpline_model( + injdata, + pedata, + param_map, + args.mass_knots, + args.q_knots, + mmin=args.mmin, + mmax=args.mmax, + ) + z_model = setup_redshift_model(injdata, pedata, param_map) + chieff_model = setup_chieff_BSpline_model(args.chieff_nsplines, injdata, pedata, param_map) + injdict = {k: injdata[param_map[k]] for k in param_names} + pedict = {k: pedata[param_map[k]] for k in param_names} + + print(f"{len(injdict['redshift'])} found injections out of {total_inj} total") + print(f"Observed {nObs} events, each with {pedict['redshift'].shape[1]} samples, over an observing time of {obs_time} yrs") + + return ( + mass_model, + chieff_model, + z_model, + pedict, + injdict, + total_inj, + nObs, + obs_time, + ) + + +def model( + mass_model, + chieff_model, + z_model, + pedict, + injdict, + total_inj, + Nobs, + Tobs, + sample_prior=False, +): + mass_knots = mass_model.primary_model.n_splines + q_knots = mass_model.ratio_model.n_splines + chieff_nsplines = chieff_model.n_splines + + mass_cs = numpyro.sample("mass_cs", dist.Normal(0, 6), sample_shape=(mass_knots,)) + mass_tau_squared = numpyro.sample("mass_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.01), low = 0, high = 1)) + mass_lambda = numpyro.deterministic("mass_lambda", 1/mass_tau_squared) + numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_lambda, degree=2)) + + q_cs = numpyro.sample("q_cs", dist.Normal(0, 4), sample_shape=(q_knots,)) + q_tau_squared = numpyro.sample("q_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + q_lambda = numpyro.deterministic("q_lambda", 1/q_tau_squared) + numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_lambda, degree=2)) + + chieff_cs = numpyro.sample("chieff_cs", dist.Normal(0,4), sample_shape=(chieff_nsplines,)) + chieff_tau_squared = numpyro.sample("chieff_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1)) + chieff_lambda = numpyro.deterministic("chieff_lambda", 1/chieff_tau_squared) + numpyro.factor("chieff_log_smoothing_prior", apply_difference_prior(chieff_cs, chieff_lambda, degree=2)) + + lamb = numpyro.sample("lamb", dist.Normal(0, 3)) + + if not sample_prior: + + def get_weights(z, prior, pe_samples = True): + p_m1q = mass_model(mass_cs, q_cs, pe_samples) + p_chieff = chieff_model(chieff_cs, pe_samples) + p_z = z_model(z, lamb) + wts = p_m1q * p_chieff * p_z / prior + + return jnp.where(jnp.isnan(wts) | jnp.isinf(wts), 0, wts) + + peweights = get_weights(pedict["redshift"], pedict["prior"]) + injweights = get_weights(injdict["redshift"], injdict["prior"], pe_samples=False) + hierarchical_likelihood( + peweights, + injweights, + total_inj=total_inj, + Nobs=Nobs, + Tobs=Tobs, + vtfct_kwargs=dict(lamb=lamb), + marginalize_selection=False, + min_neff_cut=False, + posterior_predictive_check=True, + pedata=pedict, + injdata=injdict, + param_names=[ + "mass_1", + "mass_ratio", + "redshift", + "chi_eff", + ], + ) + + +def main(): + args = load_parser() + label = f"{args.outdir}/bsplines_{args.chieff_nsplines}chieff_{args.mass_knots}m1_{args.q_knots}q_z" + mass, chieff, z, pedict, injdict, total_inj, nObs, obs_time = setup(args) + if not args.skip_inference: + RNG = random.PRNGKey(0) + MCMC_RNG, PRIOR_RNG, _RNG = random.split(RNG, num=3) + kernel = NUTS(model) + mcmc = MCMC( + kernel, + thinning=args.thinning, + num_warmup=args.warmup, + num_samples=args.samples, + num_chains=args.chains, + ) + print("running mcmc: sampling prior...") + mcmc.run( + PRIOR_RNG, + mass, + chieff, + z, + pedict, + injdict, + float(total_inj), + nObs, + obs_time, + sample_prior=True, + ) + prior = mcmc.get_samples() + dd.io.save(f"{label}_prior_samples.h5", prior) + + kernel = NUTS(model) + mcmc = MCMC( + kernel, + thinning=args.thinning, + num_warmup=args.warmup, + num_samples=args.samples, + num_chains=args.chains, + ) + print("running mcmc: sampling posterior...") + mcmc.run( + MCMC_RNG, + mass, + chieff, + z, + pedict, + injdict, + float(total_inj), + nObs, + obs_time, + sample_prior=False, + ) + mcmc.print_summary() + posterior = mcmc.get_samples() + dd.io.save(f"{label}_posterior_samples.h5", posterior) + plot_params = [ + "detection_efficency", + "lamb", + "log_nEff_inj", + "log_nEffs", + "logBFs", + "log_l", + "chieff_cs", + "chieff_lambda", + "mass_cs", + "mass_lambda", + "q_cs", + "q_lambda", + "rate", + "surveyed_hypervolume", + ] + fig = az.plot_trace(az.from_numpyro(mcmc), var_names=plot_params) + plt.savefig(f"{label}_trace_plot.png") + del fig, mcmc, pedict, injdict, total_inj, obs_time + else: + print(f"loading prior and posterior samples from run with label: {label}...") + prior = dd.io.load(f"{label}_prior_samples.h5") + posterior = dd.io.load(f"{label}_posterior_samples.h5") + + print("calculating mass prior ppds...") + prior_pm1s, prior_pqs, ms, qs = calculate_m1q_bspline_ppds( + prior["mass_cs"], + prior["q_cs"], + BSplinePrimaryBSplineRatio, + args.mass_knots, + args.q_knots, + mmin=args.mmin, + m1mmin=args.mmin, + mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, + ) + print("calculating mass posterior ppds...") + pm1s, pqs, ms, qs = calculate_m1q_bspline_ppds( + posterior["mass_cs"], + posterior["q_cs"], + BSplinePrimaryBSplineRatio, + args.mass_knots, + args.q_knots, + mmin=args.mmin, + m1mmin=args.mmin, + mmax=args.mmax, + basis_m=LogXLogYBSpline, + basis_q=LogYBSpline, + ) + + if not args.skip_prior: + print("calculating rate prior ppds...") + prior_Rofz, zs = calculate_powerlaw_rate_of_z_ppds(prior["lamb"], jnp.ones_like(prior["lamb"]), z) + print("calculating rate posterior ppds...") + Rofz, zs = calculate_powerlaw_rate_of_z_ppds(posterior["lamb"], posterior["rate"], z) + + if not args.skip_prior: + print("calculating chieff prior ppds...") + prior_pchieff, xs = calculate_chieff_bspline_ppds( + coefs=prior["chieff_cs"], + model=chieff, + nknots=args.chieff_nsplines, + basis=LogYBSpline, + ) + + print("calculating chieff posterior ppds...") + pchieff, xs = calculate_chieff_bspline_ppds( + coefs=posterior["chieff_cs"], + model=BSplineChiEffective, + nknots=args.chieff_nsplines, + basis=LogYBSpline, + ) + + print("plotting mass distribution...") + priors = None if args.skip_prior else {"m1": prior_pm1s, "q": prior_pqs} + fig = plot_mass_dist( + pm1s, + pqs, + ms, + qs, + mmin=5.0, + mmax=args.mmax, + priors=priors, + ) + plt.savefig(f"{label}_mass_distribution.png") + del fig + + print("plotting chieff distribution...") + prior = None if args.skip_prior else prior_pchieff + fig = plot_chieff_dist(pchieff, xs, prior=prior) + plt.savefig(f"{label}_chieff_distribution.png") + del fig + + print("plotting R(z)...") + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, prior=prior) + plt.savefig(f"{label}_rate_vs_z.png") + del fig + prior = None if args.skip_prior else prior_Rofz + fig = plot_rofz(Rofz, zs, logx=True, prior=prior) + plt.savefig(f"{label}_rate_vs_z_logscale.png") + del fig + + print("plotting m1/z PPC...") + fig = plot_m1_vs_z_ppc(posterior, nObs, 5.0, args.mmax, z.zmax) + plt.savefig(f"{label}_m1_vs_z_ppc.png") + del fig + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/gwinferno/cosmology.py b/gwinferno/cosmology.py index 16b1c5fd..05d397d6 100644 --- a/gwinferno/cosmology.py +++ b/gwinferno/cosmology.py @@ -48,7 +48,7 @@ def __init__(self, Ho, omega_matter, omega_radiation, omega_lambda, distance_uni self.z = jnp.array([0.0]) self.Dc = jnp.array([0.0]) self.Vc = jnp.array([0.0]) - # self.extend(max_z=2.3, dz=DEFAULT_DZ) + self.extend(max_z=2.3, dz=DEFAULT_DZ) @property def DL(self): diff --git a/gwinferno/interpolation.py b/gwinferno/interpolation.py index 7327c52f..79561648 100644 --- a/gwinferno/interpolation.py +++ b/gwinferno/interpolation.py @@ -75,6 +75,7 @@ def __init__( xrange=(0, 1), k=4, normalize=True, + norm_grid = 1000 ): """ Class to construct a basis spline (with the M-Spline basis) @@ -105,7 +106,7 @@ def __init__( self.normalize = normalize self.basis_vols = np.ones(self.N) if normalize: - self.grid = jnp.linspace(*xrange, 1000) + self.grid = jnp.linspace(*xrange, norm_grid) self.grid_bases = jnp.array(self.bases(self.grid)) self.basis_vols = jnp.array([jnp.trapz(self.grid_bases[i, :], self.grid) for i in range(self.N)]) @@ -237,6 +238,7 @@ def __init__( xrange=(0, 1), k=4, normalize=False, + **kwargs ): """ Class to construct a basis spline (B-Spline) @@ -258,6 +260,7 @@ def __init__( xrange=xrange, k=k, normalize=normalize, + **kwargs ) def _bases(self, xs): @@ -432,11 +435,12 @@ def __init__( ydf, xrange=(0, 1), yrange=(0, 1), - kx=4, - ky=4, + xorder=4, + yorder=4, xbasis=BSpline, ybasis=BSpline, normalize=True, + norm_grid=(1000, 1000) ): """ Class to construct a 2D (bivariate) rectangular basis spline @@ -454,14 +458,14 @@ def __init__( """ self.xdf = xdf self.ydf = ydf - self.x_interpolator = xbasis(xdf, xrange=xrange, k=kx, normalize=False) - self.y_interpolator = ybasis(ydf, xrange=yrange, k=ky, normalize=False) + self.x_interpolator = xbasis(xdf, xrange=xrange, k=xorder, normalize=False) + self.y_interpolator = ybasis(ydf, xrange=yrange, k=yorder, normalize=False) self.normalize = normalize self.x_bases = None self.y_bases = None if self.normalize: - self.gridx = jnp.linspace(*xrange, 750) - self.gridy = jnp.linspace(*yrange, 750) + self.gridx = jnp.linspace(*xrange, norm_grid[0]) + self.gridy = jnp.linspace(*yrange, norm_grid[1]) self.gxx, self.gyy = jnp.meshgrid(self.gridx, self.gridy) self.grid_bases = self.bases(self.gxx, self.gyy) @@ -488,17 +492,18 @@ def bases(self, xs, ys): Args: xs (array_like): input values to evaluate the X basis spline at - xs (array_like): input values to evaluate the Y basis spline at + ys (array_like): input values to evaluate the Y basis spline at Returns: array_like: the design matrix evaluated at xs. shape (xdf, ydf, *xs.shape) """ - self.x_bases = self.x_interpolator.bases(xs) + self.x_bases = self.x_interpolator.bases(xs) self.y_bases = self.y_interpolator.bases(ys) out = jnp.array([[self.x_bases[i] * self.y_bases[j] for i in range(self.xdf)] for j in range(self.ydf)]).reshape( self.xdf, self.ydf, *xs.shape ) - self.reset_bases() + self._reset_bases() + return out def _project(self, bases, coefs): @@ -512,7 +517,7 @@ def _project(self, bases, coefs): Returns: array_like: The linear combination of the basis components given the coefficients """ - return jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs)) + return jnp.einsum("ij...,ij->...", bases, coefs) def project(self, bases, coefs): """ @@ -526,3 +531,56 @@ def project(self, bases, coefs): array_like: The linear combination of the basis components given the coefficients """ return self._project(bases, coefs) * self.norm_2d(coefs) + +class LogZRectBivariateBasisSpline(RectBivariateBasisSpline): + def __init__( + self, + xdf, + ydf, + xrange=(0, 1), + yrange=(0, 1), + xorder=4, + yorder=4, + xbasis=BSpline, + ybasis=BSpline, + normalize=True, + norm_grid=(1000, 1000) + ): + """ + Class to construct a 2D (bivariate) rectangular basis spline + + Args: + xdf (int): number of degrees of freedom for the spline in the X direction + ydf (int): number of degrees of freedom for the spline in the Y direction + xrange (tuple, optional): domain of X spline. Defaults to (0, 1). + yrange (tuple, optional): domain of Y spline. Defaults to (0, 1). + kx (int, optional): order of the X spline +1, i.e. cubcic splines->k=4. Defaults to 4 (cubic spline). + ky (int, optional): order of the Y spline +1, i.e. cubcic splines->k=4. Defaults to 4 (cubic spline). + xbasis (object, optional): Choice of basis to use for the X spline. Defaults to BSpline. + ybasis (object, optional): Choice of basis to use for the Y spline. Defaults to BSpline. + normalize (bool, optional): flag whether or not to numerically normalize the spline. Defaults to True. + """ + super().__init__( + xdf, + ydf, + xrange=xrange, + yrange=yrange, + xorder=xorder, + yorder=yorder, + xbasis=xbasis, + ybasis=ybasis, + normalize=normalize, + norm_grid = norm_grid + ) + def _project(self, bases, coefs): + """ + _project given a design matrix (or bases) and coefficients, project the coefficients onto the spline + + Args: + bases (array_like): The set of basis components or design matrix to project onto + coefs (array_like): coefficients for the basis components + + Returns: + array_like: The linear combination of the basis components given the coefficients + """ + return jnp.exp(jnp.einsum("ij...,ij->...", bases, coefs)) \ No newline at end of file diff --git a/gwinferno/models/bsplines/joint.py b/gwinferno/models/bsplines/joint.py index a7549ba4..0191362d 100644 --- a/gwinferno/models/bsplines/joint.py +++ b/gwinferno/models/bsplines/joint.py @@ -4,7 +4,7 @@ import jax.numpy as jnp -from ...interpolation import RectBivariateBasisSpline +from ...interpolation import RectBivariateBasisSpline, BSpline, LogZRectBivariateBasisSpline class Base2DBSplineModel(object): @@ -18,6 +18,8 @@ def __init__( yy_inj, xrange=(0, 1), yrange=(0, 1), + xbasis = BSpline, + ybasis = BSpline, basis=RectBivariateBasisSpline, **kwargs, ): @@ -25,7 +27,7 @@ def __init__( self.yknots = ynknots self.xmin, self.xmax = xrange self.ymin, self.ymax = yrange - self.interpolator = basis(xnknots, ynknots, xrange=xrange, yrange=yrange, **kwargs) + self.interpolator = basis(xnknots, ynknots, xrange=xrange, yrange=yrange, xbasis=xbasis, ybasis=ybasis, **kwargs) self.pe_design_matrix = jnp.array(self.interpolator.bases(xx, yy)) self.inj_design_matrix = jnp.array(self.interpolator.bases(xx_inj, yy_inj)) self.funcs = [self.inj_pdf, self.pe_pdf] @@ -39,8 +41,8 @@ def pe_pdf(self, coefs): def inj_pdf(self, coefs): return self.eval_spline(self.inj_design_matrix, coefs) - def __call__(self, ndim, coefs): - return self.funcs[ndim - 1](coefs) + def __call__(self, coefs, pe_samples = True): + return self.funcs[1](coefs) if pe_samples else self.funcs[0](coefs) class BSplineJointMassRatioChiEffective(Base2DBSplineModel): @@ -52,6 +54,10 @@ def __init__( q, chieff_inj, q_inj, + chieff_range=(-1,1), + q_range=(0,1), + chi_order = 4, + q_order = 4, **kwargs, ): super().__init__( @@ -61,7 +67,39 @@ def __init__( yy=q, xx_inj=chieff_inj, yy_inj=q_inj, - xrange=(-1, 1), - yrange=(0, 1), + xrange=chieff_range, + yrange=q_range, + xorder = chi_order, + yorder = q_order, **kwargs, ) +class BSplineJointMassRedshift(Base2DBSplineModel): + def __init__( + nknots_m, + nknots_z, + m1, + z, + m1_inj, + z_inj, + mmin=3., + mmax=100., + order_m=3, + order_z=3, + basis_m=BSpline, + basis_z=BSpline, + **kwargs, + ): + super().__init__( + nknots_m, + nknots_z, + m1, + z, + m1_inj, + z_inj, + xorder = order_m, + yorder = order_z, + xrange = (mmin, mmax), + yrange = (0, 2), + xbasis = basis_m, + ybasis = basis_z, + ) \ No newline at end of file diff --git a/gwinferno/models/bsplines/separable.py b/gwinferno/models/bsplines/separable.py index b5a86122..8fc546e3 100644 --- a/gwinferno/models/bsplines/separable.py +++ b/gwinferno/models/bsplines/separable.py @@ -676,3 +676,56 @@ def __call__(self, ecoefs, pcoefs, pe_samples=True): p_chieff = self.chi_eff_model(ecoefs, pe_samples=pe_samples) p_chip = self.chi_p_model(pcoefs, pe_samples=pe_samples) return p_chieff * p_chip + +class BSplineJointMassRedshiftBSplineRatio(object): + def __init__( + self, + nknots_m, + nknots_z, + nknots_q, + m1, + m1_inj, + q, + q_inj, + z, + z_inj, + order_m=3, + order_q=3, + order_z=3, + m1min=3.0, + m2min=3.0, + mmax=100.0, + basis_m=BSpline, + basis_q=BSpline, + basis_z=BSpline, + **kwargs, + ): + self.primary_model = BSplineJointMassRedshift( + nknots_m, + nknots_z, + m1, + z, + m1_inj, + z_inj, + mmin=m1min, + mmax=mmax, + order_m=order_m, + order_z=order_z, + basis_m=basis_m, + basis_z=basis_z, + **kwargs, + ) + self.ratio_model = BSplineRatio( + nknots_q, + q, + q_inj, + qmin=m2min / mmax, + knots=knots_q, + order=order_q, + prefix=prefix_q, + basis=basis_q, + **kwargs, + ) + + def __call__(self, ndim, mcoefs, qcoefs): + return self.ratio_model(ndim, qcoefs) * self.primary_model(ndim, mcoefs) \ No newline at end of file diff --git a/gwinferno/models/bsplines/single.py b/gwinferno/models/bsplines/single.py index 93c20f04..f4264653 100644 --- a/gwinferno/models/bsplines/single.py +++ b/gwinferno/models/bsplines/single.py @@ -140,7 +140,6 @@ def __init__( **kwargs, ) - class BSplineSpinTilt(Base1DBSplineModel): """Class to construct a cosine tilt (cos(theta)) B-Spline model for a single binary component diff --git a/gwinferno/models/bsplines/smoothing.py b/gwinferno/models/bsplines/smoothing.py index 3bbed04f..f8c11bd1 100644 --- a/gwinferno/models/bsplines/smoothing.py +++ b/gwinferno/models/bsplines/smoothing.py @@ -22,7 +22,7 @@ def apply_difference_prior(coefs, inv_var, degree=1): return -0.5 * inv_var * jnp.dot(delta_c, delta_c.T) -def apply_twod_difference_prior(coefs, inv_var_row, inv_var_col, degree_row=1, degree_col=1): +def apply_2d_difference_prior(coefs, inv_var_row, inv_var_col, degree_row=1, degree_col=1): """ Computes the difference penalty for a 2d B-spline. Uses equation 4.19 from Practical Smoothing by Eilers and Marx. diff --git a/gwinferno/models/gwpopulation/gwpopulation.py b/gwinferno/models/gwpopulation/gwpopulation.py index e09c4686..43181dd9 100644 --- a/gwinferno/models/gwpopulation/gwpopulation.py +++ b/gwinferno/models/gwpopulation/gwpopulation.py @@ -124,4 +124,4 @@ def __call__(self, z, lamb): jnp.less_equal(z, self.zmax), self.prob(z, dVdz, lamb) / self.normalization(lamb), 0, - ) + ) \ No newline at end of file diff --git a/gwinferno/pipeline/analysis.py b/gwinferno/pipeline/analysis.py index a02b62a1..9bc50d99 100644 --- a/gwinferno/pipeline/analysis.py +++ b/gwinferno/pipeline/analysis.py @@ -246,9 +246,9 @@ def hierarchical_likelihood( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), - jnp.nan_to_num(-jnp.inf), - jnp.nan_to_num(log_l), + jnp.isnan(log_l) | jnp.isnan(jnp.min(logn_effs)) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), + -1000, + log_l, ), ) else: @@ -367,7 +367,7 @@ def hierarchical_likelihood_in_log( numpyro.factor( "log_likelihood", jnp.where( - jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), 10), + jnp.isnan(log_l) | jnp.less_equal(jnp.exp(jnp.min(logn_effs)), Nobs), jnp.nan_to_num(-jnp.inf), jnp.nan_to_num(log_l), ), diff --git a/gwinferno/pipeline/parser.py b/gwinferno/pipeline/parser.py index 75617325..d61d9cec 100644 --- a/gwinferno/pipeline/parser.py +++ b/gwinferno/pipeline/parser.py @@ -120,11 +120,11 @@ def add_mixture_model(self, param, subd): def load_base_parser(): parser = ArgumentParser() - parser.add_argument("--data-dir", type=str, default="/home/bruce.edelman/projects/GWTC3_allevents/") + parser.add_argument("--data-dir", type=str, default="/projects/farr_lab/shared/GWTC3/all_events") parser.add_argument( "--inj-file", type=str, - default="/home/bruce.edelman/projects/GWTC3_allevents/o1o2o3_mixture_injections.hdf5", + default="/projects/farr_lab/shared/GWTC3/o1o2o3_mixture_injections.hdf5", ) parser.add_argument("--outdir", type=str, default="results") parser.add_argument("--mmin", type=float, default=3.0) diff --git a/gwinferno/preprocess/data_collection.py b/gwinferno/preprocess/data_collection.py index 7c495a88..6aebe0b8 100644 --- a/gwinferno/preprocess/data_collection.py +++ b/gwinferno/preprocess/data_collection.py @@ -246,7 +246,7 @@ def setup_posterior_samples_and_injections(data_dir, inj_file, param_names=None, injdata, new_pmap = convert_component_spin_injections_to_chieff(injdata, param_map, chip=chi_p) param_map = new_pmap pedata = jnp.array(pedata) - injdata = jnp.array(pedata) + injdata = jnp.array(injdata) if save: mag_data = { "injdata": injdata,