Skip to content

Commit f135086

Browse files
committed
Merge branch 'master-dev'
2 parents 5d40a23 + 015a718 commit f135086

12 files changed

+107
-53
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*.json
22
Manifest.toml
3+
coverage/

src/KernelFunctions.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module KernelFunctions
33
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa, kernelpdmat
44
export get_params, set_params!
55

6-
76
export Kernel
87
export ConstantKernel, WhiteKernel, ZeroKernel
98
export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel
@@ -13,6 +12,7 @@ export LinearKernel, PolynomialKernel
1312
export RationalQuadraticKernel, GammaRationalQuadraticKernel
1413
export KernelSum, KernelProduct
1514

15+
export SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1616

1717

1818
using Distances, LinearAlgebra
@@ -25,16 +25,21 @@ const defaultobs = 2
2525
include("utils.jl")
2626
include("distances/dotproduct.jl")
2727
include("distances/delta.jl")
28-
include("transform/transform.jl")
2928

3029

30+
"""
31+
Abstract type defining a slice-wise transformation on an input matrix
32+
"""
33+
abstract type Transform end
3134
abstract type Kernel{T,Tr<:Transform} end
3235

36+
include("transform/transform.jl")
3337
kernels = ["exponential","matern","polynomial","constant","rationalquad","exponentiated"]
3438
for k in kernels
3539
include(joinpath("kernels",k*".jl"))
3640
end
3741
include("matrix/kernelmatrix.jl")
42+
include("matrix/kernelpdmat.jl")
3843
include("kernels/kernelsum.jl")
3944
include("kernels/kernelproduct.jl")
4045

src/generic.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52
2929
end
3030
end
3131

32-
function set!(k::Kernel,x)
32+
function set_params!(k::Kernel,x)
3333
@error "Setting parameters to this kernel is either not possible or has not been implemented"
3434
end
35-
36-
set_params!(k::Kernel{T,<:ScaleTransform{<:Base.RefValue{<:Tρ}}}::AbstractVector{<:Tρ}) where {T,Tρ<:Real} = set!(k.transform,ρ[1])
37-
set_params!(k::Kernel{T,<:ScaleTransform{<:AbstractVector{<:Tρ}}}::AbstractVector{<:Tρ}) where {T,Tρ<:Real} = set!(k.transform,ρ)
38-
set_params!(k::Kernel{T,<:LowRankTransform{<:AbstractMatrix{<:Tm}}},m::AbstractMatrix{<:Tm}) where {T,Tm<:Real} = set!(k.transform,m)
39-
40-
get_params(k::Kernel{T,<:ScaleTransform{<:Base.RefValue{<:Tρ}}}) where {T,Tρ} = [k.transform.s[]]
41-
get_params(k::Kernel{T,<:ScaleTransform{<:AbstractVector{<:Tρ}}}) where {T,Tρ} = k.transform.s

src/kernels/kernelproduct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
2424
Base.:*(kp::KernelProduct,k::Kernel) = KernelProduct(vcat(kp.kernels,k))
2525

2626
Base.length(k::KernelProduct) = length(k.kernels)
27-
metric(k::KernelProduct) = getmetric.(k.kernels) #TODO Add test
27+
metric(k::KernelProduct) = metric.(k.kernels) #TODO Add test
2828
transform(k::KernelProduct) = transform.(k.kernels) #TODO Add test
2929
transform(k::KernelProduct,x::AbstractVecOrMat) = transform.(k.kernels,[x]) #TODO Add test
3030
transform(k::KernelProduct,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim) #TODO Add test

