diff --git a/HISTORY.md b/HISTORY.md index 2b2ca2c35..d9be6da03 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,12 @@ # DynamicPPL Changelog +## 0.38.8 + +Added a new exported struct, `DynamicPPL.ParamsWithStats`. +This can broadly be used to represent the output of a model: it consists of an `OrderedDict` of `VarName` parameters and their values, along with a `stats` NamedTuple which can hold arbitrary data, such as (but not limited to) log-probabilities. + +Implemented the functions `AbstractMCMC.to_samples` and `AbstractMCMC.from_samples`, which convert between an `MCMCChains.Chains` object and a matrix of `DynamicPPL.ParamsWithStats` objects. + ## 0.38.7 Made a small tweak to DynamicPPL's compiler output to avoid potential undefined variables when resuming model functions midway through (e.g. with Libtask in Turing's SMC/PG samplers). diff --git a/Project.toml b/Project.toml index 9d013d287..c71b89bc7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.38.7" +version = "0.38.8" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" -AbstractMCMC = "5" +AbstractMCMC = "5.10" AbstractPPL = "0.13.1" Accessors = "0.1" BangBang = "0.4.1" diff --git a/docs/Project.toml b/docs/Project.toml index fed06ebde..169a1b626 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -24,6 +25,6 @@ FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10, 0.11" LogDensityProblems = "2" -MarginalLogDensities = "0.4" MCMCChains = "5, 6, 7" +MarginalLogDensities = "0.4" StableRNGs = "1" diff --git a/docs/make.jl b/docs/make.jl index b8f8de7bb..8ac9709ce 100644 --- 
a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,7 @@ using Distributions using DocumenterMermaid # load MCMCChains package extension to make `predict` available using MCMCChains +using AbstractMCMC: AbstractMCMC using MarginalLogDensities: MarginalLogDensities # Need this to document a method which uses a type inside the extension... diff --git a/docs/src/api.md b/docs/src/api.md index b04bd445d..bbe39fb73 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -505,3 +505,29 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` + +### Converting VarInfos to/from chains + +It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis. + +This can be accomplished by first converting each `VarInfo` into a `ParamsWithStats` object: + +```@docs +DynamicPPL.ParamsWithStats +``` + +Once you have a **matrix** of these, you can convert them into a chains object using: + +```@docs +AbstractMCMC.from_samples(::Type{MCMCChains.Chains}, ::AbstractMatrix{<:DynamicPPL.ParamsWithStats}) +``` + +If you only have a vector you can use `hcat` to convert it into an `N×1` matrix first. + +Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with: + +```@docs +AbstractMCMC.to_samples(::Type{DynamicPPL.ParamsWithStats}, ::MCMCChains.Chains) +``` + +With these, you can (for example) extract the parameter dictionaries and use `InitFromParams` to re-evaluate a model at each point in the chain. 
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 003372449..d8c343917 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMCMCChainsExt -using DynamicPPL: DynamicPPL, AbstractPPL +using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC using MCMCChains: MCMCChains _has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names @@ -36,6 +36,110 @@ function chain_sample_to_varname_dict( return d end +""" + AbstractMCMC.from_samples( + ::Type{MCMCChains.Chains}, + params_and_stats::AbstractMatrix{<:ParamsWithStats} + ) + +Convert a matrix of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object. +""" +function AbstractMCMC.from_samples( + ::Type{MCMCChains.Chains}, + params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats}, +) + # Handle parameters + all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() + split_dicts = map(params_and_stats) do ps + # Separate into individual VarNames. 
+ vn_leaves_and_vals = if isempty(ps.params) + Tuple{DynamicPPL.VarName,Any}[] + else + iters = map( + AbstractPPL.varname_and_value_leaves, + keys(ps.params), + values(ps.params), + ) + mapreduce(collect, vcat, iters) + end + vn_leaves = map(first, vn_leaves_and_vals) + vals = map(last, vn_leaves_and_vals) + for vn_leaf in vn_leaves + push!(all_vn_leaves, vn_leaf) + end + DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals)) + end + vn_leaves = collect(all_vn_leaves) + param_vals = [ + get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)), + key in vn_leaves, j in eachindex(axes(split_dicts, 2)) + ] + param_symbols = map(Symbol, vn_leaves) + # Handle statistics + stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}() + for ps in params_and_stats + for k in keys(ps.stats) + push!(stat_keys, k) + end + end + stat_keys = collect(stat_keys) + stat_vals = [ + get(params_and_stats[i, j].stats, key, missing) for + i in eachindex(axes(params_and_stats, 1)), key in stat_keys, + j in eachindex(axes(params_and_stats, 2)) + ] + # Construct name map and info + name_map = (internals=stat_keys,) + info = ( + varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict( + zip(all_vn_leaves, param_symbols) + ), + ) + # Concatenate parameter and statistic values + vals = cat(param_vals, stat_vals; dims=2) + symbols = vcat(param_symbols, stat_keys) + return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info) +end + +""" + AbstractMCMC.to_samples( + ::Type{DynamicPPL.ParamsWithStats}, + chain::MCMCChains.Chains + ) + +Convert an `MCMCChains.Chains` object to an array of `DynamicPPL.ParamsWithStats`. + +For this to work, `chain` must contain the `varname_to_symbol` mapping in its `info` field. 
+""" +function AbstractMCMC.to_samples( + ::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains +) + idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + # Get parameters + params_matrix = map(idxs) do (sample_idx, chain_idx) + d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() + for vn in DynamicPPL.varnames(chain) + d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx) + end + d + end + # Statistics + stats_matrix = if :internals in MCMCChains.sections(chain) + internals_chain = MCMCChains.get_sections(chain, :internals) + map(idxs) do (sample_idx, chain_idx) + get(internals_chain[sample_idx, :, chain_idx], keys(internals_chain); flatten=true) + end + else + fill(NamedTuple(), size(idxs)) + end + # Bundle them together + return map(idxs) do (sample_idx, chain_idx) + DynamicPPL.ParamsWithStats( + params_matrix[sample_idx, chain_idx], stats_matrix[sample_idx, chain_idx] + ) + end +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -110,7 +214,6 @@ function DynamicPPL.predict( DynamicPPL.VarInfo(), ( DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(), DynamicPPL.ValuesAsInModelAccumulator(false), ), @@ -118,34 +221,17 @@ function DynamicPPL.predict( _, varinfo = DynamicPPL.init!!(model, varinfo) varinfo = DynamicPPL.typed_varinfo(varinfo) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - predictive_samples = map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) - # Resample any variables that are not present in `values_dict` + params_and_stats = AbstractMCMC.to_samples( + DynamicPPL.ParamsWithStats, parameter_only_chain + ) + predictions = map(params_and_stats) do ps _, varinfo = DynamicPPL.init!!( - rng, - model, - varinfo, - DynamicPPL.InitFromParams(values_dict, 
DynamicPPL.InitFromPrior()), + rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) ) - vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values - varname_vals = mapreduce( - collect, - vcat, - map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)), - ) - - return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) + DynamicPPL.ParamsWithStats(varinfo) end + chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) - chain_result = reduce( - MCMCChains.chainscat, - [ - _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for - chain_idx in 1:size(predictive_samples, 2) - ], - ) parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -164,45 +250,6 @@ function DynamicPPL.predict( ) end -function _predictive_samples_to_arrays(predictive_samples) - variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() - - sample_dicts = map(predictive_samples) do sample - varname_value_pairs = sample.varname_and_values - varnames = map(first, varname_value_pairs) - values = map(last, varname_value_pairs) - for varname in varnames - push!(variable_names_set, varname) - end - - return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values)) - end - - variable_names = collect(variable_names_set) - variable_values = [ - get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts), - key in variable_names - ] - - return variable_names, variable_values -end - -function _predictive_samples_to_chains(predictive_samples) - variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples) - variable_names_symbols = map(Symbol, variable_names) - - internal_parameters = [:lp] - log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1) - - parameter_names = [variable_names_symbols; internal_parameters] - parameter_values = hcat(variable_values, log_probabilities) - parameter_values = 
MCMCChains.concretize(parameter_values) - - return MCMCChains.Chains( - parameter_values, parameter_names, (internals=internal_parameters,) - ) -end - """ returned(model::Model, chain::MCMCChains.Chains) @@ -266,17 +313,15 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha chain = MCMCChains.get_sections(chain_full, :parameters) varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - return map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) - # Resample any variables that are not present in `values_dict`, and - # return the model's retval. - retval, _ = DynamicPPL.init!!( - model, - varinfo, - DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(params_with_stats) do ps + first( + DynamicPPL.init!!( + model, + varinfo, + DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), + ), ) - retval end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f5bd33d6d..e66f3fe11 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -126,6 +126,8 @@ export AbstractVarInfo, prefix, returned, to_submodel, + # Struct to hold model outputs + ParamsWithStats, # Convenience macros @addlogprob!, value_iterator_from_chain, @@ -169,7 +171,6 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") -include("chains.jl") include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") @@ -193,6 +194,7 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") +include("chains.jl") include("bijector.jl") include("debug_utils.jl") diff --git a/src/chains.jl b/src/chains.jl index fd6564e5b..2b5976b9b 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -23,3 
+23,113 @@ Return an iterator over the varnames present in `chains`. Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). """ function varnames end + +""" + ParamsWithStats + +A struct which contains parameter values extracted from a `VarInfo`, along with any +statistics associated with the VarInfo. The statistics are provided as a NamedTuple and are +optional. +""" +struct ParamsWithStats{P<:OrderedDict{<:VarName,<:Any},S<:NamedTuple} + params::P + stats::S +end + +""" + ParamsWithStats( + varinfo::AbstractVarInfo, + model::Model, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, + ) + +Generate a `ParamsWithStats` by re-evaluating the given `model` with the provided `varinfo`. +Re-evaluation of the model is often necessary to obtain correct parameter values as well as +log probabilities. This is especially true when using linked VarInfos, i.e., when variables +have been transformed to unconstrained space, and if this is not done, subtle correctness +bugs may arise: see, e.g., https://github.com/TuringLang/Turing.jl/issues/2195. + +`include_colon_eq` controls whether variables on the left-hand side of `:=` are included in +the resulting parameters. + +`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log +joint) are added to the resulting statistics NamedTuple. 
+""" +function ParamsWithStats( + varinfo::AbstractVarInfo, + model::DynamicPPL.Model, + stats::NamedTuple=NamedTuple(); + include_colon_eq::Bool=true, + include_log_probs::Bool=true, +) + varinfo = maybe_to_typed_varinfo(varinfo) + accs = if include_log_probs + ( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq), + ) + else + (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),) + end + varinfo = DynamicPPL.setaccs!!(varinfo, accs) + varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) + params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values + if include_log_probs + stats = merge( + stats, + ( + logprior=DynamicPPL.getlogprior(varinfo), + loglikelihood=DynamicPPL.getloglikelihood(varinfo), + lp=DynamicPPL.getlogjoint(varinfo), + ), + ) + end + return ParamsWithStats(params, stats) +end + +# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to +# convert it to a typed varinfo first, hence this method. +# https://github.com/TuringLang/Turing.jl/issues/2604 +maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) +maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) +maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi + +""" + ParamsWithStats( + varinfo::AbstractVarInfo, + stats::NamedTuple=NamedTuple(); + include_log_probs::Bool=true, + ) + +There is one case where re-evaluation is not necessary, which is when `varinfo` +already contains a `DynamicPPL.ValuesAsInModelAccumulator`. This accumulator stores values +as seen during the model evaluation, so the values can be simply read off. In this case, +the `model` argument can be omitted, and no re-evaluation will be performed. However, it is +the caller's responsibility to ensure that `ValuesAsInModelAccumulator` is indeed present +inside `varinfo`. 
+ +`include_log_probs` controls whether log probabilities (log prior, log likelihood, and log +joint) are added to `stats` as `logprior`, `loglikelihood`, `lp` if the accumulators exist. +""" +function ParamsWithStats( + varinfo::AbstractVarInfo, stats::NamedTuple=NamedTuple(); include_log_probs::Bool=true +) + params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values + if include_log_probs + has_prior_acc = DynamicPPL.hasacc(varinfo, Val(:LogPrior)) + has_likelihood_acc = DynamicPPL.hasacc(varinfo, Val(:LogLikelihood)) + if has_prior_acc + stats = merge(stats, (logprior=DynamicPPL.getlogprior(varinfo),)) + end + if has_likelihood_acc + stats = merge(stats, (loglikelihood=DynamicPPL.getloglikelihood(varinfo),)) + end + if has_prior_acc && has_likelihood_acc + stats = merge(stats, (lp=DynamicPPL.getlogjoint(varinfo),)) + end + end + return ParamsWithStats(params, stats) +end diff --git a/test/Project.toml b/test/Project.toml index c96087d66..2dbd5b455 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" -AbstractMCMC = "5" +AbstractMCMC = "5.10" AbstractPPL = "0.13" Accessors = "0.1" Aqua = "0.8" diff --git a/test/chains.jl b/test/chains.jl new file mode 100644 index 000000000..ab0ff4475 --- /dev/null +++ b/test/chains.jl @@ -0,0 +1,69 @@ +module DynamicPPLChainsTests + +using DynamicPPL +using Distributions +using Test + +@testset "ParamsWithStats" begin + @model function f(z) + x ~ Normal() + y := x + 1 + return z ~ Normal(y) + end + z = 1.0 + model = f(z) + + @testset "with reevaluation" begin + ps = ParamsWithStats(VarInfo(model), model) + @test haskey(ps.params, @varname(x)) + @test haskey(ps.params, @varname(y)) + @test length(ps.params) == 2 + @test haskey(ps.stats, :logprior) + @test haskey(ps.stats, :loglikelihood) + @test haskey(ps.stats, :lp) + @test length(ps.stats) == 3 + @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood + @test ps.params[@varname(y)] ≈ 
ps.params[@varname(x)] + 1 + @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) + @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(y)]), z) + end + + @testset "without colon_eq" begin + ps = ParamsWithStats(VarInfo(model), model; include_colon_eq=false) + @test haskey(ps.params, @varname(x)) + @test length(ps.params) == 1 + @test haskey(ps.stats, :logprior) + @test haskey(ps.stats, :loglikelihood) + @test haskey(ps.stats, :lp) + @test length(ps.stats) == 3 + @test ps.stats.lp ≈ ps.stats.logprior + ps.stats.loglikelihood + @test ps.stats.logprior ≈ logpdf(Normal(), ps.params[@varname(x)]) + @test ps.stats.loglikelihood ≈ logpdf(Normal(ps.params[@varname(x)] + 1), z) + end + + @testset "without log probs" begin + ps = ParamsWithStats(VarInfo(model), model; include_log_probs=false) + @test haskey(ps.params, @varname(x)) + @test haskey(ps.params, @varname(y)) + @test length(ps.params) == 2 + @test isempty(ps.stats) + end + + @testset "no reevaluation" begin + # Without VAIM, it should error + @test_throws ErrorException ParamsWithStats(VarInfo(model)) + # With VAIM, it should work + vi = DynamicPPL.setaccs!!( + VarInfo(model), (DynamicPPL.ValuesAsInModelAccumulator(true),) + ) + vi = last(DynamicPPL.evaluate!!(model, vi)) + ps = ParamsWithStats(vi) + @test haskey(ps.params, @varname(x)) + @test haskey(ps.params, @varname(y)) + @test length(ps.params) == 2 + # Because we didn't evaluate with log prob accumulators, there should be no stats + @test isempty(ps.stats) + end +end + +end # module diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 79e13ad84..f537415d5 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -1,3 +1,7 @@ +module DynamicPPLMCMCChainsExtTests + +using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC + @testset "DynamicPPLMCMCChainsExt" begin @model demo() = x ~ Normal() model = demo() @@ -11,6 +15,54 @@ chain_generated = 
@test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 + + @testset "from_samples" begin + @model function f(z) + x ~ Normal() + y := x + 1 + return z ~ Normal(y) + end + + z = 1.0 + model = f(z) + + @testset "matrix" begin + ps = [ParamsWithStats(VarInfo(model), model) for _ in 1:50, _ in 1:3] + c = AbstractMCMC.from_samples(MCMCChains.Chains, ps) + @test c isa MCMCChains.Chains + @test size(c, 1) == 50 + @test size(c, 3) == 3 + @test Set(c.name_map.parameters) == Set([:x, :y]) + @test Set(c.name_map.internals) == Set([:logprior, :loglikelihood, :lp]) + @test logpdf.(Normal(), c[:x]) ≈ c[:logprior] + @test c.info.varname_to_symbol[@varname(x)] == :x + @test c.info.varname_to_symbol[@varname(y)] == :y + end + end + + @testset "to_samples" begin + @model function f(z) + x ~ Normal() + y := x + 1 + return z ~ Normal(y) + end + # Make the chain first + z = 1.0 + model = f(z) + ps = hcat([ParamsWithStats(VarInfo(model), model) for _ in 1:50]) + c = AbstractMCMC.from_samples(MCMCChains.Chains, ps) + # Then convert back to ParamsWithStats + arr_pss = AbstractMCMC.to_samples(ParamsWithStats, c) + @test size(arr_pss) == (50, 1) + for i in 1:50 + new_p = arr_pss[i, 1] + p = ps[i] + @test new_p.params == p.params + @test new_p.stats == p.stats + end + end end # test for `predict` is in `test/model.jl` + +end # module diff --git a/test/runtests.jl b/test/runtests.jl index 7a9c12525..861d3bb87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,6 +73,7 @@ include("test_util.jl") include("threadsafe.jl") include("debug_utils.jl") include("submodels.jl") + include("chains.jl") include("bijector.jl") end diff --git a/test/test_util.jl b/test/test_util.jl index 164751c7b..94fdbd744 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -62,35 +62,10 @@ Construct an MCMCChains.Chains object by sampling from the prior of `model` for `n_iters` iterations. 
""" function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) - # Sample from the prior - varinfos = [VarInfo(rng, model) for _ in 1:n_iters] - # Extract all varnames found in any dictionary. Doing it this way guards - # against the possibility of having different varnames in different - # dictionaries, e.g. for models that have dynamic variables / array sizes - varnames = OrderedSet{VarName}() - # Convert each varinfo into an OrderedDict of vns => params. - # We have to use varname_and_value_leaves so that each parameter is a scalar - dicts = map(varinfos) do t - vals = DynamicPPL.values_as(t, OrderedDict) - iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) - tuples = mapreduce(collect, vcat, iters) - # The following loop is a replacement for: - # push!(varnames, map(first, tuples)...) - # which causes a stack overflow if `map(first, tuples)` is too large. - # Unfortunately there isn't a union() function for OrderedSet. - for vn in map(first, tuples) - push!(varnames, vn) - end - OrderedDict(tuples) - end - # Convert back to list - varnames = collect(varnames) - # Construct matrix of values - vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] - # Construct dict of varnames -> symbol - vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) - # Construct and return the Chains object - return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) + vi = VarInfo(model) + vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.ValuesAsInModelAccumulator(false),)) + ps = hcat([ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi))) for _ in 1:n_iters]) + return AbstractMCMC.from_samples(MCMCChains.Chains, ps) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters)