Skip to content

Commit 052807f

Browse files
committed
Added tests and corrected some
1 parent 23a9797 commit 052807f

File tree

6 files changed

+22
-7
lines changed

6 files changed

+22
-7
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
3+
export kernel, kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
44
export Kernel
55
export ConstantKernel, WhiteKernel, ZeroKernel
66
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel

src/kernels/polynomial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function LinearKernel(ρ::A,c::T=zero(eltype(ρ))) where {A<:AbstractVector{<:Re
2424
LinearKernel{eltype(A),ScaleTransform{A},T}(ScaleTransform(ρ),c)
2525
end
2626

27-
function LinearKernel(t::Tr,c::T=zero(eltype(t))) where {Tr<:Transform,T<:Real}
27+
function LinearKernel(t::Tr,c::T=zero(Float64)) where {Tr<:Transform,T<:Real}
2828
LinearKernel{eltype(t),Tr,T}(t,c)
2929
end
3030

@@ -59,7 +59,7 @@ function PolynomialKernel(ρ::A,d::T₁=2.0,c::T₂=zero(eltype(ρ))) where {A<:
5959
PolynomialKernel{eltype(A),ScaleTransform{A},T₁,T₂}(ScaleTransform(ρ),c,d)
6060
end
6161

62-
function PolynomialKernel(t::Tr,d::T₁=2.0,c::T₂=zero(eltype(t))) where {Tr<:Transform,T₁<:Real,T₂<:Real}
62+
function PolynomialKernel(t::Tr,d::T₁=2.0,c::T₂=zero(eltype(T₁))) where {Tr<:Transform,T₁<:Real,T₂<:Real}
6363
@check_args(PolynomialKernel, d, d >= one(T₁), "d >= 1")
6464
PolynomialKernel{eltype(Tr),Tr,T₁,T₂}(t,c,d)
6565
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
using Test
22
using KernelFunctions
33
using Distances
4-
using FiniteDifferences
4+
# using FiniteDifferences
55
using Random
66
using Zygote
77

88
# Helpful functionality for writing tests.
9-
include("test_util.jl")
9+
# include("test_util.jl")
1010

1111
@testset "KernelFunctions" begin
1212
include("test_kernelmatrix.jl")

test/test_distances.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ B = rand(20,5)
77
@testset "Distance" begin
88
@testset "Dot Product" begin
99
d = KernelFunctions.DotProduct()
10-
@test diag(pairwise(d,A,dims=2)) == dot.(eachcol(A),eachcol(A))
10+
@test diag(pairwise(d,A,dims=2)) == [dot(A[:,i],A[:,i]) for i in 1:size(A,2)]
1111
@test_throws DimensionMismatch d(rand(3),rand(4))
1212
@test d(3.0,2.0) == 6.0
1313
end

test/test_kernelmatrix.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ k = SqExponentialKernel()
1212
@testset "Inplace Kernel Matrix" begin
1313
for obsdim in [1,2]
1414
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
15+
@test kernelmatrix!(K[obsdim],k,A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
1516
@test kerneldiagmatrix!(Kdiag[obsdim],k,A,obsdim=obsdim) == kerneldiagmatrix(k,A,obsdim=obsdim)
1617
end
1718
end
@@ -22,5 +23,6 @@ end
2223
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
2324
@test k(A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
2425
@test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
26+
@test kernel(k,1.0,2.0) == kernel(k,[1.0],[2.0])
2527
end
2628
end

test/test_kernels.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using LinearAlgebra
33
using KernelFunctions
44
using SpecialFunctions
55

6-
x = rand()*2; v1 = rand(3); v2 = rand(3)
6+
x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
77
@testset "Kappa functions of kernels" begin
88
@testset "Constant" begin
99
@testset "ZeroKernel" begin
@@ -32,6 +32,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
3232
k = SqExponentialKernel()
3333
@test kappa(k,x) exp(-x)
3434
@test k(v1,v2) exp(-norm(v1-v2)^2)
35+
@test kappa(SqExponentialKernel(id),x) == kappa(k,x)
3536
l = 0.5
3637
k = SqExponentialKernel(l)
3738
@test k(v1,v2) exp(-l^2*norm(v1-v2)^2)
@@ -43,6 +44,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
4344
k = ExponentialKernel()
4445
@test kappa(k,x) exp(-x)
4546
@test k(v1,v2) exp(-norm(v1-v2))
47+
@test kappa(ExponentialKernel(id),x) == kappa(k,x)
4648
l = 0.5
4749
k = ExponentialKernel(l)
4850
@test k(v1,v2) exp(-l*norm(v1-v2))
@@ -54,12 +56,16 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
5456
k = GammaExponentialKernel(1.0,2.0)
5557
@test kappa(k,x) exp(-(x)^(k.γ))
5658
@test k(v1,v2) exp(-norm(v1-v2)^(2k.γ))
59+
@test kappa(GammaExponentialKernel(id),x) == kappa(k,x)
5760
l = 0.5
5861
k = GammaExponentialKernel(l,1.5)
5962
@test k(v1,v2) exp(-l^(3.0)*norm(v1-v2)^(3.0))
6063
v = rand(3)
6164
k = GammaExponentialKernel(v,3.0)
6265
@test k(v1,v2) exp(-norm(v.*(v1-v2)).^6.0)
66+
#Coherence :
67+
@test kernel(GammaExponentialKernel(1.0,1.0),v1,v2) kernel(SqExponentialKernel(),v1,v2)
68+
@test kernel(GammaExponentialKernel(1.0,0.5),v1,v2) kernel(ExponentialKernel(),v1,v2)
6369
end
6470
end
6571
@testset "Exponentiated" begin
@@ -83,6 +89,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
8389
matern(x,ν) = 2^(1-ν)/gamma(ν)*(sqrt(2ν)*x)^ν*besselk(ν,sqrt(2ν)*x)
8490
@test kappa(k,x) matern(x,ν)
8591
@test kappa(k,0.0) == 1.0
92+
@test kappa(MaternKernel(id,ν),x) == kappa(k,x)
8693
l = 0.5; ν = 3.0
8794
k = MaternKernel(l,ν)
8895
@test k(v1,v2) matern(l*norm(v1-v2),ν)
@@ -94,6 +101,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
94101
k = Matern32Kernel()
95102
@test kappa(k,x) (1+sqrt(3)*x)exp(-sqrt(3)*x)
96103
@test k(v1,v2) (1+sqrt(3)*norm(v1-v2))exp(-sqrt(3)*norm(v1-v2))
104+
@test kappa(Matern32Kernel(id),x) == kappa(k,x)
97105
l = 0.5
98106
k = Matern32Kernel(l)
99107
@test k(v1,v2) (1+l*sqrt(3)*norm(v1-v2))exp(-l*sqrt(3)*norm(v1-v2))
@@ -105,6 +113,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
105113
k = Matern52Kernel()
106114
@test kappa(k,x) (1+sqrt(5)*x+5/3*x^2)exp(-sqrt(5)*x)
107115
@test k(v1,v2) (1+sqrt(5)*norm(v1-v2)+5/3*norm(v1-v2)^2)exp(-sqrt(5)*norm(v1-v2))
116+
@test kappa(Matern52Kernel(id),x) == kappa(k,x)
108117
l = 0.5
109118
k = Matern52Kernel(l)
110119
@test k(v1,v2) (1+l*sqrt(5)*norm(v1-v2)+l^2*5/3*norm(v1-v2)^2)exp(-l*sqrt(5)*norm(v1-v2))
@@ -124,6 +133,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
124133
k = LinearKernel()
125134
@test kappa(k,x) x
126135
@test k(v1,v2) dot(v1,v2)
136+
@test kappa(LinearKernel(id),x) == kappa(k,x)
127137
l = 0.5
128138
k = LinearKernel(l,c)
129139
@test k(v1,v2) l^2*dot(v1,v2) + c
@@ -135,6 +145,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
135145
k = PolynomialKernel()
136146
@test kappa(k,x) x^2
137147
@test k(v1,v2) dot(v1,v2)^2
148+
@test kappa(PolynomialKernel(id),x) == kappa(k,x)
138149
d = 3.0
139150
l = 0.5
140151
k = PolynomialKernel(l,d,c)
@@ -151,6 +162,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
151162
k = RationalQuadraticKernel()
152163
@test kappa(k,x) (1.0+x/2.0)^-2
153164
@test k(v1,v2) (1.0+norm(v1-v2)^2/2.0)^-2
165+
@test kappa(RationalQuadraticKernel(id),x) == kappa(k,x)
154166
l = 0.5
155167
a = 1.0 + rand()
156168
k = RationalQuadraticKernel(l,a)
@@ -163,6 +175,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
163175
k = GammaRationalQuadraticKernel()
164176
@test kappa(k,x) (1.0+x^2.0/2.0)^-2
165177
@test k(v1,v2) (1.0+norm(v1-v2)^4.0/2.0)^-2
178+
@test kappa(GammaRationalQuadraticKernel(id),x) == kappa(k,x)
166179
l = 0.5
167180
a = 1.0 + rand()
168181
g = 4.0

0 commit comments

Comments
 (0)