-
Notifications
You must be signed in to change notification settings - Fork 36
subset
and merge
for VarInfo
(clean version)
#544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
028a81a
caa6e25
cac7fa8
5e41c4f
d5a2631
0ade696
db21844
1dbca4c
b67288f
e43029e
cd4033d
8f47dfe
aba9008
3b621ae
2c2c90b
5c1ece3
cfff96c
ed5d948
00c36cf
cf02816
d02cb61
7f01ada
14105e0
c164d32
743162a
dc9ad94
2f320e6
d3a9b56
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,7 @@ export AbstractVarInfo, | |
SimpleVarInfo, | ||
push!!, | ||
empty!!, | ||
subset, | ||
getlogp, | ||
setlogp!!, | ||
acclogp!!, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -236,6 +236,344 @@ else | |
_tail(nt::NamedTuple) = Base.tail(nt) | ||
end | ||
|
||
# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert | ||
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which | ||
# might result in a `Vector{Any}`. | ||
""" | ||
subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) | ||
|
||
Subset a `varinfo` to only contain the variables `vns`. | ||
|
||
!!! warning | ||
The ordering of the variables in the resulting `varinfo` will _not_ | ||
necessarily follow the ordering of the variables in `varinfo`. | ||
Hence care must be taken, in particular when used in conjunction with | ||
other methods which use the vector representation of `varinfo`, e.g. | ||
`getindex(varinfo, sampler)` | ||
|
||
# Examples | ||
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) | ||
julia> @model function demo() | ||
s ~ InverseGamma(2, 3) | ||
m ~ Normal(0, sqrt(s)) | ||
x = Vector{Float64}(undef, 2) | ||
x[1] ~ Normal(m, sqrt(s)) | ||
x[2] ~ Normal(m, sqrt(s)) | ||
end | ||
demo (generic function with 2 methods) | ||
|
||
julia> model = demo(); | ||
|
||
julia> varinfo = VarInfo(model); | ||
|
||
julia> keys(varinfo) | ||
4-element Vector{VarName}: | ||
s | ||
m | ||
x[1] | ||
x[2] | ||
|
||
julia> for (i, vn) in enumerate(keys(varinfo)) | ||
varinfo[vn] = i | ||
end | ||
|
||
julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] | ||
4-element Vector{Float64}: | ||
1.0 | ||
2.0 | ||
3.0 | ||
4.0 | ||
|
||
julia> # Extract one with only `m`. | ||
varinfo_subset1 = subset(varinfo, [@varname(m),]); | ||
|
||
|
||
julia> keys(varinfo_subset1) | ||
1-element Vector{VarName{:m, Setfield.IdentityLens}}: | ||
m | ||
|
||
julia> varinfo_subset1[@varname(m)] | ||
2.0 | ||
|
||
julia> # Extract one with both `s` and `x[2]`. | ||
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); | ||
|
||
julia> keys(varinfo_subset2) | ||
2-element Vector{VarName}: | ||
s | ||
x[2] | ||
|
||
julia> varinfo_subset2[[@varname(s), @varname(x[2])]] | ||
2-element Vector{Float64}: | ||
1.0 | ||
4.0 | ||
``` | ||
|
||
`subset` is particularly useful when combined with [`merge(varinfo_left::VarInfo, varinfo_right::VarInfo)`](@ref) | ||
|
||
```jldoctest varinfo-subset | ||
julia> # Merge the two. | ||
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); | ||
|
||
julia> keys(varinfo_subset_merged) | ||
3-element Vector{VarName}: | ||
m | ||
s | ||
x[2] | ||
|
||
julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] | ||
3-element Vector{Float64}: | ||
1.0 | ||
2.0 | ||
4.0 | ||
|
||
julia> # Merge the two with the original. | ||
varinfo_merged = merge(varinfo, varinfo_subset_merged); | ||
|
||
julia> keys(varinfo_merged) | ||
4-element Vector{VarName}: | ||
s | ||
m | ||
x[1] | ||
x[2] | ||
|
||
julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] | ||
4-element Vector{Float64}: | ||
1.0 | ||
2.0 | ||
3.0 | ||
4.0 | ||
``` | ||
|
||
# Notes | ||
|
||
## Type-stability | ||
|
||
!!! warning | ||
This function is only type-stable when `vns` contains only varnames | ||
with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will | ||
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. | ||
""" | ||
function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName})
    # Restrict the single metadata container to the requested variables.
    metadata = subset(varinfo.metadata, vns)
    # `logp` and `num_produce` are reference objects (`num_produce` is read with `[]`
    # and both are rebuilt with `Ref(...)` in `_merge`). Copy them so that in-place
    # updates of the returned `VarInfo` do not leak into `varinfo`, and vice versa.
    return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce))
