diff --git a/HISTORY.md b/HISTORY.md
index be6331ca74..3854806d7f 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,3 +1,65 @@
+# 0.41.0
+
+## DynamicPPL 0.38
+
+Turing.jl v0.41 brings with it all the underlying changes in DynamicPPL 0.38.
+Please see [the DynamicPPL changelog](https://github.com/TuringLang/DynamicPPL.jl/blob/main/HISTORY.md) for full details: in this section we only describe the changes that will directly affect end-users of Turing.jl.
+
+### Performance
+
+A number of functions such as `returned` and `predict` will have substantially better performance in this release.
+
+### `ProductNamedTupleDistribution`
+
+`Distributions.ProductNamedTupleDistribution` can now be used on the right-hand side of `~` in Turing models.
+
+### Initial parameters
+
+**Initial parameters for MCMC sampling must now be specified in a different form.**
+You still need to use the `initial_params` keyword argument to `sample`, but the allowed values are different.
+For almost all samplers in Turing.jl (except `Emcee`), this should now be a `DynamicPPL.AbstractInitStrategy`.
+
+There are three kinds of initialisation strategies provided out of the box with Turing.jl (they are exported, so you can use them directly with `using Turing`):
+
+  - `InitFromPrior()`: Sample from the prior distribution. This is the default for most samplers in Turing.jl (if you don't specify `initial_params`).
+
+  - `InitFromUniform(a, b)`: Sample uniformly from `[a, b]` in linked space. This is the default for Hamiltonian samplers. If `a` and `b` are not specified, it defaults to `[-2, 2]`, which preserves the behaviour of previous versions (and mimics that of Stan).
+  - `InitFromParams(p)`: Explicitly provide a set of initial parameters. **Note: `p` must be either a `NamedTuple` or an `AbstractDict{<:VarName}`; it can no longer be a `Vector`.** Parameters must be provided in unlinked space, even if the sampler later performs linking.
+
+    For this release of Turing.jl, you can also provide a `NamedTuple` or `AbstractDict{<:VarName}` and this will be automatically wrapped in `InitFromParams` for you. This is an intermediate measure for backwards compatibility, and will eventually be removed.
+
+This change was made because `Vector`s are semantically ambiguous.
+It is not clear which element of the vector corresponds to which variable in the model, nor is it clear whether the parameters are in linked or unlinked space.
+Previously, both of these would depend on the internal structure of the VarInfo, which is an implementation detail.
+In contrast, the behaviour of `AbstractDict`s and `NamedTuple`s is invariant to the ordering of variables, and it is also easier for readers to understand which variable is being set to which value.
+
+If you were previously using `varinfo[:]` to extract a vector of initial parameters, you can now use `Dict(k => varinfo[k] for k in keys(varinfo))` to extract a Dict of initial parameters.
+
+For more details about initialisation, you can also refer to [the main TuringLang docs](https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters) and/or the [DynamicPPL API docs](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.InitFromPrior).
+
+### `resume_from` and `loadstate`
+
+The `resume_from` keyword argument to `sample` has been removed.
+Instead of `sample(...; resume_from=chain)`, you can use `sample(...; initial_state=loadstate(chain))`, which is entirely equivalent.
+`loadstate` is now exported from Turing instead of DynamicPPL.
+
+Note that `loadstate` only works with `MCMCChains.Chains`.
+FlexiChains users should consult the FlexiChains documentation, where this functionality is described in detail.
+
+### `pointwise_logdensities`
+
+`pointwise_logdensities(model, chn)`, `pointwise_loglikelihoods(...)`, and `pointwise_prior_logdensities(...)` now return an `MCMCChains.Chains` object if `chn` is itself an `MCMCChains.Chains` object.
+The old behaviour of returning an `OrderedDict` is still available: you just need to pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chn, OrderedDict)`.
+
+## Initial step in MCMC sampling
+
+HMC and NUTS samplers no longer take an extra single step before starting the chain.
+This means that if you do not discard any samples at the start, the first sample will be the initial parameters (which may be user-provided).
+
+Note that if the initial sample is included, the corresponding sampler statistics will be `missing`.
+Due to a technical limitation of MCMCChains.jl, this causes all indexing into MCMCChains to return `Union{Float64, Missing}` or similar.
+If you want the old behaviour, you can discard the first sample (e.g. using `discard_initial=1`).
+
 # 0.40.5
 
 Bump Optimization.jl compatibility to include v5.
diff --git a/Project.toml b/Project.toml
index a2e5f206f4..d38d8e1159 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.40.5"
+version = "0.41.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -45,7 +45,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 
 [extensions]
 TuringDynamicHMCExt = "DynamicHMC"
-TuringOptimExt = "Optim"
+TuringOptimExt = ["Optim", "AbstractPPL"]
 
 [compat]
 ADTypes = "1.9"
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.37.2"
+DynamicPPL = "0.38"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.3"
diff --git a/docs/make.jl b/docs/make.jl
index ab6855e87f..0579e950f8 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -4,14 +4,15 @@ using Turing
 using DocumenterInterLinks
 
 links = InterLinks(
-    "DynamicPPL" => "https://turinglang.org/DynamicPPL.jl/stable/objects.inv",
-    "AbstractPPL" => "https://turinglang.org/AbstractPPL.jl/stable/objects.inv",
-    "LinearAlgebra" =>
"https://docs.julialang.org/en/v1/objects.inv", - "AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/objects.inv", - "ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/objects.inv", - "AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/objects.inv", - "DistributionsAD" => "https://turinglang.org/DistributionsAD.jl/stable/objects.inv", - "OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/objects.inv", + "DynamicPPL" => "https://turinglang.org/DynamicPPL.jl/stable/", + "AbstractPPL" => "https://turinglang.org/AbstractPPL.jl/stable/", + "LinearAlgebra" => "https://docs.julialang.org/en/v1/", + "AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/", + "ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/", + "AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/", + "DistributionsAD" => "https://turinglang.org/DistributionsAD.jl/stable/", + "OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/", + "Distributions" => "https://juliastats.org/Distributions.jl/stable/", ) # Doctest setup @@ -27,6 +28,7 @@ makedocs(; "Inference" => "api/Inference.md", "Optimisation" => "api/Optimisation.md", "Variational " => "api/Variational.md", + "RandomMeasures " => "api/RandomMeasures.md", ], ], checkdocs=:exports, diff --git a/docs/src/api.md b/docs/src/api.md index 0b8351eb3b..885d587ea6 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -31,7 +31,7 @@ DynamicPPL.@model function my_model() end sample(my_model(), Turing.Inference.Prior(), 100) ``` -even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` module and [`@model`](@ref) in the `DynamicPPL` package. +even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` module and [`@model`](@extref `DynamicPPL.@model`) in the `DynamicPPL` package. 
 ### Modelling
 
@@ -46,12 +46,13 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 ### Inference
 
-| Exported symbol | Documentation | Description |
-|:--- |:--- |:--- |
-| `sample` | [`StatsBase.sample`](https://turinglang.org/AbstractMCMC.jl/stable/api/#Sampling-a-single-chain) | Sample from a model |
-| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
-| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
-| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
+| Exported symbol | Documentation | Description |
+|:--- |:--- |:--- |
+| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
+| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
+| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
+| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC without parallelism |
+| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from `MCMCChains.Chains` |
 
 ### Samplers
 
@@ -75,6 +76,34 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
 | `RepeatSampler` | [`Turing.Inference.RepeatSampler`](@ref) | A sampler that runs multiple times on the same variable |
 | `externalsampler` | [`Turing.Inference.externalsampler`](@ref) | Wrap an external sampler for use in Turing |
+
+### DynamicPPL utilities
+
+Please see the [generated quantities](https://turinglang.org/docs/tutorials/usage-generated-quantities/) and [probability interface](https://turinglang.org/docs/tutorials/usage-probability-interface/) guides for more information.
+
+| Exported symbol | Documentation | Description |
+|:--- |:--- |:--- |
+| `returned` | [`DynamicPPL.returned`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.returned-Tuple%7BModel,%20NamedTuple%7D) | Calculate additional quantities defined in a model |
+| `predict` | [`StatsAPI.predict`](https://turinglang.org/DynamicPPL.jl/stable/api/#Predicting) | Generate samples from posterior predictive distribution |
+| `pointwise_loglikelihoods` | [`DynamicPPL.pointwise_loglikelihoods`](@extref) | Compute log likelihoods for each sample in a chain |
+| `logprior` | [`DynamicPPL.logprior`](@extref) | Compute log prior probability |
+| `logjoint` | [`DynamicPPL.logjoint`](@extref) | Compute log joint probability |
+| `condition` | [`AbstractPPL.condition`](@extref) | Condition a model on data |
+| `decondition` | [`AbstractPPL.decondition`](@extref) | Remove conditioning on data |
+| `conditioned` | [`DynamicPPL.conditioned`](@extref) | Return the conditioned values of a model |
+| `fix` | [`DynamicPPL.fix`](@extref) | Fix the value of a variable |
+| `unfix` | [`DynamicPPL.unfix`](@extref) | Unfix the value of a variable |
+| `OrderedDict` | [`OrderedCollections.OrderedDict`](@extref) | An ordered dictionary |
+
+### Initialisation strategies
+
+Turing.jl provides several strategies to initialise parameters for models.
+ +| Exported symbol | Documentation | Description | +|:----------------- |:--------------------------------------- |:--------------------------------------------------------------- | +| `InitFromPrior` | [`DynamicPPL.InitFromPrior`](@extref) | Obtain initial parameters from the prior distribution | +| `InitFromUniform` | [`DynamicPPL.InitFromUniform`](@extref) | Obtain initial parameters by sampling uniformly in linked space | +| `InitFromParams` | [`DynamicPPL.InitFromParams`](@extref) | Manually specify (possibly a subset of) initial parameters | + ### Variational inference See the [docs of AdvancedVI.jl](https://turinglang.org/AdvancedVI.jl/stable/) for detailed usage and the [variational inference tutorial](https://turinglang.org/docs/tutorials/09-variational-inference/) for a basic walkthrough. @@ -124,29 +153,6 @@ LogPoisson | `arraydist` | [`DistributionsAD.arraydist`](@extref) | Create a product distribution from an array of distributions | | `NamedDist` | [`DynamicPPL.NamedDist`](@extref) | A distribution that carries the name of the variable | -### Predictions - -| Exported symbol | Documentation | Description | -|:--------------- |:--------------------------------------------------------------------------------- |:------------------------------------------------------- | -| `predict` | [`StatsAPI.predict`](https://turinglang.org/DynamicPPL.jl/stable/api/#Predicting) | Generate samples from posterior predictive distribution | - -### Querying model probabilities and quantities - -Please see the [generated quantities](https://turinglang.org/docs/tutorials/usage-generated-quantities/) and [probability interface](https://turinglang.org/docs/tutorials/usage-probability-interface/) guides for more information. 
- -| Exported symbol | Documentation | Description | -|:-------------------------- |:---------------------------------------------------------------------------------------------------------------------------- |:-------------------------------------------------- | -| `returned` | [`DynamicPPL.returned`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.returned-Tuple%7BModel,%20NamedTuple%7D) | Calculate additional quantities defined in a model | -| `pointwise_loglikelihoods` | [`DynamicPPL.pointwise_loglikelihoods`](@extref) | Compute log likelihoods for each sample in a chain | -| `logprior` | [`DynamicPPL.logprior`](@extref) | Compute log prior probability | -| `logjoint` | [`DynamicPPL.logjoint`](@extref) | Compute log joint probability | -| `condition` | [`AbstractPPL.condition`](@extref) | Condition a model on data | -| `decondition` | [`AbstractPPL.decondition`](@extref) | Remove conditioning on data | -| `conditioned` | [`DynamicPPL.conditioned`](@extref) | Return the conditioned values of a model | -| `fix` | [`DynamicPPL.fix`](@extref) | Fix the value of a variable | -| `unfix` | [`DynamicPPL.unfix`](@extref) | Unfix the value of a variable | -| `OrderedDict` | [`OrderedCollections.OrderedDict`](@extref) | An ordered dictionary | - ### Point estimates See the [mode estimation tutorial](https://turinglang.org/docs/tutorials/docs-17-mode-estimation/) for more information. 
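Since the API page above points readers at the mode estimation tutorial, a hedged usage sketch of the exported point-estimation entry points may help; the model and data here are made up for illustration:

```julia
using Turing

@model function coinflip(y)
    p ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end

model = coinflip([1, 1, 0, 1])

# Point estimates via Turing.Optimisation
map_estimate = maximum_a_posteriori(model)  # posterior mode
mle_estimate = maximum_likelihood(model)    # maximum-likelihood estimate
```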
diff --git a/docs/src/api/RandomMeasures.md b/docs/src/api/RandomMeasures.md new file mode 100644 index 0000000000..f37b6118c4 --- /dev/null +++ b/docs/src/api/RandomMeasures.md @@ -0,0 +1,6 @@ +# API: `Turing.RandomMeasures` + +```@autodocs +Modules = [Turing.RandomMeasures] +Order = [:type, :function] +``` diff --git a/ext/TuringDynamicHMCExt.jl b/ext/TuringDynamicHMCExt.jl index 2c4bd08980..9e4c8b6ef8 100644 --- a/ext/TuringDynamicHMCExt.jl +++ b/ext/TuringDynamicHMCExt.jl @@ -44,26 +44,22 @@ struct DynamicNUTSState{L,V<:DynamicPPL.AbstractVarInfo,C,M,S} stepsize::S end -function DynamicPPL.initialsampler(::DynamicPPL.Sampler{<:DynamicNUTS}) - return DynamicPPL.SampleFromUniform() -end - -function DynamicPPL.initialstep( +function Turing.Inference.initialstep( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:DynamicNUTS}, + spl::DynamicNUTS, vi::DynamicPPL.AbstractVarInfo; kwargs..., ) # Ensure that initial sample is in unconstrained space. - if !DynamicPPL.islinked(vi) + if !DynamicPPL.is_transformed(vi) vi = DynamicPPL.link!!(vi, model) vi = last(DynamicPPL.evaluate!!(model, vi)) end # Define log-density function. ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) # Perform initial step. @@ -84,14 +80,14 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:DynamicNUTS}, + spl::DynamicNUTS, state::DynamicNUTSState; kwargs..., ) # Compute next sample. vi = state.vi ℓ = state.logdensity - steps = DynamicHMC.mcmc_steps(rng, spl.alg.sampler, state.metric, ℓ, state.stepsize) + steps = DynamicHMC.mcmc_steps(rng, spl.sampler, state.metric, ℓ, state.stepsize) Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache) # Create next sample and state. 
diff --git a/ext/TuringOptimExt.jl b/ext/TuringOptimExt.jl index 0f755988ef..21aecafbe4 100644 --- a/ext/TuringOptimExt.jl +++ b/ext/TuringOptimExt.jl @@ -1,6 +1,7 @@ module TuringOptimExt using Turing: Turing +using AbstractPPL: AbstractPPL import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation using Optim: Optim @@ -186,7 +187,7 @@ function _optimize( f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype ) vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) vns_vals_iter = mapreduce(collect, vcat, iters) varnames = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) diff --git a/src/Turing.jl b/src/Turing.jl index 0cdbe24586..58a58eb2af 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -73,7 +73,10 @@ using DynamicPPL: conditioned, to_submodel, LogDensityFunction, - @addlogprob! + @addlogprob!, + InitFromPrior, + InitFromUniform, + InitFromParams using StatsBase: predict using OrderedCollections: OrderedDict @@ -148,11 +151,17 @@ export fix, unfix, OrderedDict, # OrderedCollections + # Initialisation strategies for models + InitFromPrior, + InitFromUniform, + InitFromParams, # Point estimates - Turing.Optimisation # The MAP and MLE exports are only needed for the Optim.jl interface. maximum_a_posteriori, maximum_likelihood, MAP, - MLE + MLE, + # Chain save/resume + loadstate end diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 53bf6dbc08..7d25ecd7ee 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -13,7 +13,6 @@ using DynamicPPL: # or implement it for other VarInfo types and export it from DPPL. 
all_varnames_grouped_by_symbol, syms, - islinked, setindex!!, push!!, setlogp!!, @@ -23,12 +22,7 @@ using DynamicPPL: getsym, getdist, Model, - Sampler, - SampleFromPrior, - SampleFromUniform, - DefaultContext, - set_flag!, - unset_flag! + DefaultContext using Distributions, Libtask, Bijectors using DistributionsAD: VectorOfMultivariate using LinearAlgebra @@ -55,12 +49,9 @@ import Random import MCMCChains import StatsBase: predict -export InferenceAlgorithm, - Hamiltonian, +export Hamiltonian, StaticHamiltonian, AdaptiveHamiltonian, - SampleFromUniform, - SampleFromPrior, MH, ESS, Emcee, @@ -78,13 +69,16 @@ export InferenceAlgorithm, RepeatSampler, Prior, predict, - externalsampler + externalsampler, + init_strategy, + loadstate -############################################### -# Abstract interface for inference algorithms # -############################################### +######################################### +# Generic AbstractMCMC methods dispatch # +######################################### -include("algorithm.jl") +const DEFAULT_CHAIN_TYPE = MCMCChains.Chains +include("abstractmcmc.jl") #################### # Sampler wrappers # @@ -262,13 +256,13 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) dicts = map(ts) do t # In general getparams returns a dict of VarName => values. We need to also # split it up into constituent elements using - # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl + # `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl # won't understand it. vals = getparams(model, t) nms_and_vs = if isempty(vals) Tuple{VarName,Any}[] else - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) mapreduce(collect, vcat, iters) end nms = map(first, nms_and_vs) @@ -315,11 +309,10 @@ end getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. 
-# This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{Transition,AbstractVarInfo}}, - model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, + ts::Vector{<:Transition}, + model::DynamicPPL.Model, + spl::AbstractSampler, state, chain_type::Type{MCMCChains.Chains}; save_state=false, @@ -378,11 +371,10 @@ function AbstractMCMC.bundle_samples( return sort_chain ? sort(chain) : chain end -# This is type piracy (for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector{<:Union{Transition,AbstractVarInfo}}, - model::AbstractModel, - spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior,RepeatSampler}, + ts::Vector{<:Transition}, + model::DynamicPPL.Model, + spl::AbstractSampler, state, chain_type::Type{Vector{NamedTuple}}; kwargs..., @@ -423,7 +415,7 @@ function group_varnames_by_symbol(vns) return d end -function save(c::MCMCChains.Chains, spl::Sampler, model, vi, samples) +function save(c::MCMCChains.Chains, spl::AbstractSampler, model, vi, samples) nt = NamedTuple{(:sampler, :model, :vi, :samples)}((spl, model, deepcopy(vi), samples)) return setinfo(c, merge(nt, c.info)) end @@ -442,18 +434,12 @@ include("sghmc.jl") include("emcee.jl") include("prior.jl") -################################################# -# Generic AbstractMCMC methods dispatch # -################################################# - -include("abstractmcmc.jl") - ################ # Typing tools # ################ function DynamicPPL.get_matching_type( - spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV} + spl::Union{PG,SMC}, vi, ::Type{TV} ) where {T,N,TV<:Array{T,N}} return Array{T,N} end diff --git a/src/mcmc/abstractmcmc.jl b/src/mcmc/abstractmcmc.jl index edd5638854..0f20925762 100644 --- a/src/mcmc/abstractmcmc.jl +++ b/src/mcmc/abstractmcmc.jl @@ -1,39 +1,101 @@ # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. 
 function _check_model(model::DynamicPPL.Model)
-    # TODO(DPPL0.38/penelopeysm): use InitContext
-    spl_model = DynamicPPL.contextualize(model, DynamicPPL.SamplingContext(model.context))
-    return DynamicPPL.check_model(spl_model, VarInfo(); error_on_failure=true)
+    new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
+    return DynamicPPL.check_model(new_model, VarInfo(); error_on_failure=true)
 end
 
-function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
+function _check_model(model::DynamicPPL.Model, ::AbstractSampler)
     return _check_model(model)
 end
 
+"""
+    Turing.Inference.init_strategy(spl::AbstractSampler)
+
+Get the default initialization strategy for a given sampler `spl`, i.e. how initial
+parameters for sampling are chosen if not specified by the user. By default, this is
+`InitFromPrior()`, which samples initial parameters from the prior distribution.
+"""
+init_strategy(::AbstractSampler) = DynamicPPL.InitFromPrior()
+
+"""
+    _convert_initial_params(initial_params)
+
+Convert `initial_params` to a `DynamicPPL.AbstractInitStrategy` if it is not already one, or
+throw a useful error message.
+"""
+_convert_initial_params(initial_params::DynamicPPL.AbstractInitStrategy) = initial_params
+function _convert_initial_params(nt::NamedTuple)
+    @info "Using a NamedTuple for `initial_params` is deprecated and will be removed in a future release. Please use `InitFromParams(namedtuple)` instead."
+    return DynamicPPL.InitFromParams(nt)
+end
+function _convert_initial_params(d::AbstractDict{<:VarName})
+    @info "Using a Dict for `initial_params` is deprecated and will be removed in a future release. Please use `InitFromParams(dict)` instead."
+    return DynamicPPL.InitFromParams(d)
+end
+function _convert_initial_params(::AbstractVector{<:Real})
+    errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or ideally a `DynamicPPL.AbstractInitStrategy`. Using a vector of parameters for `initial_params` is no longer supported.
Please see https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters for details on how to update your code." + throw(ArgumentError(errmsg)) +end +function _convert_initial_params(@nospecialize(_::Any)) + errmsg = "`initial_params` must be a `NamedTuple`, an `AbstractDict{<:VarName}`, or a `DynamicPPL.AbstractInitStrategy`." + throw(ArgumentError(errmsg)) +end + +""" + default_varinfo(rng, model, sampler) + +Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns a NTVarInfo (i.e. 'typed varinfo'). +""" +function default_varinfo( + rng::Random.AbstractRNG, model::DynamicPPL.Model, ::AbstractSampler +) + # Note that in `AbstractMCMC.step`, the values in the varinfo returned here are + # immediately overwritten by a subsequent call to `init!!`. The reason why we + # _do_ create a varinfo with parameters here (as opposed to simply returning + # an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty + # typed VarInfo would fail. This can happen if two VarNames have different types + # but share the same symbol (e.g. `x.a` and `x.b`). + # TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments + # and return an empty VarInfo instead. + return DynamicPPL.typed_varinfo(VarInfo(rng, model)) +end + ######################################### # Default definitions for the interface # ######################################### function AbstractMCMC.sample( - model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs... + model::DynamicPPL.Model, spl::AbstractSampler, N::Integer; kwargs... ) - return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...) + return AbstractMCMC.sample(Random.default_rng(), model, spl, N; kwargs...) 
end function AbstractMCMC.sample( rng::AbstractRNG, - model::AbstractModel, - alg::InferenceAlgorithm, + model::DynamicPPL.Model, + spl::AbstractSampler, N::Integer; + initial_params=init_strategy(spl), check_model::Bool=true, + chain_type=DEFAULT_CHAIN_TYPE, kwargs..., ) - check_model && _check_model(model, alg) - return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...) + check_model && _check_model(model, spl) + return AbstractMCMC.mcmcsample( + rng, + model, + spl, + N; + initial_params=_convert_initial_params(initial_params), + chain_type, + kwargs..., + ) end function AbstractMCMC.sample( - model::AbstractModel, - alg::InferenceAlgorithm, + model::DynamicPPL.Model, + alg::AbstractSampler, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; @@ -45,15 +107,74 @@ function AbstractMCMC.sample( end function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - alg::InferenceAlgorithm, + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::AbstractSampler, ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; + chain_type=DEFAULT_CHAIN_TYPE, check_model::Bool=true, + initial_params=fill(init_strategy(spl), n_chains), kwargs..., ) - check_model && _check_model(model, alg) - return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...) + check_model && _check_model(model, spl) + if !(initial_params isa AbstractVector) || length(initial_params) != n_chains + errmsg = "`initial_params` must be an AbstractVector of length `n_chains`; one element per chain" + throw(ArgumentError(errmsg)) + end + return AbstractMCMC.mcmcsample( + rng, + model, + spl, + ensemble, + N, + n_chains; + chain_type, + initial_params=map(_convert_initial_params, initial_params), + kwargs..., + ) +end + +""" + loadstate(chain::MCMCChains.Chains) + +Load the final state of the sampler from a `MCMCChains.Chains` object. 
+ +To save the final state of the sampler, you must use `sample(...; save_state=true)`. If this +argument was not used during sampling, calling `loadstate` will throw an error. +""" +function loadstate(chain::MCMCChains.Chains) + if !haskey(chain.info, :samplerstate) + throw( + ArgumentError( + "the chain object does not contain the final state of the sampler; to save the final state you must sample with `save_state=true`", + ), + ) + end + return chain.info[:samplerstate] +end + +# TODO(penelopeysm): Remove initialstep and generalise MCMC sampling procedures +function initialstep end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + spl::AbstractSampler; + initial_params, + kwargs..., +) + # Generate the default varinfo. Note that any parameters inside this varinfo + # will be immediately overwritten by the next call to `init!!`. + vi = default_varinfo(rng, model, spl) + + # Fill it with initial parameters. Note that, if `InitFromParams` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = DynamicPPL.init!!(rng, model, vi, initial_params) + + # Call the actual function that does the first step. + return initialstep(rng, model, spl, vi; initial_params, kwargs...) end diff --git a/src/mcmc/algorithm.jl b/src/mcmc/algorithm.jl deleted file mode 100644 index d45ae0d4a7..0000000000 --- a/src/mcmc/algorithm.jl +++ /dev/null @@ -1,14 +0,0 @@ -""" - InferenceAlgorithm - -Abstract type representing an inference algorithm in Turing. Note that this is -not the same as an `AbstractSampler`: the latter is what defines the necessary -methods for actually sampling. - -To create an `AbstractSampler`, the `InferenceAlgorithm` needs to be wrapped in -`DynamicPPL.Sampler`. If `sample()` is called with an `InferenceAlgorithm`, -this wrapping occurs automatically. 
-""" -abstract type InferenceAlgorithm end - -DynamicPPL.default_chain_type(sampler::Sampler{<:InferenceAlgorithm}) = MCMCChains.Chains diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index 98ed20b40e..226536aca2 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -13,7 +13,7 @@ Foreman-Mackey, D., Hogg, D. W., Lang, D., & Goodman, J. (2013). emcee: The MCMC Hammer. Publications of the Astronomical Society of the Pacific, 125 (925), 306. https://doi.org/10.1086/670067 """ -struct Emcee{E<:AMH.Ensemble} <: InferenceAlgorithm +struct Emcee{E<:AMH.Ensemble} <: AbstractSampler ensemble::E end @@ -31,37 +31,37 @@ struct EmceeState{V<:AbstractVarInfo,S} states::S end -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:Emcee}; - resume_from=nothing, - initial_params=nothing, - kwargs..., +# Utility function to tetrieve the number of walkers +_get_n_walkers(e::Emcee) = e.ensemble.n_walkers + +# Because Emcee expects n_walkers initialisations, we need to override this +function Turing.Inference.init_strategy(spl::Emcee) + return fill(DynamicPPL.InitFromPrior(), _get_n_walkers(spl)) +end +# We also have to explicitly allow this or else it will error... +function Turing.Inference._convert_initial_params( + x::AbstractVector{<:DynamicPPL.AbstractInitStrategy} ) - if resume_from !== nothing - state = loadstate(resume_from) - return AbstractMCMC.step(rng, model, spl, state; kwargs...) - end + return x +end +function AbstractMCMC.step( + rng::Random.AbstractRNG, model::Model, spl::Emcee; initial_params, kwargs... +) # Sample from the prior - n = spl.alg.ensemble.n_walkers - vis = [VarInfo(rng, model, SampleFromPrior()) for _ in 1:n] + n = _get_n_walkers(spl) + vis = [VarInfo(rng, model) for _ in 1:n] # Update the parameters if provided. 
- if initial_params !== nothing - length(initial_params) == n || - throw(ArgumentError("initial parameters have to be specified for each walker")) - vis = map(vis, initial_params) do vi, init - # TODO(DPPL0.38/penelopeysm) This whole thing can be replaced with init!! - vi = DynamicPPL.initialize_parameters!!(vi, init, model) - - # Update log joint probability. - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, SampleFromPrior(), model.context) - ) - last(DynamicPPL.evaluate!!(spl_model, vi)) - end + if !( + initial_params isa AbstractVector{<:DynamicPPL.AbstractInitStrategy} && + length(initial_params) == n + ) + err_msg = "initial_params for `Emcee` must be a vector of `DynamicPPL.AbstractInitStrategy`, with length equal to the number of walkers ($n)" + throw(ArgumentError(err_msg)) + end + vis = map(vis, initial_params) do vi, strategy + last(DynamicPPL.init!!(rng, model, vi, strategy)) end # Compute initial transition and states. @@ -80,7 +80,7 @@ function AbstractMCMC.step( end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:Emcee}, state::EmceeState; kwargs... + rng::AbstractRNG, model::Model, spl::Emcee, state::EmceeState; kwargs... ) # Generate a log joint function. vi = state.vi @@ -92,7 +92,7 @@ function AbstractMCMC.step( ) # Compute the next states. - t, states = AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states) + t, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states) # Compute the next transition and state. 
     transition = map(states) do _state
@@ -107,7 +107,7 @@ end
 function AbstractMCMC.bundle_samples(
     samples::Vector{<:Vector},
     model::AbstractModel,
-    spl::Sampler{<:Emcee},
+    spl::Emcee,
     state::EmceeState,
     chain_type::Type{MCMCChains.Chains};
     save_state=false,
diff --git a/src/mcmc/ess.jl b/src/mcmc/ess.jl
index 3afd91607c..18dbfa4171 100644
--- a/src/mcmc/ess.jl
+++ b/src/mcmc/ess.jl
@@ -20,11 +20,11 @@ Mean
 │ 1 │ m │ 0.824853 │
 ```
 """
-struct ESS <: InferenceAlgorithm end
+struct ESS <: AbstractSampler end

 # always accept in the first step
-function DynamicPPL.initialstep(
-    rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
+function Turing.Inference.initialstep(
+    rng::AbstractRNG, model::DynamicPPL.Model, ::ESS, vi::AbstractVarInfo; kwargs...
 )
     for vn in keys(vi)
         dist = getdist(vi, vn)
@@ -35,7 +35,7 @@ end

 function AbstractMCMC.step(
-    rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
+    rng::AbstractRNG, model::DynamicPPL.Model, ::ESS, vi::AbstractVarInfo; kwargs...
 )
     # obtain previous sample
     f = vi[:]
@@ -82,23 +82,8 @@ EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true

 # Only define out-of-place sampling
 function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
-    varinfo = p.varinfo
-    # TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
-    # TODO(DPPL0.38/penelopeysm): This can be replaced with `init!!(p.model,
-    # p.varinfo, PriorInit())` after TuringLang/DynamicPPL.jl#984. The reason
-    # why we had to use the 'del' flag before this was because
-    # SampleFromPrior() wouldn't overwrite existing variables.
-    # The main problem I'm rather unsure about is ESS-within-Gibbs. The
-    # current implementation I think makes sure to only resample the variables
-    # that 'belong' to the current ESS sampler. InitContext on the other hand
-    # would resample all variables in the model (??) Need to think about this
-    # carefully.
-    vns = keys(varinfo)
-    for vn in vns
-        set_flag!(varinfo, vn, "del")
-    end
-    p.model(rng, varinfo)
-    return varinfo[:]
+    _, vi = DynamicPPL.init!!(rng, p.model, p.varinfo, DynamicPPL.InitFromPrior())
+    return vi[:]
 end

 # Mean of prior distribution
@@ -118,3 +103,18 @@ struct ESSLikelihood{L<:DynamicPPL.LogDensityFunction}
 end

 (ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f)
+
+# Needed for method ambiguity resolution, even though this method is never going to be
+# called in practice. This just shuts Aqua up.
+# TODO(penelopeysm): Remove this when the default `step(rng, ::DynamicPPL.Model,
+# ::AbstractSampler)` method in `src/mcmc/abstractmcmc.jl` is removed.
+function AbstractMCMC.step(
+    rng::AbstractRNG,
+    model::DynamicPPL.Model,
+    sampler::EllipticalSliceSampling.ESS;
+    kwargs...,
+)
+    return error(
+        "This method is not implemented! If you want to use the ESS sampler in Turing.jl, please use `Turing.ESS()` instead. If you want the default behaviour in EllipticalSliceSampling.jl, wrap your model in a different subtype of `AbstractMCMC.AbstractModel`, and then implement the necessary EllipticalSliceSampling.jl methods on it.",
+    )
+end
diff --git a/src/mcmc/external_sampler.jl b/src/mcmc/external_sampler.jl
index af31e0243f..f8673f6eef 100644
--- a/src/mcmc/external_sampler.jl
+++ b/src/mcmc/external_sampler.jl
@@ -1,7 +1,8 @@
 """
     ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}

-Represents a sampler that is not an implementation of `InferenceAlgorithm`.
+Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng,
+::DynamicPPL.Model, spl)`.

 The `Unconstrained` type-parameter is to indicate whether the sampler requires unconstrained space.
@@ -10,25 +11,49 @@ $(TYPEDFIELDS)

 # Turing.jl's interface for external samplers

-When implementing a new `MySampler <: AbstractSampler`,
-`MySampler` must first and foremost conform to the `AbstractMCMC` interface to work with Turing.jl's `externalsampler` function.
-In particular, it must implement:
+If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl
+models, there are two options:

-- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is documented in AbstractMCMC.jl)
-- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the parameters from the transition returned by your sampler (i.e., the first return value of `step`).
-  There is a default implementation for this method, which is to return `external_transition.θ`.
+1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the
+   most powerful option and is what Turing.jl's in-house samplers do. Implementing this
+   means that you can directly call `sample(model, MySampler(), N)`.
+
+2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This
+   struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
+   implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use
+   this with Turing.jl, you will need to wrap your sampler: `sample(model,
+   externalsampler(MySampler()), N)`.
+
+This section describes the latter.
+
+`MySampler` must implement the following methods:
+
+- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is
+  documented in AbstractMCMC.jl)
+- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the
+  parameters from the transition returned by your sampler (i.e., the first return value of
+  `step`). There is a default implementation for this method, which is to return
+  `external_transition.θ`.

 !!! note
-    In a future breaking release of Turing, this is likely to change to `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method. `Turing.Inference.getparams` is technically an internal method, so the aim here is to unify the interface for samplers at a higher level.
+    In a future breaking release of Turing, this is likely to change to
+    `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method.
+    `Turing.Inference.getparams` is technically an internal method, so the aim here is to
+    unify the interface for samplers at a higher level.

-There are a few more optional functions which you can implement to improve the integration with Turing.jl:
+There are a few more optional functions which you can implement to improve the integration
+with Turing.jl:

-- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
+- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as
+  a component in Turing's Gibbs sampler, you should make this evaluate to `true`.

-- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.
+- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires
+  unconstrained space, you should return `true`. This tells Turing to perform linking on the
+  VarInfo before evaluation, and ensures that the parameter values passed to your sampler
+  will always be in unconstrained (Euclidean) space.
 """
 struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
-       InferenceAlgorithm
+       AbstractSampler
     "the sampler to wrap"
     sampler::S
     "the automatic differentiation (AD) backend to use"
@@ -115,36 +140,39 @@ getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.pa
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    sampler_wrapper::Sampler{<:ExternalSampler};
+    sampler_wrapper::ExternalSampler;
     initial_state=nothing,
-    initial_params=nothing,
+    initial_params,  # passed through from sample
     kwargs...,
 )
-    alg = sampler_wrapper.alg
-    sampler = alg.sampler
+    sampler = sampler_wrapper.sampler

     # Initialise varinfo with initial params and link the varinfo if needed.
     varinfo = DynamicPPL.VarInfo(model)
-    if requires_unconstrained_space(alg)
-        if initial_params !== nothing
-            # If we have initial parameters, we need to set the varinfo before linking.
-            varinfo = DynamicPPL.link(DynamicPPL.unflatten(varinfo, initial_params), model)
-            # Extract initial parameters in unconstrained space.
-            initial_params = varinfo[:]
-        else
-            varinfo = DynamicPPL.link(varinfo, model)
-        end
+    _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)
+
+    if requires_unconstrained_space(sampler_wrapper)
+        varinfo = DynamicPPL.link(varinfo, model)
     end

+    # We need to extract the vectorised initial_params, because the later call to
+    # AbstractMCMC.step only sees a `LogDensityModel` which expects `initial_params`
+    # to be a vector.
+    initial_params_vector = varinfo[:]
+
     # Construct LogDensityFunction
     f = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, varinfo; adtype=alg.adtype
+        model, DynamicPPL.getlogjoint_internal, varinfo; adtype=sampler_wrapper.adtype
     )

     # Then just call `AbstractMCMC.step` with the right arguments.
     if initial_state === nothing
         transition_inner, state_inner = AbstractMCMC.step(
-            rng, AbstractMCMC.LogDensityModel(f), sampler; initial_params, kwargs...
+            rng,
+            AbstractMCMC.LogDensityModel(f),
+            sampler;
+            initial_params=initial_params_vector,
+            kwargs...,
         )
     else
         transition_inner, state_inner = AbstractMCMC.step(
@@ -152,7 +180,7 @@ function AbstractMCMC.step(
             AbstractMCMC.LogDensityModel(f),
             sampler,
             initial_state;
-            initial_params,
+            initial_params=initial_params_vector,
             kwargs...,
         )
     end
@@ -170,11 +198,11 @@ end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    sampler_wrapper::Sampler{<:ExternalSampler},
+    sampler_wrapper::ExternalSampler,
     state::TuringState;
     kwargs...,
 )
-    sampler = sampler_wrapper.alg.sampler
+    sampler = sampler_wrapper.sampler
     f = state.ldf

     # Then just call `AdvancedMCMC.step` with the right arguments.
diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl
index 17bc881535..7d15829a3c 100644
--- a/src/mcmc/gibbs.jl
+++ b/src/mcmc/gibbs.jl
@@ -1,12 +1,11 @@
 """
-    isgibbscomponent(alg::Union{InferenceAlgorithm, AbstractMCMC.AbstractSampler})
+    isgibbscomponent(spl::AbstractSampler)

-Return a boolean indicating whether `alg` is a valid component for a Gibbs sampler.
+Return a boolean indicating whether `spl` is a valid component for a Gibbs sampler.

 Defaults to `false` if no method has been defined for a particular algorithm type.
 """
-isgibbscomponent(::InferenceAlgorithm) = false
-isgibbscomponent(spl::Sampler) = isgibbscomponent(spl.alg)
+isgibbscomponent(::AbstractSampler) = false

 isgibbscomponent(::ESS) = true
 isgibbscomponent(::HMC) = true
@@ -47,7 +46,7 @@ A context used in the implementation of the Turing.jl Gibbs sampler.

 There will be one `GibbsContext` for each iteration of a component sampler.

 `target_varnames` is a tuple of `VarName`s that the current component sampler
-is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume`
+is sampling. For those `VarName`s, `GibbsContext` will just pass `tilde_assume!!`
 calls to its child context. For other variables, their values will be fixed to
 the values they have in `global_varinfo`.
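As an illustration of the external-sampler interface described in `external_sampler.jl` above (option 2, stepping against an `AbstractMCMC.LogDensityModel`), a minimal random-walk Metropolis sampler could look like the following. This is a hedged sketch, not part of the diff; `MySampler`, `MyTransition`, and `MyState` are hypothetical names.

```julia
using AbstractMCMC, LogDensityProblems, Random

struct MySampler <: AbstractMCMC.AbstractSampler
    step_size::Float64
end

# `Turing.Inference.getparams` defaults to reading the `θ` field of the transition.
struct MyTransition
    θ::Vector{Float64}
end

struct MyState
    θ::Vector{Float64}
    logp::Float64
end

# Initial step: evaluate the log density at the supplied initial parameters.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::MySampler;
    initial_params,
    kwargs...,
)
    logp = LogDensityProblems.logdensity(model.logdensity, initial_params)
    return MyTransition(initial_params), MyState(initial_params, logp)
end

# Subsequent steps: symmetric random-walk proposal with Metropolis accept/reject.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::MySampler,
    state::MyState;
    kwargs...,
)
    proposal = state.θ .+ spl.step_size .* randn(rng, length(state.θ))
    logp_prop = LogDensityProblems.logdensity(model.logdensity, proposal)
    if log(rand(rng)) < logp_prop - state.logp
        return MyTransition(proposal), MyState(proposal, logp_prop)
    else
        return MyTransition(state.θ), state
    end
end
```

Such a sampler would then be used with a Turing model as `sample(model, externalsampler(MySampler(0.1)), 1000)`, with Turing handling the conversion between the model and the LogDensityProblems.jl interface.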
@@ -140,7 +139,9 @@ function is_target_varname(context::GibbsContext, vns::AbstractArray{<:VarName})
 end

 # Tilde pipeline
-function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
+function DynamicPPL.tilde_assume!!(
+    context::GibbsContext, right::Distribution, vn::VarName, vi::DynamicPPL.AbstractVarInfo
+)
     child_context = DynamicPPL.childcontext(context)

     # Note that `child_context` may contain `PrefixContext`s -- in which case
@@ -175,47 +176,7 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)

     return if is_target_varname(context, vn)
         # Fall back to the default behavior.
-        DynamicPPL.tilde_assume(child_context, right, vn, vi)
-    elseif has_conditioned_gibbs(context, vn)
-        # This branch means that a different sampler is supposed to handle this
-        # variable. From the perspective of this sampler, this variable is
-        # conditioned on, so we can just treat it as an observation.
-        # The only catch is that the value that we need is to be obtained from
-        # the global VarInfo (since the local VarInfo has no knowledge of it).
-        # Note that tilde_observe!! will trigger resampling in particle methods
-        # for variables that are handled by other Gibbs component samplers.
-        val = get_conditioned_gibbs(context, vn)
-        DynamicPPL.tilde_observe!!(child_context, right, val, vn, vi)
-    else
-        # If the varname has not been conditioned on, nor is it a target variable, its
-        # presumably a new variable that should be sampled from its prior. We need to add
-        # this new variable to the global `varinfo` of the context, but not to the local one
-        # being used by the current sampler.
-        value, new_global_vi = DynamicPPL.tilde_assume(
-            child_context,
-            DynamicPPL.SampleFromPrior(),
-            right,
-            vn,
-            get_global_varinfo(context),
-        )
-        set_global_varinfo!(context, new_global_vi)
-        value, vi
-    end
-end
-
-# As above but with an RNG.
-function DynamicPPL.tilde_assume(
-    rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi
-)
-    # See comment in the above, rng-less version of this method for an explanation.
-    child_context = DynamicPPL.childcontext(context)
-    vn, child_context = DynamicPPL.prefix_and_strip_contexts(child_context, vn)
-
-    return if is_target_varname(context, vn)
-        # This branch means that that `sampler` is supposed to handle
-        # this variable. We can thus use its default behaviour, with
-        # the 'local' sampler-specific VarInfo.
-        DynamicPPL.tilde_assume(rng, child_context, sampler, right, vn, vi)
+        DynamicPPL.tilde_assume!!(child_context, right, vn, vi)
     elseif has_conditioned_gibbs(context, vn)
         # This branch means that a different sampler is supposed to handle this
         # variable. From the perspective of this sampler, this variable is
@@ -231,10 +192,10 @@ function DynamicPPL.tilde_assume(
         # presumably a new variable that should be sampled from its prior. We need to add
         # this new variable to the global `varinfo` of the context, but not to the local one
         # being used by the current sampler.
-        value, new_global_vi = DynamicPPL.tilde_assume(
-            rng,
-            child_context,
-            DynamicPPL.SampleFromPrior(),
+        value, new_global_vi = DynamicPPL.tilde_assume!!(
+            # child_context might be a PrefixContext so we have to be careful to not
+            # overwrite it.
+            DynamicPPL.setleafcontext(child_context, DynamicPPL.InitContext()),
             right,
             vn,
             get_global_varinfo(context),
@@ -275,9 +236,6 @@ function make_conditional(
     return DynamicPPL.contextualize(model, gibbs_context), gibbs_context_inner
 end

-wrap_in_sampler(x::AbstractMCMC.AbstractSampler) = x
-wrap_in_sampler(x::InferenceAlgorithm) = DynamicPPL.Sampler(x)
-
 to_varname(x::VarName) = x
 to_varname(x::Symbol) = VarName{x}()
 to_varname_list(x::Union{VarName,Symbol}) = [to_varname(x)]
@@ -307,10 +265,8 @@ Gibbs((@varname(x), :y) => NUTS(), :z => MH())

 # Fields
 $(TYPEDFIELDS)
 """
-struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
-       InferenceAlgorithm
-    # TODO(mhauru) Revisit whether A should have a fixed element type once
-    # InferenceAlgorithm/Sampler types have been cleaned up.
+struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <: AbstractSampler
+    # TODO(mhauru) Revisit whether A should have a fixed element type.
     "varnames representing variables for each sampler"
     varnames::V
     "samplers for each entry in `varnames`"
@@ -328,7 +284,7 @@ struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
             end
         end

-        samplers = tuple(map(wrap_in_sampler, samplers)...)
+        samplers = tuple(samplers...)
         varnames = tuple(map(to_varname_list, varnames)...)
         return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
     end
@@ -352,32 +308,21 @@ This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated h
 support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
 incompatible with step_warmup.
 """
-function initial_varinfo(rng, model, spl, initial_params)
-    vi = DynamicPPL.default_varinfo(rng, model, spl)
-
-    # Update the parameters if provided.
-    if initial_params !== nothing
-        vi = DynamicPPL.initialize_parameters!!(vi, initial_params, model)
-
-        # Update joint log probability.
-        # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
-        # and https://github.com/TuringLang/Turing.jl/issues/1563
-        # to avoid that existing variables are resampled
-        vi = last(DynamicPPL.evaluate!!(model, vi))
-    end
+function initial_varinfo(rng, model, spl, initial_params::DynamicPPL.AbstractInitStrategy)
+    vi = Turing.Inference.default_varinfo(rng, model, spl)
+    _, vi = DynamicPPL.init!!(rng, model, vi, initial_params)
     return vi
 end

 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:Gibbs};
-    initial_params=nothing,
+    spl::Gibbs;
+    initial_params=Turing.Inference.init_strategy(spl),
     kwargs...,
 )
-    alg = spl.alg
-    varnames = alg.varnames
-    samplers = alg.samplers
+    varnames = spl.varnames
+    samplers = spl.samplers

     vi = initial_varinfo(rng, model, spl, initial_params)

     vi, states = gibbs_initialstep_recursive(
@@ -396,13 +341,12 @@ end
 function AbstractMCMC.step_warmup(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:Gibbs};
-    initial_params=nothing,
+    spl::Gibbs;
+    initial_params=Turing.Inference.init_strategy(spl),
     kwargs...,
 )
-    alg = spl.alg
-    varnames = alg.varnames
-    samplers = alg.samplers
+    varnames = spl.varnames
+    samplers = spl.samplers

     vi = initial_varinfo(rng, model, spl, initial_params)

     vi, states = gibbs_initialstep_recursive(
@@ -434,7 +378,7 @@ function gibbs_initialstep_recursive(
     samplers,
     vi,
     states=();
-    initial_params=nothing,
+    initial_params,
     kwargs...,
 )
     # End recursion
@@ -445,13 +389,6 @@ end
     varnames, varname_vecs_tail... = varname_vecs
     sampler, samplers_tail... = samplers

-    # Get the initial values for this component sampler.
-    initial_params_local = if initial_params === nothing
-        nothing
-    else
-        DynamicPPL.subset(vi, varnames)[:]
-    end
-
     # Construct the conditioned model.
     conditioned_model, context = make_conditional(model, varnames, vi)

@@ -462,7 +399,7 @@
         sampler;
         # FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
        # This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
-        initial_params=initial_params_local,
+        initial_params=initial_params,
         kwargs...,
     )
     new_vi_local = get_varinfo(new_state)
@@ -489,14 +426,13 @@ end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:Gibbs},
+    spl::Gibbs,
     state::GibbsState;
     kwargs...,
 )
     vi = get_varinfo(state)
-    alg = spl.alg
-    varnames = alg.varnames
-    samplers = alg.samplers
+    varnames = spl.varnames
+    samplers = spl.samplers
     states = state.states
     @assert length(samplers) == length(state.states)
@@ -509,14 +445,13 @@ end
 function AbstractMCMC.step_warmup(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    spl::DynamicPPL.Sampler{<:Gibbs},
+    spl::Gibbs,
     state::GibbsState;
     kwargs...,
 )
     vi = get_varinfo(state)
-    alg = spl.alg
-    varnames = alg.varnames
-    samplers = alg.samplers
+    varnames = spl.varnames
+    samplers = spl.samplers
     states = state.states
     @assert length(samplers) == length(state.states)
@@ -527,7 +462,7 @@ function AbstractMCMC.step_warmup(
 end

 """
-    setparams_varinfo!!(model, sampler::Sampler, state, params::AbstractVarInfo)
+    setparams_varinfo!!(model, sampler::AbstractSampler, state, params::AbstractVarInfo)

 A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameters, takes an
 `AbstractVarInfo` object. Also takes the `sampler` as an argument. By default, falls back to
@@ -536,12 +471,14 @@ A lot like AbstractMCMC.setparams!!, but instead of taking a vector of parameter
 `model` is typically a `DynamicPPL.Model`, but can also be e.g. an
 `AbstractMCMC.LogDensityModel`.
 """
-function setparams_varinfo!!(model, ::Sampler, state, params::AbstractVarInfo)
+function setparams_varinfo!!(
+    model::DynamicPPL.Model, ::AbstractSampler, state, params::AbstractVarInfo
+)
     return AbstractMCMC.setparams!!(model, state, params[:])
 end

 function setparams_varinfo!!(
-    model::DynamicPPL.Model, sampler::Sampler{<:MH}, state::MHState, params::AbstractVarInfo
+    model::DynamicPPL.Model, sampler::MH, state::MHState, params::AbstractVarInfo
 )
     # Re-evaluate to update the logprob.
     new_vi = last(DynamicPPL.evaluate!!(model, params))
@@ -549,10 +486,7 @@ function setparams_varinfo!!(
 end

 function setparams_varinfo!!(
-    model::DynamicPPL.Model,
-    sampler::Sampler{<:ESS},
-    state::AbstractVarInfo,
-    params::AbstractVarInfo,
+    model::DynamicPPL.Model, sampler::ESS, state::AbstractVarInfo, params::AbstractVarInfo
 )
     # The state is already a VarInfo, so we can just return `params`, but first we need to
     # update its logprob.
@@ -561,24 +495,21 @@ end

 function setparams_varinfo!!(
     model::DynamicPPL.Model,
-    sampler::Sampler{<:ExternalSampler},
+    sampler::ExternalSampler,
     state::TuringState,
     params::AbstractVarInfo,
 )
     logdensity = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.alg.adtype
+        model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.adtype
     )
-    new_inner_state = setparams_varinfo!!(
-        AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
+    new_inner_state = AbstractMCMC.setparams!!(
+        AbstractMCMC.LogDensityModel(logdensity), state.state, params[:]
     )
     return TuringState(new_inner_state, params, logdensity)
 end

 function setparams_varinfo!!(
-    model::DynamicPPL.Model,
-    sampler::Sampler{<:Hamiltonian},
-    state::HMCState,
-    params::AbstractVarInfo,
+    model::DynamicPPL.Model, sampler::Hamiltonian, state::HMCState, params::AbstractVarInfo
 )
     θ_new = params[:]
     hamiltonian = get_hamiltonian(model, sampler, params, state, length(θ_new))
@@ -592,7 +523,7 @@ function setparams_varinfo!!(
 end

 function setparams_varinfo!!(
-    model::DynamicPPL.Model, sampler::Sampler{<:PG}, state::PGState, params::AbstractVarInfo
+    model::DynamicPPL.Model, sampler::PG, state::PGState, params::AbstractVarInfo
 )
     return PGState(params, state.rng)
 end
@@ -606,22 +537,22 @@ variables, and one might need it to be linked while the other doesn't.
 """
 function match_linking!!(varinfo_local, prev_state_local, model)
     prev_varinfo_local = get_varinfo(prev_state_local)
-    was_linked = DynamicPPL.istrans(prev_varinfo_local)
-    is_linked = DynamicPPL.istrans(varinfo_local)
+    was_linked = DynamicPPL.is_transformed(prev_varinfo_local)
+    is_linked = DynamicPPL.is_transformed(varinfo_local)
     if was_linked && !is_linked
         varinfo_local = DynamicPPL.link!!(varinfo_local, model)
     elseif !was_linked && is_linked
         varinfo_local = DynamicPPL.invlink!!(varinfo_local, model)
     end
     # TODO(mhauru) The above might run into trouble if some variables are linked and others
-    # are not. `istrans(varinfo)` returns an `all` over the individual variables. This could
+    # are not. `is_transformed(varinfo)` returns an `all` over the individual variables. This could
     # especially be a problem with dynamic models, where new variables may get introduced,
     # but also in cases where component samplers have partial overlap in their target
     # variables. The below is how I would like to implement this, but DynamicPPL at this
     # time does not support linking individual variables selected by `VarName`. It soon
     # should though, so come back to this.
     # Issue ref: https://github.com/TuringLang/Turing.jl/issues/2401
-    # prev_links_dict = Dict(vn => DynamicPPL.istrans(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local))
+    # prev_links_dict = Dict(vn => DynamicPPL.is_transformed(prev_varinfo_local, vn) for vn in keys(prev_varinfo_local))
     # any_linked = any(values(prev_links_dict))
     # for vn in keys(varinfo_local)
     #     was_linked = if haskey(prev_varinfo_local, vn)
@@ -631,7 +562,7 @@ function match_linking!!(varinfo_local, prev_state_local, model)
     #         # of the variables of the old state were linked.
     #         any_linked
     #     end
-    #     is_linked = DynamicPPL.istrans(varinfo_local, vn)
+    #     is_linked = DynamicPPL.is_transformed(varinfo_local, vn)
     #     if was_linked && !is_linked
     #         varinfo_local = DynamicPPL.invlink!!(varinfo_local, vn)
     #     elseif !was_linked && is_linked
diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl
index d776a68a86..101847b75c 100644
--- a/src/mcmc/hmc.jl
+++ b/src/mcmc/hmc.jl
@@ -1,4 +1,4 @@
-abstract type Hamiltonian <: InferenceAlgorithm end
+abstract type Hamiltonian <: AbstractSampler end
 abstract type StaticHamiltonian <: Hamiltonian end
 abstract type AdaptiveHamiltonian <: Hamiltonian end

@@ -80,24 +80,26 @@ function HMC(
     return HMC(ϵ, n_leapfrog, metricT; adtype=adtype)
 end

-DynamicPPL.initialsampler(::Sampler{<:Hamiltonian}) = SampleFromUniform()
+Turing.Inference.init_strategy(::Hamiltonian) = DynamicPPL.InitFromUniform()

 # Handle setting `nadapts` and `discard_initial`
 function AbstractMCMC.sample(
     rng::AbstractRNG,
     model::DynamicPPL.Model,
-    sampler::Sampler{<:AdaptiveHamiltonian},
+    sampler::AdaptiveHamiltonian,
     N::Integer;
-    chain_type=DynamicPPL.default_chain_type(sampler),
-    resume_from=nothing,
-    initial_state=DynamicPPL.loadstate(resume_from),
+    check_model=true,
+    chain_type=DEFAULT_CHAIN_TYPE,
+    initial_params=Turing.Inference.init_strategy(sampler),
+    initial_state=nothing,
     progress=PROGRESS[],
-    nadapts=sampler.alg.n_adapts,
+    nadapts=sampler.n_adapts,
     discard_adapt=true,
     discard_initial=-1,
     kwargs...,
 )
-    if resume_from === nothing
+    check_model && _check_model(model, sampler)
+    if initial_state === nothing
         # If `nadapts` is `-1`, then the user called a convenience
         # constructor like `NUTS()` or `NUTS(0.65)`,
         # and we should set a default for them.
@@ -124,6 +126,7 @@ function AbstractMCMC.sample(
             progress=progress,
             nadapts=_nadapts,
             discard_initial=_discard_initial,
+            initial_params=initial_params,
             kwargs...,
         )
     else
@@ -138,6 +141,7 @@ function AbstractMCMC.sample(
             nadapts=0,
             discard_adapt=false,
             discard_initial=0,
+            initial_params=initial_params,
             kwargs...,
         )
     end
@@ -147,7 +151,8 @@ function find_initial_params(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     varinfo::DynamicPPL.AbstractVarInfo,
-    hamiltonian::AHMC.Hamiltonian;
+    hamiltonian::AHMC.Hamiltonian,
+    init_strategy::DynamicPPL.AbstractInitStrategy;
     max_attempts::Int=1000,
 )
     varinfo = deepcopy(varinfo)  # Don't mutate
@@ -158,15 +163,10 @@ function find_initial_params(
         isfinite(z) && return varinfo, z

         attempts == 10 &&
-            @warn "failed to find valid initial parameters in $(attempts) tries; consider providing explicit initial parameters using the `initial_params` keyword"
+            @warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"

         # Resample and try again.
-        # NOTE: varinfo has to be linked to make sure this samples in unconstrained space
-        varinfo = last(
-            DynamicPPL.evaluate_and_sample!!(
-                rng, model, varinfo, DynamicPPL.SampleFromUniform()
-            ),
-        )
+        _, varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
     end

     # if we failed to find valid initial parameters, error
@@ -175,12 +175,14 @@ function find_initial_params(
     )
 end

-function DynamicPPL.initialstep(
+function Turing.Inference.initialstep(
     rng::AbstractRNG,
-    model::AbstractModel,
-    spl::Sampler{<:Hamiltonian},
+    model::DynamicPPL.Model,
+    spl::Hamiltonian,
     vi_original::AbstractVarInfo;
-    initial_params=nothing,
+    # the initial_params kwarg is always passed on from sample(), cf. DynamicPPL
+    # src/sampler.jl, so we don't need to provide a default value here
+    initial_params::DynamicPPL.AbstractInitStrategy,
     nadapts=0,
     verbose::Bool=true,
     kwargs...,
@@ -192,65 +194,47 @@ function DynamicPPL.initialstep(
     theta = vi[:]

     # Create a Hamiltonian.
-    metricT = getmetricT(spl.alg)
+    metricT = getmetricT(spl)
     metric = metricT(length(theta))
     ldf = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
     )
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
     lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
     hamiltonian = AHMC.Hamiltonian(metric, lp_func, lp_grad_func)

-    # If no initial parameters are provided, resample until the log probability
-    # and its gradient are finite. Otherwise, just use the existing parameters.
-    vi, z = if initial_params === nothing
-        find_initial_params(rng, model, vi, hamiltonian)
-    else
-        vi, AHMC.phasepoint(rng, theta, hamiltonian)
-    end
+    # Note that there is already one round of 'initialisation' before we reach this step,
+    # inside DynamicPPL's `AbstractMCMC.step` implementation. That leads to a possible issue
+    # that this `find_initial_params` function might override the parameters set by the
+    # user.
+    # Luckily for us, `find_initial_params` always checks if the logp and its gradient are
+    # finite. If it is already finite with the params inside the current `vi`, it doesn't
+    # attempt to find new ones. This means that the parameters passed to `sample()` will be
+    # respected instead of being overridden here.
+    vi, z = find_initial_params(rng, model, vi, hamiltonian, initial_params)
     theta = vi[:]

     # Find good eps if not provided one
-    if iszero(spl.alg.ϵ)
+    if iszero(spl.ϵ)
         ϵ = AHMC.find_good_stepsize(rng, hamiltonian, theta)
         verbose && @info "Found initial step size" ϵ
     else
-        ϵ = spl.alg.ϵ
+        ϵ = spl.ϵ
     end

+    # Generate a kernel and adaptor.
+    kernel = make_ahmc_kernel(spl, ϵ)
+    adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ)
-    # Generate a kernel.
-    kernel = make_ahmc_kernel(spl.alg, ϵ)
-
-    # Create initial transition and state.
-    # Already perform one step since otherwise we don't get any statistics.
-    t = AHMC.transition(rng, hamiltonian, kernel, z)
-
-    # Adaptation
-    adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
-    if spl.alg isa AdaptiveHamiltonian
-        hamiltonian, kernel, _ = AHMC.adapt!(
-            hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate
-        )
-    end
-
-    # Update VarInfo parameters based on acceptance
-    new_params = if t.stat.is_accept
-        t.z.θ
-    else
-        theta
-    end
-    vi = DynamicPPL.unflatten(vi, new_params)
-
-    transition = Transition(model, vi, t)
-    state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
+    transition = Transition(model, vi, NamedTuple())
+    state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor)

     return transition, state
 end

 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
-    model::Model,
-    spl::Sampler{<:Hamiltonian},
+    model::DynamicPPL.Model,
+    spl::Hamiltonian,
     state::HMCState;
     nadapts=0,
     kwargs...,
@@ -265,7 +249,7 @@ function AbstractMCMC.step(

     # Adaptation
     i = state.i + 1
-    if spl.alg isa AdaptiveHamiltonian
+    if spl isa AdaptiveHamiltonian
         hamiltonian, kernel, _ = AHMC.adapt!(
             hamiltonian,
             state.kernel,
@@ -295,7 +279,7 @@ end
 function get_hamiltonian(model, spl, vi, state, n)
     metric = gen_metric(n, spl, state)
     ldf = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype
+        model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
     )
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
     lp_grad_func = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ldf)
@@ -460,17 +444,17 @@ end
 #####
 ##### HMC core functions
 #####

-getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ
-getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor)
+getstepsize(sampler::Hamiltonian, state) = sampler.ϵ
+getstepsize(sampler::AdaptiveHamiltonian, state) = AHMC.getϵ(state.adaptor)

 function getstepsize(
-    sampler::Sampler{<:AdaptiveHamiltonian},
+    sampler::AdaptiveHamiltonian,
     state::HMCState{TV,TKernel,THam,PhType,AHMC.Adaptation.NoAdaptation},
 ) where {TV,TKernel,THam,PhType}
     return state.kernel.τ.integrator.ϵ
 end

-gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim)
-function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state)
+gen_metric(dim::Int, spl::Hamiltonian, state) = AHMC.UnitEuclideanMetric(dim)
+function gen_metric(dim::Int, spl::AdaptiveHamiltonian, state)
     return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc))
 end

@@ -492,15 +476,6 @@ function make_ahmc_kernel(alg::NUTS, ϵ)
     )
 end

-####
-#### Compiler interface, i.e. tilde operators.
-####
-function DynamicPPL.assume(
-    rng, ::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi
-)
-    return DynamicPPL.assume(dist, vn, vi)
-end
-
 ####
 #### Default HMC stepsize and mass matrix adaptor
 ####
diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl
index 319e424fcb..88f915d1f9 100644
--- a/src/mcmc/is.jl
+++ b/src/mcmc/is.jl
@@ -24,35 +24,47 @@ end
 sample(gdemo([1.5, 2]), IS(), 1000)
 ```
 """
-struct IS <: InferenceAlgorithm end
+struct IS <: AbstractSampler end

-DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler
-
-function DynamicPPL.initialstep(
-    rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs...
+function Turing.Inference.initialstep(
+    rng::AbstractRNG, model::Model, spl::IS, vi::AbstractVarInfo; kwargs...
 )
     return Transition(model, vi, nothing), nothing
 end

 function AbstractMCMC.step(
-    rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs...
+    rng::Random.AbstractRNG, model::Model, spl::IS, ::Nothing; kwargs...
 )
-    vi = VarInfo(rng, model, spl)
+    model = DynamicPPL.setleafcontext(model, ISContext(rng))
+    _, vi = DynamicPPL.evaluate!!(model, DynamicPPL.VarInfo())
+    vi = DynamicPPL.typed_varinfo(vi)
     return Transition(model, vi, nothing), nothing
 end

 # Calculate evidence.
-function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state) +function getlogevidence(samples::Vector{<:Transition}, ::IS, state) return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples)) end -function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName, vi) +struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext + rng::R +end +DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf() + +function DynamicPPL.tilde_assume!!( + ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) if haskey(vi, vn) r = vi[vn] else - r = rand(rng, dist) + r = rand(ctx.rng, dist) vi = push!!(vi, vn, r, dist) end vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist) return r, vi end +function DynamicPPL.tilde_observe!!( + ::ISContext, right::Distribution, left, vn::Union{VarName,Nothing}, vi::AbstractVarInfo +) + return DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl index 863db559ce..833303b864 100644 --- a/src/mcmc/mh.jl +++ b/src/mcmc/mh.jl @@ -104,7 +104,7 @@ mean(chain) ``` """ -struct MH{P} <: InferenceAlgorithm +struct MH{P} <: AbstractSampler proposals::P function MH(proposals...) @@ -178,8 +178,6 @@ get_varinfo(s::MHState) = s.varinfo # Utility functions # ##################### -# TODO(DPPL0.38/penelopeysm): This function should no longer be needed -# once InitContext is merged. """ set_namedtuple!(vi::VarInfo, nt::NamedTuple) @@ -207,15 +205,24 @@ end # NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems # interface in that it gets evaluated with a NamedTuple. Hence we need this # method just to deal with MH. -# TODO(DPPL0.38/penelopeysm): Check the extent to which this method is actually -# needed. If it's still needed, replace this with `init!!(f.model, f.varinfo, -# ParamsInit(x))`. Much less hacky than `set_namedtuple!` (hopefully...). 
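For context on the `getlogevidence` method above: importance sampling estimates the log evidence as the log of the average likelihood weight, computed stably via log-sum-exp. A numeric sketch using only Base Julia; the log-likelihood values are made up.

```julia
# Log-likelihoods of three hypothetical IS samples.
loglikes = [-1.2, -0.8, -1.5]

# Numerically stable log-sum-exp: shift by the maximum before exponentiating.
m = maximum(loglikes)
lse = m + log(sum(exp.(loglikes .- m)))

# Log evidence estimate: log of the average likelihood, as in `getlogevidence`.
logZ = lse - log(length(loglikes))
```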
-# In general, we should much prefer to either (1) conform to the
-# LogDensityProblems interface or (2) use VarNames anyway.
 function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple)
     vi = deepcopy(f.varinfo)
+    # Note that the NamedTuple `x` does NOT conform to the structure required for
+    # `InitFromParams`. In particular, for models that look like this:
+    #
+    #     @model function f()
+    #         v = Vector{Vector{Float64}}(undef, 1)
+    #         v[1] ~ MvNormal(zeros(2), I)
+    #     end
+    #
+    # `InitFromParams` will expect Dict(@varname(v[1]) => [x1, x2]), but `x` will have the
+    # format `(; v = [x1, x2])`. Hence we still need this `set_namedtuple!` function.
+    #
+    # In general `init!!(f.model, vi, InitFromParams(x))` will work iff the model only
+    # contains 'basic' varnames.
     set_namedtuple!(vi, x)
-    vi_new = last(DynamicPPL.evaluate!!(f.model, vi))
+    # Update log probability.
+    _, vi_new = DynamicPPL.evaluate!!(f.model, vi)
     lj = f.getlogdensity(vi_new)
     return lj
 end
@@ -240,16 +247,16 @@ function reconstruct(dist::AbstractVector{<:MultivariateDistribution}, val::Abst
 end

 """
-    dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo)
+    dist_val_tuple(spl::MH, vi::VarInfo)

 Return two `NamedTuples`.

 The first `NamedTuple` has symbols as keys and distributions as values.
 The second `NamedTuple` has model symbols as keys and their stored values as values.
 """
-function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
+function dist_val_tuple(spl::MH, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
     vns = all_varnames_grouped_by_symbol(vi)
-    dt = _dist_tuple(spl.alg.proposals, vi, vns)
+    dt = _dist_tuple(spl.proposals, vi, vns)
     vt = _val_tuple(vi, vns)
     return dt, vt
 end
@@ -317,9 +324,7 @@ function maybe_link!!(varinfo, sampler, proposal, model)
 end

 # Make a proposal if we don't have a covariance proposal matrix (the default).
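The distinction drawn in the comment above can be sketched directly, assuming the DynamicPPL 0.38 `InitFromParams` API; `x1`/`x2` are stand-in values.

```julia
using DynamicPPL

x1, x2 = 0.1, 0.2  # hypothetical values

# What `InitFromParams` expects for a model containing `v[1] ~ MvNormal(...)`:
# keys are full VarNames, indexing included.
params_dict = Dict(@varname(v[1]) => [x1, x2])

# What this `logdensity` method receives instead: a NamedTuple keyed by the
# bare symbol `v`, with no record that the value belongs to `v[1]`.
params_nt = (; v = [x1, x2])
```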
-function propose!!( - rng::AbstractRNG, prev_state::MHState, model::Model, spl::Sampler{<:MH}, proposal -) +function propose!!(rng::AbstractRNG, prev_state::MHState, model::Model, spl::MH, proposal) vi = prev_state.varinfo # Retrieve distribution and value NamedTuples. dt, vt = dist_val_tuple(spl, vi) @@ -329,13 +334,11 @@ function propose!!( prev_trans = AMH.Transition(vt, prev_state.logjoint_internal, false) # Make a new transition. - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, spl, model.context) - ) + model = DynamicPPL.setleafcontext(model, MHContext(rng)) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) @@ -353,7 +356,7 @@ function propose!!( rng::AbstractRNG, prev_state::MHState, model::Model, - spl::Sampler{<:MH}, + spl::MH, proposal::AdvancedMH.RandomWalkProposal, ) vi = prev_state.varinfo @@ -362,17 +365,15 @@ function propose!!( vals = vi[:] # Create a sampler and the previous transition. - mh_sampler = AMH.MetropolisHastings(spl.alg.proposals) + mh_sampler = AMH.MetropolisHastings(spl.proposals) prev_trans = AMH.Transition(vals, prev_state.logjoint_internal, false) # Make a new transition. 
- spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, spl, model.context) - ) + model = DynamicPPL.setleafcontext(model, MHContext(rng)) densitymodel = AMH.DensityModel( Base.Fix1( LogDensityProblems.logdensity, - DynamicPPL.LogDensityFunction(spl_model, DynamicPPL.getlogjoint_internal, vi), + DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi), ), ) trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans) @@ -385,38 +386,46 @@ function propose!!( return MHState(vi, trans.lp) end -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:MH}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::AbstractRNG, model::DynamicPPL.Model, spl::MH, vi::AbstractVarInfo; kwargs... ) # If we're doing random walk with a covariance matrix, # just link everything before sampling. - vi = maybe_link!!(vi, spl, spl.alg.proposals, model) + vi = maybe_link!!(vi, spl, spl.proposals, model) return Transition(model, vi, nothing), MHState(vi, DynamicPPL.getlogjoint_internal(vi)) end function AbstractMCMC.step( - rng::AbstractRNG, model::Model, spl::Sampler{<:MH}, state::MHState; kwargs... + rng::AbstractRNG, model::DynamicPPL.Model, spl::MH, state::MHState; kwargs... ) # Cases: # 1. A covariance proposal matrix # 2. A bunch of NamedTuples that specify the proposal space - new_state = propose!!(rng, state, model, spl, spl.alg.proposals) + new_state = propose!!(rng, state, model, spl, spl.proposals) return Transition(model, new_state.varinfo, nothing), new_state end -#### -#### Compiler interface, i.e. tilde operators. 
-####
-function DynamicPPL.assume(
-    rng::Random.AbstractRNG, spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi
+struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
+    rng::R
+end
+DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()
+
+function DynamicPPL.tilde_assume!!(
+    context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
+)
+    # Allow MH to sample new values from the prior if the variable is not already
+    # present in the VarInfo.
+    dispatch_ctx = if haskey(vi, vn)
+        DynamicPPL.DefaultContext()
+    else
+        DynamicPPL.InitContext(context.rng, DynamicPPL.InitFromPrior())
+    end
+    return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi)
+end
+function DynamicPPL.tilde_observe!!(
+    ::MHContext, right::Distribution, left, vn::Union{VarName,Nothing}, vi::AbstractVarInfo
 )
-    # Just defer to `SampleFromPrior`.
-    retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
-    return retval
+    return DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi)
 end
diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl
index e80ec527bb..7aadef09ef 100644
--- a/src/mcmc/particle_mcmc.jl
+++ b/src/mcmc/particle_mcmc.jl
@@ -4,62 +4,28 @@
 ### AdvancedPS models and interface

-"""
-    set_all_del!(vi::AbstractVarInfo)
-
-Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
-resampling.
-"""
-function set_all_del!(vi::AbstractVarInfo)
-    # TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we
-    # could either:
-    # - keep a boolean 'resample' flag on the trace, or
-    # - modify the model context appropriately.
-    # However, this refactoring will have to wait until InitContext is
-    # merged into DPPL.
-    for vn in keys(vi)
-        DynamicPPL.set_flag!(vi, vn, "del")
-    end
-    return nothing
-end
-
-"""
-    unset_all_del!(vi::AbstractVarInfo)
-
-Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing
-them from being resampled.
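The dispatch inside `MHContext`'s `tilde_assume!!` above can be summarised as a small helper; this is a hypothetical miniature, not code from the diff, and assumes the DynamicPPL context types shown above.

```julia
using DynamicPPL
using Random

# Choose a context based on whether the variable already has a value, then
# delegate the actual tilde handling to that context.
function choose_ctx(vi, vn, rng::Random.AbstractRNG)
    return if haskey(vi, vn)
        DynamicPPL.DefaultContext()  # reuse the stored value
    else
        DynamicPPL.InitContext(rng, DynamicPPL.InitFromPrior())  # sample fresh
    end
end
```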
-""" -function unset_all_del!(vi::AbstractVarInfo) - for vn in keys(vi) - DynamicPPL.unset_flag!(vi, vn, "del") - end - return nothing +struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext + rng::R end +DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf() -struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: - AdvancedPS.AbstractGenericModel +struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel model::M - sampler::S varinfo::V evaluator::E + resample::Bool end function TracedModel( - model::Model, - sampler::AbstractSampler, - varinfo::AbstractVarInfo, - rng::Random.AbstractRNG, + model::Model, varinfo::AbstractVarInfo, rng::Random.AbstractRNG, resample::Bool ) - spl_context = DynamicPPL.SamplingContext(rng, sampler, model.context) - spl_model = DynamicPPL.contextualize(model, spl_context) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(spl_model, varinfo) - if kwargs !== nothing && !isempty(kwargs) - error( - "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", - ) - end - evaluator = (spl_model.f, args...) - return TracedModel(spl_model, sampler, varinfo, evaluator) + model = DynamicPPL.setleafcontext(model, ParticleMCMCContext(rng)) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) + isempty(kwargs) || error( + "Particle sampling methods do not currently support models with keyword arguments.", + ) + evaluator = (model.f, args...) + return TracedModel(model, varinfo, evaluator, resample) end function AdvancedPS.advance!( @@ -75,16 +41,9 @@ function AdvancedPS.delete_retained!(trace::TracedModel) # This method is called if, during a CSMC update, we perform a resampling # and choose the reference particle as the trajectory to carry on from. # In such a case, we need to ensure that when we continue sampling (i.e. 
-    # the next time we hit tilde_assume), we don't use the values in the
+    # the next time we hit tilde_assume!!), we don't use the values in the
     # reference particle but rather sample new values.
-    #
-    # Here, we indiscriminately set the 'del' flag for all variables in the
-    # VarInfo. This is slightly overkill: it is not necessary to set the 'del'
-    # flag for variables that were already sampled. However, it allows us to
-    # avoid keeping track of which variables were sampled, which leads to many
-    # simplifications in the VarInfo data structure.
-    set_all_del!(trace.varinfo)
-    return trace
+    return TracedModel(trace.model, trace.varinfo, trace.evaluator, true)
 end

 function AdvancedPS.reset_model(trace::TracedModel)
@@ -97,7 +56,7 @@ function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...)
     )
 end

-abstract type ParticleInference <: InferenceAlgorithm end
+abstract type ParticleInference <: AbstractSampler end

 ####
 #### Generic Sequential Monte Carlo sampler.
 ####

@@ -117,8 +76,8 @@ struct SMC{R} <: ParticleInference
 end

 """
-    SMC([resampler = AdvancedPS.ResampleWithESSThreshold()])
-    SMC([resampler = AdvancedPS.resample_systematic, ]threshold)
+    SMC([resampler = AdvancedPS.ResampleWithESSThreshold()])
+    SMC([resampler = AdvancedPS.resample_systematic, ]threshold)

 Create a sequential Monte Carlo sampler of type [`SMC`](@ref).
@@ -142,69 +101,57 @@ struct SMCState{P,F<:AbstractFloat} average_logevidence::F end -function getlogevidence(samples, sampler::Sampler{<:SMC}, state::SMCState) +function getlogevidence(samples, ::SMC, state::SMCState) return state.average_logevidence end function AbstractMCMC.sample( rng::AbstractRNG, model::DynamicPPL.Model, - sampler::Sampler{<:SMC}, + sampler::SMC, N::Integer; - chain_type=DynamicPPL.default_chain_type(sampler), - resume_from=nothing, - initial_state=DynamicPPL.loadstate(resume_from), + check_model=true, + chain_type=DEFAULT_CHAIN_TYPE, + initial_params=Turing.Inference.init_strategy(sampler), progress=PROGRESS[], kwargs..., ) - if resume_from === nothing - return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - N; - chain_type=chain_type, - progress=progress, - nparticles=N, - kwargs..., - ) - else - return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - N; - chain_type, - initial_state, - progress=progress, - nparticles=N, - kwargs..., - ) - end + check_model && _check_model(model, sampler) + # need to add on the `nparticles` keyword argument for `initialstep` to make use of + return AbstractMCMC.mcmcsample( + rng, + model, + sampler, + N; + chain_type=chain_type, + initial_params=initial_params, + progress=progress, + nparticles=N, + kwargs..., + ) end -function DynamicPPL.initialstep( +function Turing.Inference.initialstep( rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:SMC}, + model::DynamicPPL.Model, + spl::SMC, vi::AbstractVarInfo; nparticles::Int, kwargs..., ) # Reset the VarInfo. vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) - set_all_del!(vi) vi = DynamicPPL.empty!!(vi) # Create a new set of particles. particles = AdvancedPS.ParticleContainer( - [AdvancedPS.Trace(model, spl, vi, AdvancedPS.TracedRNG()) for _ in 1:nparticles], + [AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for _ in 1:nparticles], AdvancedPS.TracedRNG(), rng, ) # Perform particle sweep. 
-    logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl)
+    logevidence = AdvancedPS.sweep!(rng, particles, spl.resampler, spl)

     # Extract the first particle and its weight.
     particle = particles.vals[1]
@@ -219,7 +166,7 @@ end

 function AbstractMCMC.step(
-    ::AbstractRNG, model::AbstractModel, spl::Sampler{<:SMC}, state::SMCState; kwargs...
+    ::AbstractRNG, model::DynamicPPL.Model, spl::SMC, state::SMCState; kwargs...
 )
     # Extract the index of the current particle.
     index = state.particleindex
@@ -258,8 +205,8 @@ struct PG{R} <: ParticleInference
 end

 """
-    PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold()])
-    PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold)
+    PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold()])
+    PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold)

 Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles.

@@ -279,7 +226,7 @@ function PG(nparticles::Int, threshold::Real)
 end

 """
-    CSMC(...)
+    CSMC(...)

 Equivalent to [`PG`](@ref).
 """
@@ -293,9 +240,7 @@ end
 get_varinfo(state::PGState) = state.vi

 function getlogevidence(
-    transitions::AbstractVector{<:Turing.Inference.Transition},
-    sampler::Sampler{<:PG},
-    state::PGState,
+    transitions::AbstractVector{<:Turing.Inference.Transition}, ::PG, ::PGState
 )
     logevidences = map(transitions) do t
         if haskey(t.stat, :logevidence)
@@ -309,27 +254,24 @@ function getlogevidence(
     return mean(logevidences)
 end

-function DynamicPPL.initialstep(
-    rng::AbstractRNG,
-    model::AbstractModel,
-    spl::Sampler{<:PG},
-    vi::AbstractVarInfo;
-    kwargs...,
+function Turing.Inference.initialstep(
+    rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, vi::AbstractVarInfo; kwargs...
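To illustrate the constructor forms documented above: a usage sketch, assuming the Turing v0.41 exports; `gdemo` refers to the example model used in the `IS` docstring and is otherwise hypothetical here.

```julia
using Turing

# Both forms from the docstrings above. The threshold form resamples whenever
# the effective sample size falls below the given fraction of the particles.
pg_default = PG(20)
pg_thresh  = PG(20, 0.5)
smc        = SMC()

# Hypothetical usage, with `gdemo` as defined in the IS docstring:
# chain = sample(gdemo([1.5, 2]), pg_default, 100)
```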
) vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) - # Reset the VarInfo before new sweep - set_all_del!(vi) # Create a new set of particles - num_particles = spl.alg.nparticles + num_particles = spl.nparticles particles = AdvancedPS.ParticleContainer( - [AdvancedPS.Trace(model, spl, vi, AdvancedPS.TracedRNG()) for _ in 1:num_particles], + [ + AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) for + _ in 1:num_particles + ], AdvancedPS.TracedRNG(), rng, ) # Perform a particle sweep. - logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl) + logevidence = AdvancedPS.sweep!(rng, particles, spl.resampler, spl) # Pick a particle to be retained. Ws = AdvancedPS.getweights(particles) @@ -344,24 +286,20 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::AbstractRNG, model::AbstractModel, spl::Sampler{<:PG}, state::PGState; kwargs... + rng::AbstractRNG, model::DynamicPPL.Model, spl::PG, state::PGState; kwargs... ) # Reset the VarInfo before new sweep. vi = state.vi vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) # Create reference particle for which the samples will be retained. - unset_all_del!(vi) - reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) - - # For all other particles, do not retain the variables but resample them. - set_all_del!(vi) + reference = AdvancedPS.forkr(AdvancedPS.Trace(model, vi, state.rng, false)) # Create a new set of particles. - num_particles = spl.alg.nparticles + num_particles = spl.nparticles x = map(1:num_particles) do i if i != num_particles - return AdvancedPS.Trace(model, spl, vi, AdvancedPS.TracedRNG()) + return AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), true) else return reference end @@ -369,7 +307,7 @@ function AbstractMCMC.step( particles = AdvancedPS.ParticleContainer(x, AdvancedPS.TracedRNG(), rng) # Perform a particle sweep. 
-    logevidence = AdvancedPS.sweep!(rng, particles, spl.alg.resampler, spl, reference)
+    logevidence = AdvancedPS.sweep!(rng, particles, spl.resampler, spl, reference)

     # Pick a particle to be retained.
     Ws = AdvancedPS.getweights(particles)
@@ -383,14 +321,10 @@ end
     return transition, PGState(_vi, newreference.rng)
 end

-function DynamicPPL.use_threadsafe_eval(
-    ::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo
-)
-    return false
-end
+DynamicPPL.use_threadsafe_eval(::ParticleMCMCContext, ::AbstractVarInfo) = false

 """
-    get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+    get_trace_local_varinfo_maybe(vi::AbstractVarInfo)

 Get the `Trace` local varinfo if one exists.

@@ -407,7 +341,24 @@ function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo)
 end

 """
-    get_trace_local_varinfo_maybe(rng::Random.AbstractRNG)
+    get_trace_local_resampled_maybe(fallback_resampled::Bool)
+
+Get the `Trace` local `resampled` if one exists.
+
+If executed within a `TapedTask`, return the `resampled` stored in the "taped globals" of
+the task, otherwise return `fallback_resampled`.
+"""
+function get_trace_local_resampled_maybe(fallback_resampled::Bool)
+    trace = try
+        Libtask.get_taped_globals(Any).other
+    catch e
+        e == KeyError(:task_variable) ? nothing : rethrow(e)
+    end
+    return (trace === nothing ? fallback_resampled : trace.model.f.resample)::Bool
+end
+
+"""
+    get_trace_local_rng_maybe(rng::Random.AbstractRNG)

 Get the `Trace` local rng if one exists.

@@ -423,7 +374,7 @@ function get_trace_local_rng_maybe(rng::Random.AbstractRNG)
 end

 """
-    set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
+    set_trace_local_varinfo_maybe(vi::AbstractVarInfo)

 Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
@@ -446,30 +397,22 @@ function set_trace_local_varinfo_maybe(vi::AbstractVarInfo) return nothing end -function DynamicPPL.assume( - rng, ::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, vi::AbstractVarInfo +function DynamicPPL.tilde_assume!!( + ctx::ParticleMCMCContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) arg_vi_id = objectid(vi) vi = get_trace_local_varinfo_maybe(vi) using_local_vi = objectid(vi) == arg_vi_id - trng = get_trace_local_rng_maybe(rng) - - if ~haskey(vi, vn) - r = rand(trng, dist) - vi = push!!(vi, vn, r, dist) - elseif DynamicPPL.is_flagged(vi, vn, "del") - DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent - # TODO(mhauru): - # The below is the only line that differs from assume called on SampleFromPrior. - # Could we just call assume on SampleFromPrior with a specific rng? - r = rand(trng, dist) - vi[vn] = DynamicPPL.tovec(r) + trng = get_trace_local_rng_maybe(ctx.rng) + resample = get_trace_local_resampled_maybe(true) + + dispatch_ctx = if ~haskey(vi, vn) || resample + DynamicPPL.InitContext(trng, DynamicPPL.InitFromPrior()) else - r = vi[vn] + DynamicPPL.DefaultContext() end - - vi = DynamicPPL.accumulate_assume!!(vi, r, 0, vn, dist) + x, vi = DynamicPPL.tilde_assume!!(dispatch_ctx, dist, vn, vi) # TODO(mhauru) Rather than this if-block, we should use try-catch within # `set_trace_local_varinfo_maybe`. 
However, currently Libtask can't handle such a block,
@@ -477,17 +420,21 @@ function DynamicPPL.assume(
     if !using_local_vi
         set_trace_local_varinfo_maybe(vi)
     end
-    return r, vi
+    return x, vi
 end

 function DynamicPPL.tilde_observe!!(
-    ctx::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, right, left, vn, vi
+    ::ParticleMCMCContext,
+    right::Distribution,
+    left,
+    vn::Union{VarName,Nothing},
+    vi::AbstractVarInfo,
 )
     arg_vi_id = objectid(vi)
     vi = get_trace_local_varinfo_maybe(vi)
     using_local_vi = objectid(vi) == arg_vi_id

-    left, vi = DynamicPPL.tilde_observe!!(ctx.context, right, left, vn, vi)
+    left, vi = DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi)

     # TODO(mhauru) Rather than this if-block, we should use try-catch within
     # `set_trace_local_varinfo_maybe`. However, currently Libtask can't handle such a block,
@@ -500,19 +447,16 @@ end

 # Convenient constructor
 function AdvancedPS.Trace(
-    model::Model,
-    sampler::Sampler{<:Union{SMC,PG}},
-    varinfo::AbstractVarInfo,
-    rng::AdvancedPS.TracedRNG,
+    model::Model, varinfo::AbstractVarInfo, rng::AdvancedPS.TracedRNG, resample::Bool
 )
     newvarinfo = deepcopy(varinfo)
-    tmodel = TracedModel(model, sampler, newvarinfo, rng)
+    tmodel = TracedModel(model, newvarinfo, rng, resample)
     newtrace = AdvancedPS.Trace(tmodel, rng)
     return newtrace
 end

 """
-    ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
+    ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator

 Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value.

@@ -573,7 +517,6 @@ Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}
 # Could the next two have tighter type bounds on the arguments, namely a GibbsContext?
 # That's the only thing that makes tilde_assume calls result in tilde_observe calls.
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume!!),Vararg}}) = true -Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_assume),Vararg}}) = true Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true function Libtask.might_produce( ::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl index 2ead40cedf..c4ec6c6f33 100644 --- a/src/mcmc/prior.jl +++ b/src/mcmc/prior.jl @@ -3,28 +3,23 @@ Algorithm for sampling from the prior. """ -struct Prior <: InferenceAlgorithm end +struct Prior <: AbstractSampler end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:Prior}, + sampler::Prior, state=nothing; kwargs..., ) - # TODO(DPPL0.38/penelopeysm): replace with init!! - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context) - ) - vi = VarInfo() vi = DynamicPPL.setaccs!!( - vi, + DynamicPPL.VarInfo(), ( DynamicPPL.ValuesAsInModelAccumulator(true), DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator(), ), ) - _, vi = DynamicPPL.evaluate!!(sampling_model, vi) + _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior()) return Transition(model, vi, nothing; reevaluate=false), nothing end diff --git a/src/mcmc/repeat_sampler.jl b/src/mcmc/repeat_sampler.jl index fa2eca96de..133517494e 100644 --- a/src/mcmc/repeat_sampler.jl +++ b/src/mcmc/repeat_sampler.jl @@ -24,11 +24,12 @@ struct RepeatSampler{S<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSa end end -function RepeatSampler(alg::InferenceAlgorithm, num_repeat::Int) - return RepeatSampler(Sampler(alg), num_repeat) -end - -function setparams_varinfo!!(model::DynamicPPL.Model, sampler::RepeatSampler, state, params) +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::RepeatSampler, + state, + params::DynamicPPL.AbstractVarInfo, +) 
return setparams_varinfo!!(model, sampler.sampler, state, params) end @@ -40,6 +41,14 @@ function AbstractMCMC.step( ) return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...) end +# The following method needed for method ambiguity resolution. +# TODO(penelopeysm): Remove this method once the default `AbstractMCMC.step(rng, +# ::DynamicPPL.Model, ::AbstractSampler)` method in `src/mcmc/abstractmcmc.jl` is removed. +function AbstractMCMC.step( + rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::RepeatSampler; kwargs... +) + return AbstractMCMC.step(rng, model, sampler.sampler; kwargs...) +end function AbstractMCMC.step( rng::Random.AbstractRNG, @@ -81,3 +90,62 @@ function AbstractMCMC.step_warmup( end return transition, state end + +# Need some extra leg work to make RepeatSampler work seamlessly with DynamicPPL models + +# samplers, instead of generic AbstractMCMC samplers. + +function Turing.Inference.init_strategy(spl::RepeatSampler) + return Turing.Inference.init_strategy(spl.sampler) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::RepeatSampler, + N::Integer; + check_model=true, + initial_params=Turing.Inference.init_strategy(sampler), + chain_type=DEFAULT_CHAIN_TYPE, + progress=PROGRESS[], + kwargs..., +) + check_model && _check_model(model, sampler) + return AbstractMCMC.mcmcsample( + rng, + model, + sampler, + N; + initial_params=_convert_initial_params(initial_params), + chain_type=chain_type, + progress=progress, + kwargs..., + ) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::RepeatSampler, + ensemble::AbstractMCMC.AbstractMCMCEnsemble, + N::Integer, + n_chains::Integer; + check_model=true, + initial_params=fill(Turing.Inference.init_strategy(sampler), n_chains), + chain_type=DEFAULT_CHAIN_TYPE, + progress=PROGRESS[], + kwargs..., +) + check_model && _check_model(model, sampler) + return AbstractMCMC.mcmcsample( + rng, + model, + sampler, + 
ensemble, + N, + n_chains; + initial_params=map(_convert_initial_params, initial_params), + chain_type=chain_type, + progress=progress, + kwargs..., + ) +end diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 34d7cf9d8d..267a216209 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -51,22 +51,18 @@ struct SGHMCState{L,V<:AbstractVarInfo,T<:AbstractVector{<:Real}} velocity::T end -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:SGHMC}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::Random.AbstractRNG, model::Model, spl::SGHMC, vi::AbstractVarInfo; kwargs... ) # Transform the samples to unconstrained space. - if !DynamicPPL.islinked(vi) + if !DynamicPPL.is_transformed(vi) vi = DynamicPPL.link!!(vi, model) end # Compute initial sample and state. sample = Transition(model, vi, nothing) ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) state = SGHMCState(ℓ, vi, zero(vi[:])) @@ -74,11 +70,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:SGHMC}, - state::SGHMCState; - kwargs..., + rng::Random.AbstractRNG, model::Model, spl::SGHMC, state::SGHMCState; kwargs... ) # Compute gradient of log density. ℓ = state.logdensity @@ -90,8 +82,8 @@ function AbstractMCMC.step( # equation (15) of Chen et al. (2014) v = state.velocity θ .+= v - η = spl.alg.learning_rate - α = spl.alg.momentum_decay + η = spl.learning_rate + α = spl.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) # Save new variables. 
@@ -190,22 +182,18 @@ struct SGLDState{L,V<:AbstractVarInfo} step::Int end -function DynamicPPL.initialstep( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:SGLD}, - vi::AbstractVarInfo; - kwargs..., +function Turing.Inference.initialstep( + rng::Random.AbstractRNG, model::Model, spl::SGLD, vi::AbstractVarInfo; kwargs... ) # Transform the samples to unconstrained space. - if !DynamicPPL.islinked(vi) + if !DynamicPPL.is_transformed(vi) vi = DynamicPPL.link!!(vi, model) end # Create first sample and state. - transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.alg.stepsize(0)))) + transition = Transition(model, vi, (; SGLD_stepsize=zero(spl.stepsize(0)))) ℓ = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.alg.adtype + model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype ) state = SGLDState(ℓ, vi, 1) @@ -213,7 +201,7 @@ function DynamicPPL.initialstep( end function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler{<:SGLD}, state::SGLDState; kwargs... + rng::Random.AbstractRNG, model::Model, spl::SGLD, state::SGLDState; kwargs... ) # Perform gradient step. ℓ = state.logdensity @@ -221,7 +209,7 @@ function AbstractMCMC.step( θ = vi[:] grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) step = state.step - stepsize = spl.alg.stepsize(step) + stepsize = spl.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) # Save new variables. 
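The SGLD update in the step above (`θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(...)`) can be reproduced on a toy target using only Base Julia; the target, schedule, and seed here are made up for illustration.

```julia
using Random

rng = MersenneTwister(1)
θ = [2.0]
stepsize(step) = 0.1 / step   # hypothetical decaying schedule

for step in 1:10
    grad = -θ                 # ∇ log p(θ) for a standard normal target
    ϵ = stepsize(step)
    # Langevin step: half-step along the gradient plus scaled Gaussian noise.
    θ .+= (ϵ / 2) .* grad .+ sqrt(ϵ) .* randn(rng, length(θ))
end
```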
diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index 19c52c381b..3a7d15e685 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -2,6 +2,7 @@ module Optimisation using ..Turing using NamedArrays: NamedArrays +using AbstractPPL: AbstractPPL using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Optimization: Optimization @@ -273,7 +274,7 @@ function StatsBase.informationmatrix( # Convert the values to their unconstrained states to make sure the # Hessian is computed with respect to the untransformed parameters. old_ldf = m.f.ldf - linked = DynamicPPL.istrans(old_ldf.varinfo) + linked = DynamicPPL.is_transformed(old_ldf.varinfo) if linked new_vi = DynamicPPL.invlink!!(old_ldf.varinfo, old_ldf.model) new_f = OptimLogDensity( @@ -320,7 +321,7 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) # m.values, but they are more convenient to filter when they are VarNames rather than # Symbols. 
vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict)) vns_and_vals = mapreduce(collect, vcat, iters) varnames = collect(map(first, vns_and_vals)) # For each symbol s in var_symbols, pick all the values from m.values for which the @@ -351,7 +352,7 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u) # `getparams` performs invlinking if needed vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) vns_vals_iter = mapreduce(collect, vcat, iters) syms = map(Symbol ∘ first, vns_vals_iter) vals = map(last, vns_vals_iter) @@ -507,10 +508,8 @@ function estimate_mode( kwargs..., ) if check_model - spl_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - DynamicPPL.check_model(spl_model, DynamicPPL.VarInfo(); error_on_failure=true) + new_model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) + DynamicPPL.check_model(new_model, DynamicPPL.VarInfo(); error_on_failure=true) end constraints = ModeEstimationConstraints(lb, ub, cons, lcons, ucons) diff --git a/src/stdlib/distributions.jl b/src/stdlib/distributions.jl index 617a75dd14..b96b048a2c 100644 --- a/src/stdlib/distributions.jl +++ b/src/stdlib/distributions.jl @@ -65,7 +65,7 @@ of success in an individual trial, with the distribution P(X = k) = {n \\choose k}{(\\text{logistic}(logitp))}^k (1 - \\text{logistic}(logitp))^{n-k}, \\quad \\text{ for } k = 0,1,2, \\ldots, n. 
``` -See also: [`Binomial`](@ref) +See also: [`Distributions.Binomial`](@extref) """ struct BinomialLogit{T<:Real,S<:Real} <: DiscreteUnivariateDistribution n::Int @@ -188,7 +188,7 @@ The distribution has the probability mass function P(X = k) = \\frac{e^{k \\cdot \\log\\lambda}}{k!} e^{-e^{\\log\\lambda}}, \\quad \\text{ for } k = 0,1,2,\\ldots. ``` -See also: [`Poisson`](@ref) +See also: [`Distributions.Poisson`](@extref) """ struct LogPoisson{T<:Real,S} <: DiscreteUnivariateDistribution logλ::T diff --git a/test/Project.toml b/test/Project.toml index 138b1a1a0d..435f8cc5f2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -53,7 +53,6 @@ Combinatorics = "1" Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.37.2" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1" HypothesisTests = "0.11" diff --git a/test/ad.jl b/test/ad.jl index dcfe4ef46c..9524199dc7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -154,31 +154,23 @@ end # context, and then call check_adtype on the result before returning the results from the # child context. 
-function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) - value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) - check_adtype(context, vi) - return value, vi -end - -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi +function DynamicPPL.tilde_assume!!( + context::ADTypeCheckContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) - value, vi = DynamicPPL.tilde_assume( - rng, DynamicPPL.childcontext(context), sampler, right, vn, vi - ) + value, vi = DynamicPPL.tilde_assume!!(DynamicPPL.childcontext(context), right, vn, vi) check_adtype(context, vi) return value, vi end -function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi) - left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi) - check_adtype(context, vi) - return left, vi -end - -function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi) +function DynamicPPL.tilde_observe!!( + context::ADTypeCheckContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) left, vi = DynamicPPL.tilde_observe!!( - DynamicPPL.childcontext(context), sampler, right, left, vi + DynamicPPL.childcontext(context), right, left, vn, vi ) check_adtype(context, vi) return left, vi diff --git a/test/essential/container.jl b/test/essential/container.jl index 124637aab7..19609b6b51 100644 --- a/test/essential/container.jl +++ b/test/essential/container.jl @@ -2,7 +2,7 @@ module ContainerTests using AdvancedPS: AdvancedPS using Distributions: Bernoulli, Beta, Gamma, Normal -using DynamicPPL: DynamicPPL, @model, Sampler +using DynamicPPL: DynamicPPL, @model using Test: @test, @testset using Turing @@ -20,9 +20,9 @@ using Turing @testset "constructor" begin vi = DynamicPPL.VarInfo() vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) - sampler = Sampler(PG(10)) + sampler 
= PG(10) model = test() - trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) + trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) # Make sure the backreference from taped_globals to the trace is in place. @test trace.model.ctask.taped_globals.other === trace @@ -45,10 +45,10 @@ using Turing end vi = DynamicPPL.VarInfo() vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator()) - sampler = Sampler(PG(10)) + sampler = PG(10) model = normal() - trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG()) + trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false) newtrace = AdvancedPS.forkr(trace) # Catch broken replay mechanism diff --git a/test/ext/OptimInterface.jl b/test/ext/OptimInterface.jl index 8fb9e2b1ac..721e255f3c 100644 --- a/test/ext/OptimInterface.jl +++ b/test/ext/OptimInterface.jl @@ -2,6 +2,7 @@ module OptimInterfaceTests using ..Models: gdemo_default using Distributions.FillArrays: Zeros +using AbstractPPL: AbstractPPL using DynamicPPL: DynamicPPL using LinearAlgebra: I using Optim: Optim @@ -124,7 +125,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol = 0.05 end end @@ -159,7 +160,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) if model.f in allowed_incorrect_mle @test isfinite(get(result_true, vn_leaf)) else diff --git a/test/ext/dynamichmc.jl b/test/ext/dynamichmc.jl index 3f609504df..004970dd3c 100644 --- a/test/ext/dynamichmc.jl +++ b/test/ext/dynamichmc.jl @@ -6,7 +6,6 @@ using Test: @test, @testset using Distributions: sample using DynamicHMC: 
DynamicHMC using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using Random: Random using Turing diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 9f69a2de53..6918eaddf9 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -6,7 +6,6 @@ using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample using AbstractMCMC: AbstractMCMC import DynamicPPL -using DynamicPPL: Sampler import ForwardDiff using LinearAlgebra: I import MCMCChains @@ -41,7 +40,9 @@ using Turing Random.seed!(5) chain2 = sample(model, sampler, MCMCThreads(), 10, 4) - @test chain1.value == chain2.value + # For HMC, the first step does not have stats, so we need to use isequal to + # avoid comparing `missing`s + @test isequal(chain1.value, chain2.value) end # Should also be stable with an explicit RNG @@ -54,7 +55,7 @@ using Turing Random.seed!(rng, local_seed) chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4) - @test chain1.value == chain2.value + @test isequal(chain1.value, chain2.value) end end @@ -64,28 +65,16 @@ using Turing StableRNG(seed), gdemo_default, HMC(0.1, 7), MCMCThreads(), 1_000, 4 ) check_gdemo(chain) - - # run sampler: progress logging should be disabled and - # it should return a Chains object - sampler = Sampler(HMC(0.1, 7)) - chains = sample(StableRNG(seed), gdemo_default, sampler, MCMCThreads(), 10, 4) - @test chains isa MCMCChains.Chains end end @testset "save/resume correctly reloads state" begin - struct StaticSampler <: Turing.Inference.InferenceAlgorithm end - function DynamicPPL.initialstep( - rng, model, ::DynamicPPL.Sampler{<:StaticSampler}, vi; kwargs... - ) + struct StaticSampler <: AbstractMCMC.AbstractSampler end + function Turing.Inference.initialstep(rng, model, ::StaticSampler, vi; kwargs...) 
return Turing.Inference.Transition(model, vi, nothing), vi end function AbstractMCMC.step( - rng, - model, - ::DynamicPPL.Sampler{<:StaticSampler}, - vi::DynamicPPL.AbstractVarInfo; - kwargs..., + rng, model, ::StaticSampler, vi::DynamicPPL.AbstractVarInfo; kwargs... ) return Turing.Inference.Transition(model, vi, nothing), vi end @@ -95,7 +84,7 @@ using Turing @testset "single-chain" begin chn1 = sample(demo(), StaticSampler(), 10; save_state=true) @test chn1.info.samplerstate isa DynamicPPL.AbstractVarInfo - chn2 = sample(demo(), StaticSampler(), 10; resume_from=chn1) + chn2 = sample(demo(), StaticSampler(), 10; initial_state=loadstate(chn1)) xval = chn1[:x][1] @test all(chn2[:x] .== xval) end @@ -107,7 +96,12 @@ using Turing @test chn1.info.samplerstate isa AbstractVector{<:DynamicPPL.AbstractVarInfo} @test length(chn1.info.samplerstate) == nchains chn2 = sample( - demo(), StaticSampler(), MCMCThreads(), 10, nchains; resume_from=chn1 + demo(), + StaticSampler(), + MCMCThreads(), + 10, + nchains; + initial_state=loadstate(chn1), ) xval = chn1[:x][1, :] @test all(i -> chn2[:x][i, :] == xval, 1:10) @@ -122,10 +116,14 @@ using Turing chn1 = sample(StableRNG(seed), gdemo_default, alg1, 10_000; save_state=true) check_gdemo(chn1) - chn1_contd = sample(StableRNG(seed), gdemo_default, alg1, 2_000; resume_from=chn1) + chn1_contd = sample( + StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=loadstate(chn1) + ) check_gdemo(chn1_contd) - chn1_contd2 = sample(StableRNG(seed), gdemo_default, alg1, 2_000; resume_from=chn1) + chn1_contd2 = sample( + StableRNG(seed), gdemo_default, alg1, 2_000; initial_state=loadstate(chn1) + ) check_gdemo(chn1_contd2) chn2 = sample( @@ -138,7 +136,9 @@ using Turing ) check_gdemo(chn2) - chn2_contd = sample(StableRNG(seed), gdemo_default, alg2, 2_000; resume_from=chn2) + chn2_contd = sample( + StableRNG(seed), gdemo_default, alg2, 2_000; initial_state=loadstate(chn2) + ) check_gdemo(chn2_contd) chn3 = sample( @@ -151,7 +151,9 @@ using 
Turing ) check_gdemo(chn3) - chn3_contd = sample(StableRNG(seed), gdemo_default, alg3, 5_000; resume_from=chn3) + chn3_contd = sample( + StableRNG(seed), gdemo_default, alg3, 5_000; initial_state=loadstate(chn3) + ) check_gdemo(chn3_contd) end @@ -608,8 +610,8 @@ using Turing @testset "names_values" begin ks, xs = Turing.Inference.names_values([(a=1,), (b=2,), (a=3, b=4)]) - @test all(xs[:, 1] .=== [1, missing, 3]) - @test all(xs[:, 2] .=== [missing, 2, 4]) + @test isequal(xs[:, 1], [1, missing, 3]) + @test isequal(xs[:, 2], [missing, 2, 4]) end @testset "check model" begin diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl new file mode 100644 index 0000000000..6f4b476130 --- /dev/null +++ b/test/mcmc/abstractmcmc.jl @@ -0,0 +1,136 @@ +module TuringAbstractMCMCTests + +using AbstractMCMC: AbstractMCMC +using DynamicPPL: DynamicPPL +using Random: AbstractRNG +using Test: @test, @testset, @test_throws +using Turing + +@testset "Initial parameters" begin + # Dummy algorithm that just returns initial value and does not perform any sampling + abstract type OnlyInit <: AbstractMCMC.AbstractSampler end + struct OnlyInitDefault <: OnlyInit end + struct OnlyInitUniform <: OnlyInit end + Turing.Inference.init_strategy(::OnlyInitUniform) = InitFromUniform() + function Turing.Inference.initialstep( + rng::AbstractRNG, + model::DynamicPPL.Model, + ::OnlyInit, + vi::DynamicPPL.VarInfo=DynamicPPL.VarInfo(rng, model); + kwargs..., + ) + return vi, nothing + end + + @testset "init_strategy" begin + # check that the default init strategy is prior + @test Turing.Inference.init_strategy(OnlyInitDefault()) == InitFromPrior() + @test Turing.Inference.init_strategy(OnlyInitUniform()) == InitFromUniform() + end + + for spl in (OnlyInitDefault(), OnlyInitUniform()) + # model with one variable: initialization p = 0.2 + @model function coinflip() + p ~ Beta(1, 1) + return 10 ~ Binomial(25, p) + end + model = coinflip() + lptrue = logpdf(Binomial(25, 0.2), 10) + let inits = 
InitFromParams((; p=0.2)) + chain = sample(model, spl, 1; initial_params=inits, progress=false) + @test chain[1].metadata.p.vals == [0.2] + @test DynamicPPL.getlogjoint(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, + spl, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test c[1].metadata.p.vals == [0.2] + @test DynamicPPL.getlogjoint(c[1]) == lptrue + end + end + + # check that Vector no longer works + @test_throws ArgumentError sample( + model, spl, 1; initial_params=[4, -1], progress=false + ) + @test_throws ArgumentError sample( + model, spl, 1; initial_params=[missing, -1], progress=false + ) + + # model with two variables: initialization s = 4, m = -1 + @model function twovars() + s ~ InverseGamma(2, 3) + return m ~ Normal(0, sqrt(s)) + end + model = twovars() + lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) + for inits in ( + InitFromParams((s=4, m=-1)), + (s=4, m=-1), + InitFromParams(Dict(@varname(s) => 4, @varname(m) => -1)), + Dict(@varname(s) => 4, @varname(m) => -1), + ) + chain = sample(model, spl, 1; initial_params=inits, progress=false) + @test chain[1].metadata.s.vals == [4] + @test chain[1].metadata.m.vals == [-1] + @test DynamicPPL.getlogjoint(chain[1]) == lptrue + + # parallel sampling + chains = sample( + model, + spl, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test c[1].metadata.s.vals == [4] + @test c[1].metadata.m.vals == [-1] + @test DynamicPPL.getlogjoint(c[1]) == lptrue + end + end + + # set only m = -1 + for inits in ( + InitFromParams((; s=missing, m=-1)), + InitFromParams(Dict(@varname(s) => missing, @varname(m) => -1)), + (; s=missing, m=-1), + Dict(@varname(s) => missing, @varname(m) => -1), + InitFromParams((; m=-1)), + InitFromParams(Dict(@varname(m) => -1)), + (; m=-1), + Dict(@varname(m) => -1), + ) + chain = sample(model, spl, 1; initial_params=inits, progress=false) + 
@test !ismissing(chain[1].metadata.s.vals[1]) + @test chain[1].metadata.m.vals == [-1] + + # parallel sampling + chains = sample( + model, + spl, + MCMCThreads(), + 1, + 10; + initial_params=fill(inits, 10), + progress=false, + ) + for c in chains + @test !ismissing(c[1].metadata.s.vals[1]) + @test c[1].metadata.m.vals == [-1] + end + end + end +end + +end # module diff --git a/test/mcmc/emcee.jl b/test/mcmc/emcee.jl index b9a041d781..44bf75858e 100644 --- a/test/mcmc/emcee.jl +++ b/test/mcmc/emcee.jl @@ -4,7 +4,6 @@ using ..Models: gdemo_default using ..NumericalTests: check_gdemo using Distributions: sample using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using Random: Random using Test: @test, @test_throws, @testset using Turing @@ -34,18 +33,21 @@ using Turing nwalkers = 250 spl = Emcee(nwalkers, 2.0) - # No initial parameters, with im- and explicit `initial_params=nothing` Random.seed!(1234) chain1 = sample(gdemo_default, spl, 1) Random.seed!(1234) - chain2 = sample(gdemo_default, spl, 1; initial_params=nothing) + chain2 = sample(gdemo_default, spl, 1) @test Array(chain1) == Array(chain2) + initial_nt = DynamicPPL.InitFromParams((s=2.0, m=1.0)) # Initial parameters have to be specified for every walker - @test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=[2.0, 1.0]) + @test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=initial_nt) + @test_throws r"must be a vector of" sample( + gdemo_default, spl, 1; initial_params=initial_nt + ) # Initial parameters - chain = sample(gdemo_default, spl, 1; initial_params=fill([2.0, 1.0], nwalkers)) + chain = sample(gdemo_default, spl, 1; initial_params=fill(initial_nt, nwalkers)) @test chain[:s] == fill(2.0, 1, nwalkers) @test chain[:m] == fill(1.0, 1, nwalkers) end diff --git a/test/mcmc/ess.jl b/test/mcmc/ess.jl index 1e1be9b45f..e497fdde3a 100644 --- a/test/mcmc/ess.jl +++ b/test/mcmc/ess.jl @@ -2,9 +2,9 @@ module ESSTests using ..Models: MoGtest, MoGtest_default, gdemo, 
gdemo_default using ..NumericalTests: check_MoGtest_default, check_numerical +using ..SamplerTestUtils: test_rng_respected, test_sampler_analytical using Distributions: Normal, sample using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using Random: Random using StableRNGs: StableRNG using Test: @test, @testset @@ -38,6 +38,12 @@ using Turing c3 = sample(gdemo_default, s2, N) end + @testset "RNG is respected" begin + test_rng_respected(ESS()) + test_rng_respected(Gibbs(:x => ESS(), :y => MH())) + test_rng_respected(Gibbs(:x => ESS(), :y => ESS())) + end + @testset "ESS inference" begin @info "Starting ESS inference tests" seed = 23 @@ -78,9 +84,9 @@ using Turing model | (s=DynamicPPL.TestUtils.posterior_mean(model).s,) end - DynamicPPL.TestUtils.test_sampler( + test_sampler_analytical( models_conditioned, - DynamicPPL.Sampler(ESS()), + ESS(), 2000; # Filter out the varnames we've conditioned on. varnames_filter=vn -> DynamicPPL.getsym(vn) != :s, @@ -108,8 +114,12 @@ using Turing spl_x = Gibbs(@varname(z) => NUTS(), @varname(x) => ESS()) spl_xy = Gibbs(@varname(z) => NUTS(), (@varname(x), @varname(y)) => ESS()) - @test sample(StableRNG(23), xy(), spl_xy, num_samples).value ≈ - sample(StableRNG(23), x12(), spl_x, num_samples).value + chn1 = sample(StableRNG(23), xy(), spl_xy, num_samples) + chn2 = sample(StableRNG(23), x12(), spl_x, num_samples) + + @test chn1.value ≈ chn2.value + @test mean(chn1[:z]) ≈ mean(Beta(2.0, 2.0)) atol = 0.05 + @test mean(chn1[:y]) ≈ -3.0 atol = 0.05 end end diff --git a/test/mcmc/external_sampler.jl b/test/mcmc/external_sampler.jl index 38b9b06608..56c03c87a8 100644 --- a/test/mcmc/external_sampler.jl +++ b/test/mcmc/external_sampler.jl @@ -1,6 +1,7 @@ module ExternalSamplerTests using ..Models: gdemo_default +using ..SamplerTestUtils: test_sampler_analytical using AbstractMCMC: AbstractMCMC using AdvancedMH: AdvancedMH using Distributions: sample @@ -45,6 +46,8 @@ using Turing.Inference: AdvancedHMC rng::Random.AbstractRNG, 
model::AbstractMCMC.LogDensityModel, sampler::MySampler; + # This initial_params should be an AbstractVector because the model is just a + # LogDensityModel, not a DynamicPPL.Model initial_params::AbstractVector, kwargs..., ) @@ -82,7 +85,10 @@ using Turing.Inference: AdvancedHMC model = test_external_sampler() a, b = 0.5, 0.0 - chn = sample(model, externalsampler(MySampler()), 10; initial_params=[a, b]) + # This `initial_params` should be an InitStrategy + chn = sample( + model, externalsampler(MySampler()), 10; initial_params=InitFromParams((a=a, b=b)) + ) @test chn isa MCMCChains.Chains @test all(chn[:a] .== a) @test all(chn[:b] .== b) @@ -156,10 +162,7 @@ function Distributions._rand!( ) model = d.model varinfo = deepcopy(d.varinfo) - for vn in keys(varinfo) - DynamicPPL.set_flag!(varinfo, vn, "del") - end - DynamicPPL.evaluate!!(model, varinfo, DynamicPPL.SamplingContext(rng)) + _, varinfo = DynamicPPL.init!!(rng, model, varinfo, DynamicPPL.InitFromPrior()) x .= varinfo[:] return x end @@ -170,16 +173,24 @@ function initialize_mh_with_prior_proposal(model) ) end -function test_initial_params( - model, sampler, initial_params=DynamicPPL.VarInfo(model)[:]; kwargs... -) +function test_initial_params(model, sampler; kwargs...) + # Generate some parameters. + dict = DynamicPPL.values_as(DynamicPPL.VarInfo(model), Dict) + init_strategy = DynamicPPL.InitFromParams(dict) + # Execute the transition with two different RNGs and check that the resulting - # parameter values are the same. + # parameter values are the same. This ensures that the `initial_params` are + # respected (i.e., regardless of the RNG, the first step should always return + # the same parameters). rng1 = Random.MersenneTwister(42) rng2 = Random.MersenneTwister(43) - transition1, _ = AbstractMCMC.step(rng1, model, sampler; initial_params, kwargs...) - transition2, _ = AbstractMCMC.step(rng2, model, sampler; initial_params, kwargs...) 
+ transition1, _ = AbstractMCMC.step( + rng1, model, sampler; initial_params=init_strategy, kwargs... + ) + transition2, _ = AbstractMCMC.step( + rng2, model, sampler; initial_params=init_strategy, kwargs... + ) vn_to_val1 = DynamicPPL.OrderedDict(transition1.θ) vn_to_val2 = DynamicPPL.OrderedDict(transition2.θ) for vn in union(keys(vn_to_val1), keys(vn_to_val2)) @@ -195,23 +206,23 @@ end # Need some functionality to initialize the sampler. # TODO: Remove this once the constructors in the respective packages become "lazy". sampler = initialize_nuts(model) - sampler_ext = DynamicPPL.Sampler( - externalsampler(sampler; adtype, unconstrained=true) - ) - # FIXME: Once https://github.com/TuringLang/AdvancedHMC.jl/pull/366 goes through, uncomment. + sampler_ext = externalsampler(sampler; adtype, unconstrained=true) + + # TODO: AdvancedHMC samplers do not return the initial parameters as the first + # step, so `test_initial_params` will fail. This should be fixed upstream in + # AdvancedHMC.jl. For reasons that are beyond my current understanding, this was + # done in https://github.com/TuringLang/AdvancedHMC.jl/pull/366, but the PR + # was then reverted and never looked at again. # @testset "initial_params" begin # test_initial_params(model, sampler_ext; n_adapts=0) # end sample_kwargs = ( - n_adapts=1_000, - discard_initial=1_000, - # FIXME: Remove this once we can run `test_initial_params` above. - initial_params=DynamicPPL.VarInfo(model)[:], + n_adapts=1_000, discard_initial=1_000, initial_params=InitFromUniform() ) @testset "inference" begin - DynamicPPL.TestUtils.test_sampler( + test_sampler_analytical( [model], sampler_ext, 2_000; @@ -240,14 +251,12 @@ end # Need some functionality to initialize the sampler. # TODO: Remove this once the constructors in the respective packages become "lazy". 
sampler = initialize_mh_rw(model) - sampler_ext = DynamicPPL.Sampler( - externalsampler(sampler; unconstrained=true) - ) + sampler_ext = externalsampler(sampler; unconstrained=true) @testset "initial_params" begin test_initial_params(model, sampler_ext) end @testset "inference" begin - DynamicPPL.TestUtils.test_sampler( + test_sampler_analytical( [model], sampler_ext, 2_000; @@ -274,12 +283,12 @@ end # @testset "MH with prior proposal" begin # @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # sampler = initialize_mh_with_prior_proposal(model); - # sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; unconstrained=false)) + # sampler_ext = externalsampler(sampler; unconstrained=false) # @testset "initial_params" begin # test_initial_params(model, sampler_ext) # end # @testset "inference" begin - # DynamicPPL.TestUtils.test_sampler( + # test_sampler_analytical( # [model], # sampler_ext, # 10_000; diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 634fcc98d0..1e3d5856c6 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -134,26 +134,24 @@ end # Test that the samplers are being called in the correct order, on the correct target # variables. +# @testset "Sampler call order" begin # A wrapper around inference algorithms to allow intercepting the dispatch cascade to # collect testing information. - struct AlgWrapper{Alg<:Inference.InferenceAlgorithm} <: Inference.InferenceAlgorithm + struct AlgWrapper{Alg<:AbstractMCMC.AbstractSampler} <: AbstractMCMC.AbstractSampler inner::Alg end - unwrap_sampler(sampler::DynamicPPL.Sampler{<:AlgWrapper}) = - DynamicPPL.Sampler(sampler.alg.inner) - # Methods we need to define to be able to use AlgWrapper instead of an actual algorithm. # They all just propagate the call to the inner algorithm. 
Inference.isgibbscomponent(wrap::AlgWrapper) = Inference.isgibbscomponent(wrap.inner) function Inference.setparams_varinfo!!( model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:AlgWrapper}, + sampler::AlgWrapper, state, params::DynamicPPL.AbstractVarInfo, ) - return Inference.setparams_varinfo!!(model, unwrap_sampler(sampler), state, params) + return Inference.setparams_varinfo!!(model, sampler.inner, state, params) end # targets_and_algs will be a list of tuples, where the first element is the target_vns @@ -175,25 +173,23 @@ end function AbstractMCMC.step( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:AlgWrapper}, + sampler::AlgWrapper, args...; kwargs..., ) - capture_targets_and_algs(sampler.alg.inner, model.context) - return AbstractMCMC.step(rng, model, unwrap_sampler(sampler), args...; kwargs...) + capture_targets_and_algs(sampler.inner, model.context) + return AbstractMCMC.step(rng, model, sampler.inner, args...; kwargs...) end - function DynamicPPL.initialstep( + function Turing.Inference.initialstep( rng::Random.AbstractRNG, model::DynamicPPL.Model, - sampler::DynamicPPL.Sampler{<:AlgWrapper}, + sampler::AlgWrapper, args...; kwargs..., ) - capture_targets_and_algs(sampler.alg.inner, model.context) - return DynamicPPL.initialstep( - rng, model, unwrap_sampler(sampler), args...; kwargs... - ) + capture_targets_and_algs(sampler.inner, model.context) + return Turing.Inference.initialstep(rng, model, sampler.inner, args...; kwargs...) end struct Wrapper{T<:Real} @@ -279,7 +275,7 @@ end @testset "Gibbs warmup" begin # An inference algorithm, for testing purposes, that records how many warm-up steps # and how many non-warm-up steps haven been taken. 
- mutable struct WarmupCounter <: Inference.InferenceAlgorithm + mutable struct WarmupCounter <: AbstractMCMC.AbstractSampler warmup_init_count::Int non_warmup_init_count::Int warmup_count::Int @@ -298,7 +294,7 @@ end Turing.Inference.get_varinfo(state::VarInfoState) = state.vi function Turing.Inference.setparams_varinfo!!( ::DynamicPPL.Model, - ::DynamicPPL.Sampler, + ::WarmupCounter, ::VarInfoState, params::DynamicPPL.AbstractVarInfo, ) @@ -306,23 +302,17 @@ end end function AbstractMCMC.step( - ::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}; - kwargs..., + ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter; kwargs... ) - spl.alg.non_warmup_init_count += 1 + spl.non_warmup_init_count += 1 vi = DynamicPPL.VarInfo(model) return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end function AbstractMCMC.step_warmup( - ::Random.AbstractRNG, - model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}; - kwargs..., + ::Random.AbstractRNG, model::DynamicPPL.Model, spl::WarmupCounter; kwargs... 
) - spl.alg.warmup_init_count += 1 + spl.warmup_init_count += 1 vi = DynamicPPL.VarInfo(model) return (Turing.Inference.Transition(model, vi, nothing), VarInfoState(vi)) end @@ -330,22 +320,22 @@ end function AbstractMCMC.step( ::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}, + spl::WarmupCounter, s::VarInfoState; kwargs..., ) - spl.alg.non_warmup_count += 1 + spl.non_warmup_count += 1 return Turing.Inference.Transition(model, s.vi, nothing), s end function AbstractMCMC.step_warmup( ::Random.AbstractRNG, model::DynamicPPL.Model, - spl::DynamicPPL.Sampler{<:WarmupCounter}, + spl::WarmupCounter, s::VarInfoState; kwargs..., ) - spl.alg.warmup_count += 1 + spl.warmup_count += 1 return Turing.Inference.Transition(model, s.vi, nothing), s end @@ -403,9 +393,6 @@ end @test sample(gdemo_default, s4, N) isa MCMCChains.Chains @test sample(gdemo_default, s5, N) isa MCMCChains.Chains @test sample(gdemo_default, s6, N) isa MCMCChains.Chains - - g = DynamicPPL.Sampler(s3) - @test sample(gdemo_default, g, N) isa MCMCChains.Chains end # Test various combinations of samplers against models for which we know the analytical @@ -489,7 +476,7 @@ end @nospecialize function AbstractMCMC.bundle_samples( samples::Vector, ::typeof(model), - ::DynamicPPL.Sampler{<:Gibbs}, + ::Gibbs, state, ::Type{MCMCChains.Chains}; kwargs..., @@ -673,14 +660,10 @@ end @testset "$sampler" for sampler in samplers # Check that taking steps performs as expected. 
rng = Random.default_rng() - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(sampler) - ) + transition, state = AbstractMCMC.step(rng, model, sampler) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(sampler), state - ) + transition, state = AbstractMCMC.step(rng, model, sampler, state) check_transition_varnames(transition, vns) end end @@ -693,13 +676,9 @@ end num_chains = 4 # Determine initial parameters to make comparison as fair as possible. + # posterior_mean returns a NamedTuple so we can plug it in directly. posterior_mean = DynamicPPL.TestUtils.posterior_mean(model) - initial_params = DynamicPPL.TestUtils.update_values!!( - DynamicPPL.VarInfo(model), - posterior_mean, - DynamicPPL.TestUtils.varnames(model), - )[:] - initial_params = fill(initial_params, num_chains) + initial_params = fill(InitFromParams(posterior_mean), num_chains) # Sampler to use for Gibbs components. hmc = HMC(0.1, 32) @@ -754,36 +733,32 @@ end @testset "with both `s` and `m` as random" begin model = gdemo(1.5, 2.0) vns = (@varname(s), @varname(m)) - alg = Gibbs(vns => MH()) + spl = Gibbs(vns => MH()) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end # `sample` Random.seed!(42) - chain = sample(model, alg, 1_000; progress=false) + chain = sample(model, spl, 1_000; progress=false) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.4) end @testset "without `m` as random" begin model = gdemo(1.5, 2.0) | (m=7 / 6,) vns = (@varname(s),) - alg = Gibbs(vns => MH()) + spl = Gibbs(vns => MH()) # `step` - transition, state = 
AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end end @@ -825,7 +800,7 @@ end @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default - alg = Gibbs( + spl = Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS(), @@ -839,25 +814,23 @@ end @varname(mu2) ) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end # Sample! 
Random.seed!(42) - chain = sample(MoGtest_default, alg, 1000; progress=false) + chain = sample(MoGtest_default, spl, 1000; progress=false) check_MoGtest_default(chain; atol=0.2) end @testset "CSMC + ESS (usage of implicit varname)" begin rng = Random.default_rng() model = MoGtest_default_z_vector - alg = Gibbs(@varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS()) + spl = Gibbs(@varname(z) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS()) vns = ( @varname(z[1]), @varname(z[2]), @@ -867,18 +840,16 @@ end @varname(mu2) ) # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + transition, state = AbstractMCMC.step(rng, model, spl) check_transition_varnames(transition, vns) for _ in 1:5 - transition, state = AbstractMCMC.step( - rng, model, DynamicPPL.Sampler(alg), state - ) + transition, state = AbstractMCMC.step(rng, model, spl, state) check_transition_varnames(transition, vns) end # Sample! Random.seed!(42) - chain = sample(model, alg, 1000; progress=false) + chain = sample(model, spl, 1000; progress=false) check_MoGtest_default_z_vector(chain; atol=0.2) end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 5f811b31d2..c6b5af2162 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -4,7 +4,7 @@ using ..Models: gdemo_default using ..NumericalTests: check_gdemo, check_numerical using Bijectors: Bijectors using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample -using DynamicPPL: DynamicPPL, Sampler +using DynamicPPL: DynamicPPL import ForwardDiff using HypothesisTests: ApproximateTwoSampleKSTest, pvalue import ReverseDiff @@ -171,6 +171,32 @@ using Turing @test Array(res1) == Array(res2) == Array(res3) end + @testset "initial params are respected" begin + @model demo_norm() = x ~ Beta(2, 2) + init_x = 0.5 + @testset "$spl_name" for (spl_name, spl) in + (("HMC", HMC(0.1, 10)), ("NUTS", NUTS())) + chain = sample( + demo_norm(), + spl, + 5; + discard_adapt=false, + 
initial_params=InitFromParams((x=init_x,)), + ) + @test chain[:x][1] == init_x + chain = sample( + demo_norm(), + spl, + MCMCThreads(), + 5, + 5; + discard_adapt=false, + initial_params=(fill(InitFromParams((x=init_x,)), 5)), + ) + @test all(chain[:x][1, :] .== init_x) + end + end + @testset "warning for difficult init params" begin attempt = 0 @model function demo_warn_initial_params() @@ -180,12 +206,11 @@ using Turing end end - @test_logs ( - :warn, - "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", - ) (:info,) match_mode = :any begin - sample(demo_warn_initial_params(), NUTS(), 5) - end + # verbose=false to suppress the initial step size notification, which messes with + # the test + @test_logs (:warn, r"consider providing a different initialisation strategy") sample( + demo_warn_initial_params(), NUTS(), 5; verbose=false + ) end @testset "error for impossible model" begin @@ -211,7 +236,7 @@ using Turing 10; nadapts=0, discard_adapt=false, - initial_state=chn1.info.samplerstate, + initial_state=loadstate(chn1), ) # if chn2 uses initial_state, its first sample should be somewhere around 5. 
If # initial_state isn't used, it will be sampled from [-2, 2] so this test should fail @@ -252,7 +277,8 @@ using Turing model = buggy_model() num_samples = 1_000 - chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0]) + initial_params = InitFromParams((lb=0.5, ub=1.75, x=1.0)) + chain = sample(model, NUTS(), num_samples; initial_params=initial_params) chain_prior = sample(model, Prior(), num_samples) # Extract the `x` like this because running `generated_quantities` was how @@ -269,12 +295,15 @@ using Turing end @testset "getstepsize: Turing.jl#2400" begin - algs = [HMC(0.1, 10), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)] - @testset "$(alg)" for alg in algs + spls = [HMC(0.1, 10), HMCDA(0.8, 0.75), NUTS(0.5), NUTS(0, 0.5)] + @testset "$(spl)" for spl in spls # Construct a HMC state by taking a single step - spl = Sampler(alg) - hmc_state = DynamicPPL.initialstep( - Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default) + hmc_state = Turing.Inference.initialstep( + Random.default_rng(), + gdemo_default, + spl, + DynamicPPL.VarInfo(gdemo_default); + initial_params=InitFromUniform(), )[2] # Check that we can obtain the current step size + @test Turing.Inference.getstepsize(spl, hmc_state) isa Float64 diff --git a/test/mcmc/is.jl b/test/mcmc/is.jl index 2811e9c866..00550d1db4 100644 --- a/test/mcmc/is.jl +++ b/test/mcmc/is.jl @@ -1,63 +1,56 @@ module ISTests -using Distributions: Normal, sample using DynamicPPL: logpdf using Random: Random +using StableRNGs: StableRNG using StatsFuns: logsumexp using Test: @test, @testset using Turing @testset "is.jl" begin - function reference(n) - as = Vector{Float64}(undef, n) - bs = Vector{Float64}(undef, n) - logps = Vector{Float64}(undef, n) + @testset "numerical accuracy" begin + function reference(n) + rng = StableRNG(468) + as = Vector{Float64}(undef, n) + bs = Vector{Float64}(undef, n) - for i in 1:n - as[i], bs[i], logps[i] = reference() + for i in 1:n + as[i] = rand(rng, Normal(4, 
5)) + bs[i] = rand(rng, Normal(as[i], 1)) + end + return (as=as, bs=bs) end - logevidence = logsumexp(logps) - log(n) - return (as=as, bs=bs, logps=logps, logevidence=logevidence) - end - - function reference() - x = rand(Normal(4, 5)) - y = rand(Normal(x, 1)) - loglik = logpdf(Normal(x, 2), 3) + logpdf(Normal(y, 2), 1.5) - return x, y, loglik - end - - @model function normal() - a ~ Normal(4, 5) - 3 ~ Normal(a, 2) - b ~ Normal(a, 1) - 1.5 ~ Normal(b, 2) - return a, b - end - - alg = IS() - seed = 0 - n = 10 + @model function normal() + a ~ Normal(4, 5) + 3 ~ Normal(a, 2) + b ~ Normal(a, 1) + 1.5 ~ Normal(b, 2) + return a, b + end - model = normal() - for i in 1:100 - Random.seed!(seed) - ref = reference(n) + alg = IS() + N = 1000 + model = normal() + chain = sample(StableRNG(468), model, alg, N) + ref = reference(N) - Random.seed!(seed) - chain = sample(model, alg, n; check_model=false) - sampled = get(chain, [:a, :b, :loglikelihood]) + # Note that in general, mean(chain) will differ from mean(ref). This is because the + # sampling process introduces extra calls to rand(), etc. which changes the output. + # These tests therefore are only meant to check that the results are qualitatively + # similar to the reference implementation of IS, and hence the atol is set to + # something fairly large. 
+ @test isapprox(mean(chain[:a]), mean(ref.as); atol=0.1) + @test isapprox(mean(chain[:b]), mean(ref.bs); atol=0.1) - @test vec(sampled.a) == ref.as - @test vec(sampled.b) == ref.bs - @test vec(sampled.loglikelihood) == ref.logps - @test chain.logevidence == ref.logevidence + function expected_loglikelihoods(as, bs) + return logpdf.(Normal.(as, 2), 3) .+ logpdf.(Normal.(bs, 2), 1.5) + end + @test isapprox(chain[:loglikelihood], expected_loglikelihoods(chain[:a], chain[:b])) + @test isapprox(chain.logevidence, logsumexp(chain[:loglikelihood]) - log(N)) end @testset "logevidence" begin - Random.seed!(100) - @model function test() a ~ Normal(0, 1) x ~ Bernoulli(1) diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 70810e1643..7c19f022b3 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -4,7 +4,6 @@ using AdvancedMH: AdvancedMH using Distributions: Bernoulli, Dirichlet, Exponential, InverseGamma, LogNormal, MvNormal, Normal, sample using DynamicPPL: DynamicPPL -using DynamicPPL: Sampler using LinearAlgebra: I using Random: Random using StableRNGs: StableRNG @@ -49,7 +48,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Set the initial parameters, because if we get unlucky with the initial state, # these chains are too short to converge to reasonable numbers. 
discard_initial = 1_000 - initial_params = [1.0, 1.0] + initial_params = InitFromParams((s=1.0, m=1.0)) @testset "gdemo_default" begin alg = MH() @@ -72,7 +71,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) chain = sample( StableRNG(seed), gdemo_default, alg, 10_000; discard_initial, initial_params ) - check_gdemo(chain; atol=0.1) + check_gdemo(chain; atol=0.15) end @testset "MoGtest_default with Gibbs" begin @@ -81,13 +80,16 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) @varname(mu1) => MH((:mu1, GKernel(1))), @varname(mu2) => MH((:mu2, GKernel(1))), ) + initial_params = InitFromParams(( + mu1=1.0, mu2=1.0, z1=0.0, z2=0.0, z3=1.0, z4=1.0 + )) chain = sample( StableRNG(seed), MoGtest_default, gibbs, 500; discard_initial=100, - initial_params=[1.0, 1.0, 0.0, 0.0, 1.0, 4.0], + initial_params=initial_params, ) check_MoGtest_default(chain; atol=0.2) end @@ -113,7 +115,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) end model = M(zeros(2), I, 1) - sampler = Inference.Sampler(MH()) + sampler = MH() dt, vt = Inference.dist_val_tuple(sampler, DynamicPPL.VarInfo(model)) @@ -184,7 +186,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Test that the small variance version is actually smaller. variance_small = var(diff(Array(chn_small["μ[1]"]); dims=1)) variance_big = var(diff(Array(chn_big["μ[1]"]); dims=1)) - @test variance_small < variance_big / 1_000.0 + @test variance_small < variance_big / 100.0 end @testset "vector of multivariate distributions" begin @@ -228,38 +230,34 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Don't link when no proposals are given since we're using priors # as proposals. 
vi = deepcopy(vi_base) - alg = MH() - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) - @test !DynamicPPL.islinked(vi) + spl = MH() + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) + @test !DynamicPPL.is_transformed(vi) # Link if proposal is `AdvancedHM.RandomWalkProposal` vi = deepcopy(vi_base) d = length(vi_base[:]) - alg = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I))) - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) - @test DynamicPPL.islinked(vi) + spl = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I))) + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) + @test DynamicPPL.is_transformed(vi) # Link if ALL proposals are `AdvancedHM.RandomWalkProposal`. vi = deepcopy(vi_base) - alg = MH(:s => AdvancedMH.RandomWalkProposal(Normal())) - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) - @test DynamicPPL.islinked(vi) + spl = MH(:s => AdvancedMH.RandomWalkProposal(Normal())) + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) + @test DynamicPPL.is_transformed(vi) # Don't link if at least one proposal is NOT `RandomWalkProposal`. # TODO: make it so that only those that are using `RandomWalkProposal` # are linked! I.e. resolve https://github.com/TuringLang/Turing.jl/issues/1583. 
# https://github.com/TuringLang/Turing.jl/pull/1582#issuecomment-817148192 vi = deepcopy(vi_base) - alg = MH( + spl = MH( :m => AdvancedMH.StaticProposal(Normal()), :s => AdvancedMH.RandomWalkProposal(Normal()), ) - spl = DynamicPPL.Sampler(alg) - vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default) - @test !DynamicPPL.islinked(vi) + vi = Turing.Inference.maybe_link!!(vi, spl, spl.proposals, gdemo_default) + @test !DynamicPPL.is_transformed(vi) end @testset "`filldist` proposal (issue #2180)" begin diff --git a/test/mcmc/repeat_sampler.jl b/test/mcmc/repeat_sampler.jl index d2ca427dfd..1a22884029 100644 --- a/test/mcmc/repeat_sampler.jl +++ b/test/mcmc/repeat_sampler.jl @@ -1,9 +1,8 @@ module RepeatSamplerTests using ..Models: gdemo_default -using DynamicPPL: Sampler -using MCMCChains: Chains -using StableRNGs: StableRNG +using MCMCChains: MCMCChains +using Random: Xoshiro using Test: @test, @testset using Turing @@ -14,10 +13,12 @@ using Turing num_samples = 10 num_chains = 2 - rng = StableRNG(0) - for sampler in [MH(), Sampler(HMC(0.01, 4))] + # Use Xoshiro instead of StableRNGs as the output should always be + # similar regardless of what kind of random seed is used (as long + # as there is a random seed). 
+ for sampler in [MH(), HMC(0.01, 4)] chn1 = sample( - copy(rng), + Xoshiro(0), gdemo_default, sampler, MCMCThreads(), @@ -27,15 +28,17 @@ using Turing ) repeat_sampler = RepeatSampler(sampler, num_repeats) chn2 = sample( - copy(rng), + Xoshiro(0), gdemo_default, repeat_sampler, MCMCThreads(), num_samples, - num_chains; - chain_type=Chains, + num_chains, ) - @test chn1.value == chn2.value + # isequal to avoid comparing `missing`s in chain stats + @test chn1 isa MCMCChains.Chains + @test chn2 isa MCMCChains.Chains + @test isequal(chn1.value, chn2.value) end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index ee943270cd..e08137109d 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -18,13 +18,6 @@ using Turing @testset "sghmc constructor" begin alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) @test alg isa SGHMC - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGHMC} - - alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1) - @test alg isa SGHMC - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGHMC} end @testset "sghmc inference" begin @@ -43,20 +36,13 @@ end @testset "sgld constructor" begin alg = SGLD(; stepsize=PolynomialStepsize(0.25)) @test alg isa SGLD - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGLD} - - alg = SGLD(; stepsize=PolynomialStepsize(0.25)) - @test alg isa SGLD - sampler = DynamicPPL.Sampler(alg) - @test sampler isa DynamicPPL.Sampler{<:SGLD} end @testset "sgld inference" begin rng = StableRNG(1) chain = sample(rng, gdemo_default, SGLD(; stepsize=PolynomialStepsize(0.5)), 20_000) - check_gdemo(chain; atol=0.2) + check_gdemo(chain; atol=0.25) # Weight samples by step sizes (cf section 4.2 in the paper by Welling and Teh) v = get(chain, [:SGLD_stepsize, :s, :m]) diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 269a71acb5..d93895e28c 100644 --- a/test/optimisation/Optimisation.jl +++ 
b/test/optimisation/Optimisation.jl @@ -1,6 +1,7 @@ module OptimisationTests using ..Models: gdemo, gdemo_default +using AbstractPPL: AbstractPPL using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL @@ -495,7 +496,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) @test get(result_true, vn_leaf) ≈ vals[Symbol(vn_leaf)] atol = 0.05 end end @@ -534,7 +535,7 @@ using Turing vals = result.values for vn in DynamicPPL.TestUtils.varnames(model) - for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn)) if model.f in allowed_incorrect_mle @test isfinite(get(result_true, vn_leaf)) else diff --git a/test/runtests.jl b/test/runtests.jl index 5fb6b21411..81b4bdde20 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,6 +43,7 @@ end end @testset "samplers (without AD)" verbose = true begin + @timeit_include("mcmc/abstractmcmc.jl") @timeit_include("mcmc/particle_mcmc.jl") @timeit_include("mcmc/emcee.jl") @timeit_include("mcmc/ess.jl") diff --git a/test/stdlib/distributions.jl b/test/stdlib/distributions.jl index 0f8a7c7183..56c2e59b13 100644 --- a/test/stdlib/distributions.jl +++ b/test/stdlib/distributions.jl @@ -51,8 +51,6 @@ using Turing end @testset "single distribution correctness" begin - rng = StableRNG(1) - n_samples = 10_000 mean_tol = 0.1 var_atol = 1.0 @@ -132,7 +130,14 @@ using Turing @model m() = x ~ dist - chn = sample(rng, m(), HMC(0.05, 20), n_samples) + seed = if dist isa GeneralizedExtremeValue + # GEV is prone to giving really wacky results that are quite + # seed-dependent. + StableRNG(469) + else + StableRNG(468) + end + chn = sample(seed, m(), HMC(0.05, 20), n_samples) # Numerical tests. 
check_dist_numerical( diff --git a/test/test_utils/sampler.jl b/test/test_utils/sampler.jl index 32a3647f98..a2ca123b11 100644 --- a/test/test_utils/sampler.jl +++ b/test/test_utils/sampler.jl @@ -1,5 +1,9 @@ module SamplerTestUtils +using AbstractMCMC +using AbstractPPL +using DynamicPPL +using Random using Turing using Test @@ -24,4 +28,71 @@ function test_chain_logp_metadata(spl) @test chn[:lp] ≈ chn[:logprior] + chn[:loglikelihood] end +""" +Check that sampling is deterministic when using the same RNG seed. +""" +function test_rng_respected(spl) + @model function f(z) + # put at least two variables here so that we can meaningfully test Gibbs + x ~ Normal() + y ~ Normal() + return z ~ Normal(x + y) + end + model = f(2.0) + chn1 = sample(Xoshiro(468), model, spl, 100) + chn2 = sample(Xoshiro(468), model, spl, 100) + @test isapprox(chn1[:x], chn2[:x]) + @test isapprox(chn1[:y], chn2[:y]) +end + +""" + test_sampler_analytical(models, sampler, args...; kwargs...) + +Test that `sampler` produces correct marginal posterior means on each model in `models`. + +In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the `model` +and `sampler` to produce a `chain`, and then checks the chain's mean for every (leaf) +varname `vn` against the corresponding value returned by +`DynamicPPL.TestUtils.posterior_mean` for each model. + +For this to work, each model in `models` must have a known analytical posterior mean +that can be computed by `DynamicPPL.TestUtils.posterior_mean`. + +# Arguments +- `models`: A collection of instances of `DynamicPPL.Model` to test on. +- `sampler`: The `AbstractMCMC.AbstractSampler` to test. +- `args...`: Arguments forwarded to `sample`. + +# Keyword arguments +- `varnames_filter`: A filter to apply to `varnames(model)`, allowing comparison for only + a subset of the varnames. +- `atol=1e-1`: Absolute tolerance used in `@test`. +- `rtol=1e-3`: Relative tolerance used in `@test`. 
+- `kwargs...`: Keyword arguments forwarded to `sample`. +""" +function test_sampler_analytical( + models, + sampler::AbstractMCMC.AbstractSampler, + args...; + varnames_filter=Returns(true), + atol=1e-1, + rtol=1e-3, + sampler_name=typeof(sampler), + kwargs..., +) + @testset "$(sampler_name) on $(nameof(model))" for model in models + chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) + target_values = DynamicPPL.TestUtils.posterior_mean(model) + for vn in filter(varnames_filter, DynamicPPL.TestUtils.varnames(model)) + # We want to compare elementwise which can be achieved by + # extracting the leaves of the `VarName` and the corresponding value. + for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn)) + target_value = get(target_values, vn_leaf) + chain_mean_value = mean(chain[Symbol(vn_leaf)]) + @test chain_mean_value ≈ target_value atol = atol rtol = rtol + end + end + end +end + end
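The `initial_params` migration that recurs throughout the test diffs above (vectors replaced by `InitFromParams` with a `NamedTuple`) can be sketched on a toy model. This is an illustrative example, not code from the diff; the model and values are made up, and it mirrors the pattern in the HMC test where `discard_adapt=false` makes the first stored sample equal the initial value:

```julia
using Turing

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# New style: initial parameters are a NamedTuple (or AbstractDict{<:VarName})
# wrapped in InitFromParams, given in unlinked (model) space.
chn = sample(
    demo(),
    NUTS(),
    100;
    discard_adapt=false,
    initial_params=InitFromParams((s=1.0, m=0.5)),
)

# With discard_adapt=false, the chain starts at the supplied values.
first_s, first_m = chn[:s][1], chn[:m][1]
```

The equivalent pre-0.41 call would have passed `initial_params=[1.0, 0.5]`, which relied on the internal ordering of the VarInfo; the keyed form above is order-independent.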