
Commit d9ccffa

Use tuples in KernelSum and KernelProduct; Make Zygote tests pass for KernelSum and KernelProduct; Improve docstring. (#146)
* Use tuples in KernelSum and KernelProduct
* Remove weights in KernelSum
* Add jldocs and remove duplicate functions
* Revert userguide and fix formatting
* Add more jldoctests
* Fix test/trainable.jl
* Add detailed docstring
* Iterate over kernels
* Sum/product of array of kernels results in array of kernels
* Modify show
* Fix AD test
* Add tests for show function
* Zygote tests pass now
* Modify + and * functions
* Define '==' in TensorProduct
* Address code review
* Update kernels.md
* Patch bump
1 parent 17b4bd9 commit d9ccffa

File tree: 12 files changed, +210 −84 lines
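
In practice the change means `KernelSum` and `KernelProduct` are now parametric in their container, so the same type can carry a `Tuple` (suggested for a few components) or a `Vector` (for many). A minimal sketch of the resulting API, using only constructors shown in the diffs below:

```julia
using KernelFunctions

k1 = SqExponentialKernel()
k2 = LinearKernel()

kt = KernelSum((k1, k2))   # tuple-backed; same as k1 + k2
kv = KernelSum([k1, k2])   # vector-backed, for many components

# The new `==` compares component kernels, so both spellings agree:
kt == kv == k1 + k2        # true
```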

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "KernelFunctions"
 uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
-version = "0.5.0"
+version = "0.5.1"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/make.jl

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,13 @@
 using Documenter
 using KernelFunctions
 
+DocMeta.setdocmeta!(
+    KernelFunctions,
+    :DocTestSetup,
+    :(using KernelFunctions, LinearAlgebra, Random);
+    recursive=true,
+)
+
 makedocs(
     sitename = "KernelFunctions",
     format = Documenter.HTML(),

docs/src/kernels.md

Lines changed: 2 additions & 2 deletions
@@ -246,9 +246,9 @@ Where $\widetilde{k}$ is another kernel and $\sigma^2 > 0$.
 The [`KernelSum`](@ref) is defined as a sum of kernels
 
 ```math
-k(x,x';\{w_i\},\{k_i\}) = \sum_i w_i k_i(x,x'),
+k(x, x'; \{k_i\}) = \sum_i k_i(x, x').
 ```
-Where $w_i > 0$.
+
 ### KernelProduct
 
 The [`KernelProduct`](@ref) is defined as a product of kernels
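
The dropped weights are not lost: scaling a kernel wraps it in a `ScaledKernel`, so a weighted sum is recovered by summing scaled components. Restated in the notation above (our gloss, not part of the docs page):

```math
\sum_i w_i k_i(x, x') = \sum_i \widetilde{k}_i(x, x'), \qquad \widetilde{k}_i = w_i k_i, \quad w_i > 0,
```

where each $\widetilde{k}_i$ is an ordinary kernel, so the sum itself needs no weights of its own.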

docs/src/userguide.md

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ For example :
 ```julia
 k1 = SqExponentialKernel()
 k2 = Matern32Kernel()
-k = 0.5*k1 + 0.2*k2 # KernelSum
-k = k1*k2 # KernelProduct
+k = 0.5 * k1 + 0.2 * k2 # KernelSum
+k = k1 * k2 # KernelProduct
 ```
 
 ## Kernel Parameters

src/kernels/kernelproduct.jl

Lines changed: 74 additions & 20 deletions
@@ -1,49 +1,103 @@
 """
-    KernelProduct(kernels::Array{Kernel})
+    KernelProduct <: Kernel
 
-Create a product of kernels.
-One can also use the operator `*` :
+Create a product of kernels. One can also use the overloaded operator `*`.
+
+There are various ways in which you can create a `KernelProduct`:
+
+The simplest way to specify a `KernelProduct` is to use the overloaded `*` operator. This is
+equivalent to creating a `KernelProduct` by passing the kernels as arguments to the constructor.
+```jldoctest kernelprod
+julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5);
+
+julia> (k = k1 * k2) == KernelProduct(k1, k2)
+true
+
+julia> kernelmatrix(k1 * k2, X) == kernelmatrix(k1, X) .* kernelmatrix(k2, X)
+true
+
+julia> kernelmatrix(k, X) == kernelmatrix(k1 * k2, X)
+true
 ```
-k1 = SqExponentialKernel()
-k2 = LinearKernel()
-k = KernelProduct([k1, k2]) == k1 * k2
-kernelmatrix(k, X) == kernelmatrix(k1, X) .* kernelmatrix(k2, X)
-kernelmatrix(k, X) == kernelmatrix(k1 * k2, X)
+
+You can also specify a `KernelProduct` by providing a `Tuple` or a `Vector` of the
+kernels to be multiplied. We suggest using a `Tuple` when you have few components
+and a `Vector` when dealing with a large number of components.
+```jldoctest kernelprod
+julia> KernelProduct((k1, k2)) == k1 * k2
+true
+
+julia> KernelProduct([k1, k2]) == KernelProduct((k1, k2)) == k1 * k2
+true
 ```
 """
-struct KernelProduct <: Kernel
-    kernels::Vector{Kernel}
+struct KernelProduct{Tk} <: Kernel
+    kernels::Tk
+end
+
+function KernelProduct(kernel::Kernel, kernels::Kernel...)
+    return KernelProduct((kernel, kernels...))
 end
 
-Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
-Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
-Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
-Base.:*(kp::KernelProduct,k::Kernel) = KernelProduct(vcat(kp.kernels,k))
+Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)
+
+function Base.:*(
+    k1::KernelProduct{<:AbstractVector{<:Kernel}},
+    k2::KernelProduct{<:AbstractVector{<:Kernel}},
+)
+    return KernelProduct(vcat(k1.kernels, k2.kernels))
+end
+
+function Base.:*(k1::KernelProduct, k2::KernelProduct)
+    return KernelProduct(k1.kernels..., k2.kernels...)
+end
+
+function Base.:*(k::Kernel, ks::KernelProduct{<:AbstractVector{<:Kernel}})
+    return KernelProduct(vcat(k, ks.kernels))
+end
+
+Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)
+
+function Base.:*(ks::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
+    return KernelProduct(vcat(ks.kernels, k))
+end
+
+Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)
 
 Base.length(k::KernelProduct) = length(k.kernels)
 
 (κ::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
 
 function kernelmatrix(κ::KernelProduct, x::AbstractVector)
-    return reduce(hadamard, kernelmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return reduce(hadamard, kernelmatrix(k, x) for k in κ.kernels)
 end
 
 function kernelmatrix(κ::KernelProduct, x::AbstractVector, y::AbstractVector)
-    return reduce(hadamard, kernelmatrix(κ.kernels[i], x, y) for i in 1:length(κ))
+    return reduce(hadamard, kernelmatrix(k, x, y) for k in κ.kernels)
 end
 
 function kerneldiagmatrix(κ::KernelProduct, x::AbstractVector)
-    return reduce(hadamard, kerneldiagmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return reduce(hadamard, kerneldiagmatrix(k, x) for k in κ.kernels)
 end
 
 function Base.show(io::IO, κ::KernelProduct)
     printshifted(io, κ, 0)
 end
 
+function Base.:(==)(x::KernelProduct, y::KernelProduct)
+    return (
+        length(x.kernels) == length(y.kernels) &&
+        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
+    )
+end
+
 function printshifted(io::IO, κ::KernelProduct, shift::Int)
     print(io, "Product of $(length(κ)) kernels:")
-    for i in 1:length(κ)
-        print(io, "\n" * ("\t" ^ (shift + 1)) * "- ")
-        printshifted(io, κ.kernels[i], shift + 2)
+    for k in κ.kernels
+        print(io, "\n")
+        for _ in 1:(shift + 1)
+            print(io, "\t")
+        end
+        printshifted(io, k, shift + 2)
     end
 end
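
How the `*` overloads above pick a container, in a short sketch; the expected `.kernels` values are taken from the updated tests further down:

```julia
using KernelFunctions

k1 = LinearKernel()
k2 = SqExponentialKernel()
k3 = RationalQuadraticKernel()

# Vector-backed products stay vector-backed (vcat)...
(KernelProduct([k1, k2]) * k3).kernels == [k1, k2, k3]   # true

# ...while tuple-backed (or mixed) products splat into a tuple.
(KernelProduct((k1, k2)) * k3).kernels == (k1, k2, k3)   # true
```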

src/kernels/kernelsum.jl

Lines changed: 70 additions & 42 deletions
@@ -1,73 +1,101 @@
 """
-    KernelSum(kernels::Array{Kernel}; weights::Array{Real}=ones(length(kernels)))
+    KernelSum <: Kernel
 
-Create a positive weighted sum of kernels. All weights should be positive.
-One can also use the operator `+`
+Create a sum of kernels. One can also use the operator `+`.
+
+There are various ways in which you can create a `KernelSum`:
+
+The simplest way to specify a `KernelSum` is to use the overloaded `+` operator. This is
+equivalent to creating a `KernelSum` by passing the kernels as arguments to the constructor.
+```jldoctest kernelsum
+julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5);
+
+julia> (k = k1 + k2) == KernelSum(k1, k2)
+true
+
+julia> kernelmatrix(k1 + k2, X) == kernelmatrix(k1, X) .+ kernelmatrix(k2, X)
+true
+
+julia> kernelmatrix(k, X) == kernelmatrix(k1 + k2, X)
+true
 ```
-k1 = SqExponentialKernel()
-k2 = LinearKernel()
-k = KernelSum([k1, k2]) == k1 + k2
-kernelmatrix(k, X) == kernelmatrix(k1, X) .+ kernelmatrix(k2, X)
-kernelmatrix(k, X) == kernelmatrix(k1 + k2, X)
-kweighted = 0.5* k1 + 2.0*k2 == KernelSum([k1, k2], weights = [0.5, 2.0])
+
+You can also specify a `KernelSum` by providing a `Tuple` or a `Vector` of the
+kernels to be summed. We suggest using a `Tuple` when you have few components
+and a `Vector` when dealing with a large number of components.
+```jldoctest kernelsum
+julia> KernelSum((k1, k2)) == k1 + k2
+true
+
+julia> KernelSum([k1, k2]) == KernelSum((k1, k2)) == k1 + k2
+true
 ```
 """
-struct KernelSum <: Kernel
-    kernels::Vector{Kernel}
-    weights::Vector{Real}
+struct KernelSum{Tk} <: Kernel
+    kernels::Tk
+end
+
+function KernelSum(kernel::Kernel, kernels::Kernel...)
+    return KernelSum((kernel, kernels...))
 end
 
-function KernelSum(
-    kernels::AbstractVector{<:Kernel};
-    weights::AbstractVector{<:Real} = ones(Float64, length(kernels)),
+Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)
+
+function Base.:+(
+    k1::KernelSum{<:AbstractVector{<:Kernel}},
+    k2::KernelSum{<:AbstractVector{<:Kernel}},
 )
-    @assert length(kernels) == length(weights) "Weights and kernel vector should be of the same length"
-    @assert all(weights .>= 0) "All weights should be positive"
-    return KernelSum(kernels, weights)
+    return KernelSum(vcat(k1.kernels, k2.kernels))
 end
 
-Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
-Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)])
-Base.:+(k1::KernelSum, k2::KernelSum) =
-    KernelSum(vcat(k1.kernels, k2.kernels), weights = vcat(k1.weights, k2.weights))
-Base.:+(k::Kernel, ks::KernelSum) =
-    KernelSum(vcat(k, ks.kernels), weights = vcat(1.0, ks.weights))
-Base.:+(k::ScaledKernel, ks::KernelSum) =
-    KernelSum(vcat(kernel(k), ks.kernels), weights = vcat(first(k.σ²), ks.weights))
-Base.:+(k::ScaledKernel, ks::Kernel) =
-    KernelSum(vcat(kernel(k), ks), weights = vcat(first(k.σ²), 1.0))
-Base.:+(ks::KernelSum, k::Kernel) =
-    KernelSum(vcat(ks.kernels, k), weights = vcat(ks.weights, 1.0))
-Base.:+(ks::KernelSum, k::ScaledKernel) =
-    KernelSum(vcat(ks.kernels, kernel(k)), weights = vcat(ks.weights, first(k.σ²)))
-Base.:+(ks::Kernel, k::ScaledKernel) =
-    KernelSum(vcat(ks, kernel(k)), weights = vcat(1.0, first(k.σ²)))
-Base.:*(w::Real, k::KernelSum) = KernelSum(k.kernels, weights = w * k.weights) #TODO add tests
+Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)
+
+function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
+    return KernelSum(vcat(k, ks.kernels))
+end
+
+Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)
+
+function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
+    return KernelSum(vcat(ks.kernels, k))
+end
+
+Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)
 
 Base.length(k::KernelSum) = length(k.kernels)
 
-(κ::KernelSum)(x, y) = sum(κ.weights[i] * κ.kernels[i](x, y) for i in 1:length(κ))
+(κ::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)
 
 function kernelmatrix(κ::KernelSum, x::AbstractVector)
-    return sum(κ.weights[i] * kernelmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return sum(kernelmatrix(k, x) for k in κ.kernels)
 end
 
 function kernelmatrix(κ::KernelSum, x::AbstractVector, y::AbstractVector)
-    return sum(κ.weights[i] * kernelmatrix(κ.kernels[i], x, y) for i in 1:length(κ))
+    return sum(kernelmatrix(k, x, y) for k in κ.kernels)
 end
 
 function kerneldiagmatrix(κ::KernelSum, x::AbstractVector)
-    return sum(κ.weights[i] * kerneldiagmatrix(κ.kernels[i], x) for i in 1:length(κ))
+    return sum(kerneldiagmatrix(k, x) for k in κ.kernels)
 end
 
 function Base.show(io::IO, κ::KernelSum)
     printshifted(io, κ, 0)
 end
 
+function Base.:(==)(x::KernelSum, y::KernelSum)
+    return (
+        length(x.kernels) == length(y.kernels) &&
+        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
+    )
+end
+
 function printshifted(io::IO, κ::KernelSum, shift::Int)
     print(io, "Sum of $(length(κ)) kernels:")
-    for i in 1:length(κ)
-        print(io, "\n" * ("\t" ^ (shift + 1)) * "- (w = $(κ.weights[i])) ")
-        printshifted(io, κ.kernels[i], shift + 2)
+    for k in κ.kernels
+        print(io, "\n")
+        for _ in 1:(shift + 1)
+            print(io, "\t")
+        end
+        printshifted(io, k, shift + 2)
     end
 end
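
As with products, `+` flattens nested sums instead of building a `KernelSum` of `KernelSum`s; a minimal sketch of that behaviour:

```julia
using KernelFunctions

k1 = SqExponentialKernel()
k2 = LinearKernel()

k = (k1 + k2) + (k2 + k1)
length(k)                        # 4: one flat KernelSum, not a nested one
k.kernels == (k1, k2, k2, k1)    # true
```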

src/kernels/tensorproduct.jl

Lines changed: 7 additions & 0 deletions
@@ -101,6 +101,13 @@ end
 
 Base.show(io::IO, kernel::TensorProduct) = printshifted(io, kernel, 0)
 
+function Base.:(==)(x::TensorProduct, y::TensorProduct)
+    return (
+        length(x.kernels) == length(y.kernels) &&
+        all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
+    )
+end
+
 function printshifted(io::IO, kernel::TensorProduct, shift::Int)
     print(io, "Tensor product of ", length(kernel), " kernels:")
     for k in kernel.kernels
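
Before this `==` method, structurally identical tensor products did not reliably compare equal (e.g. with `Vector`-backed kernels, where the fallback comparison fails); the elementwise definition fixes that. A small sketch, assuming the file's existing vararg constructor:

```julia
using KernelFunctions

k1 = SqExponentialKernel()
k2 = LinearKernel()

# Elementwise comparison of the wrapped kernels:
TensorProduct(k1, k2) == TensorProduct(k1, k2)   # true
```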

src/trainable.jl

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ trainable(k::GaborKernel) = (k.kernel,)
 
 trainable(κ::KernelProduct) = κ.kernels
 
-trainable(κ::KernelSum) = (κ.weights, κ.kernels) #To check
+trainable(κ::KernelSum) = κ.kernels #To check
 
 trainable(κ::ScaledKernel) = (κ.σ², κ.kernel)
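
With the weights field gone, the only trainable state a `KernelSum` exposes is its component kernels; scale parameters travel with `ScaledKernel` (its `σ²`) instead. A hedged sketch of how this surfaces, assuming the kernel types participate in Flux's parameter collection as this file's `trainable` methods suggest:

```julia
using Flux, KernelFunctions

k = 2.0 * SqExponentialKernel() + 0.5 * LinearKernel()

# The KernelSum contributes no weights of its own; the 2.0 and 0.5 live in
# the ScaledKernel wrappers and are reached via trainable(::ScaledKernel).
ps = Flux.params(k)
```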

test/kernels/kernelproduct.jl

Lines changed: 26 additions & 8 deletions
@@ -1,19 +1,38 @@
 @testset "kernelproduct" begin
     rng = MersenneTwister(123456)
+    x = rand(rng) * 2
     v1 = rand(rng, 3)
     v2 = rand(rng, 3)
 
     k1 = LinearKernel()
     k2 = SqExponentialKernel()
     k3 = RationalQuadraticKernel()
+    X = rand(rng, 2, 2)
 
-    k = KernelProduct([k1, k2])
+    k = KernelProduct(k1, k2)
+    ks1 = 2.0 * k1
+    ks2 = 0.5 * k2
     @test length(k) == 2
+    @test string(k) == (
+        "Product of 2 kernels:\n\tLinear Kernel (c = 0.0)\n\tSquared " *
+        "Exponential Kernel"
+    )
     @test k(v1, v2) == (k1 * k2)(v1, v2)
-    @test (k * k)(v1, v2) ≈ k(v1, v2)^2
-    @test (k * k3)(v1, v2) ≈ (k3 * k)(v1, v2)
+    @test (k * k3)(v1, v2) ≈ (k3 * k)(v1, v2)
+    @test (k1 * k2)(v1, v2) == KernelProduct(k1, k2)(v1, v2)
+    @test (k * ks1)(v1, v2) ≈ (ks1 * k)(v1, v2)
+    @test (k * k)(v1, v2) == KernelProduct([k1, k2, k1, k2])(v1, v2)
+    @test KernelProduct([k1, k2]) == KernelProduct((k1, k2)) == k1 * k2
 
-    @testset "kernelmatrix" begin
+    @test (KernelProduct([k1, k2]) * KernelProduct([k2, k1])).kernels == [k1, k2, k2, k1]
+    @test (KernelProduct([k1, k2]) * k3).kernels == [k1, k2, k3]
+    @test (k3 * KernelProduct([k1, k2])).kernels == [k3, k1, k2]
+
+    @test (KernelProduct((k1, k2)) * KernelProduct((k2, k1))).kernels == (k1, k2, k2, k1)
+    @test (KernelProduct((k1, k2)) * k3).kernels == (k1, k2, k3)
+    @test (k3 * KernelProduct((k1, k2))).kernels == (k3, k1, k2)
+
+    @testset "kernelmatrix" begin
         rng = MersenneTwister(123456)
 
         Nx = 5
@@ -22,8 +41,8 @@
 
         w1 = rand(rng) + 1e-3
         w2 = rand(rng) + 1e-3
-        k1 = SqExponentialKernel()
-        k2 = LinearKernel()
+        k1 = w1 * SqExponentialKernel()
+        k2 = w2 * LinearKernel()
        k = k1 * k2
 
         @testset "$(typeof(x))" for (x, y) in [
@@ -47,6 +66,5 @@
             @test kerneldiagmatrix!(tmp_diag, k, x) ≈ kerneldiagmatrix(k, x)
         end
     end
-    test_ADs(x -> SqExponentialKernel() * LinearKernel(c = x[1]), rand(1), ADs = [:ForwardDiff, :ReverseDiff])
-    @test_broken "Zygote issue"
+    test_ADs(x -> SqExponentialKernel() * LinearKernel(c = x[1]), rand(1), ADs = [:ForwardDiff, :ReverseDiff, :Zygote])
 end
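
`test_ADs` is the package's test helper; the newly passing Zygote case corresponds to a gradient check of roughly this shape (a sketch, not the helper's actual code):

```julia
using KernelFunctions, Zygote

x = rand(5)

# Differentiate a kernel hyperparameter through a tuple-backed KernelProduct,
# the pattern that previously broke under Zygote:
g, = Zygote.gradient(0.5) do c
    sum(kernelmatrix(SqExponentialKernel() * LinearKernel(c = c), x))
end
```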
