So with 1.7 (maybe 1.6 too?), there is an issue with Zygote gradients of the expected_loglik function because it hits a BLAS function.
I tried to rewrite it to make fewer allocations and here are the results:
using BenchmarkTools, Distributions, ApproximateGPs, IrrationalConstants, FastGaussQuadrature

# The new proposed version
function expected_loglik(
    gh::GaussHermite, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik
)
    xs, ws = gausshermite(gh.n_points)
    return mapreduce(+, q_f, y) do q, y
        μ = mean(q)
        σ = std(q)
        mapreduce(+, xs, ws) do x, w
            f = sqrt2 * σ * x + μ
            loglikelihood(lik(f), y) * w
        end
    end / sqrtπ
end

# The previous version
function expected_loglik_old(
    gh::GaussHermite, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik
)
    xs, ws = gausshermite(gh.n_points)
    fs = sqrt2 * std.(q_f) .* xs' .+ mean.(q_f)
    lls = loglikelihood.(lik.(fs), y)
    return sum(lls * ws) / √π
end

function evaluate_speed(N)
    gh = GaussHermite(100)
    lik = BernoulliLikelihood()
    y = rand(0:1, N)
    q_f = Normal.(randn(N), rand(N))
    @btime expected_loglik($gh, $y, $q_f, $lik)
    @btime expected_loglik_old($gh, $y, $q_f, $lik)
end

for N in [10, 100, 500, 1000]
    @info N
    evaluate_speed(N)
end
[ Info: 10
164.407 μs (192 allocations: 44.45 KiB)
139.752 μs (87 allocations: 48.83 KiB)
[ Info: 100
396.769 μs (1002 allocations: 146.42 KiB)
312.959 μs (89 allocations: 191.52 KiB)
[ Info: 500
1.433 ms (4602 allocations: 599.61 KiB)
1.027 ms (89 allocations: 826.08 KiB)
[ Info: 1000
2.736 ms (9102 allocations: 1.14 MiB)
1.972 ms (89 allocations: 1.58 MiB)
So the old approach is faster but makes bigger allocations. I actually don't know where all these allocations come from in the first (new) approach, any clue?
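
For reference, here is the quantity that both versions approximate, written out as I read it from the code. The symbols are my own notation, not something defined in the package: $\mu_n$ and $\sigma_n$ are the mean and standard deviation of q_f[n], and $x_i$, $w_i$ are the nodes and weights returned by gausshermite.

```latex
\begin{aligned}
\mathbb{E}_{q(f_n)}\big[\log p(y_n \mid f_n)\big]
  &= \int \mathcal{N}(f \mid \mu_n, \sigma_n^2)\, \log p(y_n \mid f)\, \mathrm{d}f \\
  &= \frac{1}{\sqrt{\pi}} \int e^{-x^2} \log p\big(y_n \mid \sqrt{2}\,\sigma_n x + \mu_n\big)\, \mathrm{d}x
     \qquad \big(f = \sqrt{2}\,\sigma_n x + \mu_n\big) \\
  &\approx \frac{1}{\sqrt{\pi}} \sum_{i} w_i \, \log p\big(y_n \mid \sqrt{2}\,\sigma_n x_i + \mu_n\big),
\end{aligned}
\qquad
\text{total} = \sum_{n} \mathbb{E}_{q(f_n)}\big[\log p(y_n \mid f_n)\big].
```

The old version evaluates the inner sum for all n at once as the matrix-vector product lls * ws, which is presumably where the BLAS call comes from. The new version accumulates the same double sum as scalars, so it never builds the N × n_points matrices fs and lls.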