Skip to content

Commit f6c3e7a

Browse files
committed
Adapting tests and constructors
1 parent 3d9a52e commit f6c3e7a

File tree

5 files changed

+27
-54
lines changed

5 files changed

+27
-54
lines changed

src/kernels/polynomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Where `c` is a real number, and `d` is a shape parameter bigger than 1
3131
struct PolynomialKernel{Td<:Real,Tc<:Real} <: BaseKernel
3232
d::Td
3333
c::Tc
34-
function PolynomialKernel(;d::Td=2.0, c::Tc=zero(Td)) where {Td<:Real, Tc<:Real}
34+
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real, Tc<:Real}
3535
@check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
3636
return new{Td, Tc}(d, c)
3737
end

src/kernels/rationalquad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ where `α` is a shape parameter of the Euclidean distance and `γ` is another sh
3232
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
3333
α::Tα
3434
γ::Tγ
35-
function GammaRationalQuadraticKernel(;α::Tα=2.0, γ::Tγ=2*one(Tα)) where {Tα<:Real, Tγ<:Real}
35+
function GammaRationalQuadraticKernel(;α::Tα=2.0, γ::Tγ=2.0) where {Tα<:Real, Tγ<:Real}
3636
@check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1")
3737
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1")
3838
return new{Tα, Tγ}(α, γ)

src/transform/scaletransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ dim(str::ScaleTransform) = 1
2222
apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s) * x
2323

2424
Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))
25+
26+
Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform s=$(first(t.s))")

test/test_constructors.jl

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,15 @@ s = ScaleTransform(l)
1212
@test KernelFunctions.metric(ExponentialKernel()) == Euclidean()
1313
@test KernelFunctions.metric(SqExponentialKernel()) == SqEuclidean()
1414
@test KernelFunctions.metric(GammaExponentialKernel()) == SqEuclidean()
15-
@test KernelFunctions.metric(GammaExponentialKernel(2.0)) == SqEuclidean()
16-
# @test isequal(transform(SqExponentialKernel(l)),s)
17-
# @test KernelFunctions.transform(SqExponentialKernel(vl)) == ARDTransform(vl)
18-
# @test isequal(KernelFunctions.transform(SqExponentialKernel(s)),s)
15+
@test KernelFunctions.metric(GammaExponentialKernel=2.0)) == SqEuclidean()
1916
end
2017

2118
## MaternKernel
2219
@testset "MaternKernel" begin
2320
@test KernelFunctions.metric(MaternKernel()) == Euclidean()
24-
@test KernelFunctions.metric(MaternKernel(2.0)) == Euclidean()
21+
@test KernelFunctions.metric(MaternKernel(ν=2.0)) == Euclidean()
2522
@test KernelFunctions.metric(Matern32Kernel()) == Euclidean()
2623
@test KernelFunctions.metric(Matern52Kernel()) == Euclidean()
27-
# @test isequal(KernelFunctions.transform(MaternKernel(l)),s)
28-
# @test isequal(KernelFunctions.transform(Matern32Kernel(l)),s)
29-
# @test isequal(KernelFunctions.transform(Matern52Kernel(l)),s)
30-
# @test KernelFunctions.transform(MaternKernel(vl)) == ARDTransform(vl)
31-
# @test KernelFunctions.transform(Matern32Kernel(vl)) == ARDTransform(vl)
32-
# @test KernelFunctions.transform(Matern52Kernel(vl)) == ARDTransform(vl)
33-
# @test KernelFunctions.transform(MaternKernel(s)) == s
34-
# @test KernelFunctions.transform(Matern32Kernel(s)) == s
35-
# @test KernelFunctions.transform(Matern52Kernel(s)) == s
3624
end
3725

3826
@testset "Exponentiated" begin
@@ -41,23 +29,23 @@ end
4129

4230
@testset "Constant" begin
4331
@test KernelFunctions.metric(ConstantKernel()) == KernelFunctions.Delta()
44-
@test KernelFunctions.metric(ConstantKernel(2.0)) == KernelFunctions.Delta()
32+
@test KernelFunctions.metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
4533
@test KernelFunctions.metric(WhiteKernel()) == KernelFunctions.Delta()
4634
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
4735
end
4836

