Skip to content

Commit e9c32a4

Browse files
committed
FInished tests for kernel functions
1 parent 6d8afad commit e9c32a4

File tree

5 files changed

+63
-10
lines changed

5 files changed

+63
-10
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
77
export ExponentiatedKernel
88
export MaternKernel, Matern32Kernel, Matern52Kernel
99
export LinearKernel, PolynomialKernel
10+
export RationalQuadraticKernel, GammaRationalQuadraticKernel
1011

1112

1213

src/kernels/polynomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ end
5151

5252
function PolynomialKernel::T₁=1.0,d::T₂=2.0,c::T₃=zero(T₁)) where {T₁<:Real,T₂<:Real,T₃<:Real}
5353
@check_args(PolynomialKernel, d, d >= one(T₁), "d >= 1")
54-
Polynomial{T₁,ScaleTransform{T₁},T₂,T₃}(ScaleTransform(ρ),c,d)
54+
PolynomialKernel{T₁,ScaleTransform{T₁},T₂,T₃}(ScaleTransform(ρ),c,d)
5555
end
5656

5757
function PolynomialKernel::A,d::T₁=2.0,c::T₂=zero(eltype(ρ))) where {A<:AbstractVector{<:Real},T₁<:Real,T₂<:Real}

src/kernels/rationalquad.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
```
66
κ(x,y)=(1+||x−y||^2/α)^(-α)
77
```
8-
where α is a shape parameter of the Euclidean distance. Check `GammaRationalQuadraticKernel` for a generalization.
8+
where `α` is a shape parameter of the Euclidean distance. Check `GammaRationalQuadraticKernel` for a generalization.
99
"""
1010
struct RationalQuadraticKernel{T,Tr,Tα<:Real} <: Kernel{T,Tr}
1111
transform::Tr
@@ -31,7 +31,7 @@ function RationalQuadraticKernel(t::Tr,α::T=2.0) where {Tr<:Transform,T<:Real}
3131
RationalQuadraticKernel{eltype(t),Tr,T}(t,α)
3232
end
3333

34-
@inline kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+0.5d²/κ.α)^(-κ.α)
34+
@inline kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+d²/κ.α)^(-κ.α)
3535

3636

3737
"""
@@ -48,27 +48,27 @@ struct GammaRationalQuadraticKernel{T,Tr,Tα<:Real,Tγ<:Real} <: Kernel{T,Tr}
4848
metric::SqEuclidean
4949
α::Tα
5050
γ::Tγ
51-
function GammaRationalQuadraticKernel{T,Tr,Tα}(t::Tr::Tα,γ::Tγ) where {T,Tr,Tα<:Real,Tγ<:Real}
51+
function GammaRationalQuadraticKernel{T,Tr,Tα,Tγ}(t::Tr::Tα,γ::Tγ) where {T,Tr,Tα<:Real,Tγ<:Real}
5252
new{T,Tr,Tα,Tγ}(t,SqEuclidean(),α,γ)
5353
end
5454
end
5555

5656
function GammaRationalQuadraticKernel::T₁=1.0::T₂=2.0::T₃=2.0) where {T₁<:Real,T₂<:Real,T₃<:Real}
5757
@check_args(GammaRationalQuadraticKernel, α, α > one(T₂), "α > 1")
58-
@check_args(GammaRationalQuadraticKernel, γ, γ > one(T₂), "γ > 1")
58+
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(T₂), "γ >= 1")
5959
GammaRationalQuadraticKernel{T₁,ScaleTransform{T₁},T₂,T₃}(ScaleTransform(ρ),α,γ)
6060
end
6161

6262
function GammaRationalQuadraticKernel::A::T₁=2.0::T₂=2.0) where {A<:AbstractVector{<:Real},T₁<:Real,T₂<:Real}
6363
@check_args(GammaRationalQuadraticKernel, α, α > one(T₁), "α > 1")
64-
@check_args(GammaRationalQuadraticKernel, γ, γ > one(T₂), "γ > 1")
64+
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(T₂), "γ >= 1")
6565
GammaRationalQuadraticKernel{eltype(A),ScaleTransform{A},T₁,T₂}(ScaleTransform(ρ),α,γ)
6666
end
6767