src/matrix/kernelmatrix.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
6262
"""
6363
kernelmatrix
6464

65+
function kernelmatrix(
66+
κ::Kernel,
67+
X::AbstractVector{<:Real};
68+
obsdim::Int=defaultobs
69+
)
70+
kernelmatrix(κ,reshape(X,1,:),obsdim=2)
71+
end
6572

6673
function kernelmatrix(
6774
κ::Kernel,

src/matrix/kernelpdmat.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ function kernelpdmat(
1010
K = kernelmatrix(κ,X,obsdim=obsdim)
1111
Kmax =maximum(K)
1212
α = eps(eltype(K))
13-
while !isposdef(K+αI) && α < 0.01*Kmax
13+
while !isposdef(K+α*I) && α < 0.01*Kmax
1414
α *= 2.0
1515
end
1616
if α >= 0.01*Kmax
17-
@error "Adding noise on the diagonal was not sufficient to build a positive-definite matrix:\n - Check that your kernel parameters are not extreme\n - Check that your data is sufficiently sparse\n - Maybe use a different kernel"
17+
throw(ErrorException("Adding noise on the diagonal was not sufficient to build a positive-definite matrix:\n\t- Check that your kernel parameters are not extreme\n\t- Check that your data is sufficiently sparse\n\t- Maybe use a different kernel"))
1818
end
19-
return PDMat(K+αI)
19+
return PDMat(K+α*I)
2020
end
21+
22+
kernelpdmat::Kernel,X::AbstractVector{<:Real};obsdim=defaultobs) = kernelpdmat(κ,reshape(X,1,:),obsdim=2)

src/transform/lowranktransform.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ function set!(t::LowRankTransform{<:AbstractMatrix{T}},M::AbstractMatrix{T}) whe
1515
@assert size(t) == size(M) "Size of the given matrix $(size(M)) and the projection matrix $(size(t)) are not the same"
1616
t.proj .= M
1717
end
18+
set_params!(k::Kernel{T,<:LowRankTransform{<:AbstractMatrix{<:Tm}}},m::AbstractMatrix{<:Tm}) where {T,Tm<:Real} = set!(k.transform,m)
19+
20+
get_params(k::Kernel{T,<:LowRankTransform}) where {T} = get_params(k.transform)
21+
get_params(t::LowRankTransform) = t.proj
1822

1923
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
2024
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test
@@ -27,7 +31,7 @@ end
2731

2832
function transform(t::LowRankTransform,x::AbstractVector{<:Real},obsdim::Int=defaultobs) #TODO Add test
2933
@assert size(t,2) == length(x) "Vector has wrong dimensions $(length(x)) compared to projection matrix"
30-
t.proj*X
34+
t.proj*x
3135
end
3236

3337
_transform(t::LowRankTransform,X::AbstractVecOrMat{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? t.proj * X : X * t.proj'

src/transform/scaletransform.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,23 @@ end
3131
function set!(t::ScaleTransform{Base.RefValue{T}}::T) where {T<:Real}
3232
t.s[] = ρ
3333
end
34-
3534
function set!(t::ScaleTransform{AbstractVector{T}}::AbstractVector{T}) where {T<:Real}
3635
@assert length(ρ) == dim(t) "Trying to set a vector of size $(length(ρ)) to ScaleTransform of dimension $(dim(t))"
3736
t.s .= ρ
3837
end
38+
set_params!(k::Kernel{T,<:ScaleTransform{<:Base.RefValue{<:Tρ}}}::Tρ) where {T,Tρ<:Real} = set!(k.transform,ρ)
39+
set_params!(k::Kernel{T,<:ScaleTransform{<:Base.RefValue{<:Tρ}}}::AbstractVector{<:Tρ}) where {T,Tρ<:Real} = set!(k.transform,ρ[1])
40+
set_params!(k::Kernel{T,<:ScaleTransform{<:AbstractVector{<:Tρ}}}::AbstractVector{<:Tρ}) where {T,Tρ<:Real} = set!(k.transform,ρ)
41+
42+
get_params(k::Kernel{T,<:ScaleTransform}) where {T} = get_params(k.transform)
43+
get_params(t::ScaleTransform{<:Base.RefValue}) = [t.s[]]
44+
get_params(t::ScaleTransform{<:AbstractVector}) = t.s
3945

4046
dim(str::ScaleTransform{Base.RefValue{<:Real}}) = 1 #TODO Add test
4147
dim(str::ScaleTransform{<:AbstractVector{<:Real}}) = length(str.s)
4248

4349
function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int)
44-
@boundscheck if dim(t) != size(X,!Bool(obsdim-1)+1)
50+
@boundscheck if dim(t) != size(X,feature_dim(obsdim))
4551
throw(DimensionMismatch("Array has size $(size(X,!Bool(obsdim-1)+1)) on dimension $(!Bool(obsdim-1)+1)) which does not match the length of the scale transform length , $(dim(t)).")) #TODO Add test
4652
end
4753
_transform(t,X,obsdim)

src/transform/selecttransform.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,30 @@ Select the dimensions `dims` that the kernel is applied to.
1010
"""
1111
struct SelectTransform{T<:AbstractVector{<:Int}} <: Transform
1212
select::T
13-
dim_max::Int
1413
end
1514

