@@ -36,6 +36,113 @@ function chain_sample_to_varname_dict(
3636 return d
3737end
3838
39+ """
40+ DynamicPPL.to_chains(
41+ ::Type{MCMCChains.Chains},
42+ params_and_stats::AbstractArray{<:ParamsWithStats}
43+ )
44+
45+ Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
46+ """
47+ function DynamicPPL. to_chains (
48+ :: Type{MCMCChains.Chains} ,
49+ params_and_stats:: AbstractMatrix{<:DynamicPPL.ParamsWithStats} ,
50+ )
51+ # Handle parameters
52+ all_vn_leaves = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
53+ split_dicts = map (params_and_stats) do ps
54+ # Separate into individual VarNames.
55+ vn_leaves_and_vals = if isempty (ps. params)
56+ Tuple{DynamicPPL. VarName,Any}[]
57+ else
58+ iters = map (
59+ AbstractPPL. varname_and_value_leaves,
60+ keys (ps. params),
61+ values (ps. params),
62+ )
63+ mapreduce (collect, vcat, iters)
64+ end
65+ vn_leaves = map (first, vn_leaves_and_vals)
66+ vals = map (last, vn_leaves_and_vals)
67+ for vn_leaf in vn_leaves
68+ push! (all_vn_leaves, vn_leaf)
69+ end
70+ DynamicPPL. OrderedCollections. OrderedDict (zip (vn_leaves, vals))
71+ end
72+ vn_leaves = collect (all_vn_leaves)
73+ param_vals = [
74+ get (split_dicts[i, j], key, missing ) for i in eachindex (axes (split_dicts, 1 )),
75+ key in vn_leaves, j in eachindex (axes (split_dicts, 2 ))
76+ ]
77+ param_symbols = map (Symbol, vn_leaves)
78+ # Handle statistics
79+ stat_keys = DynamicPPL. OrderedCollections. OrderedSet {Symbol} ()
80+ for ps in params_and_stats
81+ for k in keys (ps. stats)
82+ push! (stat_keys, k)
83+ end
84+ end
85+ stat_keys = collect (stat_keys)
86+ stat_vals = [
87+ get (params_and_stats[i, j]. stats, key, missing ) for
88+ i in eachindex (axes (params_and_stats, 1 )), key in stat_keys,
89+ j in eachindex (axes (params_and_stats, 2 ))
90+ ]
91+ # Construct name map and info
92+ name_map = (internals= stat_keys,)
93+ info = (
94+ varname_to_symbol= DynamicPPL. OrderedCollections. OrderedDict (
95+ zip (all_vn_leaves, param_symbols)
96+ ),
97+ )
98+ # Concatenate parameter and statistic values
99+ vals = cat (param_vals, stat_vals; dims= 2 )
100+ symbols = vcat (param_symbols, stat_keys)
101+ return MCMCChains. Chains (MCMCChains. concretize (vals), symbols, name_map; info= info)
102+ end
103+ function DynamicPPL. to_chains (
104+ :: Type{MCMCChains.Chains} , ps:: AbstractVector{<:DynamicPPL.ParamsWithStats}
105+ )
106+ return DynamicPPL. to_chains (MCMCChains. Chains, hcat (ps))
107+ end
108+
109+ function DynamicPPL. from_chains (
110+ :: Type{T} , chain:: MCMCChains.Chains
111+ ) where {T<: AbstractDict{<:DynamicPPL.VarName} }
112+ idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
113+ matrix = map (idxs) do (sample_idx, chain_idx)
114+ d = T ()
115+ for vn in DynamicPPL. varnames (chain)
116+ d[vn] = DynamicPPL. getindex_varname (chain, sample_idx, vn, chain_idx)
117+ end
118+ d
119+ end
120+ return matrix
121+ end
122+ function DynamicPPL. from_chains (:: Type{NamedTuple} , chain:: MCMCChains.Chains )
123+ idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
124+ matrix = map (idxs) do (sample_idx, chain_idx)
125+ get (chain[sample_idx, :, chain_idx], keys (chain); flatten= true )
126+ end
127+ return matrix
128+ end
129+ function DynamicPPL. from_chains (
130+ :: Type{DynamicPPL.ParamsWithStats} , chain:: MCMCChains.Chains
131+ )
132+ idxs = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
133+ internals_chain = MCMCChains. get_sections (chain, :internals )
134+ params = DynamicPPL. from_chains (
135+ DynamicPPL. OrderedCollections. OrderedDict{DynamicPPL. VarName,eltype (chain. value)},
136+ chain,
137+ )
138+ stats = DynamicPPL. from_chains (NamedTuple, internals_chain)
139+ return map (idxs) do (sample_idx, chain_idx)
140+ DynamicPPL. ParamsWithStats (
141+ params[sample_idx, chain_idx], stats[sample_idx, chain_idx]
142+ )
143+ end
144+ end
145+
39146"""
40147 predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41148
@@ -110,7 +217,6 @@ function DynamicPPL.predict(
110217 DynamicPPL. VarInfo (),
111218 (
112219 DynamicPPL. LogPriorAccumulator (),
113- DynamicPPL. LogJacobianAccumulator (),
114220 DynamicPPL. LogLikelihoodAccumulator (),
115221 DynamicPPL. ValuesAsInModelAccumulator (false ),
116222 ),
@@ -129,23 +235,9 @@ function DynamicPPL.predict(
129235 varinfo,
130236 DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
131237 )
132- vals = DynamicPPL. getacc (varinfo, Val (:ValuesAsInModel )). values
133- varname_vals = mapreduce (
134- collect,
135- vcat,
136- map (AbstractPPL. varname_and_value_leaves, keys (vals), values (vals)),
137- )
138-
139- return (varname_and_values= varname_vals, logp= DynamicPPL. getlogjoint (varinfo))
238+ DynamicPPL. ParamsWithStats (varinfo, nothing )
140239 end
141-
142- chain_result = reduce (
143- MCMCChains. chainscat,
144- [
145- _predictive_samples_to_chains (predictive_samples[:, chain_idx]) for
146- chain_idx in 1 : size (predictive_samples, 2 )
147- ],
148- )
240+ chain_result = DynamicPPL. to_chains (MCMCChains. Chains, predictive_samples)
149241 parameter_names = if include_all
150242 MCMCChains. names (chain_result, :parameters )
151243 else
@@ -164,45 +256,6 @@ function DynamicPPL.predict(
164256 )
165257end
166258
167- function _predictive_samples_to_arrays (predictive_samples)
168- variable_names_set = DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
169-
170- sample_dicts = map (predictive_samples) do sample
171- varname_value_pairs = sample. varname_and_values
172- varnames = map (first, varname_value_pairs)
173- values = map (last, varname_value_pairs)
174- for varname in varnames
175- push! (variable_names_set, varname)
176- end
177-
178- return DynamicPPL. OrderedCollections. OrderedDict (zip (varnames, values))
179- end
180-
181- variable_names = collect (variable_names_set)
182- variable_values = [
183- get (sample_dicts[i], key, missing ) for i in eachindex (sample_dicts),
184- key in variable_names
185- ]
186-
187- return variable_names, variable_values
188- end
189-
190- function _predictive_samples_to_chains (predictive_samples)
191- variable_names, variable_values = _predictive_samples_to_arrays (predictive_samples)
192- variable_names_symbols = map (Symbol, variable_names)
193-
194- internal_parameters = [:lp ]
195- log_probabilities = reshape ([sample. logp for sample in predictive_samples], :, 1 )
196-
197- parameter_names = [variable_names_symbols; internal_parameters]
198- parameter_values = hcat (variable_values, log_probabilities)
199- parameter_values = MCMCChains. concretize (parameter_values)
200-
201- return MCMCChains. Chains (
202- parameter_values, parameter_names, (internals= internal_parameters,)
203- )
204- end
205-
206259"""
207260 returned(model::Model, chain::MCMCChains.Chains)
208261
0 commit comments