Skip to content

Conversation

@SamuelBrand1
Copy link
Contributor

@SamuelBrand1 SamuelBrand1 commented Oct 13, 2025

This PR closes #558

This pull request extends the Hamiltonian Monte Carlo (HMC) implementation to support arbitrary metric (mass matrix) choices for the hmc sampler provided by Gen.

Contribution

hmc now supports dense and diagonal metrics through the new metric kwarg which dispatches on new methods of sample_momenta and assess_momenta.

  • metric = nothing: Lowers down to current default (effectively a diagonal metric of ones).
  • metric::AbstractVector: Diagonal metric implemented as simple rescaling of the standard normal draws of momenta. Has safety assertion for strict positive.
  • metric::Diagonal: Lower to vector by capturing the diagonal elements.
  • metric::AbstractMatrix: Dense metric. This uses the Gen mvnormal functions for both sampling and logpdf calculations.

Documentation

I've extend the docstring for hmc, however, I have not added to the documentation of Gen itself. I had a pass at adding a new section to mcmc https://github.com/probcomp/Gen.jl/blob/master/docs/src/tutorials/mcmc_map.md . However, building on the example (linear regression with outliers) I couldn't find a use case where hmc with a non-uniform metric worked where default hmc didn't.

I did note that there is currently no default hmc example, and I could add one but that felt out of scope for this PR.

EDIT: I did note in contribution guide that adding docs should be on this PR. So open to doing this if you can give a steer where?

Unit tests

I added to the current smoke tests for the hmc sampler with:

  • similar smoke tests for the metric choices
  • A simple validation that returned gradient were correct for the simple normals used in the unit test example.
  • A few checks that changing the metric in the sample sampling target changed the return samples.

Extended the HMC sampler to support arbitrary metrics (mass matrices), including vectors, Diagonal, and dense matrices, for improved sampling efficiency. Updated momenta sampling and log-probability assessment to handle these metrics. Added smoke unit tests
- Classic grad = -x for logpdf of x ~ N(0,1) check
- Check that identical metrics in different forms give similar sampling
- Check that different metrics give different sampling
- Check the bad metric catches
@SamuelBrand1
Copy link
Contributor Author

SamuelBrand1 commented Oct 14, 2025

As noted in #560 this is only failing CI because of Random.Xoshiro call which was not in julia 1.6 stdlib (or was in a different way I can't remember!).

Given julia 1.6 is no longer LTS... whether I change the code to be 1.6 compliant is a call for the Gen devs. My instinct is no, and would prefer Gen to only support LTS 1.10 and more recent.

If the answer is yes, I'll change Xoshiro for MarsenneTwister which was available in 1.6 https://docs.julialang.org/en/v1.6/stdlib/Random/ EDIT: no longer planned.

@ztangent
Copy link
Member

Thanks for this PR!

Is the rng in the test cases actually being used anywhere? Since #520 has not been merged yet, I think Julia uses the global RNG by default, and you'll have to seed the global RNG via Random.seed!(1) if you want to ensure determinism between runs. So I think there's no need for Random.Xoshiro, and we can maintain support for 1.6 (which I would prefer not to drop support for until the next major release).

This maintains backwards compat with Julia 1.6
@SamuelBrand1
Copy link
Contributor Author

Thanks for this PR!

Is the rng in the test cases actually being used anywhere? Since #520 has not been merged yet, I think Julia uses the global RNG by default, and you'll have to seed the global RNG via Random.seed!(1) if you want to ensure determinism between runs. So I think there's no need for Random.Xoshiro, and we can maintain support for 1.6 (which I would prefer not to drop support for until the next major release).

I'm mainly trying to get into the habit of using rng given the new features in Julia 1.12! But you make good points. I've changed the test to use the global RNG.

@ztangent
Copy link
Member

Looks good to me! Merging.

@ztangent ztangent merged commit aa45748 into probcomp:master Oct 22, 2025
5 checks passed
@SamuelBrand1 SamuelBrand1 deleted the add-metric-to-hmc branch October 22, 2025 15:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature request: Ability to sample momenta for HMC with mass matrix

2 participants