Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a631b9c
2d b-splines
Jul 20, 2023
0b787dc
Merge branch '2d-b-spline' of https://github.com/FarrOutLab/GWInferno…
Jul 20, 2023
0c67174
bug fixes
jaxengodfrey Jul 11, 2023
44490e6
add xarray data util function
jaxengodfrey Jul 11, 2023
b15abba
fix pre-commits
bfarr Jul 17, 2023
006047e
fix pre-commit warning about ambiguous flags
bfarr Jul 17, 2023
eadac6b
fix method name
bfarr Jul 24, 2023
46b6d1e
add vectorized option in comment
bfarr Jul 24, 2023
f191925
made an edit to call function of Base2DBSplineModel
Aug 14, 2023
a829bd9
2d b-spline development in RectBivariateBasisSpline class
Aug 14, 2023
1ddd296
2d-b-spline development in RectBivariateBasisSpline
CarrilloG Aug 14, 2023
4684355
changed apply_twod_difference_prior method name
CarrilloG Aug 14, 2023
1360f7f
fixed bug in saving posterior_samples_and_injection_chi_effff.h5
CarrilloG Aug 15, 2023
7e0bdd7
Merge branch 'main' into 2d-b-spline
CarrilloG Aug 21, 2023
e5efd21
Separated the RectBivariateBasisSpline class into a linear and log Z …
CarrilloG Aug 31, 2023
f1f223d
added chieff_range and q_range to BSplineJointMassRatioChiEffective
CarrilloG Sep 11, 2023
a79283f
raise n_eff cutoff from Nobs to 4*Nobs
CarrilloG Oct 3, 2023
127eff2
bug fix 2d spline: norm grids and order
CarrilloG Oct 3, 2023
0418004
updated gitignore
CarrilloG Oct 3, 2023
bbba67d
updated gitignore
CarrilloG Oct 6, 2023
4b517c6
Fixed directory locations for PE samples and injections.
CarrilloG Oct 20, 2023
0114560
Added n_splines attribute to the class BSplineRatio.
CarrilloG Oct 20, 2023
1257932
Fixed typos
CarrilloG Oct 24, 2023
f6617c2
Deleted extra line spacings and extra line of code
CarrilloG Oct 24, 2023
c06b9ae
typo fixes
CarrilloG Oct 24, 2023
985a051
deleted line of code to get it working again
CarrilloG Oct 25, 2023
4de5df8
fixed bug
CarrilloG Oct 27, 2023
aba83c3
testing funcs. Not intended to be a real commit
CarrilloG Oct 27, 2023
470746f
merging from main
CarrilloG Oct 27, 2023
c4944dd
Fixed bugs to get stable sampling
CarrilloG Oct 31, 2023
0f0bc70
Uncommented a piece of code that is used in cosmology class
CarrilloG Nov 3, 2023
2807b6d
Changed neff cut from 4*Nobs back to Nobs
CarrilloG Nov 3, 2023
adcb82e
Added another BSpline example script. This model includes a chi_eff B…
CarrilloG Nov 17, 2023
69a59b0
Removed unused functions and added a comment.
CarrilloG Nov 17, 2023
1c4d71b
Removed unused separable object that was used during developing 2d mo…
CarrilloG Nov 17, 2023
c6e1d5c
Removed nan to num jnp in logic for min_neff_cut and added an additio…
CarrilloG Nov 27, 2023
3320c75
Changed the nEffs cut back to 4*Nobs as done in Cover Your Basis paper.
CarrilloG Dec 19, 2023
52cb30a
Changed min_neff_cut to Nobs. Order 10 Nobs should be enough
CarrilloG Jan 18, 2024
65f1863
Removed an old note and random spacing indents
CarrilloG Jan 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,9 @@ docs/_build
docs/api/**
docs/_build/**
docs/_build/_sources
docs/_build/_static
docs/_build/_static
run-experimentation
**.h5
**.ipynb
/posterior_samples_and_injections_spin_magnitude
**.txt
157 changes: 90 additions & 67 deletions examples/basis_spline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ def load_parser():
parser.add_argument("--q-knots", type=int, default=30)
parser.add_argument("--tilt-knots", type=int, default=25)
parser.add_argument("--z-knots", type=int, default=20)
parser.add_argument("--skip-prior", action="store_true", default=True)
return parser.parse_args()


def setup_mass_BSpline_model(injdata, pedata, pmap, nknots, qknots, mmin=3.0, mmax=100.0):
print(f"Basis Spline model in m1 w/ {nknots} knots logspaced from {mmin} to {mmax}...")
print(f"Basis Spline model in q w/ {qknots} knots linspaced from {mmin/mmax} to 1...")
print(f"Basis Spline model in m1 w/ {nknots} number of bases. Knots are logspaced from {mmin} to {mmax}...")
print(f"Basis Spline model in q w/ {qknots} number of bases. Knots are linspaced from {mmin/mmax} to 1...")

model = BSplinePrimaryBSplineRatio(
nknots,
Expand Down Expand Up @@ -152,47 +153,53 @@ def model(
Tobs,
sample_prior=False,
):
mass_knots = mass_model.primary_model.nknots
q_knots = mass_model.ratio_model.nknots
mass_knots = mass_model.primary_model.n_splines
q_knots = mass_model.ratio_model.n_splines
mag_model = spin_models["mag"]
tilt_model = spin_models["tilt"]
mag_knots = mag_model.primary_model.nknots
tilt_knots = tilt_model.primary_model.nknots
mag_knots = mag_model.primary_model.n_splines
tilt_knots = tilt_model.primary_model.n_splines
z_knots = z_model.nknots

mass_cs = numpyro.sample("mass_cs", dist.Normal(0, 6), sample_shape=(mass_knots,))
mass_tau = numpyro.sample("mass_tau", dist.Uniform(1, 1000))
numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_tau, degree=2))
mass_tau_squared = numpyro.sample("mass_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.01), low = 0, high = 1))
mass_lambda = numpyro.deterministic("mass_lambda", 1/mass_tau_squared)
numpyro.factor("mass_log_smoothing_prior", apply_difference_prior(mass_cs, mass_lambda, degree=2))

q_cs = numpyro.sample("q_cs", dist.Normal(0, 4), sample_shape=(q_knots,))
q_tau = numpyro.sample("q_tau", dist.Uniform(1, 25))
numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_tau, degree=2))
q_tau_squared = numpyro.sample("q_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1))
q_lambda = numpyro.deterministic("q_lambda", 1/q_tau_squared)
numpyro.factor("q_log_smoothing_prior", apply_difference_prior(q_cs, q_lambda, degree=2))

mag_cs = numpyro.sample("mag_cs", dist.Normal(0, 2), sample_shape=(mag_knots,))
mag_tau = numpyro.sample("mag_tau", dist.Uniform(1, 10))
numpyro.factor("mag_log_smoothing_prior", apply_difference_prior(mag_cs, mag_tau, degree=2))
mag_tau_squared = numpyro.sample("mag_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1))
mag_lambda = numpyro.deterministic("mag_lambda", 1/mag_tau_squared)
numpyro.factor("mag_log_smoothing_prior", apply_difference_prior(mag_cs, mag_lambda, degree=2))

tilt_cs = numpyro.sample("tilt_cs", dist.Normal(0, 2), sample_shape=(tilt_knots,))
tilt_tau = numpyro.sample("tilt_tau", dist.Uniform(1, 10))
numpyro.factor("tilt_log_smoothing_prior", apply_difference_prior(tilt_cs, tilt_tau, degree=2))
tilt_tau_squared = numpyro.sample("tilt_tau_squared", dist.TruncatedDistribution(dist.Normal(scale = 0.1), low = 0, high = 1))
tilt_lambda = numpyro.deterministic("tilt_lambda", 1/tilt_tau_squared)
numpyro.factor("tilt_log_smoothing_prior", apply_difference_prior(tilt_cs, tilt_lambda, degree=2))

lamb = numpyro.sample("lamb", dist.Normal(0, 3))
z_cs = numpyro.sample("z_cs", dist.Normal(), sample_shape=(z_knots,))
z_tau = numpyro.sample("z_tau", dist.Uniform(1, 5))
numpyro.factor("z_log_smoothing_prior", apply_difference_prior(z_cs, z_tau, degree=2))
z_tau_squared = numpyro.sample("z_tau_squared", dist.Uniform(1, 10))
z_lambda = numpyro.deterministic("z_lambda", 1/z_tau_squared)
numpyro.factor("z_log_smoothing_prior", apply_difference_prior(z_cs, z_lambda, degree=2))

if not sample_prior:

def get_weights(z, prior):
p_m1q = mass_model(len(z.shape), mass_cs, q_cs)
p_a1a2 = mag_model(len(z.shape), mag_cs)
p_ct1ct2 = tilt_model(len(z.shape), tilt_cs)
def get_weights(z, prior, pe_samples = True):
p_m1q = mass_model(mass_cs, q_cs, pe_samples)
p_a1a2 = mag_model(mag_cs, pe_samples)
p_ct1ct2 = tilt_model(tilt_cs, pe_samples)
p_z = z_model(z, lamb, z_cs)
wts = p_m1q * p_a1a2 * p_ct1ct2 * p_z / prior

return jnp.where(jnp.isnan(wts) | jnp.isinf(wts), 0, wts)

peweights = get_weights(pedict["redshift"], pedict["prior"])
injweights = get_weights(injdict["redshift"], injdict["prior"])
injweights = get_weights(injdict["redshift"], injdict["prior"], pe_samples=False)
hierarchical_likelihood(
peweights,
injweights,
Expand All @@ -202,7 +209,7 @@ def get_weights(z, prior):
surv_hypervolume_fct=z_model.normalization,
vtfct_kwargs=dict(lamb=lamb, cs=z_cs),
marginalize_selection=False,
min_neff_cut=True,
min_neff_cut=False,
posterior_predictive_check=True,
pedata=pedict,
injdata=injdict,
Expand Down Expand Up @@ -281,17 +288,17 @@ def main():
"logBFs",
"log_l",
"mag_cs",
"mag_tau",
"mag_lambda",
"mass_cs",
"mass_tau",
"mass_lambda",
"q_cs",
"q_tau",
"q_lambda",
"rate",
"surveyed_hypervolume",
"tilt_cs",
"tilt_tau",
"tilt_lambda",
"z_cs",
"z_tau",
"z_lambda",
]
fig = az.plot_trace(az.from_numpyro(mcmc), var_names=plot_params)
plt.savefig(f"{label}_trace_plot.png")
Expand All @@ -311,6 +318,8 @@ def main():
mmin=args.mmin,
m1mmin=args.mmin,
mmax=args.mmax,
basis_m=LogXLogYBSpline,
basis_q=LogYBSpline,
)
print("calculating mass posterior ppds...")
pm1s, pqs, ms, qs = calculate_m1q_bspline_ppds(
Expand All @@ -322,74 +331,88 @@ def main():
mmin=args.mmin,
m1mmin=args.mmin,
mmax=args.mmax,
basis_m=LogXLogYBSpline,
basis_q=LogYBSpline,
)

print("calculating mag prior ppds...")
prior_pmags, mags = calculate_iid_spin_bspline_ppds(prior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1)
if not args.skip_prior:
print("calculating mag prior ppds...")
prior_pmags, mags = calculate_iid_spin_bspline_ppds(prior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1, basis=LogYBSpline)
print("calculating mag posterior ppds...")
pmags, mags = calculate_iid_spin_bspline_ppds(posterior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1)
pmags, mags = calculate_iid_spin_bspline_ppds(posterior["mag_cs"], BSplineIIDSpinMagnitudes, args.mag_knots, xmin=0, xmax=1, basis=LogYBSpline)

print("calculating tilt prior ppds...")
prior_ptilts, tilts = calculate_iid_spin_bspline_ppds(prior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1)
if not args.skip_prior:
print("calculating tilt prior ppds...")
prior_ptilts, tilts = calculate_iid_spin_bspline_ppds(prior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1, basis=LogYBSpline)
print("calculating tilt posterior ppds...")
ptilts, tilts = calculate_iid_spin_bspline_ppds(posterior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1)
ptilts, tilts = calculate_iid_spin_bspline_ppds(posterior["tilt_cs"], BSplineIIDSpinTilts, args.tilt_knots, xmin=-1, xmax=1, basis=LogYBSpline)

print("calculating rate prior ppds...")
prior_Rofz, zs = calculate_powerbspline_rate_of_z_ppds(prior["lamb"], prior["z_cs"], jnp.ones_like(prior["lamb"]), z)
if not args.skip_prior:
print("calculating rate prior ppds...")
prior_Rofz, zs = calculate_powerbspline_rate_of_z_ppds(prior["lamb"], prior["z_cs"], jnp.ones_like(prior["lamb"]), z)
print("calculating rate posterior ppds...")
Rofz, zs = calculate_powerbspline_rate_of_z_ppds(posterior["lamb"], posterior["z_cs"], posterior["rate"], z)

ppd_dict = {
"dRdm1": pm1s,
"dRdq": pqs,
"m1s": ms,
"qs": qs,
"dRda": pmags,
"mags": mags,
"dRdct": ptilts,
"tilts": tilts,
"Rofz": Rofz,
"zs": zs,
}
dd.io.save(f"{label}_ppds.h5", ppd_dict)
prior_ppd_dict = {
"pm1": prior_pm1s,
"pq": prior_pqs,
"pa": prior_pmags,
"pct": prior_ptilts,
"m1s": ms,
"qs": qs,
"mags": mags,
"tilts": tilts,
"Rofz": prior_Rofz,
"zs": zs,
}
dd.io.save(f"{label}_prior_ppds.h5", prior_ppd_dict)
del ppd_dict, prior_ppd_dict
# Lines (357-383) are commented out due to deepdish errors
# if not args.skip_prior:
# prior_ppd_dict = {
# "pm1": prior_pm1s,
# "pq": prior_pqs,
# "pa": prior_pmags,
# "pct": prior_ptilts,
# "m1s": ms,
# "qs": qs,
# "mags": mags,
# "tilts": tilts,
# "Rofz": prior_Rofz,
# "zs": zs,
# }

# dd.io.save(f"{label}_prior_ppds.h5", prior_ppd_dict)
# del prior_ppd_dict

# ppd_dict = {
# "dRdm1": pm1s,
# "dRdq": pqs,
# "m1s": ms,
# "qs": qs,
# "dRda": pmags,
# "mags": mags,
# "dRdct": ptilts,
# "tilts": tilts,
# "Rofz": Rofz,
# "zs": zs,
# }
# dd.io.save(f"{label}_ppds.h5", ppd_dict)
# del ppd_dict

print("plotting mass distribution...")
priors = None if args.skip_prior else {"m1": prior_pm1s, "q": prior_pqs}
fig = plot_mass_dist(
pm1s,
pqs,
ms,
qs,
mmin=5.0,
mmax=args.mmax,
priors={"m1": prior_pm1s, "q": prior_pqs},
priors=priors,
)
plt.savefig(f"{label}_mass_distribution.png")
del fig

print("plotting spin distributions...")
fig = plot_iid_spin_dist(pmags, ptilts, mags, tilts, priors={"mags": prior_pmags, "tilts": prior_ptilts})
priors = None if args.skip_prior else {"mags": prior_pmags, "tilts": prior_ptilts}
fig = plot_iid_spin_dist(pmags, ptilts, mags, tilts, priors=priors)
plt.savefig(f"{label}_iid_component_spin_distribution.png")
del fig

print("plotting R(z)...")
fig = plot_rofz(Rofz, zs, prior=prior_Rofz)
prior = None if args.skip_prior else prior_Rofz
fig = plot_rofz(Rofz, zs, prior=prior)
plt.savefig(f"{label}_rate_vs_z.png")
del fig
fig = plot_rofz(Rofz, zs, logx=True, prior=prior_Rofz)
prior = None if args.skip_prior else prior_Rofz
fig = plot_rofz(Rofz, zs, logx=True, prior=prior)
plt.savefig(f"{label}_rate_vs_z_logscale.png")
del fig

Expand All @@ -405,4 +428,4 @@ def main():


if __name__ == "__main__":
main()
main()
Loading