Skip to content

Commit 1cd7b6d

Browse files
committed
2 parents e5f4fd7 + 9d8069e commit 1cd7b6d

File tree

8 files changed

+96
-0
lines changed

8 files changed

+96
-0
lines changed

src/ClosedFormExpectations.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@ include("Exponential/Exponential.jl")
5555
include("Normal/expectation.jl")
5656
include("Normal/williams/normal.jl")
5757
include("Normal/williams/normal_mean_variance.jl")
58+
include("Normal/williams/ef_parametrization.jl")
59+
5860
# gamma
5961
include("Gamma/Gamma.jl")
6062

63+
# exponetial family distribution interface
64+
include("exponential_family_interface.jl")
65+
6166
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using StaticArrays
2+
import Distributions: var, mean, std
3+
import ExponentialFamily: ExponentialFamilyDistribution, NormalMeanVariance, getnaturalparameters
4+
5+
function mean(expectation::ClosedWilliamsProduct, f, q::ExponentialFamilyDistribution{T}) where {T <: NormalMeanVariance}
6+
η = getnaturalparameters(q)
7+
jacobian = @SMatrix [-inv(2*η[2]) η[1]/(2*η[2]^2); 0.0 (-1/η[2])^(3/2)/(2*sqrt(2))]
8+
normal = Normal(mean(q), std(q))
9+
return mean(expectation, f, normal)' * jacobian
10+
end
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import ExponentialFamily: ExponentialFamilyDistribution
2+
3+
function mean(expectation::ClosedFormExpectation, f, q::ExponentialFamilyDistribution)
4+
dist = convert(Distribution, q)
5+
return mean(expectation, f, dist)
6+
end

test/Exponential/mean_tests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,20 @@ end
5454
λ2 = rand(rng)*10
5555
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), Exponential(λ1))
5656
end
57+
end
58+
59+
@testitem "mean(::ClosedFormExpectation, ::Logpdf{Exponential}, ::ExponentialFamilyDistribution{Exponential}" begin
60+
using Distributions
61+
using ClosedFormExpectations
62+
using ExponentialFamily
63+
using StableRNGs
64+
using Base.MathConstants: eulergamma
65+
66+
include("../test_utils.jl")
67+
rng = StableRNG(123)
68+
for _ in 1:10
69+
λ1 = rand(rng)*10
70+
λ2 = rand(rng)*10
71+
central_limit_theorem_test(ClosedFormExpectation(), Logpdf(Exponential(λ2)), convert(ExponentialFamilyDistribution, Exponential(λ1)))
72+
end
5773
end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
include("../normal_utils.jl")
2+
3+
using ExponentialFamily
4+
5+
score(q::ExponentialFamilyDistribution{NormalMeanVariance}, x) = sufficientstatistics(q, x) .- gradlogpartition(q)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Normal}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
2+
include("ef_utils.jl")
3+
rng = StableRNG(123)
4+
for _ in 1:10
5+
μ1, σ1 = rand(rng)*10, rand(rng)*5
6+
μ2, σ2 = rand(rng)*10, rand(rng)*5
7+
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
8+
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Normal(μ2, σ2)), ef, score)
9+
end
10+
end
11+
12+
@testitem "mean(::ClosedWilliamsProduct, p::Abs, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
13+
include("ef_utils.jl")
14+
rng = StableRNG(123)
15+
for _ in 1:10
16+
μ1, σ1 = rand(rng)*10, rand(rng)*5
17+
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
18+
central_limit_theorem_test(ClosedWilliamsProduct(), Abs(), ef, score)
19+
end
20+
end
21+
22+
@testitem "mean(::ClosedWilliamsProduct, p::Logpdf{Laplace}, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
23+
include("ef_utils.jl")
24+
rng = StableRNG(123)
25+
for _ in 1:10
26+
μ1, σ1 = rand(rng)*10, rand(rng)*5
27+
loc, θ = rand(rng)*10, rand(rng)*10
28+
ef = convert(ExponentialFamilyDistribution, Normal(μ1, σ1))
29+
central_limit_theorem_test(ClosedWilliamsProduct(), Logpdf(Laplace(loc, θ)), ef, score)
30+
end
31+
end
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
@testitem "Support ExponetialFamilyDistribution for ClosedFormExpectation" begin
2+
include("../test_utils.jl")
3+
using ExponentialFamily
4+
@testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{Exponential})" begin
5+
dist = Exponential(1.0)
6+
ef = convert(ExponentialFamilyDistribution, Exponential(1.0))
7+
@test mean(ClosedFormExpectation(), Logpdf(Exponential(1.0)), ef) mean(ClosedFormExpectation(), Logpdf(Exponential(1.0)), dist)
8+
end
9+
@testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{NormalMeanVariance})" begin
10+
dist = Normal(1.0, 1.0)
11+
ef = convert(ExponentialFamilyDistribution, Normal(1.0, 1.0))
12+
@test mean(ClosedFormExpectation(), Abs(), ef) mean(ClosedFormExpectation(), Abs(), dist)
13+
end
14+
@testset "mean(::ClosedFormExpectation, f, q::ExponentialFamilyDistribution{Gamma})" begin
15+
import LogExpFunctions: xlogx
16+
dist = Gamma(1.0, 1.0)
17+
ef = convert(ExponentialFamilyDistribution, Gamma(1.0, 1.0))
18+
@test mean(ClosedFormExpectation(), xlogx, ef) mean(ClosedFormExpectation(), xlogx, dist)
19+
end
20+
end

test/interface/gauss_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ end
2727

2828
@testset "ClosedWilliamsProduct interface" begin
2929
using Distributions
30+
using ExponentialFamily
3031

3132
nmv = NormalMeanVariance(0.0, 1.0)
3233
nmp = NormalMeanPrecision(0.0, 1.0)
34+
ef = convert(ExponentialFamilyDistribution, nmv)
3335

3436
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) isa AbstractArray
3537
@test mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1)) isa AbstractArray
3638
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), Normal(0, 1)) mean(ClosedWilliamsProduct(), Logpdf(nmp), Normal(0, 1))
3739
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), nmv) mean(ClosedWilliamsProduct(), Logpdf(nmp), nmv)
40+
@test mean(ClosedWilliamsProduct(), Logpdf(nmv), ef) isa AbstractArray
3841
end
3942
end

0 commit comments

Comments
 (0)