Skip to content

Commit ea21742

Browse files
committed
Decouple the adaptation from MCMC kernels
1 parent ba55e6a commit ea21742

File tree

2 files changed

+98
-123
lines changed

2 files changed

+98
-123
lines changed

aehmc/window_adaptation.py

Lines changed: 93 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
from typing import Callable, List, Tuple
1+
from typing import List, Tuple
22

33
import aesara
44
import aesara.tensor as at
55
from aesara import config
6+
from aesara.ifelse import ifelse
67
from aesara.tensor.shape import shape_tuple
78
from aesara.tensor.var import TensorVariable
89

@@ -11,7 +12,7 @@
1112

1213

1314
def run(
14-
kernel_factory,
15+
kernel,
1516
initial_state,
1617
num_steps=1000,
1718
*,
@@ -20,13 +21,12 @@ def run(
2021
target_acceptance_rate=0.80
2122
):
2223

23-
init, update, final = window_adaptation(
24-
kernel_factory, is_mass_matrix_full, initial_step_size, target_acceptance_rate
24+
init_adapt, update_adapt, final_adapt = window_adaptation(
25+
num_steps, is_mass_matrix_full, initial_step_size, target_acceptance_rate
2526
)
2627

2728
def one_step(
28-
stage, # schedule
29-
is_middle_window_end,
29+
warmup_step,
3030
q, # chain state
3131
potential_energy,
3232
potential_energy_grad,
@@ -35,61 +35,65 @@ def one_step(
3535
log_step_size_avg,
3636
gradient_avg,
3737
mu,
38-
inverse_mass_matrix, # inverse mass matrix
3938
mean, # mass matrix adaptation state
4039
m2,
4140
sample_size,
41+
step_size, # parameters
42+
inverse_mass_matrix,
4243
):
4344
chain_state = (q, potential_energy, potential_energy_grad)
44-
4545
warmup_state = (
4646
(step, log_step_size, log_step_size_avg, gradient_avg, mu),
47-
inverse_mass_matrix,
4847
(mean, m2, sample_size),
4948
)
49+
parameters = (step_size, inverse_mass_matrix)
50+
51+
# Advance the chain by one step
52+
chain_state, inner_updates = kernel(*chain_state, *parameters)
5053

51-
(chain_state, warmup_state), inner_updates = update(
52-
stage, is_middle_window_end, chain_state, warmup_state
54+
# Update the warmup state and parameters
55+
warmup_state, parameters = update_adapt(
56+
warmup_step, warmup_state, parameters, chain_state
5357
)
5458

5559
return (
56-
*chain_state,
60+
chain_state[0], # q
61+
chain_state[1], # potential_energy
62+
chain_state[2], # potential_energy_grad
5763
*warmup_state[0],
58-
warmup_state[1],
59-
*warmup_state[2],
64+
*warmup_state[1],
65+
*parameters,
6066
), inner_updates
6167

62-
schedule = build_schedule(num_steps)
63-
stage = at.as_tensor([s[0] for s in schedule])
64-
is_middle_window_end = at.as_tensor([s[1] for s in schedule])
68+
(da_state, mm_state), parameters = init_adapt(initial_state)
6569

66-
da_state, inverse_mass_matrix, wc_state = init(initial_state)
70+
warmup_steps = at.arange(0, num_steps)
6771
state, updates = aesara.scan(
6872
fn=one_step,
69-
outputs_info=(*initial_state, *da_state, inverse_mass_matrix, *wc_state),
70-
sequences=(stage, is_middle_window_end),
73+
outputs_info=(*initial_state, *da_state, *mm_state, *parameters),
74+
sequences=(warmup_steps,),
75+
name="window_adaptation",
7176
)
7277

7378
last_chain_state = (state[0][-1], state[1][-1], state[2][-1])
74-
last_warmup_state = (
75-
(state[3][-1], state[4][-1], state[5][-1], state[6][-1], state[7][-1]),
76-
state[8][-1],
77-
(state[9][-1], state[10][-1], state[11][-1]),
78-
)
79-
80-
step_size, inverse_mass_matrix = final(last_warmup_state)
79+
step_size = state[-2][-1]
80+
inverse_mass_matrix = state[-1][-1]
8181

8282
return last_chain_state, (step_size, inverse_mass_matrix), updates
8383

8484

8585
def window_adaptation(
86-
kernel_factory: Callable[[TensorVariable], Callable],
86+
num_steps: int,
8787
is_mass_matrix_full: bool = False,
8888
initial_step_size: TensorVariable = at.as_tensor(1.0, dtype=config.floatX),
8989
target_acceptance_rate: TensorVariable = 0.80,
9090
):
9191
mm_init, mm_update, mm_final = covariance_adaptation(is_mass_matrix_full)
9292
da_init, da_update = dual_averaging_adaptation(target_acceptance_rate)
93+
schedule = build_schedule(num_steps)
94+
95+
schedule_stage = at.as_tensor([s[0] for s in schedule])
96+
schedule_middle_window = at.as_tensor([s[1] for s in schedule])
9397

9498
def init(initial_chain_state: Tuple):
9599
if initial_chain_state[0].ndim == 0:
@@ -98,114 +102,88 @@ def init(initial_chain_state: Tuple):
98102
num_dims = shape_tuple(initial_chain_state[0])[0]
99103
inverse_mass_matrix, mm_state = mm_init(num_dims)
100104

101-
step, logstepsize, logstepsize_avg, gradient_avg, mu = da_init(
102-
initial_step_size
103-
)
105+
da_state = da_init(initial_step_size)
106+
step_size = at.exp(da_state[1])
104107

105-
return (
106-
(step, logstepsize, logstepsize_avg, gradient_avg, mu),
107-
inverse_mass_matrix,
108-
mm_state,
109-
)
108+
warmup_state = (da_state, mm_state)
109+
parameters = (step_size, inverse_mass_matrix)
110+
return warmup_state, parameters
111+
112+
def fast_update(p_accept, warmup_state, parameters):
113+
da_state, mm_state = warmup_state
114+
_, inverse_mass_matrix = parameters
115+
116+
new_da_state = da_update(p_accept, *da_state)
117+
step_size = at.exp(new_da_state[1])
110118

111-
def fast_update(p_accept, da_state, inverse_mass_matrix, mm_state):
112-
da_state = da_update(p_accept, *da_state)
113-
return (da_state, inverse_mass_matrix, mm_state)
119+
return (new_da_state, mm_state), (step_size, inverse_mass_matrix)
114120

115-
def slow_update(position, p_accept, da_state, inverse_mass_matrix, mm_state):
116-
da_state = da_update(p_accept, *da_state)
117-
mm_state = mm_update(position, mm_state)
118-
return (da_state, inverse_mass_matrix, mm_state)
121+
def slow_update(position, p_accept, warmup_state, parameters):
122+
da_state, mm_state = warmup_state
123+
_, inverse_mass_matrix = parameters
124+
125+
new_da_state = da_update(p_accept, *da_state)
126+
new_mm_state = mm_update(position, mm_state)
127+
step_size = at.exp(new_da_state[1])
128+
129+
return (new_da_state, new_mm_state), (step_size, inverse_mass_matrix)
119130

120131
def slow_final(warmup_state):
121132
"""We recompute the inverse mass matrix and re-initialize the dual averaging scheme at the end of each 'slow window'."""
122-
da_state, inverse_mass_matrix, mm_state = warmup_state
133+
da_state, mm_state = warmup_state
123134

124-
new_inverse_mass_matrix = mm_final(mm_state)
125-
_, new_mm_state = mm_init(inverse_mass_matrix.ndim)
135+
inverse_mass_matrix = mm_final(mm_state)
136+
137+
if inverse_mass_matrix.ndim == 0:
138+
num_dims = 0
139+
else:
140+
num_dims = shape_tuple(inverse_mass_matrix)[0]
141+
_, new_mm_state = mm_init(num_dims)
126142

127143
step_size = at.exp(da_state[1])
128-
step, logstepsize, logstepsize_avg, gradient_avg, mu = da_init(step_size)
129-
return (
130-
(step, logstepsize, logstepsize_avg, gradient_avg, mu),
131-
new_inverse_mass_matrix,
132-
new_mm_state,
133-
)
144+
new_da_state = da_init(step_size)
134145

135-
def update(
136-
stage: int, is_middle_window_end: bool, chain_state: Tuple, warmup_state: Tuple
137-
):
138-
da_state, inverse_mass_matrix, mm_state = warmup_state
146+
warmup_state = (new_da_state, new_mm_state)
147+
parameters = (step_size, inverse_mass_matrix)
148+
return warmup_state, parameters
139149

140-
step_size = at.exp(da_state[1])
141-
kernel = kernel_factory(inverse_mass_matrix)
142-
(*chain_state, p_accept, _, _, _), updates = kernel(
143-
*chain_state, step_size, inverse_mass_matrix
144-
)
150+
def final(
151+
warmup_state: Tuple, parameters: Tuple
152+
) -> Tuple[TensorVariable, TensorVariable]:
153+
da_state, _ = warmup_state
154+
_, inverse_mass_matrix = parameters
155+
step_size = at.exp(da_state[2]) # return stepsize_avg at the end
156+
return step_size, inverse_mass_matrix
157+
158+
def update(step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Tuple):
159+
position, _, _, p_accept, *_ = chain_state
145160

146-
warmup_state = where_warmup_state(
161+
stage = schedule_stage[step]
162+
warmup_state, parameters = where_warmup_state(
147163
at.eq(stage, 0),
148-
fast_update(p_accept, da_state, inverse_mass_matrix, mm_state),
149-
slow_update(
150-
chain_state[0], p_accept, da_state, inverse_mass_matrix, mm_state
151-
),
164+
fast_update(p_accept, warmup_state, parameters),
165+
slow_update(position, p_accept, warmup_state, parameters),
152166
)
153-
warmup_state = where_warmup_state(
154-
is_middle_window_end, slow_final(warmup_state), warmup_state
167+
168+
is_middle_window_end = schedule_middle_window[step]
169+
warmup_state, parameters = where_warmup_state(
170+
is_middle_window_end, slow_final(warmup_state), (warmup_state, parameters)
155171
)
156172

157-
return (chain_state, warmup_state), updates
173+
is_last_step = at.eq(step, num_steps - 1)
174+
parameters = ifelse(is_last_step, final(warmup_state, parameters), parameters)
158175

159-
def final(warmup_state: Tuple) -> Tuple[TensorVariable, TensorVariable]:
160-
da_state, inverse_mass_matrix, mm_state = warmup_state
161-
step_size = at.exp(da_state[2]) # return stepsize_avg at the end
162-
return step_size, inverse_mass_matrix
176+
return warmup_state, parameters
163177

164178
def where_warmup_state(do_pick_left, left_warmup_state, right_warmup_state):
165-
(
166-
left_step,
167-
left_logstepsize,
168-
left_logstepsize_avg,
169-
left_gradient_avg,
170-
left_mu,
171-
) = left_warmup_state[0]
172-
(
173-
right_step,
174-
right_logstepsize,
175-
right_logstepsize_avg,
176-
right_gradient_avg,
177-
right_mu,
178-
) = right_warmup_state[0]
179-
180-
step = at.where(do_pick_left, left_step, right_step)
181-
logstepsize = at.where(do_pick_left, left_logstepsize, right_logstepsize)
182-
logstepsize_avg = at.where(
183-
do_pick_left, left_logstepsize_avg, right_logstepsize_avg
184-
)
185-
gradient_avg = at.where(do_pick_left, left_gradient_avg, right_gradient_avg)
186-
mu = at.where(do_pick_left, left_mu, right_mu)
179+
(left_da_state, left_mm_state), left_params = left_warmup_state
180+
(right_da_state, right_mm_state), right_params = right_warmup_state
187181

188-
left_inverse_mass_matrix = left_warmup_state[1]
189-
right_inverse_mass_matrix = right_warmup_state[1]
190-
inverse_mass_matrix = at.where(
191-
do_pick_left, left_inverse_mass_matrix, right_inverse_mass_matrix
192-
)
193-
194-
right_mean, right_m2, right_sample_size = right_warmup_state[2]
195-
left_mean, left_m2, left_sample_size = left_warmup_state[2]
196-
mean = at.where(do_pick_left, left_mean, right_mean)
197-
m2 = at.where(do_pick_left, left_m2, right_m2)
198-
sample_size = at.where(do_pick_left, left_sample_size, right_sample_size)
182+
da_state = ifelse(do_pick_left, left_da_state, right_da_state)
183+
mm_state = ifelse(do_pick_left, left_mm_state, right_mm_state)
184+
params = ifelse(do_pick_left, left_params, right_params)
199185

200-
return (
201-
(step, logstepsize, logstepsize_avg, gradient_avg, mu),
202-
inverse_mass_matrix,
203-
(
204-
mean,
205-
m2,
206-
sample_size,
207-
),
208-
)
186+
return (da_state, mm_state), params
209187

210188
return init, update, final
211189

tests/test_hmc.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,12 @@ def logprob_fn(y: TensorVariable):
2020
logprob = joint_logprob({Y_rv: y})
2121
return logprob
2222

23-
def kernel_factory(inverse_mass_matrix: TensorVariable):
24-
return nuts.new_kernel(srng, logprob_fn)
25-
2623
y_vv = Y_rv.clone()
24+
kernel = nuts.new_kernel(srng, logprob_fn)
2725
initial_state = nuts.new_state(y_vv, logprob_fn)
2826

2927
state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
30-
kernel_factory, initial_state, num_steps=1000
28+
kernel, initial_state, num_steps=1000
3129
)
3230

3331
# Compile the warmup and execute to get a value for the step size and the
@@ -42,6 +40,7 @@ def kernel_factory(inverse_mass_matrix: TensorVariable):
4240

4341
assert final_state[0] != 3.0 # the chain has moved
4442
assert np.ndim(step_size) == 0 # scalar step size
43+
assert step_size != 1.0 # step size changed
4544
assert step_size > 0.1 and step_size < 2 # stable range for the step size
4645
assert np.ndim(inverse_mass_matrix) == 0 # scalar mass matrix
4746
assert inverse_mass_matrix == pytest.approx(4, rel=1.0)
@@ -61,14 +60,12 @@ def logprob_fn(y: TensorVariable):
6160
logprob = joint_logprob({Y_rv: y})
6261
return logprob
6362

64-
def kernel_factory(inverse_mass_matrix: TensorVariable):
65-
return nuts.new_kernel(srng, logprob_fn)
66-
6763
y_vv = Y_rv.clone()
64+
kernel = nuts.new_kernel(srng, logprob_fn)
6865
initial_state = nuts.new_state(y_vv, logprob_fn)
6966

7067
state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
71-
kernel_factory, initial_state, num_steps=1000
68+
kernel, initial_state, num_steps=1000
7269
)
7370

7471
# Compile the warmup and execute to get a value for the step size and the

0 commit comments

Comments
 (0)