Gibbs sampler #2647
base: main
Changes from 3 commits
New file (`@@ -0,0 +1,245 @@`):

````julia
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend
# the generic implementations in gibbs.jl

"""
    GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

The `conditional` function should take a `NamedTuple` of conditioned variables and return
a `Distribution` from which to sample the variable `sym`.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
    m ~ Normal(0, sqrt(1 / λ))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
````
**Review comment:** Likewise, comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the `b` here be `inv(3)`?
````julia
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :λ => GibbsConditional(:λ, cond_λ),
    :m => GibbsConditional(:m, cond_m)
), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm
    conditional::C

    function GibbsConditional(sym::Symbol, conditional::C) where {C}
        return new{sym,C}(conditional)
    end
end

# Mark GibbsConditional as a valid Gibbs component
isgibbscomponent(::GibbsConditional) = true

"""
    DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Initialize the GibbsConditional sampler.
"""
function DynamicPPL.initialstep(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # GibbsConditional doesn't need any special initialization
    # Just return the initial state
    return nothing, vi
end

"""
    AbstractMCMC.step(rng, model, sampler::GibbsConditional, state)

Perform a step of GibbsConditional sampling.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional{S}},
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
) where {S}
    alg = sampler.alg

    # For GibbsConditional within Gibbs, we need to get all variable values
    # Check if we're in a Gibbs context
    global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
````
**Suggested change** (outdated):

```diff
- global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
+ global_vi = if isdefined(model, :context) && model.context isa GibbsContext
```
**Review comment:** The core idea here of finding the possible `GibbsContext` and getting the global varinfo from it is good. However, `GibbsContext` is a bit weird, in that it's always inserted at the bottom of the context stack. By the context stack I mean the fact that contexts often have child contexts, and thus `model.context` may in fact be many nested contexts. See e.g. how the `GibbsContext` is set here, by calling `setleafcontext` rather than `setcontext` (line 258 in d75e6f2):

```julia
gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner)
```

So rather than check whether `model.context isa GibbsContext`, I think you'll need to traverse the whole context stack, and check if any of them are a `GibbsContext`, until you hit a leaf context and the stack ends.

Moreover, I think you'll need to check not just for `GibbsContext`, but also for `ConditionContext` and `FixedContext`, which condition/fix the values of some variables. So all in all, if you go through the whole stack, starting with `model.context` and going through its child contexts, and collect any variables set in `ConditionContext`, `FixedContext`, and `GibbsContext`, that should give you all of the variable values you need. See here for more details on `condition` and `fix`: https://github.com/TuringLang/DynamicPPL.jl/blob/1ed8cc8d9f013f46806c88a83e93f7a4c5b891dd/src/contexts.jl#L258

As mentioned on Slack a week or two ago, all this context stack business is likely changing Soon (TM), since @penelopeysm is overhauling `condition` and `fix` over here, TuringLang/DynamicPPL.jl#1010, and as a result we may be able to overhaul `GibbsContext` as well. You could wait for that to be finished first, at least if it looks like getting this to work would be a lot of work.
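The traversal described above could be sketched roughly as follows. This is a non-authoritative sketch assuming DynamicPPL's `NodeTrait`/`childcontext` interface; `find_gibbs_context` is a hypothetical helper name, not part of the PR:

```julia
# Hypothetical sketch: walk the context stack from `model.context` toward the
# leaf, returning the first GibbsContext encountered (or `nothing`).
function find_gibbs_context(ctx)
    while true
        ctx isa GibbsContext && return ctx
        # Stop once we reach a leaf context: the stack ends here.
        DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf && return nothing
        ctx = DynamicPPL.childcontext(ctx)
    end
end
```

The same loop could also accumulate values from any `ConditionContext` or `FixedContext` it passes on the way down.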
**Review comment** (outdated): The operating principle of the new(ish) Gibbs sampler is that every component sampler only ever sees a `VarInfo` with the variables that that component sampler is supposed to sample. Thus, you should be able to assume that `updated` includes values for all the variables in `state`, and for nothing else. Hence the below checks and loops I think shouldn't be necessary. The solution might be as simple as `new_state = unflatten(state, updated)`, though there may be details there that I'm not thinking of right now. (What if `state` is linked? But maybe we can guarantee that it's never linked, because the sampler can control it.) Happy to discuss details more if `unflatten` by itself doesn't seem to cut it.
**Copilot AI review comment** (Aug 7, 2025, outdated): The `local` keyword is unnecessary here since `updated_vi` is already in a local scope. This adds visual clutter without functional benefit.

```diff
- local updated_vi = state
+ updated_vi = state
```
**Copilot AI review comment** (Aug 7, 2025, outdated): The error message could be more helpful by suggesting what variables are available or providing debugging information about the `VarInfo` contents.

```diff
- error("Could not find variable $S in VarInfo")
+ error("Could not find variable $S in VarInfo. Available variables: $(join([string(DynamicPPL.getsym(k)) for k in keys(state)], ", ")).")
```
**Review comment** (outdated): I think you shouldn't need this, because the log joint is going to be recomputed anyway by the Gibbs sampler once it's looped over all component samplers. Saves one model evaluation.
**Review comment** (outdated): I would hope you wouldn't need to overload this or `gibbs_initialstep_recursive`. Also, the below implementation seems to be just a repeat of `step`.
**Review comment:** Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be `Gamma(2, inv(3))`?
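As background to this comment (an editorial note, assuming Distributions.jl's conventions): `Gamma(α, θ)` there is parameterized by shape and scale, so a prior specified with rate β = 3 corresponds to `Gamma(2, inv(3))`, not `Gamma(2, 3)`. A self-contained check of the two resulting means:

```julia
# Gamma(α, θ) in Distributions.jl uses shape α and SCALE θ; mean = α * θ.
α = 2.0
mean_scale3 = α * 3.0       # Gamma(2, 3): scale 3, mean 6.0
mean_rate3  = α * inv(3.0)  # Gamma(2, inv(3)): scale 1/3, mean 2/3
```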