-
Notifications
You must be signed in to change notification settings - Fork 36
link
and invlink
should correctly work with Selector
and thus Gibbs
#542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
4de2a01
1e4d9f1
bf4fcc6
f1fde0b
67770d3
4bf5f7c
d15df29
203aeb3
b170c0b
1086731
40cce25
cb2f40e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -902,33 +902,60 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) | |||||||
return vi | ||||||||
end | ||||||||
|
||||||||
# HACK: We need `SampleFromPrior` to result in ALL values which are in need | ||||||||
# of a transformation to be transformed. `_getvns` will by default return | ||||||||
# an empty iterable for `SampleFromPrior`, so we need to override it here. | ||||||||
# This is quite hacky, but seems safer than changing the behavior of `_getvns`. | ||||||||
_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) | ||||||||
_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing | ||||||||
_getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( | ||||||||
Base.Returns(nothing), | ||||||||
varinfo.metadata | ||||||||
) | ||||||||
torfjelde marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
|
||||||||
function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) | ||||||||
return _link(varinfo) | ||||||||
return _link(varinfo, spl) | ||||||||
end | ||||||||
|
||||||||
function _link(varinfo::UntypedVarInfo) | ||||||||
function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) | ||||||||
varinfo = deepcopy(varinfo) | ||||||||
return VarInfo( | ||||||||
_link_metadata!(varinfo, varinfo.metadata), | ||||||||
_link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), | ||||||||
Base.Ref(getlogp(varinfo)), | ||||||||
Ref(get_num_produce(varinfo)), | ||||||||
) | ||||||||
end | ||||||||
|
||||||||
function _link(varinfo::TypedVarInfo) | ||||||||
function _link(varinfo::TypedVarInfo, spl::AbstractSampler) | ||||||||
varinfo = deepcopy(varinfo) | ||||||||
md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata) | ||||||||
# TODO: Update logp, etc. | ||||||||
md = _link_metadata_namedtuple!( | ||||||||
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) | ||||||||
) | ||||||||
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) | ||||||||
end | ||||||||
|
||||||||
function _link_metadata!(varinfo::VarInfo, metadata::Metadata) | ||||||||
@generated function _link_metadata_namedtuple!( | ||||||||
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} | ||||||||
) where {names,space} | ||||||||
vals = Expr(:tuple) | ||||||||
for f in names | ||||||||
if inspace(f, space) || length(space) == 0 | ||||||||
push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f))) | ||||||||
else | ||||||||
push!(vals.args, :(metadata.$f)) | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
return :(NamedTuple{$names}($vals)) | ||||||||
end | ||||||||
function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) | ||||||||
vns = metadata.vns | ||||||||
|
||||||||
# Construct the new transformed values, and keep track of their lengths. | ||||||||
vals_new = map(vns) do vn | ||||||||
# Return early if we're already in unconstrained space. | ||||||||
if istrans(varinfo, vn) | ||||||||
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not entirely clear to me why this is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe |
||||||||
if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) | ||||||||
return metadata.vals[getrange(metadata, vn)] | ||||||||
end | ||||||||
|
||||||||
|
@@ -972,32 +999,49 @@ end | |||||||
function invlink( | ||||||||
::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model | ||||||||
) | ||||||||
return _invlink(varinfo) | ||||||||
return _invlink(varinfo, spl) | ||||||||
end | ||||||||
|
||||||||
function _invlink(varinfo::UntypedVarInfo) | ||||||||
function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) | ||||||||
varinfo = deepcopy(varinfo) | ||||||||
return VarInfo( | ||||||||
_invlink_metadata!(varinfo, varinfo.metadata), | ||||||||
_invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), | ||||||||
Base.Ref(getlogp(varinfo)), | ||||||||
Ref(get_num_produce(varinfo)), | ||||||||
) | ||||||||
end | ||||||||
|
||||||||
function _invlink(varinfo::TypedVarInfo) | ||||||||
function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) | ||||||||
varinfo = deepcopy(varinfo) | ||||||||
md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata) | ||||||||
# TODO: Update logp, etc. | ||||||||
md = _invlink_metadata_namedtuple!( | ||||||||
varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) | ||||||||
) | ||||||||
return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) | ||||||||
end | ||||||||
|
||||||||
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) | ||||||||
@generated function _invlink_metadata_namedtuple!( | ||||||||
varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} | ||||||||
) where {names,space} | ||||||||
vals = Expr(:tuple) | ||||||||
for f in names | ||||||||
if inspace(f, space) || length(space) == 0 | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||
push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f))) | ||||||||
else | ||||||||
push!(vals.args, :(metadata.$f)) | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
return :(NamedTuple{$names}($vals)) | ||||||||
end | ||||||||
function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) | ||||||||
vns = metadata.vns | ||||||||
|
||||||||
# Construct the new transformed values, and keep track of their lengths. | ||||||||
vals_new = map(vns) do vn | ||||||||
# Return early if we're already in constrained space. | ||||||||
if !istrans(varinfo, vn) | ||||||||
# Return early if we're already in constrained space OR if we're not | ||||||||
# supposed to touch this `vn`. | ||||||||
torfjelde marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. | ||||||||
if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) | ||||||||
return metadata.vals[getrange(metadata, vn)] | ||||||||
end | ||||||||
|
||||||||
|
Uh oh!
There was an error while loading. Please reload this page.