Skip to content

Commit 180a928

Browse files
Introduction of SwapSampler + make TemperedSampler a fancy version of CompositionSampler (#152)
* split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! * moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase * removed expected_order * Apply suggestions from code review Co-authored-by: Harrison Wilde <[email protected]> * removed unnecessary variable in tests * Update src/sampler.jl Co-authored-by: Harrison Wilde <[email protected]> * Apply suggestions from code review Co-authored-by: Harrison Wilde <[email protected]> * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * split the transitions and states field in TemperedState * improved internals of CompositionSampler * ongoing work * added swap sampler * added ordering specification and a TemperedComposition * integrated work on TemperedComposition into TemperedSampler and removed the former * reorederd stuff so it actually works * fixed bug in swapping computation * added length implementation for MultiModel * improved construct for TemperedSampler and added some convenience methods * fixed bundle_samples for Chains and TemperedTransition * fixed breaking bug in setparams_and_logprob!! for SwapState * remove usage of adapted HMC in tests * remove doubling of iterations when testing tempering * fixed bugs with MALA and tempering * relax atol a bit for HMC * relax another atol * TemperedComposition is now truly just a wrapper around a CompositionSampler * added method for computing roundtrips * fixed testing + added test for roundtrips * added docs for roundtrips method * added some tests for SwapSampler without tempering * remove ordering from SwapSampler since it should only interact with ProcessOrdering * simplified the sorting according to chains and processes * added some comments * some minor refactoring * some refactoring + TemperedSampler now orders the samplers correctly * remove expected_ordering and make ordering assumptions more explicit * relax type-constraints in state_for_chain so it also works with TemperedState * removed redundant implementations of swap_attempt * rename swap_betas! to swap! * moved swap_attempt as it now requires definition of SwapSampler * removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase * removed expected_order * removed unnecessary variable in tests * Apply suggestions from code review Co-authored-by: Harrison Wilde <[email protected]> * removed burn-in from step in prep for AbstractMCMC improvements * remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach * Apply suggestions from code review Co-authored-by: Harrison Wilde <[email protected]> * added CompositionTransition + quite a few bundle_samples with a `bundle_resolve_swaps` kwarg to allow converting into chains more easily * more samples * reduce requirement for ess comparison for AHMC a bit * significant improvements to the simple Gaussian example, now testing using MCSE to get tolerances, etc. and small improvements to the rest of the tests * trying to debug these tests * more debug * fixed typy * reduce significance even further --------- Co-authored-by: Harrison Wilde <[email protected]>
1 parent ef97a94 commit 180a928

18 files changed

+1087
-450
lines changed

src/MCMCTempering.jl

Lines changed: 137 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ include("abstractmcmc.jl")
2020
include("adaptation.jl")
2121
include("swapping.jl")
2222
include("state.jl")
23+
include("swapsampler.jl")
2324
include("sampler.jl")
2425
include("sampling.jl")
2526
include("ladders.jl")
2627
include("stepping.jl")
2728
include("model.jl")
29+
include("utils.jl")
2830

2931
export tempered,
3032
tempered_sample,
@@ -43,21 +45,82 @@ maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensity
4345
maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model
4446

4547
# Bundling.
46-
# TODO: Improve this, somehow.
47-
# TODO: Move this to an extension.
48+
# Bundling of non-tempered samples.
49+
function bundle_nontempered_samples(
50+
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
51+
model::AbstractMCMC.AbstractModel,
52+
sampler::TemperedSampler,
53+
state::TemperedState,
54+
::Type{T};
55+
kwargs...
56+
) where {T}
57+
# Create the same model and sampler as we do in the initial step for `TemperedSampler`.
58+
multimodel = MultiModel([
59+
make_tempered_model(sampler, model, sampler.chain_to_beta[i])
60+
for i in 1:numtemps(sampler)
61+
])
62+
multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
63+
multitransitions = [
64+
MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
65+
for t in ts
66+
]
67+
68+
return AbstractMCMC.bundle_samples(
69+
multitransitions,
70+
multimodel,
71+
multisampler,
72+
MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
73+
T
74+
)
75+
end
76+
77+
function AbstractMCMC.bundle_samples(
78+
ts::Vector{<:MultipleTransitions},
79+
model::MultiModel,
80+
sampler::MultiSampler,
81+
state::MultipleStates,
82+
# TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
83+
::Type{Vector{MCMCChains.Chains}};
84+
kwargs...
85+
)
86+
return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
87+
AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
88+
end
89+
end
90+
91+
# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
92+
function AbstractMCMC.bundle_samples(
93+
ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
94+
model::AbstractMCMC.AbstractModel,
95+
sampler::TemperedSampler,
96+
state::TemperedState,
97+
::Type{Vector{T}};
98+
bundle_resolve_swaps::Bool=false,
99+
kwargs...
100+
) where {T}
101+
if bundle_resolve_swaps
102+
return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
103+
end
104+
105+
# TODO: Do better?
106+
return ts
107+
end
108+
48109
function AbstractMCMC.bundle_samples(
49-
ts::AbstractVector{<:TemperedTransition},
110+
ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
50111
model::AbstractMCMC.AbstractModel,
51112
sampler::TemperedSampler,
52113
state::TemperedState,
53114
::Type{MCMCChains.Chains};
54115
kwargs...
55116
)
117+
# Extract the transitions ordered, which are ordered according to processes, according to the chains.
118+
ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
56119
return AbstractMCMC.bundle_samples(
57-
map(Base.Fix2(getproperty, :transition), filter(!Base.Fix2(getproperty, :is_swap), ts)), # Remove the swaps.
120+
ts_actual,
58121
model,
59-
sampler_for_chain(sampler, state),
60-
state_for_chain(state),
122+
sampler_for_chain(sampler, state, 1),
123+
state_for_chain(state, 1),
61124
MCMCChains.Chains;
62125
kwargs...
63126
)
@@ -68,27 +131,85 @@ function AbstractMCMC.bundle_samples(
68131
model::AbstractMCMC.AbstractModel,
69132
sampler::CompositionSampler,
70133
state::CompositionState,
71-
::Type{MCMCChains.Chains};
134+
::Type{T};
72135
kwargs...
73-
)
136+
) where {T}
137+
# In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
138+
# the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
139+
ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
140+
# TODO: Should we really always default to outer sampler?
74141
return AbstractMCMC.bundle_samples(
75-
ts, model, sampler.sampler_outer, state.state_outer, MCMCChains.Chains;
142+
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
76143
kwargs...
77144
)
78145
end
79146