16-
function SelectTransform(dims::AbstractVector{T}) where {T<:Int}
15+
function SelectTransform(dims::V) where {V<:AbstractVector{T} where {T<:Int}}
1716
@assert all(dims.>0) "Selective dimensions should all be positive integers"
18-
SelectTransform{T}(dims,maximum(dims))
17+
SelectTransform{V}(dims)
1918
end
2019

21-
function set!(t::SelectTransform{<:AbstractVector{T}},s::AbstractVector{T}) where {T<:Real}
22-
t.proj .= s
23-
end
20+
get_params(t::SelectTransform) = t.select
21+
get_params(k::Kernel{T,<:SelectTransform}) where {T} = get_params(k.transform)
22+
23+
set!(t::SelectTransform{<:AbstractVector{T}},dims::AbstractVector{T}) where {T<:Int} = t.select .= dims
24+
set_params!(k::Kernel{T,<:SelectTransform{Td}},dims::AbstractVector{Td}) where {T,Td<:Int} = set!(k.transform,dims)
2425

2526
Base.maximum(t::SelectTransform) = maximum(t.select)
2627

2728
function transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs)
28-
@boundscheck t.dim_max <= size(X,feature_dim(obsdim)) ?
29-
throw(DimensionMismatch("The highest index $(t.dim_max) is higher then the feature dimension of X : $(size(X,feature_dim(obsdim)))")) : nothing
29+
@boundscheck maximum(t) >= size(X,feature_dim(obsdim)) ?
30+
throw(DimensionMismatch("The highest index $(maximum(t)) is higher then the feature dimension of X : $(size(X,feature_dim(obsdim)))")) : nothing
3031
@inbounds _transform(t,X,obsdim)
3132
end
3233

3334
function transform(t::SelectTransform,x::AbstractVector{<:Real},obsdim::Int=defaultobs) #TODO Add test
34-
@assert t.dim_max <= length(x) "The highest index $(t.dim_max) is higher then the vector length : $(length(x))"
35-
return x[t.select]
35+
@assert maximum(t) <= length(x) "The highest index $(maximum(t)) is higher then the vector length : $(length(x))"
36+
return @inbounds x[t.select]
3637
end
3738

3839
_transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? X[t.select,:] : X[:,t.select]

src/transform/transform.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
export Transform, IdentityTransform, ScaleTransform, LowRankTransform, FunctionTransform, ChainTransform
22
export transform
33

