Skip to content

Commit 6d8afad

Browse files
committed
Improved tests considerably and corrected a lot of bugs
1 parent efe525c commit 6d8afad

13 files changed

+202
-39
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[![Build Status](https://travis-ci.org/theogf/KernelFunctions.jl.svg?branch=master)](https://travis-ci.org/theogf/AugmentedGaussianProcesses.jl)
2+
[![Coverage Status](https://coveralls.io/repos/github/theogf/KernelFunctions.jl/badge.svg?branch=master)](https://coveralls.io/github/theogf/KernelFunctions.jl?branch=master)
23
[![Documentation](https://img.shields.io/badge/docs-dev-blue.svg)](https://theogf.github.io/KernelFunctions.jl/dev/)
34
# KernelFunctions.jl (WIP)
45
Julia Package for kernel functions for machine learning

docs/src/transform.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
11
# Transform
2+
3+
`Transform` is the object that takes care of transforming the input data before distances are being computed. It can be as standard as `IdentityTransform` returning the same input, can be a scalar with `ScaleTransform` multiplying the vectors by a scalar or a vector.
4+
There is a more general `Transform`: `FunctionTransform` that uses a function and apply it on each vector via `mapslices`.
5+
You can also create a pipeline of `Transform` via `TransformChain`. For example `LowRankTransform(rand(10,5))∘ScaleTransform(2.0)`.
6+
7+
## Transforms :
8+
9+
```@docs
10+
IdentityTransform
11+
ScaleTransform
12+
LowRankTransform
13+
FunctionTransform
14+
TransformChain
15+
```

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
44
export Kernel
5+
export ConstantKernel, WhiteKernel, ZeroKernel
56
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
7+
export ExponentiatedKernel
68
export MaternKernel, Matern32Kernel, Matern52Kernel
79
export LinearKernel, PolynomialKernel
8-
export ConstantKernel, WhiteKernel, ZeroKernel
910

1011

1112

src/generic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
55
@eval begin
66
@inline::$k)(d::Real) = kappa(κ,d)
7-
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
8-
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9-
@inline::$k)(x::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,obsdim=obsdim)
7+
@inline::$k)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate.metric,transform(κ,x),transform(κ,y)))
8+
@inline::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
9+
@inline::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)
1010
end
1111
end
1212

src/kernels/constant.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ struct ZeroKernel{T,Tr} <: Kernel{T,Tr}
88
transform::Tr
99
metric::Delta
1010
function ZeroKernel{T,Tr}(t::Tr) where {T,Tr<:Transform}
11-
new{eltype{Tr},Tr}(t,Delta())
11+
new{T,Tr}(t,Delta())
1212
end
1313
end
1414

15+
function ZeroKernel(t::Tr=IdentityTransform()) where {Tr<:Transform}
16+
ZeroKernel{eltype(Tr),Tr}(t)
17+
end
18+
1519
@inline kappa::ZeroKernel,d::T) where {T<:Real} = zero(T)
1620

