Skip to content

Commit fdd317f

Browse files
authored
Merge pull request #83 from theogf/general_kernelmatrix
[WIP] Rework on kernelmatrix to work with Vectors and more complex kernels
2 parents d52def6 + efa1479 commit fdd317f

26 files changed

+219
-273
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export duplicate, set! # Helpers
1010

1111
export Kernel
1212
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel
13+
export CosineKernel
1314
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
1415
export LaplacianKernel, ExponentialKernel, GammaExponentialKernel
1516
export ExponentiatedKernel
@@ -43,6 +44,7 @@ Abstract type defining a slice-wise transformation on an input matrix
4344
abstract type Transform end
4445
abstract type Kernel end
4546
abstract type BaseKernel <: Kernel end
47+
abstract type SimpleKernel <: BaseKernel end
4648

4749
include("utils.jl")
4850
include("distances/dotproduct.jl")

src/basekernels/constant.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Create a kernel that always returning zero
77
```
88
The output type depends of `x` and `y`
99
"""
10-
struct ZeroKernel <: BaseKernel end
10+
struct ZeroKernel <: SimpleKernel end
1111

1212
kappa(κ::ZeroKernel, d::T) where {T<:Real} = zero(T)
1313

@@ -24,7 +24,7 @@ Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
2424
```
2525
Kernel function working as an equivalent to add white noise. Can also be called via `EyeKernel()`
2626
"""
27-
struct WhiteKernel <: BaseKernel end
27+
struct WhiteKernel <: SimpleKernel end
2828

2929
"""
3030
EyeKernel()
@@ -48,7 +48,7 @@ Kernel function always returning a constant value `c`
4848
κ(x,y) = c
4949
```
5050
"""
51-
struct ConstantKernel{Tc<:Real} <: BaseKernel
51+
struct ConstantKernel{Tc<:Real} <: SimpleKernel
5252
c::Vector{Tc}
5353
function ConstantKernel(;c::T=1.0) where {T<:Real}
5454
new{T}([c])

src/basekernels/cosine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The cosine kernel is a stationary kernel for a sinusoidal given by
66
κ(x,y) = cos( π * (x-y) )
77
```
88
"""
9-
struct CosineKernel <: BaseKernel end
9+
struct CosineKernel <: SimpleKernel end
1010

1111
kappa(κ::CosineKernel, d::Real) = cospi(d)
1212
metric(::CosineKernel) = Euclidean()

src/basekernels/exponential.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Can also be called via `SEKernel`, `GaussianKernel` or `SEKernel`.
99
See also [`ExponentialKernel`](@ref) for a
1010
related form of the kernel or [`GammaExponentialKernel`](@ref) for a generalization.
1111
"""
12-
struct SqExponentialKernel <: BaseKernel end
12+
struct SqExponentialKernel <: SimpleKernel end
1313

1414
kappa(κ::SqExponentialKernel, d²::Real) = exp(-d²)
1515
iskroncompatible(::SqExponentialKernel) = true
@@ -30,7 +30,7 @@ The exponential kernel is a Mercer kernel given by the formula:
3030
κ(x,y) = exp(-‖x-y‖)
3131
```
3232
"""
33-
struct ExponentialKernel <: BaseKernel end
33+
struct ExponentialKernel <: SimpleKernel end
3434

3535
kappa(κ::ExponentialKernel, d::Real) = exp(-d)
3636
iskroncompatible(::ExponentialKernel) = true
@@ -51,7 +51,7 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
5151
Where `γ > 0`, (the keyword `γ` can be replaced by `gamma`)
5252
For `γ = 1`, see `SqExponentialKernel` and `γ = 0.5`, see `ExponentialKernel`
5353
"""
54-
struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
54+
struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
5555
γ::Vector{Tγ}
5656
function GammaExponentialKernel(; gamma::T=2.0, γ::T=gamma) where {T<:Real}
5757
@check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")

src/basekernels/exponentiated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The exponentiated kernel is a Mercer kernel given by:
66
κ(x,y) = exp(xᵀy)
77
```
88
"""
9-
struct ExponentiatedKernel <: BaseKernel end
9+
struct ExponentiatedKernel <: SimpleKernel end
1010

1111
kappa(κ::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy)
1212
metric(::ExponentiatedKernel) = DotProduct()

src/basekernels/fbm.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,6 @@ function kernelmatrix!(
6666
return K
6767
end
6868

69-
## Apply kernel on two vectors ##
70-
function _kernel(
71-
κ::FBMKernel,
72-
x::AbstractVector,
73-
y::AbstractVector;
74-
obsdim::Int = defaultobs
75-
)
76-
@assert length(x) == length(y) "x and y don't have the same dimension!"
77-
return kappa(κ, x, y)
78-
end
79-
8069
function kappa(κ::FBMKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
8170
modX = sum(abs2, x)
8271
modY = sum(abs2, y)

src/basekernels/maha.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Mahalanobis distance-based kernel given by
88
where the matrix P is the metric.
99
1010
"""
11-
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: BaseKernel
11+
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: SimpleKernel
1212
P::A
1313
function MahalanobisKernel(P::AbstractMatrix{T}) where {T<:Real}
1414
LinearAlgebra.checksquare(P)

