Skip to content

Commit 467b076

Browse files
torfjeldedevmotionsunxd3github-actions[bot]
authored
Add getparameters and setparameters!! (#86)
* added state_from_transition, parameters and setparameters!! * Update src/AbstractMCMC.jl Co-authored-by: David Widmann <[email protected]> * renamed state_from_transition to updatestate!! * adhere to julia convention * added docs * fixed docs * fixed docs * added example for why updatestate!! is useful * improved MixtureState example * further improvements to docs * renamed parameters and setparameters!! to values and setvalues!! * fixed typo in docs * fixed documenting values * improved and fixed some bugs in docs * fixed typo in docs * renamed values and setvalues!! to realize and realize!! * added model to updatestate!! * Apply suggestions from code review Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update docs/src/api.md Co-authored-by: Xianda Sun <[email protected]> * Apply suggestions from code review Co-authored-by: Xianda Sun <[email protected]> * Update docs/src/api.md Co-authored-by: Xianda Sun <[email protected]> * version bump --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Xianda Sun <[email protected]>
1 parent fc8cfa6 commit 467b076

File tree

3 files changed

+163
-1
lines changed

3 files changed

+163
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.4.0"
6+
version = "5.5.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

docs/src/api.md

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,144 @@ For chains of this type, AbstractMCMC defines the following two methods.
113113
AbstractMCMC.chainscat
114114
AbstractMCMC.chainsstack
115115
```
116+
117+
## Interacting with states of samplers
118+
119+
To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
120+
```@docs
121+
AbstractMCMC.getparams
122+
AbstractMCMC.setparams!!
123+
```
124+
These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers.
125+
126+
### Example: `MixtureSampler`
127+
128+
In a `MixtureSampler` we need two things:
129+
- `components`: collection of samplers.
130+
- `weights`: collection of weights representing the probability of choosing the corresponding sampler.
131+
132+
```julia
133+
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
134+
components::C
135+
weights::W
136+
end
137+
```
138+
139+
To implement the state, we need to keep track of a couple of things:
140+
- `index`: the index of the sampler used in this `step`.
141+
- `states`: the current states of _all_ the components.
142+
We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
143+
The reason is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.
144+
145+
146+
```julia
147+
struct MixtureState{S}
148+
index::Int
149+
states::S
150+
end
151+
```
152+
The `step` for a `MixtureSampler` is defined by the following generative process
153+
```math
154+
\begin{aligned}
155+
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
156+
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
157+
\end{aligned}
158+
```
159+
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
160+
[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.
161+
162+
If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code:
163+
164+
```julia
165+
# Update the corresponding state, i.e. `state.states[i]`, using
166+
# the state and transition from the previous iteration.
167+
state_current = AbstractMCMC.setparams!!(
168+
state.states[i],
169+
AbstractMCMC.getparams(state.states[i_prev]),
170+
)
171+
172+
# Take a `step` for this sampler using the updated state.
173+
transition, state_current = AbstractMCMC.step(
174+
rng, model, sampler_current, sampler_state;
175+
kwargs...
176+
)
177+
```
178+
179+
The full [`AbstractMCMC.step`](@ref) implementation would then be something like:
180+
181+
```julia
182+
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
183+
# Sample the component to use in this `step`.
184+
i = rand(Categorical(sampler.weights))
185+
sampler_current = sampler.components[i]
186+
187+
# Update the corresponding state, i.e. `state.states[i]`, using
188+
# the state and transition from the previous iteration.
189+
i_prev = state.index
190+
state_current = AbstractMCMC.setparams!!(
191+
state.states[i],
192+
AbstractMCMC.getparams(state.states[i_prev]),
193+
)
194+
195+
# Take a `step` for this sampler using the updated state.
196+
transition, state_current = AbstractMCMC.step(
197+
rng, model, sampler_current, state_current;
198+
kwargs...
199+
)
200+
201+
# Create the new states.
202+
# NOTE: Code below will result in `states_new` being a `Vector`.
203+
# If we wanted to allow usage of alternative containers, e.g. `Tuple`,
204+
# it would be better to use something like `@set states[i] = state_current`
205+
# where `@set` is from Setfield.jl.
206+
states_new = map(1:length(state.states)) do j
207+
if j == i
208+
# Replace the i-th state with the new one.
209+
state_current
210+
else
211+
# Otherwise we just carry over the previous ones.
212+
state.states[j]
213+
end
214+
end
215+
216+
# Create the new `MixtureState`.
217+
state_new = MixtureState(i, states_new)
218+
219+
return transition, state_new
220+
end
221+
```
222+
223+
And for the initial [`AbstractMCMC.step`](@ref) we have:
224+
225+
```julia
226+
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
227+
# Initialize every state.
228+
transitions_and_states = map(sampler.components) do spl
229+
AbstractMCMC.step(rng, model, spl; kwargs...)
230+
end
231+
232+
# Sample the component to use this `step`.
233+
i = rand(Categorical(sampler.weights))
234+
# Extract the corresponding transition.
235+
transition = first(transitions_and_states[i])
236+
# Extract states.
237+
states = map(last, transitions_and_states)
238+
# Create new `MixtureState`.
239+
state = MixtureState(i, states)
240+
241+
return transition, state
242+
end
243+
```
244+
245+
Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`.
246+
247+
248+
To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do
249+
250+
```julia
251+
sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9])
252+
transition, state = AbstractMCMC.step(rng, model, sampler)
253+
while ...
254+
transition, state = AbstractMCMC.step(rng, model, sampler, state)
255+
end
256+
```

src/AbstractMCMC.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,27 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr
8080
"""
8181
struct MCMCSerial <: AbstractMCMCEnsemble end
8282

83+
"""
84+
getparams(state[; kwargs...])
85+
86+
Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`.
87+
"""
88+
function getparams end
89+
90+
"""
91+
setparams!!(state, params)
92+
93+
Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`.
94+
95+
This function should follow the `BangBang` interface: mutate `state` in-place if possible and
96+
return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters.
97+
98+
Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In another
99+
word, the sampler should implement a consistent transformation between its internal representation
100+
and the vector representation of the parameter values.
101+
"""
102+
function setparams!! end
103+
83104
include("samplingstats.jl")
84105
include("logging.jl")
85106
include("interface.jl")

0 commit comments

Comments
 (0)