Skip to content

Commit 6649f10

Browse files
devmotionsunxd3
andauthored
Replace internal AD backend types with ADTypes (#2047)
* Replace internal AD backend types with ADTypes * Remove upstreamed functionality * Update ADBackend code * Formatting * Update Project.toml * Update src/essential/ad.jl * Fix tests * Switch ADType version to 0.1.5 for CI testing * A few fixes * Update Project.toml * Make ad type a field * Improve docs and tests * small fix * Fix test errors --------- Co-authored-by: Xianda Sun <[email protected]>
1 parent d4a7975 commit 6649f10

File tree

12 files changed

+175
-169
lines changed

12 files changed

+175
-169
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.29.4"
3+
version = "0.30.0"
44

55
[deps]
6+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
67
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
78
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
89
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
@@ -44,6 +45,7 @@ TuringDynamicHMCExt = "DynamicHMC"
4445
TuringOptimExt = "Optim"
4546

4647
[compat]
48+
ADTypes = "0.2"
4749
AbstractMCMC = "4, 5"
4850
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
4951
AdvancedMH = "0.8"
@@ -61,7 +63,7 @@ EllipticalSliceSampling = "0.5, 1, 2"
6163
ForwardDiff = "0.10.3"
6264
Libtask = "0.7, 0.8"
6365
LogDensityProblems = "2"
64-
LogDensityProblemsAD = "1.4"
66+
LogDensityProblemsAD = "1.7.0"
6567
MCMCChains = "5, 6"
6668
NamedArrays = "0.9, 0.10"
6769
Optim = "1"

ext/TuringDynamicHMCExt.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ if isdefined(Base, :get_extension)
88
import DynamicHMC
99
using Turing
1010
using Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
11-
using Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
11+
using Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
1212
else
1313
import ..DynamicHMC
1414
using ..Turing
1515
using ..Turing: AbstractMCMC, Random, LogDensityProblems, DynamicPPL
16-
using ..Turing.Inference: LogDensityProblemsAD, TYPEDFIELDS
16+
using ..Turing.Inference: ADTypes, LogDensityProblemsAD, TYPEDFIELDS
1717
end
1818

1919
"""
@@ -26,14 +26,18 @@ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
2626
using DynamicHMC
2727
```
2828
"""
29-
struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian{AD}
29+
struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian
3030
sampler::T
31+
adtype::AD
3132
end
3233

33-
DynamicNUTS(args...) = DynamicNUTS{Turing.ADBackend()}(args...)
34-
DynamicNUTS{AD}(spl::DynamicHMC.NUTS, space::Tuple) where AD = DynamicNUTS{AD, space, typeof(spl)}(spl)
35-
DynamicNUTS{AD}(spl::DynamicHMC.NUTS) where AD = DynamicNUTS{AD}(spl, ())
36-
DynamicNUTS{AD}() where AD = DynamicNUTS{AD}(DynamicHMC.NUTS())
34+
function DynamicNUTS(
35+
spl::DynamicHMC.NUTS = DynamicHMC.NUTS(),
36+
space::Tuple = ();
37+
adtype::ADTypes.AbstractADType = Turing.ADBackend()
38+
)
39+
return DynamicNUTS{typeof(adtype),space,typeof(spl)}(spl, adtype)
40+
end
3741
Turing.externalsampler(spl::DynamicHMC.NUTS) = DynamicNUTS(spl)
3842

3943
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

src/essential/Essential.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using Bijectors: PDMatDistribution
1111
using AdvancedVI
1212
using StatsFuns: logsumexp, softmax
1313
@reexport using DynamicPPL
14+
using ADTypes: ADTypes, AutoForwardDiff, AutoTracker, AutoReverseDiff, AutoZygote
1415

1516
import AdvancedPS
1617
import LogDensityProblems
@@ -40,10 +41,10 @@ export @model,
4041
ADBackend,
4142
setadbackend,
4243
setadsafe,
43-
ForwardDiffAD,
44-
TrackerAD,
45-
ZygoteAD,
46-
ReverseDiffAD,
44+
AutoForwardDiff,
45+
AutoTracker,
46+
AutoZygote,
47+
AutoReverseDiff,
4748
value,
4849
CHUNKSIZE,
4950
ADBACKEND,

src/essential/ad.jl

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,10 @@ function setchunksize(chunk_size::Int)
3636
AdvancedVI.setchunksize(chunk_size)
3737
end
3838

39-
abstract type ADBackend end
40-
struct ForwardDiffAD{chunk,standardtag} <: ADBackend end
39+
getchunksize(::AutoForwardDiff{chunk}) where {chunk} = chunk
4140

42-
# Use standard tag if not specified otherwise
43-
ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}()
44-
45-
getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk
46-
47-
standardtag(::ForwardDiffAD{<:Any,true}) = true
48-
standardtag(::ForwardDiffAD) = false
49-
50-
struct TrackerAD <: ADBackend end
51-
struct ZygoteAD <: ADBackend end
52-
53-
struct ReverseDiffAD{cache} <: ADBackend end
41+
standardtag(::AutoForwardDiff{<:Any,Nothing}) = true
42+
standardtag(::AutoForwardDiff) = false
5443