src/basekernels/matern.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The matern kernel is a Mercer kernel given by the formula:
77
```
88
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
99
"""
10-
struct MaternKernel{Tν<:Real} <: BaseKernel
10+
struct MaternKernel{Tν<:Real} <: SimpleKernel
1111
ν::Vector{Tν}
1212
function MaternKernel(;nu::T=1.5, ν::T=nu) where {T<:Real}
1313
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
@@ -37,7 +37,7 @@ The matern 3/2 kernel is a Mercer kernel given by the formula:
3737
κ(x,y) = (1+√(3)‖x-y‖)exp(-√(3)‖x-y‖)
3838
```
3939
"""
40-
struct Matern32Kernel <: BaseKernel end
40+
struct Matern32Kernel <: SimpleKernel end
4141

4242
kappa(κ::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d)
4343
metric(::Matern32Kernel) = Euclidean()
@@ -52,7 +52,7 @@ The matern 5/2 kernel is a Mercer kernel given by the formula:
5252
κ(x,y) = (1+√(5)‖x-y‖ + 5/3‖x-y‖^2)exp(-√(5)‖x-y‖)
5353
```
5454
"""
55-
struct Matern52Kernel <: BaseKernel end
55+
struct Matern52Kernel <: SimpleKernel end
5656

5757
kappa(κ::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)
5858
metric(::Matern52Kernel) = Euclidean()