1721
"""
@@ -30,11 +34,7 @@ struct WhiteKernel{T,Tr} <: Kernel{T,Tr}
3034
end
3135
end
3236

33-
function WhiteKernel()
34-
WhiteKernel{Float64,IdentityTransform}(IdentityTransform())
35-
end
36-
37-
function WhiteKernel(t::Tr) where {Tr<:Transform}
37+
function WhiteKernel(t::Tr=IdentityTransform()) where {Tr<:Transform}
3838
WhiteKernel{eltype(Tr),Tr}(t)
3939
end
4040

src/kernels/exponential.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080

8181
function GammaExponentialKernel::T₁=1.0,gamma::T₂=2.0) where {T₁<:Real,T₂<:Real}
8282
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T₂), "gamma > 0")
83-
Polynomial{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),gamma)
83+
GammaExponentialKernel{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),gamma)
8484
end
8585

8686
function GammaExponentialKernel::A,gamma::T₁=2.0) where {A<:AbstractVector{<:Real},T₁<:Real}
@@ -93,4 +93,4 @@ function GammaExponentialKernel(t::Tr,gamma::T₁=2.0) where {Tr<:Transform,T₁
9393
GammaExponentialKernel{eltype(Tr),Tr,T₁}(t,gamma)
9494
end
9595

96-
@inline kappa::GammaExponentialKernel, d²::Real) where {T} = exp(-^γ)
96+
@inline kappa::GammaExponentialKernel, d²::Real) where {T} = exp(-^κ.γ)

src/kernels/exponentiated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
struct ExponentiatedKernel{T,Tr} <: Kernel{T,Tr}
1111
transform::Tr
1212
metric::DotProduct
13-
function ExponentiatedKernel{T}(transform::Tr) where {T,Tr<:Transform}
13+
function ExponentiatedKernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
1414
return new{T,Tr}(transform,DotProduct())
1515
end
1616
end

src/transform/scaletransform.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Scale Transform
3+
"""
14
struct ScaleTransform{T<:Union{Real,AbstractVector{<:Real}}} <: Transform
25
s::T
36
end
@@ -26,7 +29,7 @@ function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix
2629
end
2730
_transform(t,X,obsdim)
2831
end
29-
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = t.s .* x
32+
transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real},obsdim::Int=defaultobs) = t.s .* x
3033
_transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.s'.*X : t.s .* X
3134

3235
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat,obsdim::Int=defaultobs) = t.s .* x

src/transform/transform.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export Transform, ScaleTransform, LowRankTransform, FunctionTransform, TransformChain
1+
export Transform, IdentityTransform, ScaleTransform, LowRankTransform, FunctionTransform, ChainTransform
22

33

44
abstract type Transform end
@@ -7,27 +7,27 @@ include("scaletransform.jl")
77
include("lowranktransform.jl")
88
include("functiontransform.jl")
99

10-
struct TransformChain <: Transform
10+
struct ChainTransform <: Transform
1111
transforms::Vector{Transform}
1212
end
1313

14-
Base.length(t::TransformChain) = length(t.transforms)
14+
Base.length(t::ChainTransform) = length(t.transforms)
1515

16-
function TransformChain(v::AbstractVector{<:Transform})
17-
TransformChain(v)
16+
function ChainTransform(v::AbstractVector{<:Transform})
17+
ChainTransform(v)
1818
end
1919

20-
function transform(t::TransformChain,X::T,obsdim::Int=defaultobs) where {T}
20+
function transform(t::ChainTransform,X::T,obsdim::Int=defaultobs) where {T}
2121
Xtr = copy(X)
2222
for tr in t.transforms
2323
Xtr = transform(tr,Xtr,obsdim)
2424
end
2525
return Xtr
2626
end
2727

28-
Base.:(t₁::Transform,t₂::Transform) = TransformChain([t₂,t₁])
29-
Base.:(t::Transform,tc::TransformChain) = TransformChain(vcat(tc.transforms,t))
30-
Base.:(tc::TransformChain,t::Transform) = TransformChain(vcat(t,tc.transforms))
28+
Base.:(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁])
29+
Base.:(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t))
30+
Base.:(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms))
3131

3232
struct IdentityTransform <: Transform end
3333

test/test_distances.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Test
2+
using Distances, LinearAlgebra
3+
using KernelFunctions
4+
5+
A = rand(10,5)
6+
B = rand(20,5)
7+
@testset "Distance" begin
8+
@testset "Dot Product" begin
9+
d = KernelFunctions.DotProduct()
10+
@test diag(pairwise(d,A,dims=2)) == dot.(eachcol(A),eachcol(A))
11+
@test_throws DimensionMismatch d(rand(3),rand(4))
12+
@test d(3.0,2.0) == 6.0
13+
end
14+
@testset "Delta" begin
15+
d = KernelFunctions.Delta()
16+
@test pairwise(d,A,dims=1) == Matrix(I,size(A,1),size(A,1))
17+
@test pairwise(d,A,B,dims=1) == zeros(size(A,1),size(B,1))
18+
@test d(1,2) == 0
19+
@test d(1,1) == 1
20+
@test_throws DimensionMismatch d(rand(3),rand(4))
21+
end
22+
end

