Skip to content

Commit c29d36e

Browse files
authored
Fix for #2180 (#2181)
* fixed dispatch issue of `propose!!` (ref #2180) * added test for `filldist` proposal for MH * bump patch version * added `demo_dot_assume_observe_index` to models that are allowed to fail the MLE test, as it should be * fixed test * try comparing median instead of mean to avoid outliers completely ruining the estimate * try with initial params * upped the number of samples used for `filldist` proposal test
1 parent 3a315ce commit c29d36e

File tree

5 files changed

+16
-5
lines changed

5 files changed

+16
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.30.6"
3+
version = "0.30.7"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/mh.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,8 @@ function propose!!(
413413
vi::AbstractVarInfo,
414414
model::Model,
415415
spl::Sampler{<:MH},
416-
proposal::AdvancedMH.RandomWalkProposal{issymmetric,<:MvNormal}
417-
) where {issymmetric}
416+
proposal::AdvancedMH.RandomWalkProposal
417+
)
418418
# If this is the case, we can just draw directly from the proposal
419419
# matrix.
420420
vals = vi[spl]

test/mcmc/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@
226226
end
227227
alg = NUTS(1000, 0.8; adtype=adbackend)
228228
gdemo_default_prior = DynamicPPL.contextualize(demo_hmc_prior(), DynamicPPL.PriorContext())
229-
chain = sample(gdemo_default_prior, alg, 10_000)
229+
chain = sample(gdemo_default_prior, alg, 10_000; initial_params=[3.0, 0.0])
230230
check_numerical(chain, [:s, :m], [mean(truncated(Normal(3, 1); lower=0)), 0], atol=0.1)
231231
end
232232

test/mcmc/mh.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,4 +234,14 @@
234234
chain = sample(rng, gdemo_default_prior, alg, n; discard_initial = burnin, thinning=10)
235235
check_numerical(chain, [:s, :m], [mean(InverseGamma(2, 3)), 0], atol=0.3)
236236
end
237+
238+
@turing_testset "`filldist` proposal (issue #2180)" begin
239+
@model demo_filldist_issue2180() = x ~ MvNormal(zeros(3), I)
240+
chain = sample(
241+
demo_filldist_issue2180(),
242+
MH(AdvancedMH.RandomWalkProposal(filldist(Normal(), 3))),
243+
10_000
244+
)
245+
check_numerical(chain, [Symbol("x[1]"), Symbol("x[2]"), Symbol("x[3]")], [0, 0, 0], atol=0.1)
246+
end
237247
end

test/optimisation/OptimInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ end
156156
DynamicPPL.TestUtils.demo_dot_assume_dot_observe_matrix,
157157
DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix,
158158
DynamicPPL.TestUtils.demo_assume_submodel_observe_index_literal,
159+
DynamicPPL.TestUtils.demo_dot_assume_observe_index,
159160
DynamicPPL.TestUtils.demo_dot_assume_observe_index_literal,
160-
DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix
161+
DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix,
161162
]
162163
@testset "MLE for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
163164
result_true = DynamicPPL.TestUtils.likelihood_optima(model)

0 commit comments

Comments
 (0)