Skip to content

Commit 11b7e01

Browse files
authored
Add to_chains and from_chains function (#1087)
* Implement `ParamsWithStats` and `to_chains` functions * actually test the tests * Fix tests * Convert untyped VarInfo to typed for performance benefits * Add `from_chains` as well
1 parent 9a2607b commit 11b7e01

File tree

10 files changed

+459
-86
lines changed

10 files changed

+459
-86
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.3
4+
5+
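Added a new exported struct, `DynamicPPL.ParamsWithStats`, and a corresponding function `DynamicPPL.to_chains`, which automatically converts a collection of `ParamsWithStats` to a given Chains type. Also added `DynamicPPL.from_chains`, which converts a chains object back into a collection of parameter dictionaries and/or stats.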
6+
37
## 0.38.2
48

59
Added a compatibility entry for JET@0.11.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.2"
3+
version = "0.38.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,21 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
505505
DynamicPPL.Experimental.determine_suitable_varinfo
506506
DynamicPPL.Experimental.is_suitable_varinfo
507507
```
508+
509+
### Converting VarInfos to chains
510+
511+
It is fairly common to need to convert a collection of `VarInfo` objects into a chains object for downstream analysis.
512+
This can be accomplished with the following:
513+
514+
```@docs
515+
DynamicPPL.ParamsWithStats
516+
DynamicPPL.to_chains
517+
```
518+
519+
Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with:
520+
521+
```@docs
522+
DynamicPPL.from_chains
523+
```
524+
525+
This is useful if you want to use the result of a chain in further model evaluations.

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 109 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,113 @@ function chain_sample_to_varname_dict(
3636
return d
3737
end
3838

"""
    DynamicPPL.to_chains(
        ::Type{MCMCChains.Chains},
        params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
    )
    DynamicPPL.to_chains(
        ::Type{MCMCChains.Chains},
        ps::AbstractVector{<:DynamicPPL.ParamsWithStats},
    )

Convert a matrix or vector of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains`
object.

The first dimension of a matrix indexes iterations and the second indexes chains; a
vector is treated as a single chain.

Each parameter is split into its leaf `VarName`s via
`AbstractPPL.varname_and_value_leaves`, and every key found in a sample's `stats`
becomes a column in the `:internals` section of the chain. A leaf or statistic that
does not occur in a given sample is recorded as `missing`. The mapping from leaf
`VarName`s to the `Symbol`s used as column names is stored in the chain's `info` under
`varname_to_symbol`.
"""
function DynamicPPL.to_chains(
    ::Type{MCMCChains.Chains},
    params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
)
    # Handle parameters
    all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
    split_dicts = map(params_and_stats) do ps
        # Separate into individual VarNames.
        vn_leaves_and_vals = if isempty(ps.params)
            # `mapreduce` is called without `init` below, so an empty collection
            # has to be special-cased.
            Tuple{DynamicPPL.VarName,Any}[]
        else
            iters = map(
                AbstractPPL.varname_and_value_leaves,
                keys(ps.params),
                values(ps.params),
            )
            mapreduce(collect, vcat, iters)
        end
        vn_leaves = map(first, vn_leaves_and_vals)
        vals = map(last, vn_leaves_and_vals)
        # Record leaves in first-seen order; this determines the column order.
        for vn_leaf in vn_leaves
            push!(all_vn_leaves, vn_leaf)
        end
        DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
    end
    vn_leaves = collect(all_vn_leaves)
    # iterations × parameters × chains; absent leaves become `missing`.
    param_vals = [
        get(split_dicts[i, j], key, missing) for i in axes(split_dicts, 1),
        key in vn_leaves, j in axes(split_dicts, 2)
    ]
    param_symbols = map(Symbol, vn_leaves)
    # Handle statistics
    stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
    for ps in params_and_stats
        for k in keys(ps.stats)
            push!(stat_keys, k)
        end
    end
    stat_keys = collect(stat_keys)
    # iterations × statistics × chains; absent statistics become `missing`.
    stat_vals = [
        get(params_and_stats[i, j].stats, key, missing) for
        i in axes(params_and_stats, 1), key in stat_keys,
        j in axes(params_and_stats, 2)
    ]
    # Construct name map and info
    name_map = (internals=stat_keys,)
    info = (
        varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
            zip(all_vn_leaves, param_symbols)
        ),
    )
    # Concatenate parameter and statistic values
    vals = cat(param_vals, stat_vals; dims=2)
    symbols = vcat(param_symbols, stat_keys)
    return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
end
# Single-chain convenience method: a vector of samples is forwarded to the matrix
# method as one column.
function DynamicPPL.to_chains(
    ::Type{MCMCChains.Chains}, ps::AbstractVector{<:DynamicPPL.ParamsWithStats}
)
    # `hcat` of a single vector yields an n×1 matrix.
    single_chain = hcat(ps)
    return DynamicPPL.to_chains(MCMCChains.Chains, single_chain)
end
108+
"""
    DynamicPPL.from_chains(
        ::Type{T}, chain::MCMCChains.Chains
    ) where {T<:AbstractDict{<:DynamicPPL.VarName}}

Convert `chain` into a matrix of dictionaries of type `T`.

The result has size `(size(chain, 1), size(chain, 3))`. Entry
`[sample_idx, chain_idx]` is a freshly constructed `T()` that maps every `VarName` in
`DynamicPPL.varnames(chain)` to the value returned by
`DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)`.
"""
function DynamicPPL.from_chains(
    ::Type{T}, chain::MCMCChains.Chains
) where {T<:AbstractDict{<:DynamicPPL.VarName}}
    # The variable names do not depend on the sample, so look them up only once
    # rather than once per (sample, chain) pair.
    vns = DynamicPPL.varnames(chain)
    idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    return map(idxs) do (sample_idx, chain_idx)
        d = T()
        for vn in vns
            d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
        end
        d
    end
end
"""
    DynamicPPL.from_chains(::Type{NamedTuple}, chain::MCMCChains.Chains)

Convert `chain` into a matrix of size `(size(chain, 1), size(chain, 3))`.

Entry `[sample_idx, chain_idx]` is the result of calling `get` with `flatten=true` on
the single-sample sub-chain `chain[sample_idx, :, chain_idx]`, requesting all of
`keys(chain)`.
"""
function DynamicPPL.from_chains(::Type{NamedTuple}, chain::MCMCChains.Chains)
    return [
        get(chain[sample_idx, :, chain_idx], keys(chain); flatten=true) for
        sample_idx in 1:size(chain, 1), chain_idx in 1:size(chain, 3)
    ]
end
"""
    DynamicPPL.from_chains(::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains)

Convert `chain` into a matrix of `DynamicPPL.ParamsWithStats` of size
`(size(chain, 1), size(chain, 3))`.

The parameters of each entry come from the `OrderedDict` method of `from_chains`
applied to the whole chain; the statistics come from the `NamedTuple` method applied
to the `:internals` section only.
"""
function DynamicPPL.from_chains(
    ::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains
)
    # Per-sample parameter dictionaries, keyed by VarName.
    dict_type = DynamicPPL.OrderedCollections.OrderedDict{
        DynamicPPL.VarName,eltype(chain.value)
    }
    params = DynamicPPL.from_chains(dict_type, chain)
    # Per-sample statistics, taken from the internals section only.
    internals_chain = MCMCChains.get_sections(chain, :internals)
    stats = DynamicPPL.from_chains(NamedTuple, internals_chain)
    # Pair the two up element by element.
    return [
        DynamicPPL.ParamsWithStats(
            params[sample_idx, chain_idx], stats[sample_idx, chain_idx]
        ) for sample_idx in 1:size(chain, 1), chain_idx in 1:size(chain, 3)
    ]
end
145+
39146
"""
40147
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
41148
@@ -110,7 +217,6 @@ function DynamicPPL.predict(
110217
DynamicPPL.VarInfo(),
111218
(
112219
DynamicPPL.LogPriorAccumulator(),
113-
DynamicPPL.LogJacobianAccumulator(),
114220
DynamicPPL.LogLikelihoodAccumulator(),
115221
DynamicPPL.ValuesAsInModelAccumulator(false),
116222
),
@@ -129,23 +235,9 @@ function DynamicPPL.predict(
129235
varinfo,
130236
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
131237
)
132-
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
133-
varname_vals = mapreduce(
134-
collect,
135-
vcat,
136-
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
137-
)
138-
139-
return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
238+
DynamicPPL.ParamsWithStats(varinfo, nothing)
140239
end
141-
142-
chain_result = reduce(
143-
MCMCChains.chainscat,
144-
[
145-
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
146-
chain_idx in 1:size(predictive_samples, 2)
147-
],
148-
)
240+
chain_result = DynamicPPL.to_chains(MCMCChains.Chains, predictive_samples)
149241
parameter_names = if include_all
150242
MCMCChains.names(chain_result, :parameters)
151243
else
@@ -164,45 +256,6 @@ function DynamicPPL.predict(
164256
)
165257
end
166258

167-
function _predictive_samples_to_arrays(predictive_samples)
168-
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
169-
170-
sample_dicts = map(predictive_samples) do sample
171-
varname_value_pairs = sample.varname_and_values
172-
varnames = map(first, varname_value_pairs)
173-
values = map(last, varname_value_pairs)
174-
for varname in varnames
175-
push!(variable_names_set, varname)
176-
end
177-
178-
return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
179-
end
180-
181-
variable_names = collect(variable_names_set)
182-
variable_values = [
183-
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
184-
key in variable_names
185-
]
186-
187-
return variable_names, variable_values
188-
end
189-
190-
function _predictive_samples_to_chains(predictive_samples)
191-
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
192-
variable_names_symbols = map(Symbol, variable_names)
193-
194-
internal_parameters = [:lp]
195-
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
196-
197-
parameter_names = [variable_names_symbols; internal_parameters]
198-
parameter_values = hcat(variable_values, log_probabilities)
199-
parameter_values = MCMCChains.concretize(parameter_values)
200-
201-
return MCMCChains.Chains(
202-
parameter_values, parameter_names, (internals=internal_parameters,)
203-
)
204-
end
205-
206259
"""
207260
returned(model::Model, chain::MCMCChains.Chains)
208261

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ export AbstractVarInfo,
126126
prefix,
127127
returned,
128128
to_submodel,
129+
# Chain construction
130+
ParamsWithStats,
131+
to_chains,
129132
# Convenience macros
130133
@addlogprob!,
131134
value_iterator_from_chain,
@@ -194,6 +197,7 @@ include("model_utils.jl")
194197
include("extract_priors.jl")
195198
include("values_as_in_model.jl")
196199
include("bijector.jl")
200+
include("to_chains.jl")
197201

198202
include("debug_utils.jl")
199203
using .DebugUtils

0 commit comments

Comments
 (0)