src/basekernels/periodic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Periodic Kernel as described in http://www.inference.org.uk/mackay/gpB.pdf eq. 4
88
κ(x,y) = exp( - 0.5 sum_i(sin (π(x_i - y_i))/r_i))
99
```
1010
"""
11-
struct PeriodicKernel{T} <: BaseKernel
11+
struct PeriodicKernel{T} <: SimpleKernel
1212
r::Vector{T}
1313
function PeriodicKernel(; r::AbstractVector{T} = ones(Float64, 1)) where {T<:Real}
1414
@assert all(r .> 0)

src/basekernels/piecewisepolynomial.jl

Lines changed: 5 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ processes are hence v times mean-square differentiable. The kernel function is:
1010
where `r` is the Mahalanobis distance mahalanobis(x,y) with `maha` as the metric.
1111
1212
"""
13-
struct PiecewisePolynomialKernel{V, A<:AbstractMatrix{<:Real}} <: BaseKernel
13+
struct PiecewisePolynomialKernel{V, A<:AbstractMatrix{<:Real}} <: SimpleKernel
1414
maha::A
15+
j::Int
1516
function PiecewisePolynomialKernel{V}(maha::AbstractMatrix{<:Real}) where V
1617
V in (0, 1, 2, 3) || error("Invalid paramter v=$(V). Should be 0, 1, 2 or 3.")
1718
LinearAlgebra.checksquare(maha)
18-
return new{V,typeof(maha)}(maha)
19+
j = div(size(maha, 1), 2) + V + 1
20+
return new{V,typeof(maha)}(maha, j)
1921
end
2022
end
2123

@@ -29,78 +31,7 @@ _f(κ::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3)
2931
_f(κ::PiecewisePolynomialKernel{3}, r, j) = 1 + (j + 3) * r +
3032
(6 * j^2 + 36j + 45) / 15 * r.^2 + (j^3 + 9 * j^2 + 23j + 15) / 15 * r.^3
3133

32-
function _piecewisepolynomial(κ::PiecewisePolynomialKernel{V}, r, j) where V
33-
return max(1 - r, 0)^(j + V) * _f(κ, r, j)
34-
end
35-
36-
function kappa(
37-
κ::PiecewisePolynomialKernel{V},
38-
x::AbstractVector{<:Real},
39-
y::AbstractVector{<:Real},
40-
) where {V}
41-
r = evaluate(metric(κ), x, y)
42-
j = div(size(x, 2), 1) + V + 1
43-
return _piecewisepolynomial(κ, r, j)
44-
end
45-
46-
function _kernel(
47-
κ::PiecewisePolynomialKernel,
48-
x::AbstractVector,
49-
y::AbstractVector;
50-
obsdim::Int = defaultobs,
51-
)
52-
@assert length(x) == length(y) "x and y don't have the same dimension!"
53-
return kappa(κ,x,y)
54-
end
55-
56-
function kernelmatrix(
57-
κ::PiecewisePolynomialKernel{V},
58-
X::AbstractMatrix;
59-
obsdim::Int = defaultobs
60-
) where {V}
61-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
62-
return map(r->_piecewisepolynomial(κ, r, j), pairwise(metric(κ), X; dims=obsdim))
63-
end
64-
65-
function _kernelmatrix(κ::PiecewisePolynomialKernel{V}, X, Y, obsdim) where {V}
66-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
67-
return map(r->_piecewisepolynomial(κ, r, j), pairwise(metric(κ), X, Y; dims=obsdim))
68-
end
69-
70-
function kernelmatrix!(
71-
K::AbstractMatrix,
72-
κ::PiecewisePolynomialKernel{V},
73-
X::AbstractMatrix;
74-
obsdim::Int = defaultobs
75-
) where {V}
76-
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
77-
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
78-
throw(DimensionMismatch(
79-
"Dimensions of the target array K $(size(K)) are not consistent with X " *
80-
"$(size(X))",
81-
))
82-
end
83-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
84-
return map!(r->_piecewisepolynomial(κ,r,j), K, pairwise(metric(κ), X; dims=obsdim))
85-
end
86-
87-
function kernelmatrix!(
88-
K::AbstractMatrix,
89-
κ::PiecewisePolynomialKernel{V},
90-
X::AbstractMatrix,
91-
Y::AbstractMatrix;
92-
obsdim::Int = defaultobs,
93-
) where {V}
94-
@assert obsdim ∈ [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
95-
if !check_dims(K, X, Y, feature_dim(obsdim), obsdim)
96-
throw(DimensionMismatch(
97-
"Dimensions $(size(K)) of the target array K are not consistent with X " *
98-
"($(size(X))) and Y ($(size(Y)))",
99-
))
100-
end
101-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
102-
return map!(r->_piecewisepolynomial(κ,r,j), K, pairwise(metric(κ), X, Y; dims=obsdim))
103-
end
34+
kappa(κ::PiecewisePolynomialKernel{V}, r) where V = max(1 - r, 0)^(κ.j + V) * _f(κ, r, κ.j)
10435

10536
metric(κ::PiecewisePolynomialKernel) = Mahalanobis(κ.maha)
10637

src/basekernels/polynomial.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The linear kernel is a Mercer kernel given by
77
```
88
Where `c` is a real number
99
"""
10-
struct LinearKernel{Tc<:Real} <: BaseKernel
10+
struct LinearKernel{Tc<:Real} <: SimpleKernel
1111
c::Vector{Tc}
1212
function LinearKernel(;c::T=0.0) where {T}
1313
new{T}([c])
@@ -28,7 +28,7 @@ The polynomial kernel is a Mercer kernel given by
2828
```
2929
Where `c` is a real number, and `d` is a shape parameter bigger than 1. For `d = 1` see [`LinearKernel`](@ref)
3030
"""
31-
struct PolynomialKernel{Td<:Real, Tc<:Real} <: BaseKernel
31+
struct PolynomialKernel{Td<:Real, Tc<:Real} <: SimpleKernel
3232
d::Vector{Td}
3333
c::Vector{Tc}
3434
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real, Tc<:Real}

