-
Notifications
You must be signed in to change notification settings - Fork 162
Issue 558: add support for custom metrics in HMC #559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Extended the HMC sampler to support arbitrary metrics (mass matrices), including vectors, Diagonal, and dense matrices, for improved sampling efficiency. Updated momenta sampling and log-probability assessment to handle these metrics. Added smoke unit tests
- Classic grad = -x for logpdf of x ~ N(0,1) check - Check that identical metrics in different forms give similar sampling - Check that different metrics give different sampling - Check the bad metric catches
|
As noted in #560 this is only failing CI because of Given julia 1.6 is no longer LTS... whether I change the code to be 1.6 compliant is a call for the Gen devs. My instinct is no, and would prefer Gen to only support LTS 1.10 and more recent.
|
|
Thanks for this PR! Is the |
This maintains backwards compat with Julia 1.6
I'm mainly trying to get into the habit of using |
|
Looks good to me! Merging. |
This PR closes #558
This pull request extends the Hamiltonian Monte Carlo (HMC) implementation to support arbitrary metric (mass matrix) choices for the
hmcsampler provided by Gen.Contribution
hmcnow supports dense and diagonal metrics through the newmetrickwarg which dispatches on new methods ofsample_momentaandassess_momenta.metric = nothing: Lowers down to current default (effectively a diagonal metric of ones).metric::AbstractVector: Diagonal metric implemented as simple rescaling of the standard normal draws of momenta. Has safety assertion for strict positive.metric::Diagonal: Lower to vector by capturing the diagonal elements.metric::AbstractMatrix: Dense metric. This uses theGenmvnormalfunctions for both sampling and logpdf calculations.Documentation
I've extend the docstring for
hmc, however, I have not added to the documentation of Gen itself. I had a pass at adding a new section to mcmc https://github.com/probcomp/Gen.jl/blob/master/docs/src/tutorials/mcmc_map.md . However, building on the example (linear regression with outliers) I couldn't find a use case where hmc with a non-uniform metric worked where default hmc didn't.I did note that there is currently no default
hmcexample, and I could add one but that felt out of scope for this PR.EDIT: I did note in contribution guide that adding docs should be on this PR. So open to doing this if you can give a steer where?
Unit tests
I added to the current smoke tests for the
hmcsampler with: