Skip to content

Commit 0b56415

Browse files
torfjeldeyebai
andauthored
Partial fix for #2095 (#2096)
* use immutable link in the initialstep for HMC * bump patch version * added test * Update hmc.jl --------- Co-authored-by: Hong Ge <[email protected]>
1 parent 4b5e4d7 commit 0b56415

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/mcmc/hmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,13 @@ function DynamicPPL.initialstep(
131131
rng::AbstractRNG,
132132
model::AbstractModel,
133133
spl::Sampler{<:Hamiltonian},
134-
vi::AbstractVarInfo;
134+
vi_original::AbstractVarInfo;
135135
initial_params=nothing,
136136
nadapts=0,
137137
kwargs...
138138
)
139139
# Transform the samples to unconstrained space and compute the joint log probability.
140-
vi = link!!(vi, spl, model)
140+
vi = DynamicPPL.link(vi_original, spl, model)
141141

142142
# Extract parameters.
143143
theta = vi[spl]

test/mcmc/hmc.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,15 @@
246246
sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5)
247247
end
248248
end
249+
250+
@turing_testset "(partially) issue: #2095" begin
251+
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
252+
xs = Vector{TV}(undef, 2)
253+
xs[1] ~ Dirichlet(ones(5))
254+
xs[2] ~ Dirichlet(ones(5))
255+
end
256+
model = vector_of_dirichlet()
257+
chain = sample(model, NUTS(), 1000)
258+
@test mean(Array(chain)) 0.2
259+
end
249260
end

0 commit comments

Comments
 (0)