@@ -36,21 +36,10 @@ function setchunksize(chunk_size::Int)
     AdvancedVI.setchunksize(chunk_size)
 end
 
-abstract type ADBackend end
-struct ForwardDiffAD{chunk,standardtag} <: ADBackend end
+getchunksize(::AutoForwardDiff{chunk}) where {chunk} = chunk
 
-# Use standard tag if not specified otherwise
-ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}()
-
-getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk
-
-standardtag(::ForwardDiffAD{<:Any,true}) = true
-standardtag(::ForwardDiffAD) = false
-
-struct TrackerAD <: ADBackend end
-struct ZygoteAD <: ADBackend end
-
-struct ReverseDiffAD{cache} <: ADBackend end
+standardtag(::AutoForwardDiff{<:Any,Nothing}) = true
+standardtag(::AutoForwardDiff) = false
 
 const RDCache = Ref(false)
 
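Note on the hunk above: the dispatch on `AutoForwardDiff{<:Any,Nothing}` relies on ADTypes.jl storing the user-supplied tag as the second type parameter. A minimal sketch of that behaviour (the `:custom` tag is only a placeholder to show the type parameter changing, not a real ForwardDiff tag):

```julia
using ADTypes: AutoForwardDiff

# No tag supplied: the tag type parameter is `Nothing`, so `standardtag`
# above returns `true` and Turing's own ForwardDiff tag is used.
AutoForwardDiff(; chunksize=0) isa AutoForwardDiff{0,Nothing}              # true

# Any explicit tag makes the parameter non-`Nothing`, so `standardtag` returns `false`.
AutoForwardDiff(; chunksize=0, tag=:custom) isa AutoForwardDiff{0,Symbol}  # true
```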
@@ -63,10 +52,10 @@ getrdcache() = RDCache[]
 ADBackend() = ADBackend(ADBACKEND[])
 ADBackend(T::Symbol) = ADBackend(Val(T))
 
-ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
-ADBackend(::Val{:tracker}) = TrackerAD
-ADBackend(::Val{:zygote}) = ZygoteAD
-ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
+ADBackend(::Val{:forwarddiff}) = AutoForwardDiff(; chunksize=CHUNKSIZE[])
+ADBackend(::Val{:tracker}) = AutoTracker()
+ADBackend(::Val{:zygote}) = AutoZygote()
+ADBackend(::Val{:reversediff}) = AutoReverseDiff(; compile=getrdcache())
 
 ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
 
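One consequence of this hunk, sketched below under the assumption that ADTypes.jl is loaded: the old mapping returned backend *types* (e.g. the unparameterised `TrackerAD`), which call sites then instantiated, whereas the new mapping returns ready-made *instances*. That is why `ADBackend()()` becomes `ADBackend()` in the next hunk.

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote

# The ADTypes keyword constructors used above return concrete backend objects.
fwd = AutoForwardDiff(; chunksize=0)     # chunk size becomes a type parameter
rev = AutoReverseDiff(; compile=false)   # tape-compilation flag stored on the object

fwd isa AutoForwardDiff{0}   # true: no extra `()` call is needed to instantiate
rev.compile                  # false, read back by the ReverseDiff method further down
AutoTracker() isa AutoTracker && AutoZygote() isa AutoZygote   # true
```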
@@ -76,18 +65,18 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t
 Find the autodifferentiation backend of the algorithm `alg`.
 """
 getADbackend(spl::Sampler) = getADbackend(spl.alg)
-getADbackend(::SampleFromPrior) = ADBackend()()
+getADbackend(::SampleFromPrior) = ADBackend()
 getADbackend(ctx::DynamicPPL.SamplingContext) = getADbackend(ctx.sampler)
 getADbackend(ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.NodeTrait(ctx), ctx)
 
-getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
+getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()
 getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))
 
 function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
     return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.context), ℓ)
 end
 
-function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
+function LogDensityProblemsAD.ADgradient(ad::AutoForwardDiff, ℓ::Turing.LogDensityFunction)
     θ = DynamicPPL.getparams(ℓ)
     f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
 
@@ -107,20 +96,8 @@ function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensi
     return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
 end
 
-function LogDensityProblemsAD.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction)
-    return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ)
-end
-
-function LogDensityProblemsAD.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction)
-    return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
-end
-
-for cache in (:true, :false)
-    @eval begin
-        function LogDensityProblemsAD.ADgradient(::ReverseDiffAD{$cache}, ℓ::Turing.LogDensityFunction)
-            return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache), x=DynamicPPL.getparams(ℓ))
-        end
-    end
+function LogDensityProblemsAD.ADgradient(ad::AutoReverseDiff, ℓ::Turing.LogDensityFunction)
+    return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val(ad.compile), x=DynamicPPL.getparams(ℓ))
 end
 
 function verifygrad(grad::AbstractVector{<:Real})
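For context on the last hunk: the old code `@eval`-generated one `ADgradient` method per `ReverseDiffAD{cache}` value, while the new single method reads the flag off the `AutoReverseDiff` instance at runtime and lifts it back into a `Val` for LogDensityProblemsAD. A minimal sketch of that pattern, with a hypothetical `compile_setting` helper standing in for the real `ADgradient` call:

```julia
using ADTypes: AutoReverseDiff

# Hypothetical stand-in for the `compile=Val(ad.compile)` keyword above:
# it only shows how the runtime Bool is turned back into a compile-time Val.
compile_setting(; compile::Val) = compile

ad = AutoReverseDiff(; compile=true)
compile_setting(; compile=Val(ad.compile)) === Val(true)   # true

# A single method now covers both cases that previously needed the
# `for cache in (:true, :false); @eval ... end` loop.
```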