You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* added state_from_transition, parameters and setparameters!!
* Update src/AbstractMCMC.jl
Co-authored-by: David Widmann <[email protected]>
* renamed state_from_transition to updatestate!!
* adhere to julia convention
* added docs
* fixed docs
* fixed docs
* added example for why updatestate!! is useful
* improved MixtureState example
* further improvements to docs
* renamed parameters and setparameters!! to values and setvalues!!
* fixed typo in docs
* fixed documenting values
* improved and fixed some bugs in docs
* fixed typo in docs
* renamed values and setvalues!! to realize and realize!!
* added model to updatestate!!
* Apply suggestions from code review
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update docs/src/api.md
Co-authored-by: Xianda Sun <[email protected]>
* Apply suggestions from code review
Co-authored-by: Xianda Sun <[email protected]>
* Update docs/src/api.md
Co-authored-by: Xianda Sun <[email protected]>
* version bump
---------
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>
Copy file name to clipboardExpand all lines: docs/src/api.md
+141Lines changed: 141 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -113,3 +113,144 @@ For chains of this type, AbstractMCMC defines the following two methods.
113
113
AbstractMCMC.chainscat
114
114
AbstractMCMC.chainsstack
115
115
```
116
+
117
+
## Interacting with states of samplers
118
+
119
+
To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
120
+
```@docs
121
+
AbstractMCMC.getparams
122
+
AbstractMCMC.setparams!!
123
+
```
124
+
These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers.
125
+
126
+
### Example: `MixtureSampler`
127
+
128
+
In a `MixtureSampler` we need two things:
129
+
-`components`: collection of samplers.
130
+
-`weights`: collection of weights representing the probability of choosing the corresponding sampler.
To implement the state, we need to keep track of a couple of things:
140
+
-`index`: the index of the sampler used in this `step`.
141
+
-`states`: the current states of _all_ the components.
142
+
We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
143
+
The reason is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.
144
+
145
+
146
+
```julia
147
+
struct MixtureState{S}
148
+
index::Int
149
+
states::S
150
+
end
151
+
```
152
+
The `step` for a `MixtureSampler` is defined by the following generative process
153
+
```math
154
+
\begin{aligned}
155
+
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
156
+
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
157
+
\end{aligned}
158
+
```
159
+
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
160
+
[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.
161
+
162
+
If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code:
163
+
164
+
```julia
165
+
# Update the corresponding state, i.e. `state.states[i]`, using
166
+
# the state and transition from the previous iteration.
167
+
state_current = AbstractMCMC.setparams!!(
168
+
state.states[i],
169
+
AbstractMCMC.getparams(state.states[i_prev]),
170
+
)
171
+
172
+
# Take a `step` for this sampler using the updated state.
173
+
transition, state_current = AbstractMCMC.step(
174
+
rng, model, sampler_current, sampler_state;
175
+
kwargs...
176
+
)
177
+
```
178
+
179
+
The full [`AbstractMCMC.step`](@ref) implementation would then be something like:
180
+
181
+
```julia
182
+
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
183
+
# Sample the component to use in this `step`.
184
+
i =rand(Categorical(sampler.weights))
185
+
sampler_current = sampler.components[i]
186
+
187
+
# Update the corresponding state, i.e. `state.states[i]`, using
188
+
# the state and transition from the previous iteration.
189
+
i_prev = state.index
190
+
state_current = AbstractMCMC.setparams!!(
191
+
state.states[i],
192
+
AbstractMCMC.getparams(state.states[i_prev]),
193
+
)
194
+
195
+
# Take a `step` for this sampler using the updated state.
196
+
transition, state_current = AbstractMCMC.step(
197
+
rng, model, sampler_current, state_current;
198
+
kwargs...
199
+
)
200
+
201
+
# Create the new states.
202
+
# NOTE: Code below will result in `states_new` being a `Vector`.
203
+
# If we wanted to allow usage of alternative containers, e.g. `Tuple`,
204
+
# it would be better to use something like `@set states[i] = state_current`
205
+
# where `@set` is from Setfield.jl.
206
+
states_new =map(1:length(state.states)) do j
207
+
if j == i
208
+
# Replace the i-th state with the new one.
209
+
state_current
210
+
else
211
+
# Otherwise we just carry over the previous ones.
212
+
state.states[j]
213
+
end
214
+
end
215
+
216
+
# Create the new `MixtureState`.
217
+
state_new =MixtureState(i, states_new)
218
+
219
+
return transition, state_new
220
+
end
221
+
```
222
+
223
+
And for the initial [`AbstractMCMC.step`](@ref) we have:
224
+
225
+
```julia
226
+
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
227
+
# Initialize every state.
228
+
transitions_and_states =map(sampler.components) do spl
229
+
AbstractMCMC.step(rng, model, spl; kwargs...)
230
+
end
231
+
232
+
# Sample the component to use this `step`.
233
+
i =rand(Categorical(sampler.weights))
234
+
# Extract the corresponding transition.
235
+
transition =first(transitions_and_states[i])
236
+
# Extract states.
237
+
states =map(last, transitions_and_states)
238
+
# Create new `MixtureState`.
239
+
state =MixtureState(i, states)
240
+
241
+
return transition, state
242
+
end
243
+
```
244
+
245
+
Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`.
246
+
247
+
248
+
To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do
0 commit comments