Skip to content

Commit aa45748

Browse files
authored
Merge pull request #559 from SamuelBrand1/add-metric-to-hmc
Issue 558: add support for custom metrics in HMC
2 parents 8f2f4cc + 3609d3e commit aa45748

File tree

2 files changed

+197
-9
lines changed

2 files changed

+197
-9
lines changed

src/inference/hmc.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
1+
# Momenta sampling with different metrics
2+
13
"""
    sample_momenta(n::Int)

Return a `Vector{Float64}` of `n` momenta, each drawn independently with
`random(normal, 0, 1)`. This is the method used when no metric is supplied.
"""
function sample_momenta(n::Int)
    momenta = Float64[]
    for _ in 1:n
        # One independent draw per selected coordinate.
        push!(momenta, random(normal, 0, 1))
    end
    return momenta
end
46

7+
"""
    sample_momenta(n::Int, metric::AbstractVector)

Return `n` momenta drawn from a zero-mean Gaussian with a diagonal covariance
whose diagonal entries (variances) are given by `metric`.

# Throws
- `DimensionMismatch` if `length(metric) != n`. Without this check a length-1
  metric would silently broadcast over all `n` draws.
- `ArgumentError` if any entry of `metric` is not strictly positive.
"""
function sample_momenta(n::Int, metric::AbstractVector)
    length(metric) == n || throw(DimensionMismatch(
        "diagonal metric has length $(length(metric)) but $n momenta were requested"))
    # Explicit throw rather than `@assert`: assertions may be compiled out,
    # and this is validation of caller-supplied input.
    all(>(0), metric) || throw(ArgumentError(
        "All diagonal metric values must be positive"))
    # Scaling the unit draws by sqrt(metric) gives each coordinate variance metric[i].
    return sqrt.(metric) .* sample_momenta(n)
end
11+
12+
"""
    sample_momenta(n::Int, metric::LinearAlgebra.Diagonal)

Sample `n` momenta for a `Diagonal` metric by forwarding its diagonal entries
to the vector-metric method of `sample_momenta`.
"""
function sample_momenta(n::Int, metric::LinearAlgebra.Diagonal)
    diagonal_entries = LinearAlgebra.diag(metric)
    return sample_momenta(n, diagonal_entries)
end
15+
16+
"""
    sample_momenta(n::Int, metric::AbstractMatrix)

Sample momenta by calling `mvnormal` with a zero mean vector of length `n`
and `metric` as its second argument.
"""
function sample_momenta(n::Int, metric::AbstractMatrix)
    zero_mean = zeros(n)
    return mvnormal(zero_mean, metric)
end
19+
20+
"""
    sample_momenta(n::Int, metric::Nothing)

Fallback for `metric = nothing`: delegates to the metric-free
`sample_momenta(n)`.
"""
sample_momenta(n::Int, metric::Nothing) = sample_momenta(n)
23+
24+
# Assessing momenta log probabilities with different metrics
25+
526
function assess_momenta(momenta)
627
logprob = 0.
728
for val in momenta
@@ -10,21 +31,56 @@ function assess_momenta(momenta)
1031
logprob
1132
end
1233

34+
"""
    assess_momenta(momenta, metric::AbstractVector)

Return the summed log density of `momenta`, scoring each entry with
`logpdf(normal, val, 0, sqrt(m))`, where `m` is the matching entry of
`metric`.

# Throws
- `DimensionMismatch` if `momenta` and `metric` differ in length. Without this
  check `zip` would stop at the shorter one and silently score only part of
  the momenta.
"""
function assess_momenta(momenta, metric::AbstractVector)
    length(momenta) == length(metric) || throw(DimensionMismatch(
        "got $(length(momenta)) momenta but a diagonal metric of length $(length(metric))"))
    logprob = 0.
    for (val, variance) in zip(momenta, metric)
        # `metric` holds variances; `normal` is given their square roots.
        logprob += logpdf(normal, val, 0, sqrt(variance))
    end
    return logprob
end
41+
42+
"""
    assess_momenta(momenta, metric::LinearAlgebra.Diagonal)

Score `momenta` under a `Diagonal` metric by forwarding its diagonal entries
to the vector-metric method of `assess_momenta`.
"""
function assess_momenta(momenta, metric::LinearAlgebra.Diagonal)
    diagonal_entries = LinearAlgebra.diag(metric)
    return assess_momenta(momenta, diagonal_entries)
