1- from typing import Callable , List , Tuple
1+ from typing import List , Tuple
22
33import aesara
44import aesara .tensor as at
55from aesara import config
6+ from aesara .ifelse import ifelse
67from aesara .tensor .shape import shape_tuple
78from aesara .tensor .var import TensorVariable
89
1112
1213
1314def run (
14- kernel_factory ,
15+ kernel ,
1516 initial_state ,
1617 num_steps = 1000 ,
1718 * ,
@@ -20,13 +21,12 @@ def run(
2021 target_acceptance_rate = 0.80
2122):
2223
23- init , update , final = window_adaptation (
24- kernel_factory , is_mass_matrix_full , initial_step_size , target_acceptance_rate
24+ init_adapt , update_adapt , final_adapt = window_adaptation (
25+ num_steps , is_mass_matrix_full , initial_step_size , target_acceptance_rate
2526 )
2627
2728 def one_step (
28- stage , # schedule
29- is_middle_window_end ,
29+ warmup_step ,
3030 q , # chain state
3131 potential_energy ,
3232 potential_energy_grad ,
@@ -35,61 +35,65 @@ def one_step(
3535 log_step_size_avg ,
3636 gradient_avg ,
3737 mu ,
38- inverse_mass_matrix , # inverse mass matrix
3938 mean , # mass matrix adaptation state
4039 m2 ,
4140 sample_size ,
41+ step_size , # parameters
42+ inverse_mass_matrix ,
4243 ):
4344 chain_state = (q , potential_energy , potential_energy_grad )
44-
4545 warmup_state = (
4646 (step , log_step_size , log_step_size_avg , gradient_avg , mu ),
47- inverse_mass_matrix ,
4847 (mean , m2 , sample_size ),
4948 )
49+ parameters = (step_size , inverse_mass_matrix )
50+
51+ # Advance the chain by one step
52+ chain_state , inner_updates = kernel (* chain_state , * parameters )
5053
51- (chain_state , warmup_state ), inner_updates = update (
52- stage , is_middle_window_end , chain_state , warmup_state
54+ # Update the warmup state and parameters
55+ warmup_state , parameters = update_adapt (
56+ warmup_step , warmup_state , parameters , chain_state
5357 )
5458
5559 return (
56- * chain_state ,
60+ chain_state [0 ], # q
61+ chain_state [1 ], # potential_energy
62+ chain_state [2 ], # potential_energy_grad
5763 * warmup_state [0 ],
58- warmup_state [1 ],
59- * warmup_state [ 2 ] ,
64+ * warmup_state [1 ],
65+ * parameters ,
6066 ), inner_updates
6167
62- schedule = build_schedule (num_steps )
63- stage = at .as_tensor ([s [0 ] for s in schedule ])
64- is_middle_window_end = at .as_tensor ([s [1 ] for s in schedule ])
68+ (da_state , mm_state ), parameters = init_adapt (initial_state )
6569
66- da_state , inverse_mass_matrix , wc_state = init ( initial_state )
70+ warmup_steps = at . arange ( 0 , num_steps )
6771 state , updates = aesara .scan (
6872 fn = one_step ,
69- outputs_info = (* initial_state , * da_state , inverse_mass_matrix , * wc_state ),
70- sequences = (stage , is_middle_window_end ),
73+ outputs_info = (* initial_state , * da_state , * mm_state , * parameters ),
74+ sequences = (warmup_steps ,),
75+ name = "window_adaptation" ,
7176 )
7277
7378 last_chain_state = (state [0 ][- 1 ], state [1 ][- 1 ], state [2 ][- 1 ])
74- last_warmup_state = (
75- (state [3 ][- 1 ], state [4 ][- 1 ], state [5 ][- 1 ], state [6 ][- 1 ], state [7 ][- 1 ]),
76- state [8 ][- 1 ],
77- (state [9 ][- 1 ], state [10 ][- 1 ], state [11 ][- 1 ]),
78- )
79-
80- step_size , inverse_mass_matrix = final (last_warmup_state )
79+ step_size = state [- 2 ][- 1 ]
80+ inverse_mass_matrix = state [- 1 ][- 1 ]
8181
8282 return last_chain_state , (step_size , inverse_mass_matrix ), updates
8383
8484
8585def window_adaptation (
86- kernel_factory : Callable [[ TensorVariable ], Callable ] ,
86+ num_steps : int ,
8787 is_mass_matrix_full : bool = False ,
8888 initial_step_size : TensorVariable = at .as_tensor (1.0 , dtype = config .floatX ),
8989 target_acceptance_rate : TensorVariable = 0.80 ,
9090):
9191 mm_init , mm_update , mm_final = covariance_adaptation (is_mass_matrix_full )
9292 da_init , da_update = dual_averaging_adaptation (target_acceptance_rate )
93+ schedule = build_schedule (num_steps )
94+
95+ schedule_stage = at .as_tensor ([s [0 ] for s in schedule ])
96+ schedule_middle_window = at .as_tensor ([s [1 ] for s in schedule ])
9397
9498 def init (initial_chain_state : Tuple ):
9599 if initial_chain_state [0 ].ndim == 0 :
@@ -98,114 +102,88 @@ def init(initial_chain_state: Tuple):
98102 num_dims = shape_tuple (initial_chain_state [0 ])[0 ]
99103 inverse_mass_matrix , mm_state = mm_init (num_dims )
100104
101- step , logstepsize , logstepsize_avg , gradient_avg , mu = da_init (
102- initial_step_size
103- )
105+ da_state = da_init (initial_step_size )
106+ step_size = at .exp (da_state [1 ])
104107
105- return (
106- (step , logstepsize , logstepsize_avg , gradient_avg , mu ),
107- inverse_mass_matrix ,
108- mm_state ,
109- )
108+ warmup_state = (da_state , mm_state )
109+ parameters = (step_size , inverse_mass_matrix )
110+ return warmup_state , parameters
111+
112+ def fast_update (p_accept , warmup_state , parameters ):
113+ da_state , mm_state = warmup_state
114+ _ , inverse_mass_matrix = parameters
115+
116+ new_da_state = da_update (p_accept , * da_state )
117+ step_size = at .exp (new_da_state [1 ])
110118
111- def fast_update (p_accept , da_state , inverse_mass_matrix , mm_state ):
112- da_state = da_update (p_accept , * da_state )
113- return (da_state , inverse_mass_matrix , mm_state )
119+ return (new_da_state , mm_state ), (step_size , inverse_mass_matrix )
114120
115- def slow_update (position , p_accept , da_state , inverse_mass_matrix , mm_state ):
116- da_state = da_update (p_accept , * da_state )
117- mm_state = mm_update (position , mm_state )
118- return (da_state , inverse_mass_matrix , mm_state )
121+ def slow_update (position , p_accept , warmup_state , parameters ):
122+ da_state , mm_state = warmup_state
123+ _ , inverse_mass_matrix = parameters
124+
125+ new_da_state = da_update (p_accept , * da_state )
126+ new_mm_state = mm_update (position , mm_state )
127+ step_size = at .exp (new_da_state [1 ])
128+
129+ return (new_da_state , new_mm_state ), (step_size , inverse_mass_matrix )
119130
120131 def slow_final (warmup_state ):
121132 """We recompute the inverse mass matrix and re-initialize the dual averaging scheme at the end of each 'slow window'."""
122- da_state , inverse_mass_matrix , mm_state = warmup_state
133+ da_state , mm_state = warmup_state
123134
124- new_inverse_mass_matrix = mm_final (mm_state )
125- _ , new_mm_state = mm_init (inverse_mass_matrix .ndim )
135+ inverse_mass_matrix = mm_final (mm_state )
136+
137+ if inverse_mass_matrix .ndim == 0 :
138+ num_dims = 0
139+ else :
140+ num_dims = shape_tuple (inverse_mass_matrix )[0 ]
141+ _ , new_mm_state = mm_init (num_dims )
126142
127143 step_size = at .exp (da_state [1 ])
128- step , logstepsize , logstepsize_avg , gradient_avg , mu = da_init (step_size )
129- return (
130- (step , logstepsize , logstepsize_avg , gradient_avg , mu ),
131- new_inverse_mass_matrix ,
132- new_mm_state ,
133- )
144+ new_da_state = da_init (step_size )
134145
135- def update (
136- stage : int , is_middle_window_end : bool , chain_state : Tuple , warmup_state : Tuple
137- ):
138- da_state , inverse_mass_matrix , mm_state = warmup_state
146+ warmup_state = (new_da_state , new_mm_state )
147+ parameters = (step_size , inverse_mass_matrix )
148+ return warmup_state , parameters
139149
140- step_size = at .exp (da_state [1 ])
141- kernel = kernel_factory (inverse_mass_matrix )
142- (* chain_state , p_accept , _ , _ , _ ), updates = kernel (
143- * chain_state , step_size , inverse_mass_matrix
144- )
150+ def final (
151+ warmup_state : Tuple , parameters : Tuple
152+ ) -> Tuple [TensorVariable , TensorVariable ]:
153+ da_state , _ = warmup_state
154+ _ , inverse_mass_matrix = parameters
155+ step_size = at .exp (da_state [2 ]) # return stepsize_avg at the end
156+ return step_size , inverse_mass_matrix
157+
158+ def update (step : int , warmup_state : Tuple , parameters : Tuple , chain_state : Tuple ):
159+ position , _ , _ , p_accept , * _ = chain_state
145160
146- warmup_state = where_warmup_state (
161+ stage = schedule_stage [step ]
162+ warmup_state , parameters = where_warmup_state (
147163 at .eq (stage , 0 ),
148- fast_update (p_accept , da_state , inverse_mass_matrix , mm_state ),
149- slow_update (
150- chain_state [0 ], p_accept , da_state , inverse_mass_matrix , mm_state
151- ),
164+ fast_update (p_accept , warmup_state , parameters ),
165+ slow_update (position , p_accept , warmup_state , parameters ),
152166 )
153- warmup_state = where_warmup_state (
154- is_middle_window_end , slow_final (warmup_state ), warmup_state
167+
168+ is_middle_window_end = schedule_middle_window [step ]
169+ warmup_state , parameters = where_warmup_state (
170+ is_middle_window_end , slow_final (warmup_state ), (warmup_state , parameters )
155171 )
156172
157- return (chain_state , warmup_state ), updates
173+ is_last_step = at .eq (step , num_steps - 1 )
174+ parameters = ifelse (is_last_step , final (warmup_state , parameters ), parameters )
158175
159- def final (warmup_state : Tuple ) -> Tuple [TensorVariable , TensorVariable ]:
160- da_state , inverse_mass_matrix , mm_state = warmup_state
161- step_size = at .exp (da_state [2 ]) # return stepsize_avg at the end
162- return step_size , inverse_mass_matrix
176+ return warmup_state , parameters
163177
164178 def where_warmup_state (do_pick_left , left_warmup_state , right_warmup_state ):
165- (
166- left_step ,
167- left_logstepsize ,
168- left_logstepsize_avg ,
169- left_gradient_avg ,
170- left_mu ,
171- ) = left_warmup_state [0 ]
172- (
173- right_step ,
174- right_logstepsize ,
175- right_logstepsize_avg ,
176- right_gradient_avg ,
177- right_mu ,
178- ) = right_warmup_state [0 ]
179-
180- step = at .where (do_pick_left , left_step , right_step )
181- logstepsize = at .where (do_pick_left , left_logstepsize , right_logstepsize )
182- logstepsize_avg = at .where (
183- do_pick_left , left_logstepsize_avg , right_logstepsize_avg
184- )
185- gradient_avg = at .where (do_pick_left , left_gradient_avg , right_gradient_avg )
186- mu = at .where (do_pick_left , left_mu , right_mu )
179+ (left_da_state , left_mm_state ), left_params = left_warmup_state
180+ (right_da_state , right_mm_state ), right_params = right_warmup_state
187181
188- left_inverse_mass_matrix = left_warmup_state [1 ]
189- right_inverse_mass_matrix = right_warmup_state [1 ]
190- inverse_mass_matrix = at .where (
191- do_pick_left , left_inverse_mass_matrix , right_inverse_mass_matrix
192- )
193-
194- right_mean , right_m2 , right_sample_size = right_warmup_state [2 ]
195- left_mean , left_m2 , left_sample_size = left_warmup_state [2 ]
196- mean = at .where (do_pick_left , left_mean , right_mean )
197- m2 = at .where (do_pick_left , left_m2 , right_m2 )
198- sample_size = at .where (do_pick_left , left_sample_size , right_sample_size )
182+ da_state = ifelse (do_pick_left , left_da_state , right_da_state )
183+ mm_state = ifelse (do_pick_left , left_mm_state , right_mm_state )
184+ params = ifelse (do_pick_left , left_params , right_params )
199185
200- return (
201- (step , logstepsize , logstepsize_avg , gradient_avg , mu ),
202- inverse_mass_matrix ,
203- (
204- mean ,
205- m2 ,
206- sample_size ,
207- ),
208- )
186+ return (da_state , mm_state ), params
209187
210188 return init , update , final
211189
0 commit comments