|
1 | 1 | """
|
2 |
| - KernelSum(kernels::Array{Kernel}; weights::Array{Real}=ones(length(kernels))) |
| 2 | + KernelSum <: Kernel |
3 | 3 |
|
4 |
| -Create a positive weighted sum of kernels. All weights should be positive. |
5 |
| -One can also use the operator `+` |
| 4 | +Create a sum of kernels. One can also use the operator `+`. |
| 5 | +
|
| 6 | +There are various ways in which you create a `KernelSum`: |
| 7 | +
|
| 8 | +The simplest way to specify a `KernelSum` would be to use the overloaded `+` operator. This is |
| 9 | +equivalent to creating a `KernelSum` by specifying the kernels as the arguments to the constructor. |
| 10 | +```jldoctest kernelsum |
| 11 | +julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5); |
| 12 | +
|
| 13 | +julia> (k = k1 + k2) == KernelSum(k1, k2) |
| 14 | +true |
| 15 | +
|
| 16 | +julia> kernelmatrix(k1 + k2, X) == kernelmatrix(k1, X) .+ kernelmatrix(k2, X) |
| 17 | +true |
| 18 | +
|
| 19 | +julia> kernelmatrix(k, X) == kernelmatrix(k1 + k2, X) |
| 20 | +true |
6 | 21 | ```
|
7 |
| - k1 = SqExponentialKernel() |
8 |
| - k2 = LinearKernel() |
9 |
| - k = KernelSum([k1, k2]) == k1 + k2 |
10 |
| - kernelmatrix(k, X) == kernelmatrix(k1, X) .+ kernelmatrix(k2, X) |
11 |
| - kernelmatrix(k, X) == kernelmatrix(k1 + k2, X) |
12 |
| - kweighted = 0.5* k1 + 2.0*k2 == KernelSum([k1, k2], weights = [0.5, 2.0]) |
| 22 | +
|
| 23 | +You could also specify a `KernelSum` by providing a `Tuple` or a `Vector` of the |
| 24 | +kernels to be summed. We suggest you to use a `Tuple` when you have fewer components |
| 25 | +and a `Vector` when dealing with a large number of components. |
| 26 | +```jldoctest kernelsum |
| 27 | +julia> KernelSum((k1, k2)) == k1 + k2 |
| 28 | +true |
| 29 | +
|
| 30 | +julia> KernelSum([k1, k2]) == KernelSum((k1, k2)) == k1 + k2 |
| 31 | +true |
13 | 32 | ```
|
14 | 33 | """
|
15 |
| -struct KernelSum <: Kernel |
16 |
| - kernels::Vector{Kernel} |
17 |
| - weights::Vector{Real} |
| 34 | +struct KernelSum{Tk} <: Kernel |
| 35 | + kernels::Tk |
| 36 | +end |
| 37 | + |
| 38 | +function KernelSum(kernel::Kernel, kernels::Kernel...) |
| 39 | + return KernelSum((kernel, kernels...)) |
18 | 40 | end
|
19 | 41 |
|
20 |
| -function KernelSum( |
21 |
| - kernels::AbstractVector{<:Kernel}; |
22 |
| - weights::AbstractVector{<:Real} = ones(Float64, length(kernels)), |
| 42 | +Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2) |
| 43 | + |
| 44 | +function Base.:+( |
| 45 | + k1::KernelSum{<:AbstractVector{<:Kernel}}, |
| 46 | + k2::KernelSum{<:AbstractVector{<:Kernel}} |
23 | 47 | )
|
24 |
| - @assert length(kernels) == length(weights) "Weights and kernel vector should be of the same length" |
25 |
| - @assert all(weights .>= 0) "All weights should be positive" |
26 |
| - return KernelSum(kernels, weights) |
| 48 | + KernelSum(vcat(k1.kernels, k2.kernels)) |
27 | 49 | end
|
28 | 50 |
|
29 |
| -Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0]) |
30 |
| -Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)]) |
31 |
| -Base.:+(k1::KernelSum, k2::KernelSum) = |
32 |
| - KernelSum(vcat(k1.kernels, k2.kernels), weights = vcat(k1.weights, k2.weights)) |
33 |
| -Base.:+(k::Kernel, ks::KernelSum) = |
34 |
| - KernelSum(vcat(k, ks.kernels), weights = vcat(1.0, ks.weights)) |
35 |
| -Base.:+(k::ScaledKernel, ks::KernelSum) = |
36 |
| - KernelSum(vcat(kernel(k), ks.kernels), weights = vcat(first(k.σ²), ks.weights)) |
37 |
| -Base.:+(k::ScaledKernel, ks::Kernel) = |
38 |
| - KernelSum(vcat(kernel(k), ks), weights = vcat(first(k.σ²), 1.0)) |
39 |
| -Base.:+(ks::KernelSum, k::Kernel) = |
40 |
| - KernelSum(vcat(ks.kernels, k), weights = vcat(ks.weights, 1.0)) |
41 |
| -Base.:+(ks::KernelSum, k::ScaledKernel) = |
42 |
| - KernelSum(vcat(ks.kernels, kernel(k)), weights = vcat(ks.weights, first(k.σ²))) |
43 |
| -Base.:+(ks::Kernel, k::ScaledKernel) = |
44 |
| - KernelSum(vcat(ks, kernel(k)), weights = vcat(1.0, first(k.σ²))) |
45 |
| -Base.:*(w::Real, k::KernelSum) = KernelSum(k.kernels, weights = w * k.weights) #TODO add tests |
| 51 | +Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...) |
| 52 | + |
| 53 | +function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}}) |
| 54 | + return KernelSum(vcat(k, ks.kernels)) |
| 55 | +end |
| 56 | + |
| 57 | +Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...) |
| 58 | + |
| 59 | +function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel) |
| 60 | + return KernelSum(vcat(ks.kernels, k)) |
| 61 | +end |
| 62 | + |
| 63 | +Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k) |
46 | 64 |
|
47 | 65 | Base.length(k::KernelSum) = length(k.kernels)
|
48 | 66 |
|
49 |
| -(κ::KernelSum)(x, y) = sum(κ.weights[i] * κ.kernels[i](x, y) for i in 1:length(κ)) |
| 67 | +(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels) |
50 | 68 |
|
51 | 69 | function kernelmatrix(κ::KernelSum, x::AbstractVector)
|
52 |
| - return sum(κ.weights[i] * kernelmatrix(κ.kernels[i], x) for i in 1:length(κ)) |
| 70 | + return sum(kernelmatrix(k, x) for k in κ.kernels) |
53 | 71 | end
|
54 | 72 |
|
55 | 73 | function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
|
56 |
| - return sum(κ.weights[i] * kernelmatrix(κ.kernels[i], x, y) for i in 1:length(κ)) |
| 74 | + return sum(kernelmatrix(k, x, y) for k in κ.kernels) |
57 | 75 | end
|
58 | 76 |
|
59 | 77 | function kerneldiagmatrix(κ::KernelSum, x::AbstractVector)
|
60 |
| - return sum(κ.weights[i] * kerneldiagmatrix(κ.kernels[i], x) for i in 1:length(κ)) |
| 78 | + return sum(kerneldiagmatrix(k, x) for k in κ.kernels) |
61 | 79 | end
|
62 | 80 |
|
63 | 81 | function Base.show(io::IO, κ::KernelSum)
|
64 | 82 | printshifted(io, κ, 0)
|
65 | 83 | end
|
66 | 84 |
|
| 85 | +function Base.:(==)(x::KernelSum, y::KernelSum) |
| 86 | + return ( |
| 87 | + length(x.kernels) == length(y.kernels) && |
| 88 | + all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels)) |
| 89 | + ) |
| 90 | +end |
| 91 | + |
67 | 92 | function printshifted(io::IO,κ::KernelSum, shift::Int)
|
68 | 93 | print(io,"Sum of $(length(κ)) kernels:")
|
69 |
| - for i in 1:length(κ) |
70 |
| - print(io, "\n" * ("\t" ^ (shift + 1)) * "- (w = $(κ.weights[i])) ") |
71 |
| - printshifted(io, κ.kernels[i], shift + 2) |
| 94 | + for k in κ.kernels |
| 95 | + print(io, "\n" ) |
| 96 | + for _ in 1:(shift + 1) |
| 97 | + print(io, "\t") |
| 98 | + end |
| 99 | + printshifted(io, k, shift + 2) |
72 | 100 | end
|
73 | 101 | end
|
0 commit comments