1+ using UUIDs: uuid4
2+
13# Default implementations of `sample`.
24const PROGRESS = Ref (true )
35
@@ -144,11 +146,22 @@ function mcmcsample(
144146 @ifwithprogresslogger progress name = progressname begin
145147 # Determine threshold values for progress logging
146148 # (one update per 0.5% of progress)
147- if (progress == true || progress === nothing )
149+ if ! (progress == false )
148150 threshold = Ntotal ÷ 200
149151 next_update = threshold
150152 end
151153
154+ # Ugly hacky code to reset the start timer if called from a multi-chain
155+ # sampling process
156+ if progress isa ProgressLogging. Progress
157+ try
158+ bartrees = Logging. current_logger (). loggers[1 ]. logger. bartrees
159+ bar = TerminalLoggers. findbar (bartrees, progress. id). data
160+ bar. tfirst = time ()
161+ catch
162+ end
163+ end
164+
152165 # Obtain the initial sample and state.
153166 sample, state = if num_warmup > 0
154167 if initial_state === nothing
@@ -170,7 +183,8 @@ function mcmcsample(
170183 if progress == true
171184 ProgressLogging. @logprogress itotal / Ntotal
172185 else
173- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
186+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
187+ progress. id
174188 end
175189 next_update = itotal + threshold
176190 end
@@ -189,7 +203,8 @@ function mcmcsample(
189203 if progress == true
190204 ProgressLogging. @logprogress itotal / Ntotal
191205 else
192- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
206+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
207+ progress. id
193208 end
194209 next_update = itotal + threshold
195210 end
@@ -218,7 +233,8 @@ function mcmcsample(
218233 if progress == true
219234 ProgressLogging. @logprogress itotal / Ntotal
220235 else
221- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
236+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
237+ progress. id
222238 end
223239 next_update = itotal + threshold
224240 end
@@ -243,7 +259,8 @@ function mcmcsample(
243259 if progress == true
244260 ProgressLogging. @logprogress itotal / Ntotal
245261 else
246- ProgressLogging. @logprogress itotal / Ntotal _id = " hello"
262+ ProgressLogging. @logprogress name = progressname itotal / Ntotal _id =
263+ progress. id
247264 end
248265 next_update = itotal + threshold
249266 end
@@ -432,6 +449,18 @@ function mcmcsample(
432449 if progress
433450 channel = Channel {Bool} (length (interval))
434451 end
452+ # Generate nchains independent UUIDs for each progress bar
453+ uuids = [uuid4 () for _ in 1 : nchains]
454+ # Start the progress bars (but in reverse order, because
455+ # ProgressLogging prints from the bottom up, and we want chain 1 to
456+ # show up at the top)
457+ # TODO : This has an unintended effect that the 'time' field in the
458+ # progress bar shows the total time since the beginning of sampling,
459+ # even if the specific chain doesn't start sampling until later on.
460+ for (i, uuid) in enumerate (reverse (uuids))
461+ ProgressLogging. @logprogress name = " Chain $(nchains- i+ 1 ) /$nchains " nothing _id =
462+ uuid
463+ end
435464
436465 Distributed. @sync begin
437466 if progress
@@ -472,17 +501,21 @@ function mcmcsample(
472501 Random. seed! (_rng, seeds[chainidx])
473502
474503 # Sample a chain and save it to the vector.
504+ child_progressname = " Chain $chainidx /$nchains "
475505 child_progress = if progress == false
476506 false
477507 else
478- nothing
508+ ProgressLogging. Progress (
509+ uuids[chainidx]; name= child_progressname
510+ )
479511 end
480- @ifwithprogresslogger progress chains[chainidx] = StatsBase. sample (
512+ chains[chainidx] = StatsBase. sample (
481513 _rng,
482514 _model,
483515 _sampler,
484516 N;
485517 progress= child_progress,
518+ progressname= child_progressname,
486519 initial_params= if initial_params === nothing
487520 nothing
488521 else
@@ -496,6 +529,8 @@ function mcmcsample(
496529 kwargs... ,
497530 )
498531
532+ ProgressLogging. @logprogress name = child_progressname " done" _id = uuids[chainidx]
533+
499534 # Update the progress bar.
500535 progress && put! (channel, true )
501536 end
0 commit comments