From 4de2a0146191304edc7a294b86eaea53daa61716 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 7 Oct 2023 23:47:34 +0100 Subject: [PATCH 01/12] link and invlink should correctly work with Selector etc. --- src/varinfo.jl | 58 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index ddb4caffb..08b986f74 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -903,25 +903,44 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) - return _link(varinfo) + return _link(varinfo, spl) end -function _link(varinfo::UntypedVarInfo) +function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata), + _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _link(varinfo::TypedVarInfo) +function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_link_metadata!, varinfo), varinfo.metadata) + md = _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) # TODO: Update logp, etc. return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end +@generated function _link_metadata!( + varinfo::VarInfo, + metadata::NamedTuple{names}, + ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!( + expr.args, + :(_link_metadata!(varinfo, metadata.$f)) + ) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end function _link_metadata!(varinfo::VarInfo, metadata::Metadata) vns = metadata.vns @@ -972,25 +991,44 @@ end function invlink( ::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model ) - return _invlink(varinfo) + return _invlink(varinfo, spl) end -function _invlink(varinfo::UntypedVarInfo) +function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata), + _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) end -function _invlink(varinfo::TypedVarInfo) +function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = map(Base.Fix1(_invlink_metadata!, varinfo), varinfo.metadata) + md = _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) # TODO: Update logp, etc. return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end +@generated function _invlink_metadata!( + varinfo::VarInfo, + metadata::NamedTuple{names}, + ::Val{space} +) where {names,space} + vals = Expr(:tuple) + for f in names + if inspace(f, space) || length(space) == 0 + push!( + expr.args, + :(_invlink_metadata!(varinfo, metadata.$f)) + ) + else + push!(vals.args, :(metadata.$f)) + end + end + + return :(NamedTuple{$names}($vals)) +end function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) vns = metadata.vns From 1e4d9f19ee4df18af801fc50bac7e58b56f5f94a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:20:02 +0100 Subject: [PATCH 02/12] more fixes to link and invlink --- src/varinfo.jl | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 08b986f74..d56d5b22d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -909,7 +909,7 @@ end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), + _link_metadata!(varinfo, varinfo.metadata, _getvns(spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -917,22 +917,22 @@ end function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _link_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) - # TODO: Update logp, etc. + md = _link_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _link_metadata!( +@generated function _link_metadata_namedtuple!( varinfo::VarInfo, metadata::NamedTuple{names}, + vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 push!( - expr.args, - :(_link_metadata!(varinfo, metadata.$f)) + vals.args, + :(_link_metadata!(varinfo, metadata.$f, vns.$f)) ) else push!(vals.args, :(metadata.$f)) @@ -941,13 +941,13 @@ end return :(NamedTuple{$names}($vals)) end -function _link_metadata!(varinfo::VarInfo, metadata::Metadata) +function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in unconstrained space. - if istrans(varinfo, vn) + if istrans(varinfo, vn) || vn ∉ target_vns return metadata.vals[getrange(metadata, vn)] end @@ -997,7 +997,7 @@ end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns(spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1005,22 +1005,22 @@ end function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(varinfo, varinfo.metadata, Val(getspace(spl))) - # TODO: Update logp, etc. + md = _invlink_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end -@generated function _invlink_metadata!( +@generated function _invlink_metadata_namedtuple!( varinfo::VarInfo, metadata::NamedTuple{names}, + vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 push!( - expr.args, - :(_invlink_metadata!(varinfo, metadata.$f)) + vals.args, + :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)) ) else push!(vals.args, :(metadata.$f)) @@ -1029,13 +1029,14 @@ end return :(NamedTuple{$names}($vals)) end -function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata) +function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn - # Return early if we're already in constrained space. - if !istrans(varinfo, vn) + # Return early if we're already in constrained space OR if we're not + # supposed to touch this `vn`. + if !istrans(varinfo, vn) || vn ∉ target_vns return metadata.vals[getrange(metadata, vn)] end From bf4fcc66f949b1e0597cc0311d96f12332aa4f4c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:20:44 +0100 Subject: [PATCH 03/12] formatting --- src/varinfo.jl | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index d56d5b22d..dedab5f97 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -917,23 +917,19 @@ end function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _link_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) + md = _link_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @generated function _link_metadata_namedtuple!( - varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space} + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!( - vals.args, - :(_link_metadata!(varinfo, metadata.$f, vns.$f)) - ) + push!(vals.args, :(_link_metadata!(varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end @@ -1005,23 +1001,19 @@ end function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) - md = _invlink_metadata_namedtuple!(varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl))) + md = _invlink_metadata_namedtuple!( + varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @generated function _invlink_metadata_namedtuple!( - varinfo::VarInfo, - metadata::NamedTuple{names}, - vns::NamedTuple, - ::Val{space} + varinfo::VarInfo, metadata::NamedTuple{names}, vns::NamedTuple, ::Val{space} ) where {names,space} vals = Expr(:tuple) for f in names if inspace(f, space) || length(space) == 0 - push!( - vals.args, - :(_invlink_metadata!(varinfo, metadata.$f, vns.$f)) - ) + push!(vals.args, :(_invlink_metadata!(varinfo, metadata.$f, vns.$f))) else push!(vals.args, :(metadata.$f)) end From f1fde0bbee3ae9213479cb78c0b4803a8ec17bc5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:35:17 +0100 Subject: [PATCH 04/12] added simple tests for usage of selectors --- test/varinfo.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index 598ea7814..7f96c071e 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,7 @@ +# A simple "algorithm" which only has `s` variables in its space. +struct MySAlg end +DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) + @testset "varinfo.jl" begin @testset "TypedVarInfo" begin @model gdemo(x, y) = begin @@ -421,4 +425,42 @@ end end end + + @testset "VarInfo with selectors" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + varinfo = VarInfo(model) + selector = DynamicPPL.Selector() + spl = Sampler(MySAlg(), model, selector) + + vns = DynamicPPL.TestUtils.varnames(model) + vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) + vns_m = filter(vn -> DynamicPPL.getsym(vn) === :m, vns) + for vn in vns_s + DynamicPPL.updategid!(varinfo, vn, spl) + end + + # Should only get the variables subsumed by `@varname(s)`. + @test varinfo[spl] == + mapreduce(Base.Fix1(DynamicPPL.getval, varinfo), vcat, vns_s) + + # `link` + varinfo_linked = DynamicPPL.link(varinfo, spl, model) + # `s` variables should be linked + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + # `m` variables should NOT be linked + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + # And `varinfo` should be unchanged + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo), vns) + + # `invlink` + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, spl, model) + # `s` variables should no longer be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_s) + # `m` variables should still not be linked + @test all(!Base.Fix1(DynamicPPL.istrans, varinfo_invlinked), vns_m) + # And `varinfo_linked` should be unchanged + @test any(Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_s) + @test any(!Base.Fix1(DynamicPPL.istrans, varinfo_linked), vns_m) + end + end end From 67770d32a917fd066695814c1cd0bff943c1fb2b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:35:37 +0100 Subject: [PATCH 05/12] bumped patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 08ca184bf..c9805dadb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.18" +version = "0.23.19" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 4bf5f7c7b33fb6fe6fbb28a2aded5fe56b85446f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 00:36:21 +0100 Subject: [PATCH 06/12] fied typos --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index dedab5f97..3e9d0c204 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -909,7 +909,7 @@ end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, _getvns(spl)), + _link_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -993,7 +993,7 @@ end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, _getvns(spl)), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) From d15df29c74dba13655241102a428437f86e54756 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 00:16:15 +0100 Subject: [PATCH 07/12] added missing _getvns_link for UntypedVarInfo --- src/varinfo.jl | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3e9d0c204..600577796 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -902,6 +902,17 @@ function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) return vi end +# HACK: We need `SampleFromPrior` to result in ALL values which are in need +# of a transformation to be transformed. `_getvns` will by default return +# an empty iterable for `SampleFromPrior`, so we need to override it here. +# This is quite hacky, but seems safer than changing the behavior of `_getvns`. +_getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) +_getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing +_getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( + Base.Returns(nothing), + _getvns(varinfo, spl) +) + function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) return _link(varinfo, spl) end @@ -909,7 +920,7 @@ end function _link(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _link_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), + _link_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -918,7 +929,7 @@ end function _link(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) md = _link_metadata_namedtuple!( - varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @@ -943,7 +954,8 @@ function _link_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in unconstrained space. - if istrans(varinfo, vn) || vn ∉ target_vns + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -993,7 +1005,7 @@ end function _invlink(varinfo::UntypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) return VarInfo( - _invlink_metadata!(varinfo, varinfo.metadata, _getvns(varinfo, spl)), + _invlink_metadata!(varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo)), ) @@ -1002,7 +1014,7 @@ end function _invlink(varinfo::TypedVarInfo, spl::AbstractSampler) varinfo = deepcopy(varinfo) md = _invlink_metadata_namedtuple!( - varinfo, varinfo.metadata, _getvns(varinfo, spl), Val(getspace(spl)) + varinfo, varinfo.metadata, _getvns_link(varinfo, spl), Val(getspace(spl)) ) return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) end @@ -1028,7 +1040,8 @@ function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not # supposed to touch this `vn`. - if !istrans(varinfo, vn) || vn ∉ target_vns + # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. + if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end From 203aeb370cc3ee725214dbacb9da1e3475f9397d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 00:17:12 +0100 Subject: [PATCH 08/12] simplify `_getvns_link` for TypedVarInfo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 600577796..a00bcf0ee 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -910,7 +910,7 @@ _getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( Base.Returns(nothing), - _getvns(varinfo, spl) + varinfo.metadata ) function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) From b170c0b3da309798b6aa2563665d6ff4a547db33 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 00:21:23 +0100 Subject: [PATCH 09/12] Update src/varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varinfo.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a00bcf0ee..0fca2e128 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -908,10 +908,9 @@ end # This is quite hacky, but seems safer than changing the behavior of `_getvns`. _getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing -_getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) = map( - Base.Returns(nothing), - varinfo.metadata -) +function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) + return map(Base.Returns(nothing), varinfo.metadata) +end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) return _link(varinfo, spl) From 1086731c5d7fbeb90b9b534351223da3fd7ba867 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 10:33:30 +0100 Subject: [PATCH 10/12] added Compat as dep so we can make use of certain features, e.g. Returns --- Project.toml | 14 ++++++++------ src/varinfo.jl | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index c9805dadb..e9c88fa9f 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -21,6 +22,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + +[extensions] +DynamicPPLMCMCChainsExt = ["MCMCChains"] + [compat] AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.6" @@ -28,6 +35,7 @@ BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" ConstructionBase = "1.5.4" +Compat = "4" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" LogDensityProblems = "2" @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1" ZygoteRules = "0.2" julia = "1.6" -[extensions] -DynamicPPLMCMCChainsExt = ["MCMCChains"] - [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - -[weakdeps] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" diff --git a/src/varinfo.jl b/src/varinfo.jl index 0fca2e128..dec1a96a4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -909,7 +909,7 @@ end _getvns_link(varinfo::VarInfo, spl::AbstractSampler) = _getvns(varinfo, spl) _getvns_link(varinfo::UntypedVarInfo, spl::SampleFromPrior) = nothing function _getvns_link(varinfo::TypedVarInfo, spl::SampleFromPrior) - return map(Base.Returns(nothing), varinfo.metadata) + return map(Returns(nothing), varinfo.metadata) end function link(::DynamicTransformation, varinfo::VarInfo, spl::AbstractSampler, model::Model) From 40cce2556673d8ed267ecb5554d8546e952b7354 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 9 Oct 2023 11:10:08 +0100 Subject: [PATCH 11/12] forgot using Compat --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 042931ebb..8e3a778ad 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -3,6 +3,7 @@ module DynamicPPL using AbstractMCMC: AbstractSampler, AbstractChains using AbstractPPL using Bijectors +using Compat using Distributions using OrderedCollections: OrderedDict From cb2f40ef65ac5cd51f22a6a57f6be1c53012dae1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 10 Oct 2023 01:33:36 +0100 Subject: [PATCH 12/12] Apply suggestions from code review Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index dec1a96a4..3e7dc119f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1038,7 +1038,7 @@ function _invlink_metadata!(varinfo::VarInfo, metadata::Metadata, target_vns) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`. + # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)]