Skip to content

Commit fa6f30a

Browse files
authored
Bugfix for Optim.jl on models with different linked dimensionality (#2196)
* fixed bug with optim interface * bump patch version * fixed test * dirichlet onles has a unique mode for alpha > 1...
1 parent c29d36e commit fa6f30a

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.30.7"
3+
version = "0.30.8"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/TuringOptimExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ function _optimize(
228228
# Convert the initial values, since it is assumed that users provide them
229229
# in the constrained space.
230230
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
231-
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
231+
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
232232
init_vals = DynamicPPL.getparams(f)
233233

234234
# Optimize!
@@ -242,9 +242,9 @@ function _optimize(
242242
# Get the VarInfo at the MLE/MAP point, and run the model to ensure
243243
# correct dimensionality.
244244
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
245-
Setfield.@set! f.varinfo = DynamicPPL.invlink!!(f.varinfo, model)
245+
Setfield.@set! f.varinfo = DynamicPPL.invlink(f.varinfo, model)
246246
vals = DynamicPPL.getparams(f)
247-
Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model)
247+
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
248248

249249
# Make one transition to get the parameter names.
250250
ts = [Turing.Inference.Transition(

test/optimisation/OptimInterface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,12 @@ end
225225
@test Turing.OptimLogDensity(m1, ctx)(w) == Turing.OptimLogDensity(m2, ctx)(w)
226226
end
227227
end
228+
229+
# Issue: https://discourse.julialang.org/t/turing-mixture-models-with-dirichlet-weightings/112910
230+
@testset "with different linked dimensionality" begin
231+
@model demo_dirichlet() = x ~ Dirichlet(2 * ones(3))
232+
model = demo_dirichlet()
233+
result = optimize(model, MAP())
234+
@test result.values mode(Dirichlet(2 * ones(3))) atol=0.2
235+
end
228236
end

0 commit comments

Comments
 (0)