end
|
||
function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym}
    # Fast path: every requested variable uses the symbol `sym`, so only the `sym`
    # field of `varinfo.metadata` has to be subset.
    metadata = subset(getfield(varinfo.metadata, sym), vns)
    # `logp` and `num_produce` are reference objects; copy them so that the returned
    # `VarInfo` does not share mutable state with `varinfo`.
    return VarInfo(
        NamedTuple{(sym,)}((metadata,)),
        deepcopy(varinfo.logp),
        deepcopy(varinfo.num_produce),
    )
end
|
||
function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName})
    # One metadata entry per distinct symbol, in order of first appearance in `vns`.
    syms = Tuple(unique(map(getsym, vns)))
    # For each symbol, subset the matching metadata field to the varnames in `vns`
    # that carry that symbol.
    metadatas = map(syms) do sym
        subset(getfield(varinfo.metadata, sym), filter(==(sym) ∘ getsym, vns))
    end

    # `logp` and `num_produce` are reference objects; copy them so that the returned
    # `VarInfo` does not share mutable state with `varinfo`.
    return VarInfo(
        NamedTuple{syms}(metadatas),
        deepcopy(varinfo.logp),
        deepcopy(varinfo.num_produce),
    )
end
|
||
function subset(metadata::Metadata, vns::AbstractVector{<:VarName})
    # Build a new `Metadata` that contains only the variables in `vns`, laid out in
    # the order given by `vns`.
    #
    # TODO: Should we error (or at least warn) if `vns` contains a variable that is
    # not in `metadata`? As written, the index lookup below fails for such a variable.

    # Position of each requested variable in the original `metadata`.
    indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
    # Position of each requested variable in the new metadata, i.e. its index in `vns`.
    indices = Dict(vn => i for (i, vn) in enumerate(vns))

    vals_original = metadata.vals
    ranges_original = metadata.ranges
    # Allocate the new `vals` and `ranges`. `ranges` must hold exactly one entry per
    # variable in `vns`; sizing it like `ranges_original` would leave uninitialised
    # trailing entries whenever `vns` is a strict subset.
    vals = similar(vals_original, sum(length, ranges_original[indices_for_vns]))
    ranges = similar(ranges_original, length(vns))

    # The new range `r` for `vns[i]` starts right after the previous one and has the
    # same length as the original range `r_original`.
    # NOTE: This means that the order of the variables in `vns` defines the order
    # in the resulting metadata! This can have performance implications, e.g.
    # if in the model we have something like
    #
    #     for i = 1:N
    #         x[i] ~ Normal()
    #     end
    #
    # and we then do
    #
    #     subset(varinfo, [@varname(x[i]) for i in shuffle(1:N)])
    #
    # the resulting `varinfo` will have `vals` ordered differently from the
    # original `varinfo`.
    offset = 0
    for (idx, idx_original) in enumerate(indices_for_vns)
        r_original = ranges_original[idx_original]
        r = (offset + 1):(offset + length(r_original))
        vals[r] = vals_original[r_original]
        ranges[idx] = r
        # `last` (rather than `r[end]`) also works for a zero-length range.
        offset = last(r)
    end

    flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags)
    # NOTE(review): `gids` is passed through unchanged, whereas `dists`, `orders` and
    # `flags` are re-indexed with `indices_for_vns` -- confirm this is intended.
    return Metadata(
        indices,
        vns,
        ranges,
        vals,
        metadata.dists[indices_for_vns],
        metadata.gids,
        metadata.orders[indices_for_vns],
        flags,
    )
end
|
||
""" | ||
merge(varinfo_left::VarInfo, varinfo_right::VarInfo) | ||
|
||
Merge two `VarInfo` instances into one, giving precedence to `varinfo_right` when reasonable. | ||
|
||
""" | ||
Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) =
    _merge(varinfo_left, varinfo_right)  # all of the work happens in `_merge`
|
||
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
    # The log-probabilities of the two inputs are added together.
    logp_total = getlogp(varinfo_left) + getlogp(varinfo_right)
    # TODO: Is this really the way we want to combine `num_produce`?
    num_produce_total = varinfo_left.num_produce[] + varinfo_right.num_produce[]
    # Per-variable metadata is combined by `merge_metadata`.
    merged_metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)

    # Wrap the scalars in fresh references for the new `VarInfo`.
    return VarInfo(merged_metadata, Ref(logp_total), Ref(num_produce_total))
end
|
||
@generated function merge_metadata(
    metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
) where {names_left,names_right}
    # Field names and value expressions of the resulting `NamedTuple`, built in lockstep.
    merged_names = Any[]
    merged_vals = Any[]

    # Fields of the left operand come first so that their order is preserved. A field
    # present on both sides is merged recursively via `merge_metadata`.
    for name in names_left
        val = if name in names_right
            :(merge_metadata(metadata_left.$name, metadata_right.$name))
        else
            :(metadata_left.$name)
        end
        push!(merged_names, QuoteNode(name))
        push!(merged_vals, val)
    end

    # Fields that exist only on the right are appended afterwards, in their own order.
    for name in names_right
        name in names_left && continue
        push!(merged_names, QuoteNode(name))
        push!(merged_vals, :(metadata_right.$name))
    end

    names_expr = Expr(:tuple, merged_names...)
    vals_expr = Expr(:tuple, merged_vals...)
    return :(NamedTuple{$names_expr}($vals_expr))