src/basekernels/rationalquad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The rational-quadratic kernel is a Mercer kernel given by the formula:
77
```
88
where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization.
99
"""
10-
struct RationalQuadraticKernel{Tα<:Real} <: BaseKernel
10+
struct RationalQuadraticKernel{Tα<:Real} <: SimpleKernel
1111
α::Vector{Tα}
1212
function RationalQuadraticKernel(;alpha::T=2.0, α::T=alpha) where {T}
1313
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
@@ -28,7 +28,7 @@ The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the f
2828
```
2929
where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter.
3030
"""
31-
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
31+
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: SimpleKernel
3232
α::Vector{Tα}
3333
γ::Vector{Tγ}
3434
function GammaRationalQuadraticKernel(;alpha::Tα=2.0, gamma::Tγ=2.0, α::Tα=alpha, γ::Tγ=gamma) where {Tα<:Real, Tγ<:Real}

src/generic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
for k in concretetypes(Kernel, [])
2424
@eval begin
25-
@inline (κ::$k)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
25+
@inline (κ::$k)(x, y) = kappa(κ, x, y)
2626
@inline (κ::$k)(X::AbstractMatrix{T}, Y::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, Y, obsdim=obsdim)
2727
@inline (κ::$k)(X::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, obsdim=obsdim)
2828
end

src/kernels/kernelproduct.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,26 @@ hadamard(x,y) = x.*y
2929
function kernelmatrix(
3030
κ::KernelProduct,
3131
X::AbstractMatrix;
32-
obsdim::Int=defaultobs)
33-
reduce(hadamard,kernelmatrix(κ.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
32+
obsdim::Int=defaultobs,
33+
)
34+
reduce(hadamard, kernelmatrix(κ.kernels[i], X, obsdim = obsdim) for i in 1:length(κ))
3435
end
3536

3637
function kernelmatrix(
3738
κ::KernelProduct,
3839
X::AbstractMatrix,
3940
Y::AbstractMatrix;
40-
obsdim::Int=defaultobs)
41-
reduce(hadamard,_kernelmatrix(κ.kernels[i],X,Y,obsdim) for i in 1:length(κ))
41+
obsdim::Int=defaultobs,
42+
)
43+
reduce(hadamard, kernelmatrix(κ.kernels[i], X, Y, obsdim = obsdim) for i in 1:length(κ))
4244
end
4345

4446
function kerneldiagmatrix(
4547
κ::KernelProduct,
4648
X::AbstractMatrix;
47-
obsdim::Int=defaultobs) #TODO Add test
48-
reduce(hadamard,kerneldiagmatrix(κ.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
49+
obsdim::Int=defaultobs,
50+
) #TODO Add test
51+
reduce(hadamard, kerneldiagmatrix(κ.kernels[i], X, obsdim = obsdim) for i in 1:length(κ))
4952
end
5053

5154
function Base.show(io::IO, κ::KernelProduct)

src/kernels/kernelsum.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function kernelmatrix(
5858
Y::AbstractMatrix;
5959
obsdim::Int = defaultobs,
6060
)
61-
sum(κ.weights[i] * _kernelmatrix(κ.kernels[i], X, Y, obsdim) for i in 1:length(κ))
61+
sum(κ.weights[i] * kernelmatrix(κ.kernels[i], X, Y, obsdim = obsdim) for i in 1:length(κ))
6262
end
6363

6464
function kerneldiagmatrix(

src/kernels/scaledkernel.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ end
1515

1616
kappa(k::ScaledKernel, x) = first(k.σ²) * kappa(k.kernel, x)
1717

18+
kappa(k::ScaledKernel, x, y) = first(k.σ²) * kappa(k.kernel, x, y)
19+
1820
metric(k::ScaledKernel) = metric(k.kernel)
1921

2022
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)

src/kernels/tensorproduct.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ end
2222

2323
Base.length(kernel::TensorProduct) = length(kernel.kernels)
2424

25-
(kernel::TensorProduct)(x, y) = kappa(kernel, x, y)
2625
function kappa(kernel::TensorProduct, x, y)
2726
return prod(kappa(k, xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
2827
end
@@ -97,7 +96,7 @@ function kernelmatrix(
9796
obsdim ∈ (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
9897

9998
featuredim = feature_dim(obsdim)
100-
if !check_dims(X, X, featuredim, obsdim)
99+
if !check_dims(X, X, featuredim)
101100
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not " *
102101
"consistent with X $(size(X))"))
103102
end
@@ -120,7 +119,7 @@ function kernelmatrix(
120119
obsdim ∈ (1, 2) || error("obsdim should be 1 or 2 (see docs of kernelmatrix))")
121120

122121
featuredim = feature_dim(obsdim)
123-
if !check_dims(X, Y, featuredim, obsdim)
122+
if !check_dims(X, Y, featuredim)
124123
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not " *
125124
"consistent with X ($(size(X))) and Y ($(size(Y)))"))
126125
end

src/kernels/transformedkernel.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,17 @@ function printshifted(io::IO, κ::TransformedKernel, shift::Int)
3939
printshifted(io, κ.kernel, shift)
4040
print(io,"\n" * ("\t" ^ (shift + 1)) * "- $(κ.transform)")
4141
end
42+
43+
# Kernel matrix operations
44+
45+
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
46+
kernelmatrix!(K, kernel(κ), apply(κ.transform, X, obsdim = obsdim), obsdim = obsdim)
47+
48+
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs) =
49+
kernelmatrix!(K, kernel(κ), apply(κ.transform, X, obsdim = obsdim), apply(κ.transform, Y, obsdim = obsdim), obsdim = obsdim)
50+
51+
kernelmatrix(κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
52+
kernelmatrix(kernel(κ), apply(κ.transform, X, obsdim = obsdim), obsdim = obsdim)
53+
54+
kernelmatrix(κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs) =
55+
kernelmatrix(kernel(κ), apply(κ.transform, X, obsdim = obsdim), apply(κ.transform, Y, obsdim = obsdim), obsdim = obsdim)

0 commit comments

Comments
 (0)