80-
# Unflatten in the case of `SequentialTransitions`
147+
# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
81148
function AbstractMCMC.bundle_samples(
82-
ts::AbstractVector{<:SequentialTransitions},
149+
ts::Vector,
83150
model::AbstractMCMC.AbstractModel,
84151
sampler::CompositionSampler,
85-
state::SequentialStates,
86-
::Type{MCMCChains.Chains};
152+
state::CompositionState,
153+
::Type{Vector{T}};
87154
kwargs...
88-
)
89-
ts_actual = [t for tseq in ts for t in tseq.transitions]
155+
) where {T}
156+
if !saveall(sampler)
157+
# In this case, we just use the `outer` for everything since this is the only
158+
# transitions we're keeping around.
159+
return AbstractMCMC.bundle_samples(
160+
ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
161+
kwargs...
162+
)
163+
end
164+
165+
# Otherwise, we don't know what to do.
166+
return ts
167+
end
168+
169+
function AbstractMCMC.bundle_samples(
170+
ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
171+
model::AbstractMCMC.AbstractModel,
172+
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
173+
state::CompositionState{<:MultipleStates,<:SwapState},
174+
::Type{T};
175+
bundle_resolve_swaps::Bool=false,
176+
kwargs...
177+
) where {T}
178+
!bundle_resolve_swaps && return ts
179+
180+
# Resolve the swaps.
181+
sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
182+
ts_actual = map(ts) do t
183+
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
184+
end
185+
186+
AbstractMCMC.bundle_samples(
187+
ts_actual, model, sampler.sampler_outer, state.state_outer, T;
188+
kwargs...
189+
)
190+
end
191+
192+
# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
193+
function AbstractMCMC.bundle_samples(
194+
ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
195+
model::AbstractMCMC.AbstractModel,
196+
sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
197+
state::CompositionState{<:MultipleStates,<:SwapState},
198+
::Type{Vector{T}};
199+
bundle_resolve_swaps::Bool=false,
200+
kwargs...
201+
) where {T}
202+
!bundle_resolve_swaps && return ts
203+
204+
# Resolve the swaps (using the already implemented resolution in `composition_transition`
205+
# for this particular sampler but without `saveall`).
206+
sampler_without_saveall = @set sampler.saveall = Val(false)
207+
ts_actual = map(ts) do t
208+
composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
209+
end
210+
90211
return AbstractMCMC.bundle_samples(
91-
ts_actual, model, sampler.sampler_outer, state.states[end], MCMCChains.Chains;
212+
ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
92213
kwargs...
93214
)
94215
end

