Skip to content

Commit cafd9d7

Browse files
authored
Implement Turing.Inference.getlogp_external (#39)
* Implement Turing.Inference.getlogp_external * Add tests * Fix tests * Don't set arch=x64 for macos-latest
1 parent c3888ff commit cafd9d7

File tree

4 files changed

+41
-19
lines changed

4 files changed

+41
-19
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,11 @@ jobs:
2525
- ubuntu-latest
2626
- macOS-latest
2727
- windows-latest
28-
arch:
29-
- x64
3028
steps:
3129
- uses: actions/checkout@v4
3230
- uses: julia-actions/setup-julia@v2
3331
with:
3432
version: ${{ matrix.version }}
35-
arch: ${{ matrix.arch }}
3633
- uses: julia-actions/cache@v1
3734
- uses: julia-actions/julia-buildpkg@v1
3835
- uses: julia-actions/julia-runtest@v1

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SliceSampling"
22
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
3-
version = "0.7.6"
3+
version = "0.7.7"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -21,7 +21,7 @@ Distributions = "0.25"
2121
LinearAlgebra = "1"
2222
LogDensityProblems = "2"
2323
Random = "1"
24-
Turing = "0.37, 0.38, 0.39"
24+
Turing = "0.39.5"
2525
julia = "1.10"
2626

2727
[extras]

ext/SliceSamplingTuringExt.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,29 @@ Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true
2222
Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true
2323
Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true
2424

25-
function Turing.Inference.getparams(
26-
::Turing.DynamicPPL.Model, sample::SliceSampling.UnivariateSliceState
27-
)
25+
const SliceSamplingStates = Union{
26+
SliceSampling.UnivariateSliceState,
27+
SliceSampling.GibbsState,
28+
SliceSampling.HitAndRunState,
29+
SliceSampling.LatentSliceState,
30+
SliceSampling.GibbsPolarSliceState,
31+
}
32+
function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates)
2833
return sample.transition.params
2934
end
3035

31-
function Turing.Inference.getparams(
32-
::Turing.DynamicPPL.Model, state::SliceSampling.GibbsState
33-
)
34-
return state.transition.params
35-
end
36-
37-
function Turing.Inference.getparams(
38-
::Turing.DynamicPPL.Model, state::SliceSampling.HitAndRunState
36+
function Turing.Inference.getlogp_external(
37+
::Turing.DynamicPPL.Model, t::SliceSampling.Transition, state
3938
)
40-
return state.transition.params
39+
return t.lp
4140
end
4241
# end
4342

4443
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
4544
model =.model
4645
vi = Turing.DynamicPPL.VarInfo(rng, model, Turing.SampleFromUniform())
4746
vi_spl = last(Turing.DynamicPPL.evaluate!!(model, rng, vi, Turing.SampleFromUniform()))
48-
θ = vi_spl[:]
47+
θ = vi_spl[:]
4948

5049
init_attempt_count = 1
5150
while !all(isfinite.(θ))

test/turing.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,13 @@
88
return nothing
99
end
1010

11+
@model function logp_check()
12+
a ~ Normal()
13+
return b ~ Normal()
14+
end
15+
1116
n_samples = 1000
12-
model = demo()
17+
model = demo()
1318

1419
@testset for sampler in [
1520
RandPermGibbs(Slice(1)),
@@ -30,6 +35,15 @@
3035
)
3136

3237
chain = sample(model, externalsampler(sampler), n_samples; progress=false)
38+
39+
chain_logp_check = sample(
40+
logp_check(), externalsampler(sampler), 100; progress=false
41+
)
42+
@test isapprox(
43+
logpdf.(Normal(), chain_logp_check[:a]) .+
44+
logpdf.(Normal(), chain_logp_check[:b]),
45+
chain_logp_check[:lp],
46+
)
3347
end
3448

3549
@testset "gibbs($sampler)" for sampler in [
@@ -46,5 +60,17 @@
4660
n_samples;
4761
progress=false,
4862
)
63+
64+
chain_logp_check = sample(
65+
logp_check(),
66+
Turing.Gibbs(:a => externalsampler(sampler), :b => externalsampler(sampler)),
67+
100;
68+
progress=false,
69+
)
70+
@test isapprox(
71+
logpdf.(Normal(), chain_logp_check[:a]) .+
72+
logpdf.(Normal(), chain_logp_check[:b]),
73+
chain_logp_check[:lp],
74+
)
4975
end
5076
end

0 commit comments

Comments
 (0)