Skip to content

Commit 63942bc

Browse files
committed
Merge branch 'master-dev'
2 parents 6fc3a6f + d095935 commit 63942bc

23 files changed

+171
-49
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.2.0"
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
910
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1011
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1112
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

README.md

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,35 @@ KernelFunctions.jl provide a flexible and complete framework for kernel function
88

99
The aim is to make the API as model-agnostic as possible while still being user-friendly.
1010

11+
## Examples
12+
13+
```julia
14+
X = reshape(collect(range(-3.0,3.0,length=100)),:,1)
15+
# Set simple scaling of the data
16+
k₁ = SqExponentialKernel(1.0)
17+
K₁ = kernelmatrix(k,X,obsdim=1)
18+
19+
# Set a function transformation on the data
20+
k₂ = MaternKernel(FunctionTransform(x->sin.(x)))
21+
K₂ = kernelmatrix(k,X,obsdim=1)
22+
23+
# Set a matrix premultiplication on the data
24+
k₃ = PolynomialKernel(LowRankTransform(randn(4,1)),0.0,2.0)
25+
K₃ = kernelmatrix(k,X,obsdim=1)
26+
27+
# Add and sum kernels
28+
k₄ = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*k₂
29+
K₄ = kernelmatrix(k,X,obsdim=1)
30+
31+
heatmap([K₁,K₂,K₃,K₄],yflip=false,colorbar=false)
32+
```
33+
<p align=center>
34+
<img src="docs/src/assets/heatmap_combination.png" width=400px>
35+
</p>
36+
1137
## Objectives (by priority)
12-
- ARD Kernels
13-
- AD Compatible (Zygote, ForwardDiff, ReverseDiff)
14-
- Kernel sum and product
38+
- AD Compatibility (Zygote, ForwardDiff)
1539
- Toeplitz Matrices
1640
- BLAS backend
1741

18-
19-
Directly inspired by the [MLKernels](https://github.com/trthatcher/MLKernels.jl) package
42+
Directly inspired by the [MLKernels](https://github.com/trthatcher/MLKernels.jl) package.

docs/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
build/
22
site/
3+
4+
# Temp to avoid too many changes

docs/create_kernel_plots.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,33 @@ x₀ = 0.0; l=0.1
1010
n_grid = 101
1111
fill(x₀,n_grid,1)
1212
xrange = reshape(collect(range(-3,3,length=n_grid)),:,1)
13+
14+
k = SqExponentialKernel(1.0)
15+
K1 = kernelmatrix(k,xrange,obsdim=1)
16+
p = heatmap(K1,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
17+
savefig(joinpath(@__DIR__,"src","assets","heatmap_sqexp.png"))
18+
19+
20+
k = Matern32Kernel(FunctionTransform(x->(sin.(x)).^2))
21+
K2 = kernelmatrix(k,xrange,obsdim=1)
22+
p = heatmap(K2,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
23+
savefig(joinpath(@__DIR__,"src","assets","heatmap_matern.png"))
24+
25+
26+
k = PolynomialKernel(LowRankTransform(randn(3,1)),2.0,0.0)
27+
K3 = kernelmatrix(k,xrange,obsdim=1)
28+
p = heatmap(K3,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
29+
savefig(joinpath(@__DIR__,"src","assets","heatmap_poly.png"))
30+
31+
k = 0.5*SqExponentialKernel()*LinearKernel(0.5) + 0.4*Matern32Kernel(FunctionTransform(x->sin.(x)))
32+
K4 = kernelmatrix(k,xrange,obsdim=1)
33+
p = heatmap(K4,yflip=true,colorbar=false,framestyle=:none,background_color=RGBA(0.0,0.0,0.0,0.0))
34+
savefig(joinpath(@__DIR__,"src","assets","heatmap_prodsum.png"))
35+
36+
plot(heatmap.([K1,K2,K3,K4],yflip=true,colorbar=false)...,layout=(2,2))
37+
savefig(joinpath(@__DIR__,"src","assets","heatmap_combination.png"))
38+
39+
1340
for k in [SqExponentialKernel,ExponentialKernel]
1441
K = kernelmatrix(k(),xrange,obsdim=1)
1542
v = rand(MvNormal(K+1e-7I))
67.8 KB
Loading

docs/src/assets/heatmap_matern.png

37.8 KB
Loading

docs/src/assets/heatmap_poly.png

17.8 KB
Loading

docs/src/assets/heatmap_prodsum.png

21.5 KB
Loading

docs/src/assets/heatmap_sqexp.png

7.43 KB
Loading

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using Distances, LinearAlgebra
1616
using Zygote: @adjoint
1717
using SpecialFunctions: lgamma, besselk
1818
using StatsFuns: logtwo
19+
using PDMats
1920

2021
const defaultobs = 2
2122

@@ -32,7 +33,7 @@ kernels = ["exponential","matern","polynomial","constant","rationalquad","expone
3233
for k in kernels
3334
include(joinpath("kernels",k*".jl"))
3435
end
35-
include("kernelmatrix.jl")
36+
include("matrix/kernelmatrix.jl")
3637
include("kernels/kernelsum.jl")
3738
include("kernels/kernelproduct.jl")
3839

src/generic.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
@inline metric::Kernel) = κ.metric
22

33
## Allows to iterate over kernels
4-
Base.length(::Kernel) = 1
4+
Base.length(::Kernel) = 1 #TODO Add test
55

6-
Base.iterate(k::Kernel) = (k,nothing)
7-
Base.iterate(k::Kernel, ::Any) = nothing
6+
Base.iterate(k::Kernel) = (k,nothing) #TODO Add test
7+
Base.iterate(k::Kernel, ::Any) = nothing #TODO Add test
88

99
### Syntactic sugar for creating matrices and using kernel functions
1010
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
1111
@eval begin
12-
@inline::$k)(d::Real) = kappa(κ,d)
12+
@inline::$k)(d::Real) = kappa(κ,d) #TODO Add test
1313
@inline::$k)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate.metric,transform(κ,x),transform(κ,y)))
1414
@inline::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
1515
@inline::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)