5544
const RDCache = Ref(false)
5645

@@ -63,10 +52,10 @@ getrdcache() = RDCache[]
6352
ADBackend() = ADBackend(ADBACKEND[])
6453
ADBackend(T::Symbol) = ADBackend(Val(T))
6554

66-
ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
67-
ADBackend(::Val{:tracker}) = TrackerAD
68-
ADBackend(::Val{:zygote}) = ZygoteAD
69-
ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
55+
ADBackend(::Val{:forwarddiff}) = AutoForwardDiff(; chunksize=CHUNKSIZE[])
56+
ADBackend(::Val{:tracker}) = AutoTracker()
57+
ADBackend(::Val{:zygote}) = AutoZygote()
58+
ADBackend(::Val{:reversediff}) = AutoReverseDiff(; compile=getrdcache())
7059

7160
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
7261

@@ -76,18 +65,18 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t
7665
Find the autodifferentiation backend of the algorithm `alg`.
7766
"""
7867
getADbackend(spl::Sampler) = getADbackend(spl.alg)
79-
getADbackend(::SampleFromPrior) = ADBackend()()
68+
getADbackend(::SampleFromPrior) = ADBackend()
8069
getADbackend(ctx::DynamicPPL.SamplingContext) = getADbackend(ctx.sampler)
8170
getADbackend(ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.NodeTrait(ctx), ctx)
8271

83-
getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
72+
getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()
8473
getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))
8574

8675
function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
8776
return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.context), ℓ)
8877
end
8978

90-
function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
79+
function LogDensityProblemsAD.ADgradient(ad::AutoForwardDiff, ℓ::Turing.LogDensityFunction)
9180
θ = DynamicPPL.getparams(ℓ)
9281
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
9382

@@ -107,20 +96,8 @@ function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensi
10796
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x = θ)
10897
end
10998

110-
function LogDensityProblemsAD.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction)
111-
return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ)
112-
end
113-
114-
function LogDensityProblemsAD.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction)
115-
return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
116-
end
117-
118-
for cache in (:true, :false)
119-
@eval begin
120-
function LogDensityProblemsAD.ADgradient(::ReverseDiffAD{$cache}, ℓ::Turing.LogDensityFunction)
121-
return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache), x=DynamicPPL.getparams(ℓ))
122-
end
123-
end
99+
function LogDensityProblemsAD.ADgradient(ad::AutoReverseDiff, ℓ::Turing.LogDensityFunction)
100+
return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val(ad.compile), x=DynamicPPL.getparams(ℓ))
124101
end
125102

126103
function verifygrad(grad::AbstractVector{<:Real})

src/mcmc/Inference.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using DocStringExtensions: TYPEDEF, TYPEDFIELDS
2323
using DataStructures: OrderedSet
2424
using Setfield: Setfield
2525

26+
import ADTypes
2627
import AbstractMCMC
2728
import AdvancedHMC; const AHMC = AdvancedHMC
2829
import AdvancedMH; const AMH = AdvancedMH
@@ -74,10 +75,10 @@ export InferenceAlgorithm,
7475
abstract type AbstractAdapter end
7576
abstract type InferenceAlgorithm end
7677
abstract type ParticleInference <: InferenceAlgorithm end
77-
abstract type Hamiltonian{AD} <: InferenceAlgorithm end
78-
abstract type StaticHamiltonian{AD} <: Hamiltonian{AD} end
79-
abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end
80-
getADbackend(::Hamiltonian{AD}) where AD = AD()
78+
abstract type Hamiltonian <: InferenceAlgorithm end
79+
abstract type StaticHamiltonian <: Hamiltonian end
80+
abstract type AdaptiveHamiltonian <: Hamiltonian end
81+
getADbackend(alg::Hamiltonian) = alg.adtype
8182

8283
"""
8384
ExternalSampler{S<:AbstractSampler}

0 commit comments

Comments
 (0)