Skip to content

Commit fc74dd4

Browse files
authored
Add chain_number keyword argument when performing multi-chain sampling (#174)
* Add `chain_number` keyword argument when performing multi-chain sampling * Document `chain_number` kwarg * minor bump instead * use channel to make test more robust
1 parent 92af231 commit fc74dd4

File tree

4 files changed

+24
-3
lines changed

4 files changed

+24
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.7.2"
6+
version = "5.8.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ Common keyword arguments for regular and parallel sampling are:
7171
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details.
7272
- `chain_type` (default: `Any`): determines the type of the returned chain
7373
- `callback` (default: `nothing`): if `callback !== nothing`, then
74-
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
74+
`callback(rng, model, sampler, sample, iteration; kwargs...)` is called after every sampling step,
7575
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
76+
- Keyword arguments `kwargs...` are passed down from the call to `sample(...)`. If you are performing multiple-chain sampling, then `kwargs` _additionally_ contains a `chain_number` keyword argument, which runs from 1 to the number of chains. This is not present when performing single-chain sampling.
7677
- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step,
7778
i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to
7879
[`AbstractMCMC.step`](@ref).

src/sample.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ function mcmcsample(
549549
else
550550
initial_state[chainidx]
551551
end,
552+
chain_number=chainidx,
552553
kwargs...,
553554
)
554555
end
@@ -669,7 +670,7 @@ function mcmcsample(
669670
Distributed.@async begin
670671
try
671672
function sample_chain(
672-
seed, initial_params, initial_state, child_progress
673+
seed, initial_params, initial_state, child_progress, chainidx
673674
)
674675
# Seed a new random number generator with the pre-made seed.
675676
Random.seed!(rng, seed)
@@ -683,6 +684,7 @@ function mcmcsample(
683684
progress=child_progress,
684685
initial_params=initial_params,
685686
initial_state=initial_state,
687+
chain_number=chainidx,
686688
kwargs...,
687689
)
688690

@@ -696,6 +698,7 @@ function mcmcsample(
696698
_initial_params,
697699
_initial_state,
698700
child_progresses,
701+
1:nchains;
699702
)
700703
finally
701704
if progress == :overall
@@ -755,6 +758,7 @@ function mcmcsample(
755758
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
756759
initial_params=initial_params,
757760
initial_state=initial_state,
761+
chain_number=i,
758762
kwargs...,
759763
)
760764
end

test/sample.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,22 @@
679679
@test all(chain[i].b == ref_chain[i].b for i in 1:N)
680680
end
681681

682+
@testset "chain_number keyword argument" begin
683+
@testset for m in [MCMCSerial(), MCMCThreads(), MCMCDistributed()]
684+
niters = 10
685+
channel = Channel{Int}() do chn
686+
# check that the `chain_number` keyword argument is passed to the callback
687+
function callback(args...; kwargs...)
688+
@test haskey(kwargs, :chain_number)
689+
return put!(chn, kwargs[:chain_number])
690+
end
691+
chain = sample(MyModel(), MySampler(), m, niters, 4; callback=callback)
692+
end
693+
chain_numbers = collect(channel)
694+
@test sort(chain_numbers) == repeat(1:4; inner=niters)
695+
end
696+
end
697+
682698
@testset "Sample vector of `NamedTuple`s" begin
683699
chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple})
684700
# Check output type

0 commit comments

Comments
 (0)