test/test_kernelmatrix.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,19 @@ A = rand(dims...)
88
B = rand(dims...)
99
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
1010
Kdiag = [zeros(dims[1]),zeros(dims[2])]
11-
kernels = [SqExponentialKernel(),MaternKernel(),Matern32Kernel(),Matern52Kernel()]
11+
k = SqExponentialKernel()
1212
@testset "Inplace Kernel Matrix" begin
13-
for k in kernels
14-
@testset "$k" begin
15-
for obsdim in [1,2]
16-
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
17-
@test kerneldiagmatrix!(Kdiag[obsdim],k,A,obsdim=obsdim) == kerneldiagmatrix(k,A,obsdim=obsdim)
18-
end
19-
end
13+
for obsdim in [1,2]
14+
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
15+
@test kerneldiagmatrix!(Kdiag[obsdim],k,A,obsdim=obsdim) == kerneldiagmatrix(k,A,obsdim=obsdim)
2016
end
2117
end
2218

2319
@testset "Kernel matrix" begin
24-
for k in kernels
25-
@testset "$k" begin
26-
for obsdim in [1,2]
27-
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
28-
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
29-
end
30-
end
20+
for obsdim in [1,2]
21+
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
22+
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
23+
@test k(A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
24+
@test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
3125
end
3226
end

test/test_kernels.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
using Test
2+
using LinearAlgebra
3+
using KernelFunctions
4+
using SpecialFunctions
5+
6+
x = rand()*2; v1 = rand(3); v2 = rand(3)
7+
@testset "Kappa functions of kernels" begin
8+
@testset "Constant" begin
9+
@testset "ZeroKernel" begin
10+
k = ZeroKernel()
11+
@test eltype(k) == Any
12+
@test kappa(k,2.0) == 0.0
13+
end
14+
@testset "WhiteKernel" begin
15+
k = WhiteKernel()
16+
@test eltype(k) == Any
17+
@test kappa(k,1.0) == 1.0
18+
@test kappa(k,0.0) == 0.0
19+
end
20+
@testset "ConstantKernel" begin
21+
c = 2.0
22+
k = ConstantKernel(c)
23+
k₂ = ConstantKernel(IdentityTransform(),c)
24+
@test eltype(k) == Any
25+
@test kappa(k,1.5)== kappa(k₂,1.5)
26+
@test kappa(k,1.0) == c
27+
@test kappa(k,0.5) == c
28+
end
29+
end
30+
@testset "Exponential" begin
31+
@testset "SqExponentialKernel" begin
32+
k = SqExponentialKernel()
33+
@test kappa(k,x) exp(-x)
34+
@test k(v1,v2) exp(-norm(v1-v2)^2)
35+
l = 0.5
36+
k = SqExponentialKernel(l)
37+
@test k(v1,v2) exp(-l^2*norm(v1-v2)^2)
38+
v = rand(3)
39+
k = SqExponentialKernel(v)
40+
@test k(v1,v2) exp(-norm(v.*(v1-v2))^2)
41+
end
42+
@testset "ExponentialKernel" begin
43+
k = ExponentialKernel()
44+
@test kappa(k,x) exp(-x)
45+
@test k(v1,v2) exp(-norm(v1-v2))
46+
l = 0.5
47+
k = ExponentialKernel(l)
48+
@test k(v1,v2) exp(-l*norm(v1-v2))
49+
v = rand(3)
50+
k = ExponentialKernel(v)
51+
@test k(v1,v2) exp(-norm(v.*(v1-v2)))
52+
end
53+
@testset "GammaExponentialKernel" begin
54+
k = GammaExponentialKernel(1.0,2.0)
55+
@test kappa(k,x) exp(-(x)^(k.γ))
56+
@test k(v1,v2) exp(-norm(v1-v2)^(2k.γ))
57+
l = 0.5
58+
k = GammaExponentialKernel(l,1.5)
59+
@test k(v1,v2) exp(-l^(3.0)*norm(v1-v2)^(3.0))
60+
v = rand(3)
61+
k = GammaExponentialKernel(v,3.0)
62+
@test k(v1,v2) exp(-norm(v.*(v1-v2)).^6.0)
63+
end
64+
end
65+
@testset "Exponentiated" begin
66+
@testset "ExponentiatedKernel" begin
67+
k = ExponentiatedKernel()
68+
@test kappa(k,x) exp(x)
69+
@test kappa(k,-x) exp(-x)
70+
@test k(v1,v2) exp(dot(v1,v2))
71+
l = 0.5
72+
k = ExponentiatedKernel(l)
73+
@test k(v1,v2) exp(l^2*dot(v1,v2))
74+
v = rand(3)
75+
k = ExponentiatedKernel(v)
76+
@test k(v1,v2) exp(dot(v.*v1,v.*v2))
77+
end
78+
end
79+
@testset "Matern" begin
80+
@testset "MaternKernel" begin
81+
ν = 2.0
82+
k = MaternKernel(1.0,ν)
83+
matern(x,ν) = 2^(1-ν)/gamma(ν)*(sqrt(2ν)*x)^ν*besselk(ν,sqrt(2ν)*x)
84+
@test kappa(k,x) matern(x,ν)
85+
@test kappa(k,0.0) == 1.0
86+
l = 0.5; ν = 3.0
87+
k = MaternKernel(l,ν)
88+
@test k(v1,v2) matern(l*norm(v1-v2),ν)
89+
v = rand(3); ν = 2.1
90+
k = MaternKernel(v,ν)
91+
@test k(v1,v2) matern(norm(v.*(v1-v2)),ν)
92+
end
93+
@testset "Matern32Kernel" begin
94+
k = Matern32Kernel()
95+
@test kappa(k,x) (1+sqrt(3)*x)exp(-sqrt(3)*x)
96+
@test k(v1,v2) (1+sqrt(3)*norm(v1-v2))exp(-sqrt(3)*norm(v1-v2))
97+
l = 0.5
98+
k = Matern32Kernel(l)
99+
@test k(v1,v2) (1+l*sqrt(3)*norm(v1-v2))exp(-l*sqrt(3)*norm(v1-v2))
100+
v = rand(3)
101+
k = Matern32Kernel(v)
102+
@test k(v1,v2) (1+sqrt(3)*norm(v.*(v1-v2)))exp(-sqrt(3)*norm(v.*(v1-v2)))
103+
end
104+
@testset "Matern52Kernel" begin
105+
k = Matern52Kernel()
106+
@test kappa(k,x) (1+sqrt(5)*x+5/3*x^2)exp(-sqrt(5)*x)
107+
@test k(v1,v2) (1+sqrt(5)*norm(v1-v2)+5/3*norm(v1-v2)^2)exp(-sqrt(5)*norm(v1-v2))
108+
l = 0.5
109+
k = Matern52Kernel(l)
110+
@test k(v1,v2) (1+l*sqrt(5)*norm(v1-v2)+l^2*5/3*norm(v1-v2)^2)exp(-l*sqrt(5)*norm(v1-v2))
111+
v = rand(3)
112+
k = Matern52Kernel(v)
113+
@test k(v1,v2) (1+sqrt(5)*norm(v.*(v1-v2))+5/3*norm(v.*(v1-v2))^2)exp(-sqrt(5)*norm(v.*(v1-v2)))
114+
end
115+
@testset "Coherence Materns" begin
116+
x = 0.5
117+
@test kappa(MaternKernel(1.0,0.5),x) kappa(ExponentialKernel(),x)
118+
@test kappa(MaternKernel(1.0,1.5),x) kappa(Matern32Kernel(),x)
119+
@test kappa(MaternKernel(1.0,2.5),x) kappa(Matern52Kernel(),x)
120+
end
121+
end
122+
@testset "Polynomial" begin
123+
c = randn();
124+
125+
end
126+
@testset "RationalQuadratic" begin
127+
end
128+
end

test/test_transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ KernelFunctions.transform(tf,X,1)
2929
tchain = TransformChain([t,tp,tf])
3030
ttptf
3131
TransformChain([t,tp])
32-
@test all(KernelFunctions.transform(tchain,X).==f(P*(s*X)))
32+
@test all(KernelFunctions.transform(tchain,X,2).==f(P*(s*X)))

0 commit comments

Comments
 (0)