src/adaptation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and
1818
"""
1919
struct Geometric end
2020

21-
defaultscale(::Geometric, Δ) = eltype(Δ)(0.9)
21+
defaultscale(::Geometric, Δ) = float(eltype))(0.9)
2222

2323
"""
2424
InverselyAdditive

src/ladders.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ end
3737
Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
3838
"""
3939
function check_inverse_temperatures(Δ)
40-
if length(Δ) <= 1
41-
error("More than one inverse temperatures must be provided.")
42-
end
40+
!isempty(Δ) || error("Inverse temperatures array is empty.")
4341
if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
4442
error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
4543
end

src/sampler.jl

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
"""
2+
TemperedState
3+
4+
A state for a tempered sampler.
5+
6+
# Fields
7+
$(FIELDS)
8+
"""
9+
@concrete struct TemperedState
10+
"state for swap-sampler"
11+
swapstate
12+
"state for the main sampler"
13+
state
14+
"inverse temperature for each of the chains"
15+
chain_to_beta
16+
end
17+
118
"""
219
TemperedSampler <: AbstractMCMC.AbstractSampler
320
@@ -7,44 +24,39 @@ A `TemperedSampler` struct wraps a sampler upon which to apply the Parallel Temp
724
825
$(FIELDS)
926
"""
10-
@concrete struct TemperedSampler <: AbstractMCMC.AbstractSampler
27+
Base.@kwdef struct TemperedSampler{SplT,A,SwapT,Adapt} <: AbstractMCMC.AbstractSampler
1128
"sampler(s) used to target the tempered distributions"
12-
sampler
29+
sampler::SplT
1330
"collection of inverse temperatures β; β[i] correponds i-th tempered model"
14-
inverse_temperatures
15-
"number of steps of `sampler` to take before proposing swaps"
16-
swap_every
17-
"the swap strategy that will be used when proposing swaps"
18-
swap_strategy
19-
# TODO: This should be replaced with `P` just being some `NoAdapt` type.
31+
chain_to_beta::A
32+
"strategy to use for swapping"
33+
swapstrategy::SwapT=ReversibleSwap()
34+
# TODO: Remove `adapt` and just consider `adaptation_states=nothing` as no adaptation.
2035
"boolean flag specifying whether or not to adapt"
21-
adapt
36+
adapt=false
2237
"adaptation parameters"
23-
adaptation_states
38+
adaptation_states::Adapt=nothing
2439
end
2540

26-
swapstrategy(sampler::TemperedSampler) = sampler.swap_strategy
41+
TemperedSampler(sampler, chain_to_beta; kwargs...) = TemperedSampler(; sampler, chain_to_beta, kwargs...)
42+
43+
swapsampler(sampler::TemperedSampler) = SwapSampler(sampler.swapstrategy)
2744

45+
# TODO: Do we need this now?
2846
getsampler(samplers, I...) = getindex(samplers, I...)
2947
getsampler(sampler::AbstractMCMC.AbstractSampler, I...) = sampler
3048
getsampler(sampler::TemperedSampler, I...) = getsampler(sampler.sampler, I...)
3149

32-
"""
33-
numsteps(sampler::TemperedSampler)
34-
35-
Return number of inverse temperatures used by `sampler`.
36-
"""
37-
numtemps(sampler::TemperedSampler) = length(sampler.inverse_temperatures)
50+
chain_to_process(state::TemperedState, I...) = chain_to_process(state.swapstate, I...)
51+
process_to_chain(state::TemperedState, I...) = process_to_chain(state.swapstate, I...)
3852

3953
"""
40-
sampler_for_chain(sampler::TemperedSampler, state::TemperedState[, I...])
54+
sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
4155
4256
Return the sampler corresponding to the chain indexed by `I...`.
43-
If `I...` is not specified, the sampler corresponding to `β=1.0` will be returned.
4457
"""
45-
sampler_for_chain(sampler::TemperedSampler, state::TemperedState) = sampler_for_chain(sampler, state, 1)
4658
function sampler_for_chain(sampler::TemperedSampler, state::TemperedState, I...)
47-
return getsampler(sampler.sampler, chain_to_process(state, I...))
59+
return sampler_for_process(sampler, state, chain_to_process(state, I...))
4860
end
4961

5062
"""
@@ -53,9 +65,51 @@ end
5365
Return the sampler corresponding to the process indexed by `I...`.
5466
"""
5567
function sampler_for_process(sampler::TemperedSampler, state::TemperedState, I...)
56-
return getsampler(sampler.sampler, I...)
68+
return _sampler_for_process_temper(sampler.sampler, state, I...)
5769
end
5870

71+
# If `sampler` is a `MultiSampler`, we assume it's ordered according to chains.
72+
_sampler_for_process_temper(sampler::MultiSampler, state, I...) = sampler.samplers[process_to_chain(state, I...)]
73+
# Otherwise, we just use the same sampler for everything.
74+
_sampler_for_process_temper(sampler, state, I...) = sampler
75+
76+
# Defer extracting the corresponding state to the `swapstate`.
77+
state_for_process(state::TemperedState, I...) = state_for_process(state.swapstate, I...)
78+
79+
# Here we make the model(s) using the temperatures.
80+
function model_for_process(sampler::TemperedSampler, model, state::TemperedState, I...)
81+
return make_tempered_model(sampler, model, beta_for_process(state, I...))
82+
end
83+
84+
"""
85+
beta_for_chain(state[, I...])
86+
87+
Return the β corresponding to the chain indexed by `I...`.
88+
If `I...` is not specified, the β corresponding to `β=1.0` will be returned.
89+
"""
90+
beta_for_chain(state::TemperedState) = beta_for_chain(state, 1)
91+
beta_for_chain(state::TemperedState, I...) = beta_for_chain(state.chain_to_beta, I...)
92+
# NOTE: Array impl. is useful for testing.
93+
beta_for_chain(chain_to_beta::AbstractArray, I...) = chain_to_beta[I...]
94+
95+
"""
96+
beta_for_process(state, I...)
97+
98+
Return the β corresponding to the process indexed by `I...`.
99+
"""
100+
beta_for_process(state::TemperedState, I...) = beta_for_process(state.chain_to_beta, state.swapstate.process_to_chain, I...)
101+
# NOTE: Array impl. is useful for testing.
102+
function beta_for_process(chain_to_beta::AbstractArray, proc2chain::AbstractArray, I...)
103+
return beta_for_chain(chain_to_beta, process_to_chain(proc2chain, I...))
104+
end
105+
106+
"""
107+
numsteps(sampler::TemperedSampler)
108+
109+
Return number of inverse temperatures used by `sampler`.
110+
"""
111+
numtemps(sampler::TemperedSampler) = length(sampler.chain_to_beta)
112+
59113
"""
60114
tempered(sampler, inverse_temperatures; kwargs...)
61115
OR
@@ -99,7 +153,7 @@ function tempered(
99153
inverse_temperatures::Vector{<:Real};
100154
swap_strategy::AbstractSwapStrategy=ReversibleSwap(),
101155
# TODO: Change `swap_every` to something like `number_of_iterations_per_swap`.
102-
swap_every::Integer=1,
156+
steps_per_swap::Integer=1,
103157
adapt::Bool=false,
104158
adapt_target::Real=0.234,
105159
adapt_stepsize::Real=1,
@@ -109,14 +163,13 @@ function tempered(
109163
kwargs...
110164
)
111165
!(adapt && typeof(swap_strategy) <: Union{RandomSwap, SingleRandomSwap}) || error("Adaptation of the inverse temperature ladder is not currently supported under the chosen swap strategy.")
112-
swap_every 1 || error("`swap_every` must take a positive integer value.")
166+
steps_per_swap > 0 || error("`steps_per_swap` must take a positive integer value.")
113167
inverse_temperatures = check_inverse_temperatures(inverse_temperatures)
114168
adaptation_states = init_adaptation(
115169
adapt_schedule, inverse_temperatures, adapt_target, adapt_scale, adapt_eta, adapt_stepsize
116170
)
117171
# NOTE: We just make a repeated sampler for `sampler_inner`.
118172
# TODO: Generalize. Allow passing in a `MultiSampler`, etc.
119-
sampler_inner = sampler^swap_every
120-
# FIXME: Remove the hard-coded `2` for swap-every, and change `should_swap` acoordingly.
121-
return TemperedSampler(sampler_inner, inverse_temperatures, 2, swap_strategy, adapt, adaptation_states)
173+
sampler_inner = sampler^steps_per_swap
174+
return TemperedSampler(sampler_inner, inverse_temperatures, swap_strategy, adapt, adaptation_states)
122175
end

0 commit comments

Comments
 (0)