6868
function GammaRationalQuadraticKernel(t::Tr::T₁=2.0::T₂=2.0) where {Tr<:Transform,T₁<:Real,T₂<:Real}
6969
@check_args(GammaRationalQuadraticKernel, α, α > one(T₁), "α > 1")
70-
@check_args(GammaRationalQuadraticKernel, γ, γ > one(T₂), "γ > 1")
70+
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(T₂), "γ >= 1")
7171
GammaRationalQuadraticKernel{eltype(t),Tr,T₁,T₂}(t,α,γ)
7272
end
7373

74-
@inline kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+0.5/κ.α)^(-κ.α)
74+
@inline kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^κ.γ/κ.α)^(-κ.α)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@ include("test_kernelmatrix.jl")
1313
include("test_constructors.jl")
1414
# include("test_AD.jl")
1515
include("test_transform.jl")
16+
include("test_distances.jl")
17+
include("test_kernels.jl")
1618
#include("types.jl")
1719
end

test/test_kernels.jl

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,66 @@ x = rand()*2; v1 = rand(3); v2 = rand(3)
113113
@test k(v1,v2) (1+sqrt(5)*norm(v.*(v1-v2))+5/3*norm(v.*(v1-v2))^2)exp(-sqrt(5)*norm(v.*(v1-v2)))
114114
end
115115
@testset "Coherence Materns" begin
116-
x = 0.5
117116
@test kappa(MaternKernel(1.0,0.5),x) kappa(ExponentialKernel(),x)
118117
@test kappa(MaternKernel(1.0,1.5),x) kappa(Matern32Kernel(),x)
119118
@test kappa(MaternKernel(1.0,2.5),x) kappa(Matern52Kernel(),x)
120119
end
121120
end
122121
@testset "Polynomial" begin
123122
c = randn();
124-
123+
@testset "LinearKernel" begin
124+
k = LinearKernel()
125+
@test kappa(k,x) x
126+
@test k(v1,v2) dot(v1,v2)
127+
l = 0.5
128+
k = LinearKernel(l,c)
129+
@test k(v1,v2) l^2*dot(v1,v2) + c
130+
v = rand(3)
131+
k = LinearKernel(v,c)
132+
@test k(v1,v2) dot(v.*v1,v.*v2) + c
133+
end
134+
@testset "PolynomialKernel" begin
135+
k = PolynomialKernel()
136+
@test kappa(k,x) x^2
137+
@test k(v1,v2) dot(v1,v2)^2
138+
d = 3.0
139+
l = 0.5
140+
k = PolynomialKernel(l,d,c)
141+
@test k(v1,v2) (l^2*dot(v1,v2) + c)^d
142+
v = rand(3)
143+
k = PolynomialKernel(v,d,c)
144+
@test k(v1,v2) (dot(v.*v1,v.*v2) + c)^d
145+
#Coherence test
146+
@test kappa(PolynomialKernel(1.0,1.0,c),x) kappa(LinearKernel(1.0,c),x)
147+
end
125148
end
126149
@testset "RationalQuadratic" begin
150+
@testset "RationalQuadraticKernel" begin
151+
k = RationalQuadraticKernel()
152+
@test kappa(k,x) (1.0+x/2.0)^-2
153+
@test k(v1,v2) (1.0+norm(v1-v2)^2/2.0)^-2
154+
l = 0.5
155+
a = 1.0 + rand()
156+
k = RationalQuadraticKernel(l,a)
157+
@test k(v1,v2) (1.0+l^2*norm(v1-v2)^2/a)^-a
158+
v = rand(3)
159+
k = RationalQuadraticKernel(v,a)
160+
@test k(v1,v2) (1.0+norm(v.*(v1-v2))^2/a)^-a
161+
end
162+
@testset "GammaRationalQuadraticKernel" begin
163+
k = GammaRationalQuadraticKernel()
164+
@test kappa(k,x) (1.0+x^2.0/2.0)^-2
165+
@test k(v1,v2) (1.0+norm(v1-v2)^4.0/2.0)^-2
166+
l = 0.5
167+
a = 1.0 + rand()
168+
g = 4.0
169+
k = GammaRationalQuadraticKernel(l,a,g)
170+
@test k(v1,v2) (1.0+(l^2g)*norm(v1-v2)^(2g)/a)^-a
171+
v = rand(3)
172+
k = GammaRationalQuadraticKernel(v,a,g)
173+
@test k(v1,v2) (1.0+(norm(v.*(v1-v2))^(2g))/a)^-a
174+
#Coherence test
175+
@test kappa(GammaRationalQuadraticKernel(1.0,a,1.0),x) kappa(RationalQuadraticKernel(1.0,a),x)
176+
end
127177
end
128178
end

0 commit comments

Comments
 (0)