src/kernels/constant.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
ZeroKernel()
2+
ZeroKernel([tr=IdentityTransform()])
33
44
Create a kernel that always returns a zero kernel matrix
55
@@ -19,7 +19,7 @@ end
1919
@inline kappa::ZeroKernel,d::T) where {T<:Real} = zero(T)
2020

2121
"""
22-
WhiteKernel()
22+
WhiteKernel([tr=IdentityTransform()])
2323
2424
```
2525
κ(x,y) = δ(x,y)
@@ -41,7 +41,7 @@ end
4141
@inline kappa::WhiteKernel,δₓₓ::Real) = δₓₓ
4242

4343
"""
44-
ConstantKernel([c=1.0])
44+
ConstantKernel([tr=IdentityTransform(),[c=1.0]])
4545
4646
```
4747
κ(x,y) = c

src/kernels/exponential.jl

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,6 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
5555
```
5656
κ(x,y) = exp(-‖x-y‖^2γ)
5757
```
58-
59-
# Examples
60-
61-
```jldoctest; setup = :(using KernelFunctions)
62-
julia> GammaExponentialKernel()
63-
GammaExponentialKernel{Float64,Float64,Float64}(1.0,2.0)
64-
65-
julia> GammaExponentialKernel(2.0f0,3.0)
66-
GammaExponentialKernel{Float32,Float32,Float64}(2.0,3.0)
67-
68-
julia> GammaExponentialKernel([2.0,3.0],2f0)
69-
GammaExponentialKernel{Float64,Array{Float64},Float32}([2.0,3.0],2.0)
70-
```
7158
"""
7259
struct GammaExponentialKernel{T,Tr,Tᵧ<:Real} <: Kernel{T,Tr}
7360
transform::Tr

src/kernels/exponentiated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
ExponentiatedKernel([α=1])
2+
ExponentiatedKernel([ρ=1])
33
44
The exponentiated kernel is a Mercer kernel given by:
55

src/kernels/kernelproduct.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
KernelProduct(kernels::Array{Kernel})
3+
Create a product of kernels.
4+
One can also use the operator `*`
5+
```
6+
kernelmatrix(SqExponentialKernel()*LinearKernel(),X) == kernelmatrix(SqExponentialKernel(),X).*kernelmatrix(LinearKernel(),X)
7+
```
8+
"""
19
struct KernelProduct{T,Tr} <: Kernel{T,Tr}
210
kernels::Vector{Kernel}
311
end
@@ -7,14 +15,15 @@ function KernelProduct(kernels::AbstractVector{<:Kernel})
715
end
816

917
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
18+
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
1019
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))
1120
Base.:*(kp::KernelProduct,k::Kernel) = KernelProduct(vcat(kp.kernels,k))
1221

1322
Base.length(k::KernelProduct) = length(k.kernels)
14-
metric(k::KernelProduct) = getmetric.(k.kernels)
15-
transform(k::KernelProduct) = transform.(k.kernels)
16-
transform(k::KernelProduct,x::AbstractVecOrMat) = transform.(k.kernels,[x])
17-
transform(k::KernelProduct,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim)
23+
metric(k::KernelProduct) = getmetric.(k.kernels) #TODO Add test
24+
transform(k::KernelProduct) = transform.(k.kernels) #TODO Add test
25+
transform(k::KernelProduct,x::AbstractVecOrMat) = transform.(k.kernels,[x]) #TODO Add test
26+
transform(k::KernelProduct,x::AbstractVecOrMat,obsdim::Int) = transform.(k.kernels,[x],obsdim) #TODO Add test
1827

1928
hadamard(x,y) = x.*y
2029

@@ -36,6 +45,6 @@ end
3645
function kerneldiagmatrix(
3746
κ::KernelProduct,
3847
X::AbstractMatrix;
39-
obsdim::Int=defaultobs)
48+
obsdim::Int=defaultobs) #TODO Add test
4049
reduce(hadamard,kerneldiagmatrix.kernels[i],X,obsdim=obsdim) for i in 1:length(κ))
4150
end

src/kernels/kernelsum.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
"""
2+
KernelSum(kernels::Array{Kernel};weights::Array{Real}=ones(length(kernels)))
3+
Create a positively weighted sum of kernels.
4+
One can also use the operator `+`
5+
```
6+
kernelmatrix(SqExponentialKernel()+LinearKernel(),X) == kernelmatrix(SqExponentialKernel(),X).+kernelmatrix(LinearKernel(),X)
7+
```
8+
"""
19
struct KernelSum{T,Tr} <: Kernel{T,Tr}
210
kernels::Vector{Kernel}
311
weights::Vector{Real}
@@ -14,8 +22,12 @@ function KernelSum(kernels::AbstractVector{<:Kernel}; weights::AbstractVector{<:
1422
end
1523

1624
Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
25+
Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))
1726
Base.:+(k::Kernel,ks::KernelSum) = KernelSum(vcat(k,ks.kernels),weights=vcat(1.0,ks.weights))
1827
Base.:+(ks::KernelSum,k::Kernel) = KernelSum(vcat(ks.kernels,k),weights=vcat(ks.weights,1.0))
28+
Base.:*(w::Real,k::Kernel) = KernelSum([k],weights=[w]) #TODO add tests
29+
Base.:*(w::Real,k::KernelSum) = KernelSum(k.kernels,weights=w*k.weights) #TODO add tests
30+
1931

2032
Base.length(k::KernelSum) = length(k.kernels)
2133
metric(k::KernelSum) = metric.(k.kernels)

src/kernels/matern.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838
@inline kappa::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((1.0-κ.ν)*logtwo-lgamma.ν) + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk.ν,sqrt(2κ.ν)*d)))
3939

4040
"""
41-
Matern32Kernel(ρ=1.0)
41+
Matern32Kernel([ρ=1.0])
4242
4343
The matern 3/2 kernel is an isotropic Mercer kernel given by the formula:
4444
@@ -59,7 +59,7 @@ end
5959
@inline kappa::Matern32Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
6060

6161
"""
62-
Matern52Kernel(ρ=1.0)
62+
Matern52Kernel([ρ=1.0])
6363
6464
The matern 5/2 kernel is an isotropic Mercer kernel given by the formula:
6565

src/kernelmatrix.jl renamed to src/matrix/kernelmatrix.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function kernelmatrix(
8787
obsdim=defaultobs
8888
)
8989
if !check_dims(X,Y,feature_dim(obsdim),obsdim)
90-
throw(DimensionMismatch("X ($(size(X))) and Y ($(size(Y))) do not have the same number of features on the dimension obsdim : $(feature_dim(obsdim))"))
90+
throw(DimensionMismatch("X $(size(X)) and Y $(size(Y)) do not have the same number of features on the dimension : $(feature_dim(obsdim))"))
9191
end
9292
_kernelmatrix(κ,X,Y,obsdim)
9393
end
@@ -114,12 +114,18 @@ function kerneldiagmatrix(
114114
end
115115
end
116116

117+
"""
118+
```
119+
kerneldiagmatrix!(K::AbstractVector,κ::Kernel, X::Matrix; obsdim::Int=2)
120+
```
121+
In-place version of `kerneldiagmatrix`.
122+
"""
117123
function kerneldiagmatrix!(
118-
K::AbstractVector{T₁},
119-
κ::Kernel{T},
120-
X::AbstractMatrix{T₂};
124+
K::AbstractVector,
125+
κ::Kernel,
126+
X::AbstractMatrix;
121127
obsdim::Int = defaultobs
122-
) where {T,T₁,T₂}
128+
)
123129
if length(K) != size(X,obsdim)
124130
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
125131
end

src/matrix/kernelpdmat.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Guarantee the return of a positive-definite matrix, in the form of a `PDMat` matrix with the Cholesky decomposition precomputed.
3+
"""
4+
function kernelpdmat(
5+
κ::Kernel,
6+
X::AbstractMatrix;
7+
obsdim::Int = defaultobs
8+
)
9+
K = kernelmatrix(κ,X,obsdim=obsdim)
10+
α = eps(eltype(K))
11+
while !isposdef(K+αI) && α < 0.01*maximum(K)
12+
α *= 2.0
13+
end
14+
if α >= 0.01*maximum(K)
15+
@error "Adding noise on the diagonal was not sufficient to build a positive-definite matrix:\n - Check that your kernel parameters are not extreme\n - Check that your data is sufficiently sparse\n - Maybe use a different kernel"
16+
end
17+
return PDMat(K+αI)
18+
end

src/transform/functiontransform.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""
22
FunctionTransform
3-
3+
```
4+
f(x) = abs.(x)
5+
tr = FunctionTransform(f)
6+
```
47
Take a function `f` as an argument which is going to act on each vector individually.
58
Make sure that `f` can act on a vector, using broadcasting if necessary.
69
For example `f(x)=sin(x)` -> `f(x)=sin.(x)`

src/transform/lowranktransform.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1+
"""
2+
LowRankTransform
3+
```
4+
P = rand(10,5)
5+
tr = LowRankTransform(P)
6+
```
7+
Apply the low-rank projection realised by the matrix `P`
8+
The second dimension of `P` must match the number of features of the target.
9+
"""
110
struct LowRankTransform{T<:AbstractMatrix{<:Real}} <: Transform
211
proj::T
312
end
413

514
Base.size(tr::LowRankTransform,i::Int) = size(tr.proj,i)
6-
Base.size(tr::LowRankTransform) = size(tr.proj)
15+
Base.size(tr::LowRankTransform) = size(tr.proj) # TODO Add test
716

817
function transform(t::LowRankTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs)
918
@boundscheck size(t,2) != size(X,feature_dim(obsdim)) ?
1019
throw(DimensionMismatch("The projection matrix has size $(size(t)) and cannot be used on X with dimensions $(size(X))")) : nothing
1120
@inbounds _transform(t,X,obsdim)
1221
end
13-
function transform(t::LowRankTransform,x::AbstractVector{<:Real})
14-
@assert size(t,2) == length(x) "Vector has wrong dimensions"
22+
23+
function transform(t::LowRankTransform,x::AbstractVector{<:Real},obsdim::Int=defaultobs) #TODO Add test
24+
@assert size(t,2) == length(x) "Vector has wrong dimensions $(length(x)) compared to projection matrix"
1525
t.proj*X
1626
end
1727

src/transform/scaletransform.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
"""
22
Scale Transform
3+
```
4+
l = 2.0
5+
tr = ScaleTransform(l)
6+
v = rand(3)
7+
tr = ScaleTransform(v)
8+
```
9+
Multiply every element of the matrix by `l` for a scalar
10+
Multiply every observation vector element-wise by `v` for a vector
311
"""
412
struct ScaleTransform{T<:Union{Real,AbstractVector{<:Real}}} <: Transform
513
s::T
@@ -10,7 +18,7 @@ function ScaleTransform(s::T=1.0) where {T<:Real}
1018
ScaleTransform{T}(s)
1119
end
1220

13-
function ScaleTransform(s::T,dims::Integer) where {T<:Real}
21+
function ScaleTransform(s::T,dims::Integer) where {T<:Real} # TODO Add test
1422
@check_args(ScaleTransform, s, s > zero(T), "s > 0")
1523
ScaleTransform{Vector{T}}(fill(s,dims))
1624
end
@@ -20,12 +28,12 @@ function ScaleTransform(s::A) where {A<:AbstractVector{<:Real}}
2028
ScaleTransform{A}(s)
2129
end
2230

23-
dim(str::ScaleTransform{<:Real}) = 1
31+
dim(str::ScaleTransform{<:Real}) = 1 #TODO Add test
2432
dim(str::ScaleTransform{<:AbstractVector{<:Real}}) = length(str.s)
2533

2634
function transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int)
2735
@boundscheck if dim(t) != size(X,!Bool(obsdim-1)+1)
28-
throw(DimensionMismatch("Array has size $(size(X,!Bool(obsdim-1)+1)) on dimension $(!Bool(obsdim-1)+1)) which does not match the length of the scale transform length , $(dim(t))."))
36+
throw(DimensionMismatch("Array has size $(size(X,!Bool(obsdim-1)+1)) on dimension $(!Bool(obsdim-1)+1)) which does not match the length of the scale transform length , $(dim(t)).")) #TODO Add test
2937
end
3038
_transform(t,X,obsdim)
3139
end

0 commit comments

Comments
 (0)