Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.21"
version = "0.24.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ DynamicPPL.reconstruct
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```
Expand Down
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ export AbstractVarInfo,
invlink,
invlink!,
invlink!!,
tonamedtuple,
values_as,
# VarName (reexport from AbstractPPL)
VarName,
Expand Down
15 changes: 0 additions & 15 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -738,21 +738,6 @@ function unflatten(sampler::AbstractSampler, varinfo::AbstractVarInfo, ::Abstrac
return unflatten(varinfo, sampler, θ)
end

"""
tonamedtuple(vi::AbstractVarInfo)

Convert a `vi` into a `NamedTuple` where each variable symbol maps to the values and
indexing string of the variable.

For example, a model that had a vector of vector-valued
variables `x` would return

```julia
(x = ([1.5, 2.0], [3.0, 1.0], ["x[1]", "x[2]"]), )
```
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
Expand Down
38 changes: 0 additions & 38 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,44 +532,6 @@ function dot_assume(
return value, lp, vi
end

# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl.
# TODO: Move away from using these `tonamedtuple` methods.
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names}
nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(TestUtils.varname_leaves(vn, val))
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end

return NamedTuple{names}(nt_vals)
end

function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}()
for vn in keys(vi)
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(TestUtils.varname_leaves(vn, val))
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
sym = only(unique(map(getsym, vns)))

# Initialize entry if not yet initialized.
if !haskey(syms_to_result, sym)
syms_to_result[sym] = (Real[], String[])
end

# Combine with old result.
old_vals, old_string_vns = syms_to_result[sym]
syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns)))
end

# Construct `NamedTuple`.
return NamedTuple(pairs(syms_to_result))
end

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
function settrans!!(vi::SimpleVarInfo, trans)
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
Expand Down
2 changes: 0 additions & 2 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return is_flagged(vi.varinfo, vn, flag)
end

tonamedtuple(vi::ThreadSafeVarInfo) = tonamedtuple(vi.varinfo)

# Transformations.
function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
return Setfield.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
Expand Down
16 changes: 0 additions & 16 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1506,22 +1506,6 @@ end
return expr
end

# TODO: Remove this completely.
tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo)
function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names}
length(names) === 0 && return NamedTuple()

vals_tuple = map(values(metadata)) do x
# NOTE: `tonamedtuple` is really only used in Turing.jl to convert to
# a "transition". This means that we really don't mutations of the values
# in `varinfo` to propoagate the previous samples. Hence we `copy.`
vals = map(copy ∘ Base.Fix1(getindex, varinfo), x.vns)
return vals, map(string, x.vns)
end

return NamedTuple{names}(vals_tuple)
end

@inline function findvns(vi, f_vns)
if length(f_vns) == 0
throw("Unidentified error, please report this error in an issue.")
Expand Down
30 changes: 15 additions & 15 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,22 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
nt = DynamicPPL.tonamedtuple(var_info)
for (k, (vals, names)) in pairs(nt)
for (n, v) in zip(names, vals)
if Symbol(n) ∉ keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) ∉ keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

Expand Down