Skip to content

Commit f90b50b

Browse files
committed
Added KernelSum and KernelProduct and removed some useless parametrization
1 parent b8a483c commit f90b50b

File tree

7 files changed

+145
-33
lines changed

7 files changed

+145
-33
lines changed

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
[![Build Status](https://travis-ci.org/theogf/KernelFunctions.jl.svg?branch=master)](https://travis-ci.org/theogf/AugmentedGaussianProcesses.jl)
22
[![Coverage Status](https://coveralls.io/repos/github/theogf/KernelFunctions.jl/badge.svg?branch=master)](https://coveralls.io/github/theogf/KernelFunctions.jl?branch=master)
33
[![Documentation](https://img.shields.io/badge/docs-dev-blue.svg)](https://theogf.github.io/KernelFunctions.jl/dev/)
4-
# KernelFunctions.jl (WIP)
5-
Julia Package for kernel functions for machine learning
4+
# KernelFunctions.jl
5+
## Kernel functions for machine learning
6+
7+
KernelFunctions.jl provide a flexible and complete framework for kernel functions, pretransforming the input data.
8+
9+
The aim is to make the API as model-agnostic as possible while still being user-friendly.
610

711
## Objectives (by priority)
812
- ARD Kernels

src/KernelFunctions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ export ExponentiatedKernel
88
export MaternKernel, Matern32Kernel, Matern52Kernel
99
export LinearKernel, PolynomialKernel
1010
export RationalQuadraticKernel, GammaRationalQuadraticKernel
11+
export KernelSum, KernelProduct
1112

1213

1314

@@ -32,6 +33,8 @@ for k in kernels
3233
include(joinpath("kernels",k*".jl"))
3334
end
3435
include("kernelmatrix.jl")
36+
include("kernels/kernelsum.jl")
37+
include("kernels/kernelproduct.jl")
3538

3639
include("generic.jl")
3740

src/generic.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
@inline metric::Kernel) = κ.metric
22

3+
## Allows to iterate over kernels
4+
Base.length(::Kernel) = 1
5+
6+
Base.iterate(k::Kernel) = (k,nothing)
7+
Base.iterate(k::Kernel, ::Any) = nothing
8+
39
### Syntactic sugar for creating matrices and using kernel functions
410
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
511
@eval begin

src/kernelmatrix.jl

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
66
"""
77
function kernelmatrix!(
8-
K::Matrix{T₁},
9-
κ::Kernel{T},
10-
X::AbstractMatrix{T₂};
8+
K::Matrix,
9+
κ::Kernel,
10+
X::AbstractMatrix;
1111
obsdim::Int = defaultobs
12-
) where {T,T₁<:Real,T₂<:Real}
12+
)
1313
if !check_dims(K,X,X,feature_dim(obsdim),obsdim)
1414
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
1515
end
@@ -23,12 +23,12 @@ end
2323
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
2424
"""
2525
function kernelmatrix!(
26-
K::AbstractMatrix{T₁},
27-
κ::Kernel{T},
28-
X::AbstractMatrix{T₂},
29-
Y::AbstractMatrix{T₃};
26+
K::AbstractMatrix,
27+
κ::Kernel,
28+
X::AbstractMatrix,
29+
Y::AbstractMatrix;
3030
obsdim::Int = defaultobs
31-
) where {T,T₁,T₂,T₃}
31+
)
3232
if !check_dims(K,X,Y,feature_dim(obsdim),obsdim)
3333
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
3434
end
@@ -41,16 +41,16 @@ end
4141
```
4242
Apply the kernel `κ` to `x` and `y`.
4343
"""
44-
function kernel::Kernel{T}, x::Real, y::Real) where {T}
45-
kernel(κ, [T(x)], [T(y)])
44+
function kernel::Kernel, x::Real, y::Real)
45+
kernel(κ, [x], [y])
4646
end
4747

4848
function kernel(
49-
κ::Kernel{T},
50-
x::AbstractArray{T₁},
51-
y::AbstractArray{T₂};
49+
κ::Kernel,
50+
x::AbstractVector,
51+
y::AbstractVector;
5252
obsdim::Int = defaultobs
53-
) where {T,T₁<:Real,T₂<:Real}
53+
)
5454
@assert length(x) == length(y) "x and y don't have the same dimension!"
5555
kappa(κ, evaluate(metric(κ),transform(κ,x),transform(κ,y)))
5656
end
@@ -64,11 +64,11 @@ Calculate the kernel matrix of `X` with respect to kernel `κ`.
6464
`obsdim=2` means the matrix `X` has size #dimension x #samples
6565
"""
6666
function kernelmatrix(
67-
κ::Kernel{T,<:Transform},
67+
κ::Kernel,
6868
X::AbstractMatrix;
6969
obsdim::Int = defaultobs,
7070
symmetrize::Bool = true
71-
) where {T}
71+
)
7272
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
7373
end
7474

@@ -81,18 +81,19 @@ Calculate the base matrix of `X` and `Y` with respect to kernel `κ`.
8181
`obsdim=2` means the matrices `X` and `Y` have size #dimension x #samples
8282
"""
8383
function kernelmatrix(
84-
κ::Kernel{T},
85-
X::AbstractMatrix{T₁},
86-
Y::AbstractMatrix{T₂};
84+
κ::Kernel,
85+
X::AbstractMatrix,
86+
Y::AbstractMatrix;
8787
obsdim=defaultobs
88-
) where {T,T₁<:Real,T₂<:Real}
88+
)
8989
if !check_dims(X,Y,feature_dim(obsdim),obsdim)
9090
throw(DimensionMismatch("X ($(size(X))) and Y ($(size(Y))) do not have the same number of features on the dimension obsdim : $(feature_dim(obsdim))"))
9191
end
92-
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
93-
return K
92+
_kernelmatrix(κ,X,Y,obsdim)
9493
end
9594

95+
@inline _kernelmatrix(κ,X,Y,obsdim) = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
96+
9697
"""
9798
```
9899
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
@@ -102,14 +103,14 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
102103
`obsdim=2` means the matrix `X` has size #dimension x #samples
103104
"""
104105
function kerneldiagmatrix(
105-
κ::Kernel{T},
106-
X::AbstractMatrix{T₁};
106+
κ::Kernel,
107+
X::AbstractMatrix;
107108
obsdim::Int = defaultobs
108-
) where {T,T₁}
109+
)
109110
if obsdim == 1
110111
[@views kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
111112
elseif obsdim == 2
112-
[@views kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
113+
[@views kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
113114
end
114115
end
115116

src/kernels/kernelproduct.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,41 @@
11
struct KernelProduct{T,Tr} <: Kernel{T,Tr}
2+
kernels::Vector{Kernel}
3+
end
4+
5+
function KernelProduct(kernels::AbstractVector{<:Kernel})
6+
KernelProduct{eltype(kernels),Transform}(kernels)
7+
end
8+
9+
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
10+
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
11+
Base.:*(kp::KernelProduct,k::Kernel) = KernelProduct(vcat(kp.kernels,k))
12+
13+
Base.length(k::KernelProduct) = length(k.kernels)
14+
metric(k::KernelProduct) = getmetric.(k.kernels)
15+
transform(k::KernelProduct) = transform.(k.kernels)
16+
transform(k::KernelProduct,x::AbstractVecOrMat) = transform.(k.kernels,[x])
17+
transform(k::KernelProduct,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim)
18+
19+
hadamard(x,y) = x.*y
20+
21+
function kernelmatrix(
22+
κ::KernelProduct,
23+
X::AbstractMatrix;
24+
obsdim::Int=defaultobs)
25+
reduce(hadamard,kernelmatrix.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
26+
end
27+
28+
function kernelmatrix(
29+
κ::KernelProduct,
30+
X::AbstractMatrix,
31+
Y::AbstractMatrix;
32+
obsdim::Int=defaultobs)
33+
reduce(hadamard,_kernelmatrix.kernels[i],X,Y,obsdim) for i in 1:length(κ))
34+
end
235

36+
function kerneldiagmatrix(
37+
κ::KernelProduct,
38+
X::AbstractMatrix;
39+
obsdim::Int=defaultobs)
40+
reduce(hadamard,kerneldiagmatrix.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
341
end

src/kernels/kernelsum.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,46 @@
11
struct KernelSum{T,Tr} <: Kernel{T,Tr}
22
kernels::Vector{Kernel}
33
weights::Vector{Real}
4+
function KernelSum{T,Tr}(kernels::AbstractVector{<:Kernel},weights::AbstractVector{<:Real}) where {T,Tr}
5+
new{T,Tr}(kernels,weights)
6+
end
47
end
58

6-
function Base.:+(k1::Kernel,k2::Kernel)
7-
KernelSum([k1,k2],[1.0,1.0])
9+
10+
function KernelSum(kernels::AbstractVector{<:Kernel}; weights::AbstractVector{<:Real}=ones(Float64,length(kernels)))
11+
@assert length(kernels)==length(weights) "Weights and kernel vector should be of the same length"
12+
@assert all(weights.>=0) "All weights should be positive"
13+
KernelSum{eltype(kernels),Transform}(kernels,weights)
14+
end
15+
16+
Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
17+
Base.:+(k::Kernel,ks::KernelSum) = KernelSum(vcat(k,ks.kernels),weights=vcat(1.0,ks.weights))
18+
Base.:+(ks::KernelSum,k::Kernel) = KernelSum(vcat(ks.kernels,k),weights=vcat(ks.weights,1.0))
19+
20+
Base.length(k::KernelSum) = length(k.kernels)
21+
metric(k::KernelSum) = metric.(k.kernels)
22+
transform(k::KernelSum) = transform.(k.kernels)
23+
transform(k::KernelSum,x::AbstractVecOrMat) = transform.(k.kernels,[x])
24+
transform(k::KernelSum,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim)
25+
26+
function kernelmatrix(
27+
κ::KernelSum,
28+
X::AbstractMatrix;
29+
obsdim::Int=defaultobs)
30+
sum.weights[i]*kernelmatrix.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
31+
end
32+
33+
function kernelmatrix(
34+
κ::KernelSum,
35+
X::AbstractMatrix,
36+
Y::AbstractMatrix;
37+
obsdim::Int=defaultobs)
38+
sum.weights[i]*_kernelmatrix.kernels[i],X,Y,obsdim) for i in 1:length(κ))
39+
end
40+
41+
function kerneldiagmatrix(
42+
κ::KernelSum,
43+
X::AbstractMatrix;
44+
obsdim::Int=defaultobs)
45+
sum(kerneldiagmatrix.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
846
end

test/test_kernelmatrix.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,34 @@ k = SqExponentialKernel()
2323
end
2424
@testset "Kernel matrix" begin
2525
for obsdim in [1,2]
26-
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
27-
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
26+
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.(k,pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
27+
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.(k,pairwise(KernelFunctions.metric(k),A,dims=obsdim))
28+
@test kerneldiagmatrix(k,A,obsdim=obsdim) == diag(kernelmatrix(k,A,obsdim=obsdim))
2829
@test k(A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
2930
@test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
3031
@test kernel(k,1.0,2.0) == kernel(k,[1.0],[2.0])
3132
@test_throws DimensionMismatch kernelmatrix(k,A,C,obsdim=obsdim)
3233
end
3334
end
35+
@testset "KernelSum" begin
36+
k1 = SqExponentialKernel()
37+
k2 = LinearKernel()
38+
k3 =
39+
ks = k1 + k2
40+
w1 = 0.4; w2 = 1.2;
41+
ks2 = KernelSum([k1,k2],weights=[w1,w2])
42+
@test all(kernelmatrix(ks,A) .== kernelmatrix(k1,A) + kernelmatrix(k2,A))
43+
@test all(kernelmatrix(ks,A,B) .== kernelmatrix(k1,A,B) + kernelmatrix(k2,A,B))
44+
@test all(kerneldiagmatrix(ks,A) .== kerneldiagmatrix(k1,A) + kerneldiagmatrix(k2,A))
45+
@test all(kernelmatrix(ks2,A) .== w1*kernelmatrix(k1,A) + w2*kernelmatrix(k2,A))
46+
end
47+
@testset "KernelProduct" begin
48+
k1 = SqExponentialKernel()
49+
k2 = LinearKernel()
50+
kp = k1 * k2
51+
@test all(kernelmatrix(kp,A) .== kernelmatrix(k1,A) .* kernelmatrix(k2,A))
52+
@test all(kernelmatrix(kp,A,B) .== kernelmatrix(k1,A,B) .* kernelmatrix(k2,A,B))
53+
@test all(kerneldiagmatrix(kp,A) .== kerneldiagmatrix(k1,A) .* kerneldiagmatrix(k2,A))
54+
@test all(kernelmatrix(kp,A) .== kernelmatrix(k1,A) .* kernelmatrix(k2,A))
55+
end
3456
end

0 commit comments

Comments
 (0)