@@ -441,7 +441,11 @@ function mcmcsample(
441441 elseif progress == false
442442 progress = :none
443443 end
444- # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
444+ progress in [:overall , :perchain , :none ] || throw (
445+ ArgumentError (
446+ " `progress` for MCMCThreads must be `:overall`, `:perchain`, `:none`, or a boolean" ,
447+ ),
448+ )
445449
446450 # Copy the random number generator, model, and sample for each thread
447451 nchunks = min (nchains, Threads. nthreads ())
@@ -581,12 +585,8 @@ function mcmcsample(
581585 # Stop updating the main progress bar (either if sampling
582586 # is done, or if an error occurs).
583587 put! (progress_channel, false )
584- # Additionally stop the per-chain progress bars (but in
585- # reverse order, because ProgressLogging prints from
586- # the bottom up, and we want chain 1 to show up at the
587- # top)
588- for (progress_name, uuid) in
589- reverse (collect (zip (progress_names, uuids)))
588+ # Additionally stop the per-chain progress bars
589+ for (progress_name, uuid) in zip (progress_names, uuids)
590590 ProgressLogging. @logprogress progress_name " done" _id = uuid
591591 end
592592 elseif progress == :overall
@@ -626,13 +626,18 @@ function mcmcsample(
626626 @warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
627627 end
628628
629- # Determine default progress bar style.
629+ # Determine default progress bar style. Note that for MCMCDistributed(),
630+ # :perchain isn't implemented.
630631 if progress == true
631- progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain
632+ progress = :overall
632633 elseif progress == false
633634 progress = :none
634635 end
635- # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
636+ progress in [:overall , :none ] || throw (
637+ ArgumentError (
638+ " `progress` for MCMCDistributed must be `:overall`, `:none`, or a boolean"
639+ ),
640+ )
636641
637642 # Ensure that initial parameters and states are `nothing` or of the correct length
638643 check_initial_params (initial_params, nchains)
@@ -652,25 +657,7 @@ function mcmcsample(
652657 local chains
653658 @ifwithprogresslogger (progress != :none ) name = progressname begin
654659 # Set up progress logging.
655- if progress == :perchain
656- # This is the 'overall' progress bar. We create a channel for each
657- # chain to report back to when it finishes sampling.
658- progress_channel = Distributed. RemoteChannel (
659- () -> Channel {Bool} (Distributed. nworkers ())
660- )
661- # These are the per-chain progress bars. We generate `nchains`
662- # independent UUIDs for each progress bar
663- uuids = [UUIDs. uuid4 () for _ in 1 : nchains]
664- progress_names = [" Chain $i /$nchains " for i in 1 : nchains]
665- # Start the per-chain progress bars (but in reverse order, because
666- # ProgressLogging prints from the bottom up, and we want chain 1 to
667- # show up at the top)
668- for (progress_name, uuid) in reverse (collect (zip (progress_names, uuids)))
669- ProgressLogging. @logprogress name = progress_name nothing _id = uuid
670- end
671- child_progresses = uuids
672- child_progressnames = progress_names
673- elseif progress == :overall
660+ if progress == :overall
674661 # Just a single progress bar for the entire sampling, but instead
675662 # of tracking each chain as it comes in, we track each sample as it
676663 # comes in. This allows us to have more granular progress updates.
@@ -684,12 +671,12 @@ function mcmcsample(
684671 end
685672
686673 Distributed. @sync begin
687- if progress != :none
674+ if progress == :overall
688675 # This task updates the progress bar
689676 Distributed. @async begin
690677 # Determine threshold values for progress logging
691678 # (one update per 0.5% of progress)
692- Ntotal = progress == :overall ? nchains * N : nchains
679+ Ntotal = nchains * N
693680 threshold = Ntotal ÷ 200
694681 next_update = threshold
695682
@@ -754,19 +741,7 @@ function mcmcsample(
754741 child_progressnames,
755742 )
756743 finally
757- if progress == :perchain
758- # Stop updating the main progress bar (either if sampling
759- # is done, or if an error occurs).
760- put! (progress_channel, false )
761- # Additionally stop the per-chain progress bars (but in
762- # reverse order, because ProgressLogging prints from
763- # the bottom up, and we want chain 1 to show up at the
764- # top)
765- for (progress_name, uuid) in
766- reverse (collect (zip (progress_names, uuids)))
767- ProgressLogging. @logprogress progress_name " done" _id = uuid
768- end
769- elseif progress == :overall
744+ if progress == :overall
770745 # Stop updating the main progress bar (either if sampling
771746 # is done, or if an error occurs).
772747 put! (progress_channel, false )
0 commit comments