end
|
||
function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
    # Merge two `Metadata` containers into a new one.
    #
    # The result contains the variables of `metadata_left` (in their order) followed by
    # the variables that only occur in `metadata_right`. For a variable present in both,
    # the values, order and flags are taken from `metadata_right`; the two entries must
    # agree on the number of values and on the distribution. `gids` is taken from
    # `metadata_right` as a whole.

    # Extract the varnames.
    vns_left = metadata_left.vns
    vns_right = metadata_right.vns
    vns_both = union(vns_left, vns_right)

    # Determine `eltype` of `vals`.
    T = promote_type(eltype(metadata_left.vals), eltype(metadata_right.vals))
    # TODO: Is this necessary?
    if !(T <: Real)
        T = Real
    end

    # Determine `eltype` of `dists`.
    D = promote_type(eltype(metadata_left.dists), eltype(metadata_right.dists))
    # TODO: Is this necessary?
    if !(D <: Distribution)
        D = Distribution
    end

    # Initialize required fields for `metadata`.
    vns = VarName[]
    idcs = Dict{VarName,Int}()
    ranges = Vector{UnitRange{Int}}()
    vals = T[]
    dists = D[]
    gids = metadata_right.gids # NOTE: giving precedence to `metadata_right`
    orders = Int[]
    flags = Dict{String,BitVector}()
    # Initialize the `flags`.
    for k in union(keys(metadata_left.flags), keys(metadata_right.flags))
        flags[k] = BitVector()
    end

    # Range offset.
    offset = 0

    for (idx, vn) in enumerate(vns_both)
        idcs[vn] = idx
        push!(vns, vn)

        # `vns_both` is the union, so at least one of these is `true`.
        in_left = vn in vns_left
        in_right = vn in vns_right

        # `metadata_right` takes precedence for values, order and flags whenever it
        # knows about `vn`; otherwise everything comes from `metadata_left`.
        source = in_right ? metadata_right : metadata_left
        vals_vn = getval(source, vn)
        dist_vn = in_left ? getdist(metadata_left, vn) : getdist(metadata_right, vn)

        if in_left && in_right
            # Merging is only valid if both sides agree on length and distribution.
            @assert length(getval(metadata_left, vn)) == length(vals_vn)
            @assert dist_vn == getdist(metadata_right, vn)
        end

        # `vals` and `ranges`
        append!(vals, vals_vn)
        r = (offset + 1):(offset + length(vals_vn))
        push!(ranges, r)
        offset = last(r)
        # `dists`
        push!(dists, dist_vn)
        # `orders`
        push!(orders, getorder(source, vn))
        # `flags`
        for k in keys(flags)
            # For variables in both, this uses `metadata_right`; should we?
            push!(flags[k], is_flagged(source, vn, k))
        end
    end

    return Metadata(idcs, vns, ranges, vals, dists, gids, orders, flags)
end
|
||
# Alias for the union of `Int`, `UnitRange` and `Vector{Int}`.
# NOTE(review): presumably the index types accepted when viewing into a value vector -- confirm against usages.
const VarView = Union{Int,UnitRange,Vector{Int}} | ||
|
||
""" | ||
|
@@ -1331,6 +1669,15 @@ function setorder!(vi::VarInfo, vn::VarName, index::Int) | |
return vi | ||
end | ||
|
||
""" | ||
getorder(vi::VarInfo, vn::VarName) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this API -- we can consider depreciating There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused. Isn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements | ||
run before sampling `vn`. | ||
""" | ||
function getorder(vi::VarInfo, vn::VarName)
    # Delegate to the metadata container that holds `vn`.
    md = getmetadata(vi, vn)
    return getorder(md, vn)
end
function getorder(metadata::Metadata, vn::VarName)
    # `orders` is indexed by the position of `vn` within `metadata`.
    position = getidx(metadata, vn)
    return metadata.orders[position]
end
|
||
####################################### | ||
# Rand & replaying method for VarInfo # | ||
####################################### | ||
|
@@ -1341,7 +1688,10 @@ end | |
Check whether `vn` has a true value for `flag` in `vi`. | ||
""" | ||
function is_flagged(vi::VarInfo, vn::VarName, flag::String)
    # Delegate to the `Metadata` method for the container that holds `vn`. The previous
    # inline lookup was left in front of this call, which made it unreachable.
    return is_flagged(getmetadata(vi, vn), vn, flag)
end
function is_flagged(metadata::Metadata, vn::VarName, flag::String)
    # Fetch the per-variable values stored for `flag`, then index them at the
    # position of `vn` within `metadata`.
    flag_values = metadata.flags[flag]
    return flag_values[getidx(metadata, vn)]
end
|
||
""" | ||
|
Uh oh!
There was an error while loading. Please reload this page.