4937
@testset "Polynomial" begin
5038
@test KernelFunctions.metric(LinearKernel()) == KernelFunctions.DotProduct()
51-
@test KernelFunctions.metric(LinearKernel(2.0)) == KernelFunctions.DotProduct()
39+
@test KernelFunctions.metric(LinearKernel(c=2.0)) == KernelFunctions.DotProduct()
5240
@test KernelFunctions.metric(PolynomialKernel()) == KernelFunctions.DotProduct()
53-
@test KernelFunctions.metric(PolynomialKernel(3.0)) == KernelFunctions.DotProduct()
54-
@test KernelFunctions.metric(PolynomialKernel(3.0,2.0)) == KernelFunctions.DotProduct()
41+
@test KernelFunctions.metric(PolynomialKernel(d=3.0)) == KernelFunctions.DotProduct()
42+
@test KernelFunctions.metric(PolynomialKernel(d=3.0,c=2.0)) == KernelFunctions.DotProduct()
5543
end
5644

5745
@testset "RationalQuadratic" begin
5846
@test KernelFunctions.metric(RationalQuadraticKernel()) == SqEuclidean()
59-
@test KernelFunctions.metric(RationalQuadraticKernel(2.0)) == SqEuclidean()
47+
@test KernelFunctions.metric(RationalQuadraticKernel(α=2.0)) == SqEuclidean()
6048
@test KernelFunctions.metric(GammaRationalQuadraticKernel()) == SqEuclidean()
61-
@test KernelFunctions.metric(GammaRationalQuadraticKernel(2.0)) == SqEuclidean()
62-
@test KernelFunctions.metric(GammaRationalQuadraticKernel(2.0,3.0)) == SqEuclidean()
49+
@test KernelFunctions.metric(GammaRationalQuadraticKernel(γ=2.0)) == SqEuclidean()
50+
@test KernelFunctions.metric(GammaRationalQuadraticKernel(γ=2.0,α=3.0)) == SqEuclidean()
6351
end