4-
"""
5-
Abstract type defining a slice-wise transformation on an input matrix
6-
"""
7-
abstract type Transform end
8-
9-
104
"""
115
```julia
126
transform(t::Transform, X::AbstractMatrix)

test/test_kernelmatrix.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Distances, LinearAlgebra
22
using Test
33
using KernelFunctions
4+
using PDMats
45

56
dims = [10,5]
67

@@ -35,7 +36,6 @@ k = SqExponentialKernel()
3536
@testset "KernelSum" begin
3637
k1 = SqExponentialKernel()
3738
k2 = LinearKernel()
38-
k3 =
3939
ks = k1 + k2
4040
w1 = 0.4; w2 = 1.2;
4141
ks2 = KernelSum([k1,k2],weights=[w1,w2])
@@ -49,12 +49,22 @@ k = SqExponentialKernel()
4949
@testset "KernelProduct" begin
5050
k1 = SqExponentialKernel()
5151
k2 = LinearKernel()
52+
k3 = RationalQuadraticKernel()
5253
kp = k1 * k2
54+
kp2 = k1 * k3
55+
@test all(KernelFunctions.metric(kp).==[KernelFunctions.metric(k1),KernelFunctions.metric(k2)])
5356
@test all(kernelmatrix(kp,A) .≈ kernelmatrix(k1,A) .* kernelmatrix(k2,A))
5457
@test all(kernelmatrix(kp*k1,A) .≈ kernelmatrix(k1,A).^2 .* kernelmatrix(k2,A))
5558
@test all(kernelmatrix(k1*kp,A) .≈ kernelmatrix(k1,A).^2 .* kernelmatrix(k2,A))
5659
@test all(kernelmatrix(kp,A) .≈ kernelmatrix(k1,A) .* kernelmatrix(k2,A))
5760
@test all(kernelmatrix(kp,A,B) .≈ kernelmatrix(k1,A,B) .* kernelmatrix(k2,A,B))
5861
@test all(kernelmatrix(kp,A) .≈ kernelmatrix(k1,A) .* kernelmatrix(k2,A))
62+
@test all(kerneldiagmatrix(kp,A) .== kerneldiagmatrix(k1,A) .* kerneldiagmatrix(k2,A))
63+
end
64+
@testset "PDMat" begin
65+
for obsdim in [1,2]
66+
@test all(kernelpdmat(k,A,obsdim=obsdim) .≈ PDMat(kernelmatrix(k,A,obsdim=obsdim)))
67+
# @test_throws ErrorException kernelpdmat(k,ones(100,100),obsdim=obsdim)
68+
end
5969
end
6070
end

test/test_transform.jl

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,58 @@ seed!(42)
66

77
dims = (10,5)
88
X = rand(dims...)
9-
##
9+
x = rand(dims[1])
1010
s = 3.0
1111
v1 = vcat(3.0,4.0*ones(dims[2]-1))
1212
v2 = vcat(3.0,4.0*ones(dims[1]-1))
13-
t = ScaleTransform(s)
14-
vt1 = ScaleTransform(v1)
15-
vt2 = ScaleTransform(v2)
16-
@test all(KernelFunctions.transform(t,X).==s*X)
17-
@test all(KernelFunctions.transform(vt1,X,1).==v1'.*X)
18-
@test all(KernelFunctions.transform(vt2,X,2).==v2.*X)
19-
##
2013
P = rand(5,10)
21-
tp = LowRankTransform(P)
22-
@test all(KernelFunctions.transform(tp,X,2).==P*X)
23-
##
14+
sdims = [1,2,3]
2415
f(x) = sin.(x)
25-
tf = FunctionTransform(f)
26-
KernelFunctions.transform(tf,X,1)
27-
@test all(KernelFunctions.transform(tf,X,1).==f(X))
28-
##
29-
tchain = ChainTransform([t,tp,tf])
30-
@test all(KernelFunctions.transform(tchain,X,2).==f(P*(s*X)))
31-
@test all(KernelFunctions.transform(tchain,X,2).==
32-
KernelFunctions.transform(tftpt,X,2))
16+
17+
@testset "Transform Test" begin
18+
## Test Scale Transform
19+
@testset "ScaleTransform" begin
20+
t = ScaleTransform(s)
21+
vt1 = ScaleTransform(v1)
22+
vt2 = ScaleTransform(v2)
23+
@test all(KernelFunctions.transform(t,X).==s*X)
24+
@test all(KernelFunctions.transform(vt1,X,1).==v1'.*X)
25+
@test all(KernelFunctions.transform(vt2,X,2).==v2.*X)
26+
end
27+
## Test LowRankTransform
28+
@testset "LowRankTransform" begin
29+
tp = LowRankTransform(P)
30+
@test all(KernelFunctions.transform(tp,X,2).==P*X)
31+
@test all(KernelFunctions.transform(tp,x).==P*x)
32+
@test all(get_params(SqExponentialKernel(tp)).==P)
33+
P2 = rand(5,10)
34+
KernelFunctions.set!(tp,P2)
35+
@test all(tp.proj.==P2)
36+
end
37+
## Test FunctionTransform
38+
@testset "FunctionTransform" begin
39+
tf = FunctionTransform(f)
40+
KernelFunctions.transform(tf,X,1)
41+
@test all(KernelFunctions.transform(tf,X,1).==f(X))
42+
end
43+
## Test SelectTransform
44+
@testset "SelectTransform" begin
45+
ts = SelectTransform(sdims)
46+
@test all(KernelFunctions.transform(ts,X,2).==X[sdims,:])
47+
@test all(KernelFunctions.transform(ts,x).==x[sdims])
48+
@test all(get_params(SqExponentialKernel(ts)).==sdims)
49+
sdims2 = [2,3,5]
50+
KernelFunctions.set!(ts,sdims2)
51+
@test all(ts.select.==sdims2)
52+
end
53+
## Test ChainTransform
54+
@testset "ChainTransform" begin
55+
t = ScaleTransform(s)
56+
tp = LowRankTransform(P)
57+
tf = FunctionTransform(f)
58+
tchain = ChainTransform([t,tp,tf])
59+
@test all(KernelFunctions.transform(tchain,X,2).==f(P*(s*X)))
60+
@test all(KernelFunctions.transform(tchain,X,2).==
61+
KernelFunctions.transform(tftpt,X,2))
62+
end
63+
end

0 commit comments

Comments
 (0)