end
45+
46+
"""
    assess_momenta(momenta, metric::AbstractMatrix)

Return `logpdf(mvnormal, ...)` of `momenta` with a zero mean vector of matching
length and `metric` as the final argument.
"""
function assess_momenta(momenta, metric::AbstractMatrix)
    zero_mean = zeros(length(momenta))
    return logpdf(mvnormal, momenta, zero_mean, metric)
end
49+
50+
"""
    assess_momenta(momenta, metric::Nothing)

Fallback for `metric = nothing`: delegates to the metric-free
`assess_momenta(momenta)`.
"""
assess_momenta(momenta, metric::Nothing) = assess_momenta(momenta)
53+
1354
"""
1455
(new_trace, accepted) = hmc(
1556
trace, selection::Selection; L=10, eps=0.1,
16-
check=false, observations=EmptyChoiceMap())
57+
check=false, observations=EmptyChoiceMap(), metric = nothing)
58+
59+
Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the
60+
selected addresses, returning the new trace (which is equal to the previous trace
61+
if the move was not accepted) and a `Bool` indicating whether the move was accepted or not.
62+
63+
Hamilton's equations are numerically integrated using leapfrog integration with
64+
step size `eps` for `L` steps and initial momenta sampled from a Gaussian distribution with
65+
covariance given by `metric` (mass matrix).
66+
67+
Sampling with HMC is improved by using a metric/mass matrix that approximates the
68+
**inverse** covariance of the target distribution, and is equivalent to a linear transformation
69+
of the parameter space (see Neal, 2011). The following options are supported for `metric`:
1770
18-
Apply a Hamiltonian Monte Carlo (HMC) update that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not.
71+
- `nothing` (default): identity matrix
72+
- `Vector`: diagonal matrix with the given vector as the diagonal
73+
- `Diagonal`: diagonal matrix, handled as the vector of its diagonal entries
74+
- `Matrix`: dense matrix
1975
20-
Hamilton's equations are numerically integrated using leapfrog integration with step size `eps` for `L` steps. See equations (5.18)-(5.20) of Neal (2011).
76+
See equations (5.18)-(5.20) of Neal (2011).
2177
2278
# References
2379
Neal, Radford M. (2011), "MCMC Using Hamiltonian Dynamics", Handbook of Markov Chain Monte Carlo, pp. 113-162. URL: http://www.mcmchandbook.net/HandbookChapter5.pdf
2480
"""
2581
function hmc(
2682
trace::Trace, selection::Selection; L=10, eps=0.1,
27-
check=false, observations=EmptyChoiceMap())
83+
check=false, observations=EmptyChoiceMap(), metric = nothing)
2884
prev_model_score = get_score(trace)
2985
args = get_args(trace)
3086
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
@@ -35,8 +91,8 @@ function hmc(
3591
(_, values_trie, gradient_trie) = choice_gradients(new_trace, selection, retval_grad)
3692
values = to_array(values_trie, Float64)
3793
gradient = to_array(gradient_trie, Float64)
38-
momenta = sample_momenta(length(values))
39-
prev_momenta_score = assess_momenta(momenta)
94+
momenta = sample_momenta(length(values), metric)
95+
prev_momenta_score = assess_momenta(momenta, metric)
4096
for step=1:L
4197

4298
# half step on momenta
@@ -60,7 +116,7 @@ function hmc(
60116
new_model_score = get_score(new_trace)
61117

62118
# assess new momenta score (negative kinetic energy)
63-
new_momenta_score = assess_momenta(-momenta)
119+
new_momenta_score = assess_momenta(-momenta, metric)
64120

65121
# accept or reject
66122
alpha = new_model_score - prev_model_score + new_momenta_score - prev_momenta_score

test/inference/hmc.jl

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
@testset "hmc" begin
2-
1+
@testset "hmc tests" begin
2+
import LinearAlgebra, Random
3+
34
# smoke test a function without retval gradient
45
@gen function foo()
56
x = @trace(normal(0, 1), :x)
@@ -17,4 +18,135 @@
1718

1819
(trace, _) = generate(foo, ())
1920
(new_trace, accepted) = hmc(trace, select(:x))
21+
22+
# For Normal(0,1), grad should be -x
23+
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x), 0)
24+
values = to_array(values_trie, Float64)
25+
grad = to_array(gradient_trie, Float64)
26+
# Approximate comparison: both sides are Float64 arrays from to_array.
@test values ≈ -grad
27+
28+
# smoke test with vector metric
29+
@gen function bar()
30+
x = @trace(normal(0, 1), :x)
31+
y = @trace(normal(0, 1), :y)
32+
return x + y
33+
end
34+
35+
(trace, _) = generate(bar, ())
36+
metric_vec = [1.0, 2.0]
37+
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)
38+
39+
# smoke test with Diagonal metric
40+
(trace, _) = generate(bar, ())
41+
metric_diag = LinearAlgebra.Diagonal([1.0, 2.0])
42+
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag)
43+
44+
# smoke test with Dense matrix metric
45+
(trace, _) = generate(bar, ())
46+
metric_dense = [1.0 0.1; 0.1 2.0]
47+
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense)
48+
49+
# smoke test with vector metric and retval gradient
50+
@gen (grad) function bar_grad()
51+
x = @trace(normal(0, 1), :x)
52+
y = @trace(normal(0, 1), :y)
53+
return x + y
54+
end
55+
56+
(trace, _) = generate(bar_grad, ())
57+
metric_vec = [0.5, 1.5]
58+
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_vec)
59+
60+
# For each Normal(0,1), grad should be -x
61+
(_, values_trie, gradient_trie) = choice_gradients(trace, select(:x, :y), 0)
62+
values = to_array(values_trie, Float64)
63+
grad = to_array(gradient_trie, Float64)
64+
# Approximate comparison: both sides are Float64 arrays from to_array.
@test values ≈ -grad
65+
66+
# smoke test with Diagonal metric and retval gradient
67+
(trace, _) = generate(bar_grad, ())
68+
metric_diag = LinearAlgebra.Diagonal([0.5, 1.5])
69+
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_diag)
70+
71+
# smoke test with Dense matrix metric and retval gradient
72+
(trace, _) = generate(bar_grad, ())
73+
metric_dense = [0.5 0.2; 0.2 1.5]
74+
(new_trace, accepted) = hmc(trace, select(:x, :y); metric=metric_dense)
2075
end
76+
77+
@testset "hmc metric behavior" begin
    import LinearAlgebra, Random

    # Model with two independently traced normal(0, 1) choices, returning their sum.
    @gen function test_metric_effect()
        x = @trace(normal(0, 1), :x)
        y = @trace(normal(0, 1), :y)
        return x + y
    end

    (trace1, _) = generate(test_metric_effect, ())
    both_choices = select(:x, :y)

    # --- Different metrics, identical RNG state ---------------------------
    # Seed, then run HMC with the default (identity) metric.
    Random.seed!(1)
    (trace_identity, _) = hmc(trace1, both_choices; L=5)

    # Re-seed to the same state, then run with a strongly anisotropic metric.
    Random.seed!(1)
    metric_scaled = [10.0, 0.1]
    (trace_scaled, _) = hmc(trace1, both_choices; L=5, metric=metric_scaled)

    # Same RNG sequence but different metrics should give different results.
    @test get_choices(trace_identity) != get_choices(trace_scaled)

    # --- Same metric, different representations ---------------------------
    # Seed the global RNG, run one HMC move with `metric`, and report the
    # acceptance flag as 0.0 or 1.0.
    function acceptance_indicator(seed, metric)
        Random.seed!(seed)
        (_, accepted) = hmc(trace1, both_choices; metric=metric)
        return float(accepted)
    end

    seeds = 1:50
    acceptances_diag = Float64[
        acceptance_indicator(s, LinearAlgebra.Diagonal([2.0, 3.0])) for s in seeds
    ]
    acceptances_dense = Float64[
        acceptance_indicator(s, [2.0 0.0; 0.0 3.0]) for s in seeds
    ]

    # Acceptance rates of the two representations should differ by less than 0.2.
    rate_diag = Distributions.mean(acceptances_diag)
    rate_dense = Distributions.mean(acceptances_dense)
    @test abs(rate_diag - rate_dense) < 0.2
end
134+
135+
@testset "Bad metric catches" begin
    # Model with two independently traced normal(0, 1) choices, returning their sum.
    @gen function bar()
        x = @trace(normal(0, 1), :x)
        y = @trace(normal(0, 1), :y)
        return x + y
    end

    # One invalid metric per supported representation.
    bad_dense = [-1.0 -20.0; 0.0 1.0]                  # Bad dense
    bad_diag = LinearAlgebra.Diagonal([-1.0, -20.0])   # Bad diag
    bad_vector = [-5.0, 20.0]                          # Bad vector diag

    for bad_metric in (bad_dense, bad_diag, bad_vector)
        # Fresh trace per case; hmc must reject the metric by throwing.
        (trace, _) = generate(bar, ())
        @test_throws Exception hmc(trace, select(:x, :y); metric=bad_metric)
    end
end

0 commit comments

Comments
 (0)