Commit c4c1641

Introduction of compositions and products of samplers and models (#151)
* added CompositionSampler, RepeatedSampler, MultiSampler together with additional methods for meta-type samplers
* added LinearAlgebra as dep
* big update but now everything finally works
* added additional pass-on-methods for meta-samplers and moved the bundle_samples to a more appropriate place
* renamed state_from_state to state_from and changed the ordering of the args to be more reasonable
* added some missing methods and fixed a typo
* added model_for_chain and model_for_process similar to other utility methods for interacting with the tempered state, etc.
* added todo
* moved bundling back to ordering of defintions
* added missing test dep
* increase number of steps for one of the tests
* specialize step for combination of RepeatedSampler and MultiSampler
* Update src/sampler.jl
  Co-authored-by: Harrison Wilde <[email protected]>
* Introduction of `SwapSampler` + make `TemperedSampler` a fancy version of `CompositionSampler` (#152)
* split the transitions and states field in TemperedState
* improved internals of CompositionSampler
* ongoing work
* added swap sampler
* added ordering specification and a TemperedComposition
* integrated work on TemperedComposition into TemperedSampler and removed the former
* reorederd stuff so it actually works
* fixed bug in swapping computation
* added length implementation for MultiModel
* improved construct for TemperedSampler and added some convenience methods
* fixed bundle_samples for Chains and TemperedTransition
* fixed breaking bug in setparams_and_logprob!! for SwapState
* remove usage of adapted HMC in tests
* remove doubling of iterations when testing tempering
* fixed bugs with MALA and tempering
* relax atol a bit for HMC
* relax another atol
* TemperedComposition is now truly just a wrapper around a CompositionSampler
* added method for computing roundtrips
* fixed testing + added test for roundtrips
* added docs for roundtrips method
* added some tests for SwapSampler without tempering
* remove ordering from SwapSampler since it should only interact with ProcessOrdering
* simplified the sorting according to chains and processes
* added some comments
* some minor refactoring
* some refactoring + TemperedSampler now orders the samplers correctly
* remove expected_ordering and make ordering assumptions more explicit
* relax type-constraints in state_for_chain so it also works with TemperedState
* removed redundant implementations of swap_attempt
* rename swap_betas! to swap!
* moved swap_attempt as it now requires definition of SwapSampler
* removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase
* removed expected_order
* Apply suggestions from code review
  Co-authored-by: Harrison Wilde <[email protected]>
* removed unnecessary variable in tests
* Update src/sampler.jl
  Co-authored-by: Harrison Wilde <[email protected]>
* Apply suggestions from code review
  Co-authored-by: Harrison Wilde <[email protected]>
* removed burn-in from step in prep for AbstractMCMC improvements
* remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach
* split the transitions and states field in TemperedState
* improved internals of CompositionSampler
* ongoing work
* added swap sampler
* added ordering specification and a TemperedComposition
* integrated work on TemperedComposition into TemperedSampler and removed the former
* reorederd stuff so it actually works
* fixed bug in swapping computation
* added length implementation for MultiModel
* improved construct for TemperedSampler and added some convenience methods
* fixed bundle_samples for Chains and TemperedTransition
* fixed breaking bug in setparams_and_logprob!! for SwapState
* remove usage of adapted HMC in tests
* remove doubling of iterations when testing tempering
* fixed bugs with MALA and tempering
* relax atol a bit for HMC
* relax another atol
* TemperedComposition is now truly just a wrapper around a CompositionSampler
* added method for computing roundtrips
* fixed testing + added test for roundtrips
* added docs for roundtrips method
* added some tests for SwapSampler without tempering
* remove ordering from SwapSampler since it should only interact with ProcessOrdering
* simplified the sorting according to chains and processes
* added some comments
* some minor refactoring
* some refactoring + TemperedSampler now orders the samplers correctly
* remove expected_ordering and make ordering assumptions more explicit
* relax type-constraints in state_for_chain so it also works with TemperedState
* removed redundant implementations of swap_attempt
* rename swap_betas! to swap!
* moved swap_attempt as it now requires definition of SwapSampler
* removed unnecessary setparams_and_logprob!! that should never be hit with the current codebase
* removed expected_order
* removed unnecessary variable in tests
* Apply suggestions from code review
  Co-authored-by: Harrison Wilde <[email protected]>
* removed burn-in from step in prep for AbstractMCMC improvements
* remove getparams_and_logprob implementation for SwapState as it's unclear what is the right approach
* Apply suggestions from code review
  Co-authored-by: Harrison Wilde <[email protected]>
* added CompositionTransition + quite a few bundle_samples with a `bundle_resolve_swaps` kwarg to allow converting into chains more easily
* more samples
* reduce requirement for ess comparison for AHMC a bit
* significant improvements to the simple Gaussian example, now testing using MCSE to get tolerances, etc. and small improvements to the rest of the tests
* trying to debug these tests
* more debug
* fixed typy
* reduce significance even further

---------

Co-authored-by: Harrison Wilde <[email protected]>

---------

Co-authored-by: Harrison Wilde <[email protected]>
1 parent 7cf05a4 commit c4c1641

21 files changed, +1835 -388 lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -9,7 +9,9 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
+MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

src/MCMCTempering.jl

Lines changed: 191 additions & 3 deletions
@@ -9,19 +9,24 @@ using ProgressLogging: ProgressLogging
 using ConcreteStructs: @concrete
 using Setfield: @set, @set!
 
+using MCMCChains: MCMCChains
+
 using InverseFunctions
 
 using DocStringExtensions
 
 include("logdensityproblems.jl")
+include("abstractmcmc.jl")
 include("adaptation.jl")
 include("swapping.jl")
 include("state.jl")
+include("swapsampler.jl")
 include("sampler.jl")
 include("sampling.jl")
 include("ladders.jl")
 include("stepping.jl")
 include("model.jl")
+include("utils.jl")
 
 export tempered,
     tempered_sample,
@@ -39,16 +44,199 @@ implements_logdensity(x) = LogDensityProblems.capabilities(x) !== nothing
 maybe_wrap_model(model) = implements_logdensity(model) ? AbstractMCMC.LogDensityModel(model) : model
 maybe_wrap_model(model::AbstractMCMC.LogDensityModel) = model
 
+# Bundling.
+# Bundling of non-tempered samples.
+function bundle_nontempered_samples(
+    ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::TemperedSampler,
+    state::TemperedState,
+    ::Type{T};
+    kwargs...
+) where {T}
+    # Create the same model and sampler as we do in the initial step for `TemperedSampler`.
+    multimodel = MultiModel([
+        make_tempered_model(sampler, model, sampler.chain_to_beta[i])
+        for i in 1:numtemps(sampler)
+    ])
+    multisampler = MultiSampler([getsampler(sampler, i) for i in 1:numtemps(sampler)])
+    multitransitions = [
+        MultipleTransitions(sort_by_chain(ProcessOrder(), t.swaptransition, t.transition.transitions))
+        for t in ts
+    ]
+
+    return AbstractMCMC.bundle_samples(
+        multitransitions,
+        multimodel,
+        multisampler,
+        MultipleStates(sort_by_chain(ProcessOrder(), state.swapstate, state.state.states)),
+        T
+    )
+end
+
 function AbstractMCMC.bundle_samples(
-    ts::AbstractVector,
+    ts::Vector{<:MultipleTransitions},
+    model::MultiModel,
+    sampler::MultiSampler,
+    state::MultipleStates,
+    # TODO: Generalize for any eltype `T`? Then need to overload for `Real`, etc.?
+    ::Type{Vector{MCMCChains.Chains}};
+    kwargs...
+)
+    return map(1:length(model), model.models, sampler.samplers, state.states) do i, model, sampler, state
+        AbstractMCMC.bundle_samples([t.transitions[i] for t in ts], model, sampler, state, MCMCChains.Chains; kwargs...)
+    end
+end
+
+# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
+function AbstractMCMC.bundle_samples(
+    ts::Vector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
     model::AbstractMCMC.AbstractModel,
     sampler::TemperedSampler,
     state::TemperedState,
-    chain_type::Type;
+    ::Type{Vector{T}};
+    bundle_resolve_swaps::Bool=false,
+    kwargs...
+) where {T}
+    if bundle_resolve_swaps
+        return bundle_nontempered_samples(ts, model, sampler, state, Vector{T}; kwargs...)
+    end
+
+    # TODO: Do better?
+    return ts
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector{<:TemperedTransition{<:SwapTransition,<:MultipleTransitions}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::TemperedSampler,
+    state::TemperedState,
+    ::Type{MCMCChains.Chains};
     kwargs...
 )
+    # Extract the transitions ordered, which are ordered according to processes, according to the chains.
+    ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
+    return AbstractMCMC.bundle_samples(
+        ts_actual,
+        model,
+        sampler_for_chain(sampler, state, 1),
+        state_for_chain(state, 1),
+        MCMCChains.Chains;
+        kwargs...
+    )
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector,
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler,
+    state::CompositionState,
+    ::Type{T};
+    kwargs...
+) where {T}
+    # In the case of `!saveall(sampler)`, the state is not a `CompositionTransition` so we just propagate
+    # the transitions to the `bundle_samples` for the outer stuff. Otherwise, we flatten the transitions.
+    ts_actual = saveall(sampler) ? mapreduce(t -> [inner_transition(t), outer_transition(t)], vcat, ts) : ts
+    # TODO: Should we really always default to outer sampler?
+    return AbstractMCMC.bundle_samples(
+        ts_actual, model, sampler.sampler_outer, state.state_outer, T;
+        kwargs...
+    )
+end
+
+# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
+function AbstractMCMC.bundle_samples(
+    ts::Vector,
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler,
+    state::CompositionState,
+    ::Type{Vector{T}};
+    kwargs...
+) where {T}
+    if !saveall(sampler)
+        # In this case, we just use the `outer` for everything since this is the only
+        # transitions we're keeping around.
+        return AbstractMCMC.bundle_samples(
+            ts, model, sampler.sampler_outer, state.state_outer, Vector{T};
+            kwargs...
+        )
+    end
+
+    # Otherwise, we don't know what to do.
+    return ts
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
+    state::CompositionState{<:MultipleStates,<:SwapState},
+    ::Type{T};
+    bundle_resolve_swaps::Bool=false,
+    kwargs...
+) where {T}
+    !bundle_resolve_swaps && return ts
+
+    # Resolve the swaps.
+    sampler_without_saveall = @set sampler.sampler_inner.saveall = Val(false)
+    ts_actual = map(ts) do t
+        composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
+    end
+
     AbstractMCMC.bundle_samples(
-        ts, maybe_wrap_model(model), sampler_for_chain(sampler, state, 1), state_for_chain(state, 1), chain_type;
+        ts_actual, model, sampler.sampler_outer, state.state_outer, T;
+        kwargs...
+    )
+end
+
+# HACK: https://github.com/TuringLang/AbstractMCMC.jl/issues/118
+function AbstractMCMC.bundle_samples(
+    ts::Vector{<:CompositionTransition{<:MultipleTransitions,<:SwapTransition}},
+    model::AbstractMCMC.AbstractModel,
+    sampler::CompositionSampler{<:MultiSampler,<:SwapSampler},
+    state::CompositionState{<:MultipleStates,<:SwapState},
+    ::Type{Vector{T}};
+    bundle_resolve_swaps::Bool=false,
+    kwargs...
+) where {T}
+    !bundle_resolve_swaps && return ts
+
+    # Resolve the swaps (using the already implemented resolution in `composition_transition`
+    # for this particular sampler but without `saveall`).
+    sampler_without_saveall = @set sampler.saveall = Val(false)
+    ts_actual = map(ts) do t
+        composition_transition(sampler_without_saveall, inner_transition(t), outer_transition(t))
+    end
+
+    return AbstractMCMC.bundle_samples(
+        ts_actual, model, sampler.sampler_outer, state.state_outer, Vector{T};
+        kwargs...
+    )
+end
+
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector,
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler,
+    state,
+    ::Type{MCMCChains.Chains};
+    kwargs...
+)
+    return AbstractMCMC.bundle_samples(ts, model, sampler.sampler, state, MCMCChains.Chains; kwargs...)
+end
+
+# Unflatten in the case of `SequentialTransitions`.
+function AbstractMCMC.bundle_samples(
+    ts::AbstractVector{<:SequentialTransitions},
+    model::AbstractMCMC.AbstractModel,
+    sampler::RepeatedSampler,
+    state::SequentialStates,
+    ::Type{MCMCChains.Chains};
+    kwargs...
+)
+    ts_actual = [t for tseq in ts for t in tseq.transitions]
+    return AbstractMCMC.bundle_samples(
+        ts_actual, model, sampler.sampler, state.states[end], MCMCChains.Chains;
         kwargs...
     )
 end
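
Note on the `MCMCChains.Chains` method for `TemperedSampler` in this file: it keeps, at every iteration, only the transition belonging to chain 1, looked up through `chain_to_process`. The following is a minimal sketch of that indexing step. The `Toy*` types and the string payloads are hypothetical stand-ins invented for illustration and are not the package's real types; only the field names and the indexing expression are taken from the diff.

```julia
# Hypothetical stand-ins for SwapTransition, MultipleTransitions and
# TemperedTransition; only the field names mirror the diff.
struct ToySwapTransition
    chain_to_process::Vector{Int}  # chain_to_process[c] = process currently running chain c
end

struct ToyMultipleTransitions{T}
    transitions::Vector{T}         # ordered by process, not by chain
end

struct ToyTemperedTransition{S,M}
    swaptransition::S
    transition::M
end

ts = [
    # Iteration 1: chain 1 runs on process 1.
    ToyTemperedTransition(
        ToySwapTransition([1, 2]),
        ToyMultipleTransitions(["chain 1, step 1", "chain 2, step 1"]),
    ),
    # Iteration 2: after a swap, chain 1 runs on process 2.
    ToyTemperedTransition(
        ToySwapTransition([2, 1]),
        ToyMultipleTransitions(["chain 2, step 2", "chain 1, step 2"]),
    ),
]

# Same expression as in the diff: pick the process holding chain 1 at each iteration.
ts_actual = [t.transition.transitions[first(t.swaptransition.chain_to_process)] for t in ts]
```

In this toy run `ts_actual` follows chain 1 across the swap, which is why the result can be bundled as a single ordinary chain.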

src/abstractmcmc.jl

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+using Setfield
+using AbstractMCMC: AbstractMCMC
+
+import LinearAlgebra: ×
+
+"""
+    getparams([model, ]state)
+
+Get the parameters from the `state`.
+
+Default implementation uses [`getparams_and_logprob`](@ref).
+"""
+getparams(state) = first(getparams_and_logprob(state))
+getparams(model, state) = first(getparams_and_logprob(model, state))
+
+"""
+    getlogprob([model, ]state)
+
+Get the log probability of the `state`.
+
+Default implementation uses [`getparams_and_logprob`](@ref).
+"""
+getlogprob(state) = last(getparams_and_logprob(state))
+getlogprob(model, state) = last(getparams_and_logprob(model, state))
+
+"""
+    getparams_and_logprob([model, ]state)
+
+Return a vector of parameters from the `state`.
+
+See also: [`setparams_and_logprob!!`](@ref).
+"""
+getparams_and_logprob(model, state) = getparams_and_logprob(state)
+
+"""
+    setparams_and_logprob!!([model, ]state, params)
+
+Set the parameters in the state to `params`, possibly mutating if it makes sense.
+
+See also: [`getparams_and_logprob`](@ref).
+"""
+setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob)
+
+"""
+    state_from(model, state_target, state_source[, transition_source, transition_target])
+
+Return a new state similar to `state_target` but updated from `state_source`, which could be
+a different type of state.
+"""
+function state_from(model, state_target, state_source, transition_target, transition_source)
+    return state_from(model, state_target, state_source)
+end
+function state_from(model, state_target, state_source)
+    params, logp = getparams_and_logprob(model, state_source)
+    return setparams_and_logprob!!(model, state_target, params, logp)
+end
+
+"""
+    SequentialTransitions
+
+A `SequentialTransitions` object is a container for a sequence of transitions.
+"""
+struct SequentialTransitions{A}
+    transitions::A
+end
+
+# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
+# last transition/state.
+getparams_and_logprob(transitions::SequentialTransitions) = getparams_and_logprob(transitions.transitions[end])
+function getparams_and_logprob(model, transitions::SequentialTransitions)
+    return getparams_and_logprob(model, transitions.transitions[end])
+end
+
+function setparams_and_logprob!!(transitions::SequentialTransitions, params, logprob)
+    return @set transitions.transitions[end] = setparams_and_logprob!!(transitions.transitions[end], params, logprob)
+end
+function setparams_and_logprob!!(model, transitions::SequentialTransitions, params, logprob)
+    return @set transitions.transitions[end] = setparams_and_logprob!!(model, transitions.transitions[end], params, logprob)
+end
+
+"""
+    SequentialStates
+
+A `SequentialStates` object is a container for a sequence of states.
+"""
+struct SequentialStates{A}
+    states::A
+end
+
+# Since it's a _sequence_ of transitions, the parameters and logprobs are the ones of the
+# last transition/state.
+getparams_and_logprob(state::SequentialStates) = getparams_and_logprob(state.states[end])
+getparams_and_logprob(model, state::SequentialStates) = getparams_and_logprob(model, state.states[end])
+
+function setparams_and_logprob!!(state::SequentialStates, params, logprob)
+    return @set state.states[end] = setparams_and_logprob!!(state.states[end], params, logprob)
+end
+function setparams_and_logprob!!(model, state::SequentialStates, params, logprob)
+    return @set state.states[end] = setparams_and_logprob!!(model, state.states[end], params, logprob)
+end
+
+# Includes.
+include("samplers/composition.jl")
+include("samplers/repeated.jl")
+include("samplers/multi.jl")
+
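
Note on `state_from` in this file: the three-argument method reads `(params, logprob)` from the source state and writes them into the target state, so the two states may be of different types. The following is a minimal sketch of that round trip. `ToyState` and `ToyStateWithStepsize` are hypothetical types made up for illustration; the three generic functions are re-declared locally with the same bodies as in the diff so the snippet runs without the package.

```julia
# Hypothetical source state: just parameters and a log-probability.
struct ToyState
    params::Vector{Float64}
    logprob::Float64
end

# Hypothetical target state carrying extra sampler-specific information.
struct ToyStateWithStepsize
    params::Vector{Float64}
    logprob::Float64
    stepsize::Float64
end

# Per-state accessors (assumed implementations for the toy types).
getparams_and_logprob(state::Union{ToyState,ToyStateWithStepsize}) = (state.params, state.logprob)
setparams_and_logprob!!(state::ToyStateWithStepsize, params, logprob) =
    ToyStateWithStepsize(params, logprob, state.stepsize)

# Model-aware fallbacks and `state_from`, with the same bodies as in the diff.
getparams_and_logprob(model, state) = getparams_and_logprob(state)
setparams_and_logprob!!(model, state, params, logprob) = setparams_and_logprob!!(state, params, logprob)

function state_from(model, state_target, state_source)
    params, logp = getparams_and_logprob(model, state_source)
    return setparams_and_logprob!!(model, state_target, params, logp)
end

target = ToyStateWithStepsize([0.0, 0.0], -1.0, 0.1)
source = ToyState([1.5, -2.0], -3.25)

# `nothing` stands in for a model, since the toy accessors ignore it.
updated = state_from(nothing, target, source)
```

Here `updated` has the type and step size of `target` but the parameters and log-probability of `source`.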

src/adaptation.jl

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ See also: [`AdaptiveState`](@ref), [`update_inverse_temperatures`](@ref), and
 """
 struct Geometric end
 
-defaultscale(::Geometric, Δ) = eltype(Δ)(0.9)
+defaultscale(::Geometric, Δ) = float(eltype(Δ))(0.9)
 
 """
     InverselyAdditive
src/ladders.jl

Lines changed: 1 addition & 3 deletions
@@ -37,9 +37,7 @@ end
 Checks and returns a sorted `Δ` containing `{β₀, ..., βₙ}` conforming such that `1 = β₀ > β₁ > ... > βₙ ≥ 0`
 """
 function check_inverse_temperatures(Δ)
-    if length(Δ) <= 1
-        error("More than one inverse temperatures must be provided.")
-    end
+    !isempty(Δ) || error("Inverse temperatures array is empty.")
     if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
         error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
     end
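
Note on the relaxed check: the old guard rejected any ladder with one element or fewer, while the new guard only rejects an empty ladder, so a single inverse temperature now passes this part of the validation. The following is a sketch of just the two checks visible in the hunk. `check_ladder_sketch` and `throws_error` are hypothetical helper names, and the rest of `check_inverse_temperatures` (not shown in the diff) is deliberately omitted.

```julia
# Only the two checks visible in the hunk; the real function continues beyond this.
function check_ladder_sketch(Δ)
    !isempty(Δ) || error("Inverse temperatures array is empty.")
    if !all(zero.(Δ) .≤ Δ .≤ one.(Δ))
        error("The temperature ladder provided has values outside of the acceptable range, ensure all values are in [0, 1].")
    end
    return Δ
end

# Small helper that reports whether calling `f` raises an exception.
function throws_error(f)
    try
        f()
        return false
    catch
        return true
    end
end

single_ok = check_ladder_sketch([1.0])                            # accepted under the new check
empty_rejected = throws_error(() -> check_ladder_sketch(Float64[]))
range_rejected = throws_error(() -> check_ladder_sketch([1.0, 1.5]))
```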