test/test_kernels.jl

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
1919
end
2020
@testset "ConstantKernel" begin
2121
c = 2.0
22-
k = ConstantKernel(c)
22+
k = ConstantKernel(c=c)
2323
@test eltype(k) == Any
2424
@test kappa(k,1.0) == c
2525
@test kappa(k,0.5) == c
@@ -31,39 +31,22 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
3131
@test kappa(k,x) exp(-x)
3232
@test k(v1,v2) exp(-norm(v1-v2)^2)
3333
@test kappa(SqExponentialKernel(),x) == kappa(k,x)
34-
# l = 0.5
35-
# k = SqExponentialKernel(l)
36-
# @test k(v1,v2) ≈ exp(-l^2*norm(v1-v2)^2)
37-
# v = rand(3)
38-
# k = SqExponentialKernel(v)
39-
# @test k(v1,v2) ≈ exp(-norm(v.*(v1-v2))^2)
4034
end
4135
@testset "ExponentialKernel" begin
4236
k = ExponentialKernel()
4337
@test kappa(k,x) exp(-x)
4438
@test k(v1,v2) exp(-norm(v1-v2))
4539
@test kappa(ExponentialKernel(),x) == kappa(k,x)
46-
# l = 0.5
47-
# k = ExponentialKernel(l)
48-
# @test k(v1,v2) ≈ exp(-l*norm(v1-v2))
49-
# v = rand(3)
50-
# k = ExponentialKernel(v)
51-
# @test k(v1,v2) ≈ exp(-norm(v.*(v1-v2)))
5240
end
5341
@testset "GammaExponentialKernel" begin
54-
k = GammaExponentialKernel(2.0)
55-
@test kappa(k,x) exp(-(x)^(k.γ))
56-
@test k(v1,v2) exp(-norm(v1-v2)^(2k.γ))
42+
γ = 2.0
43+
k = GammaExponentialKernel=γ)
44+
@test kappa(k,x) exp(-(x)^(γ))
45+
@test k(v1,v2) exp(-norm(v1-v2)^(2γ))
5746
@test kappa(GammaExponentialKernel(),x) == kappa(k,x)
58-
# l = 0.5
59-
# k = GammaExponentialKernel(l,1.5)
60-
# @test k(v1,v2) ≈ exp(-l^(3.0)*norm(v1-v2)^(3.0))
61-
# v = rand(3)
62-
# k = GammaExponentialKernel(v,3.0)
63-
# @test k(v1,v2) ≈ exp(-norm(v.*(v1-v2)).^6.0)
6447
#Coherence :
65-
@test KernelFunctions._kernel(GammaExponentialKernel(1.0),v1,v2) KernelFunctions._kernel(SqExponentialKernel(),v1,v2)
66-
@test KernelFunctions._kernel(GammaExponentialKernel(0.5),v1,v2) KernelFunctions._kernel(ExponentialKernel(),v1,v2)
48+
@test KernelFunctions._kernel(GammaExponentialKernel(γ=1.0),v1,v2) KernelFunctions._kernel(SqExponentialKernel(),v1,v2)
49+
@test KernelFunctions._kernel(GammaExponentialKernel(γ=0.5),v1,v2) KernelFunctions._kernel(ExponentialKernel(),v1,v2)
6750
end
6851
end
6952
@testset "Exponentiated" begin
@@ -77,11 +60,11 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
7760
@testset "Matern" begin
7861
@testset "MaternKernel" begin
7962
ν = 2.0
80-
k = MaternKernel(ν)
63+
k = MaternKernel=ν)
8164
matern(x,ν) = 2^(1-ν)/gamma(ν)*(sqrt(2ν)*x)^ν*besselk(ν,sqrt(2ν)*x)
8265
@test kappa(k,x) matern(x,ν)
8366
@test kappa(k,0.0) == 1.0
84-
@test kappa(MaternKernel(ν),x) == kappa(k,x)
67+
@test kappa(MaternKernel=ν),x) == kappa(k,x)
8568
end
8669
@testset "Matern32Kernel" begin
8770
k = Matern32Kernel()
@@ -96,9 +79,9 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
9679
@test kappa(Matern52Kernel(),x) == kappa(k,x)
9780
end
9881
@testset "Coherence Materns" begin
99-
@test kappa(MaternKernel(0.5),x) kappa(ExponentialKernel(),x)
100-
@test kappa(MaternKernel(1.5),x) kappa(Matern32Kernel(),x)
101-
@test kappa(MaternKernel(2.5),x) kappa(Matern52Kernel(),x)
82+
@test kappa(MaternKernel(ν=0.5),x) kappa(ExponentialKernel(),x)
83+
@test kappa(MaternKernel(ν=1.5),x) kappa(Matern32Kernel(),x)
84+
@test kappa(MaternKernel(ν=2.5),x) kappa(Matern52Kernel(),x)
10285
end
10386
end
10487
@testset "Polynomial" begin
@@ -115,7 +98,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
11598
@test k(v1,v2) dot(v1,v2)^2
11699
@test kappa(PolynomialKernel(),x) == kappa(k,x)
117100
#Coherence test
118-
@test kappa(PolynomialKernel(1.0,c),x) kappa(LinearKernel(c),x)
101+
@test kappa(PolynomialKernel(d=1.0,c=c),x) kappa(LinearKernel(c=c),x)
119102
end
120103
end
121104
@testset "RationalQuadratic" begin
@@ -132,7 +115,7 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
132115
@test kappa(GammaRationalQuadraticKernel(),x) == kappa(k,x)
133116
a = 1.0 + rand()
134117
#Coherence test
135-
@test kappa(GammaRationalQuadraticKernel(a,1.0),x) kappa(RationalQuadraticKernel(a),x)
118+
@test kappa(GammaRationalQuadraticKernel(α=a,γ=1.0),x) kappa(RationalQuadraticKernel(α=a),x)
136119
end
137120
end
138121
@testset "Transformed/Scaled Kernel" begin

0 commit comments

Comments
 (0)