Skip to content

Commit c9411ee

Browse files
committed
Merge branch 'master-dev'
2 parents 9e8dc48 + b8fd2e9 commit c9411ee

File tree

11 files changed

+111
-133
lines changed

11 files changed

+111
-133
lines changed

benchmark/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using KernelFunctions
22

33
SUITE["KernelFunctions"] = BenchmarkGroup()
44

5-
kernelnames = ["SquaredExponentialKernel"]
5+
kernelnames = ["SqExponentialKernel"]
66
kerneltypes = ["ARD","ISO"]
77
kernels=Dict{String,Dict{String,KernelFunctions.Kernel}}()
88
for k in kernelnames

dev/debugAD.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,39 @@ testfunction(k,A,B) = det(kernelmatrix(k,A,B))
1212
testfunction(k,A) = sum(kernelmatrix(k,A))
1313
k = MaternKernel(vl)
1414
KernelFunctions.kappa(k,3)
15-
testfunction(SquaredExponentialKernel(vl),A)
15+
testfunction(SqExponentialKernel(vl),A)
1616
testfunction(MaternKernel(vl),A)
1717
@which kernelmatrix(MaternKernel(vl),A,B)
1818
#For debugging
1919
@info "Running Zygote gradients"
2020
Zygote.refresh()
2121
## Zygote
22-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
22+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A),vl)
2323
Zygote.gradient(x->testfunction(MaternKernel(x),A),vl)
24-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)[1]
24+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl)[1]
2525
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),vl)[1]
26-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
26+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A,B),l)
2727
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),l)
28-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
28+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A),l)
2929
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
3030
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
3131
Zygote.gradient(x->kernelmatrix(MaternKernel(x,1.0),A)[1],l)
3232
@info "Running Tracker gradients"
3333
## Tracker
34-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
35-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(l),x[:,:]),A)
36-
# # Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
37-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
38-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
39-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
34+
# Tracker.gradient(x->testfunction(SqExponentialKernel(vl),x,B),A)
35+
# Tracker.gradient(x->testfunction(SqExponentialKernel(l),x[:,:]),A)
36+
# # Tracker.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl)
37+
# Tracker.gradient(x->testfunction(SqExponentialKernel(x),A),vl)
38+
# Tracker.gradient(x->testfunction(SqExponentialKernel(x),A,B),l)
39+
# Tracker.gradient(x->testfunction(SqExponentialKernel(x),A),l)
4040

4141
@info "Running ForwardDiff gradients"
4242
## ForwardDiff
43-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl) #
43+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl) #
4444
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A,B),vl) #
45-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl) #
45+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A),vl) #
4646
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A),vl) #
47-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A,B),[l])
47+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x[1]),A,B),[l])
4848
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A,B),[l])
49-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
49+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x[1]),A),[l])
5050
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A),[l])

dev/matrixvsvectors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ timekf = similar(timestheno); memkf = similar(timestheno)
1414
B = randn(D,1001)
1515

1616
# Standardised eq kernel with length-scale 0.1.
17-
medkf = median(@benchmark KernelFunctions.kernelmatrix(SquaredExponentialKernel(0.01),$A,$B,obsdim=2))
17+
medkf = median(@benchmark KernelFunctions.kernelmatrix(SqExponentialKernel(0.01),$A,$B,obsdim=2))
1818
timekf[i] = medkf.time/1e6; memkf[i] = medkf.memory/2^20
1919
medstheno = median(@benchmark pw(eq(; l=0.1), ColsAreObs($A), ColsAreObs($B)))
2020
timestheno[i] = medstheno.time/1e6; memstheno[i] = medstheno.memory/2^20

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
4-
export Kernel, SquaredExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
4+
export Kernel, SqExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
55

66
export Transform, ScaleTransform
77

@@ -18,7 +18,7 @@ include("utils.jl")
1818
include("transform/transform.jl")
1919
include("kernelmatrix.jl")
2020

21-
kernels = ["squaredexponential","matern"]
21+
kernels = ["sqexponential","matern"]
2222
for k in kernels
2323
include(joinpath("kernels",k*".jl"))
2424
end

src/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11

22
@inline metric::Kernel) = κ.metric
33
kernels =
4-
for k in [:SquaredExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
4+
for k in [:SqExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
55
eval(quote
66
@inline::$k)(d::Real) = kappa(κ,d)
77
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
88
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9+
@inline::$k)(x::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,obsdim=obsdim)
910
end)
1011
end
1112
### Transform generics

src/kernels/matern.jl

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""
2-
MaternKernel([[ρ=1],ν=3/2])
2+
MaternKernel([ρ=1.0,[ν=1.0]])
33
44
The matern kernel is an isotropic Mercer kernel given by the formula:
55
66
```
7-
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
7+
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
88
```
99
10-
For `ν=n+1/2, n=0,1,2,...` it can be simplified (it will be converted automatically).
11-
`ρ` is a lengthscale parameter.
10+
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use `ExponentialKernel` for `n=0`, `Matern32Kernel`, for `n=1`, Matern52Kernel for `n=2` and `SqExponentialKernel` for `n=∞`.
11+
`ρ` is the lengthscale parameter(s) or the transform object.
1212
1313
# Examples
1414
1515
```jldoctest; setup = :(using KernelFunctions)
1616
julia> MaternKernel()
17-
Matern3_2Kernel{Float64,Float64}(1.0)
17+
MaternKernel{Float64,Float64}(1.0,1.0)
1818
1919
julia> MaternKernel(2.0f0,3.0)
2020
MaternKernel{Float32,Float32}(2.0,3.0)
2121
22-
julia> MaternKernel([2.0,3.0],5/2)
23-
Matern5_2Kernel{Float64,Array{Float64}}([2.0,3.0])
22+
julia> MaternKernel([2.0,3.0],2.5)
23+
MaternKernel{Float64,Array{Float64}}([2.0,3.0],2.5)
2424
```
2525
"""
2626
struct MaternKernel{T,Tr<:Transform} <: Kernel{T,Tr}
@@ -34,47 +34,17 @@ end
3434

3535
function MaternKernel::T₁=1.0::T₂=1.5) where {T₁<:Real,T₂<:Real}
3636
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
37-
if ν == 0.5
38-
ExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
39-
elseif ν == 1.5
40-
Matern32Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
41-
elseif ν == 2.5
42-
Matern52Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
43-
elseif ν == Inf
44-
SquaredExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
45-
else
46-
MaternKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ),ν)
47-
end
37+
MaternKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ),ν)
4838
end
4939

5040
function MaternKernel::A::T=1.5) where {A<:AbstractVector{<:Real},T<:Real}
5141
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
52-
if ν == 0.5
53-
ExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
54-
elseif ν == 1.5
55-
Matern32Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
56-
elseif ν == 2.5
57-
Matern52Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
58-
elseif ν == Inf
59-
SquaredExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
60-
else
61-
MaternKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ),ν)
62-
end
42+
MaternKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ),ν)
6343
end
6444

65-
function MaternKernel(t::T₁::T₂=1.5) where {T₁<:Transform,T₂<:Real}
66-
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
67-
if ν == 0.5
68-
ExponentialKernel{eltype(t),T₁}(t)
69-
elseif ν == 1.5
70-
Matern32Kernel{eltype(t),T₁}(t)
71-
elseif ν == 2.5
72-
Matern52Kernel{eltype(t),T₁}(t)
73-
elseif ν == Inf
74-
SquaredExponentialKernel{eltype(t),T₁}(t)
75-
else
76-
MaternKernel{eltype(t),T₁}(t,ν)
77-
end
45+
function MaternKernel(t::Tr::T=1.5) where {Tr<:Transform,T<:Real}
46+
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
47+
MaternKernel{eltype(t),Tr}(t,ν)
7848
end
7949

8050
@inline kappa::MaternKernel, d::Real) where {T} = exp((1.0-κ.ν)*logtwo - lgamma.ν) - κ.ν*log(sqrt(2κ.ν)*d))*besselk.ν,sqrt(2κ.ν)*d)
@@ -96,8 +66,8 @@ function Matern32Kernel(ρ::A) where {A<:AbstractVector{<:Real}}
9666
Matern32Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
9767
end
9868

99-
function Matern32Kernel(t::Transform)
100-
Matern52Kernel{eltype(A),ScaleTransform{A}}(t)
69+
function Matern32Kernel(t::Tr) where {Tr<:Transform}
70+
Matern52Kernel{eltype(Tr),Tr}(t)
10171
end
10272

10373
@inline kappa::Matern32Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
@@ -118,8 +88,8 @@ function Matern52Kernel(ρ::A) where {A<:AbstractVector{<:Real}}
11888
Matern52Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
11989
end
12090

121-
function Matern52Kernel(t::Transform)
122-
Matern52Kernel{eltype(A),ScaleTransform{A}}(t)
91+
function Matern52Kernel(t::Tr) where {Tr<:Transform}
92+
Matern52Kernel{eltype(Tr),Tr}(t)
12393
end
12494

12595
@inline kappa::Matern52Kernel, d::Real) where {T} = (1+sqrt(5)*d+5*d^2/3)*exp(-sqrt(5)*d)

src/kernels/sqexponential.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
SqExponentialKernel([α=1])
3+
4+
The squared exponential kernel is an isotropic Mercer kernel given by the formula:
5+
6+
```
7+
κ(x,y) = exp(-‖x-y‖²)
8+
```
9+
10+
See also [`ExponentialKernel`](@ref) for a
11+
related form of the kernel or [`GammaExponentialKernel`](@ref) for a generalization.
12+
13+
# Examples
14+
15+
```jldoctest; setup = :(using KernelFunctions)
16+
julia> SqExponentialKernel()
17+
SqExponentialKernel{Float64,Float64}(1.0)
18+
19+
julia> SqExponentialKernel(2.0f0)
20+
SqExponentialKernel{Float32,Float32}(2.0)
21+
22+
julia> SqExponentialKernel([2.0,3.0])
23+
SqExponentialKernel{Float64,Array{Float64}}(1.0)
24+
```
25+
"""
26+
struct SqExponentialKernel{T,Tr<:Transform} <: Kernel{T,Tr}
27+
transform::Tr
28+
metric::SemiMetric
29+
function SqExponentialKernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
30+
return new{T,Tr}(transform,SqEuclidean())
31+
end
32+
end
33+
34+
function SqExponentialKernel::T=1.0) where {T<:Real}
35+
SqExponentialKernel{T,ScaleTransform{T}}(ScaleTransform(α))
36+
end
37+
38+
function SqExponentialKernel::A) where {A<:AbstractVector{<:Real}}
39+
SqExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(α))
40+
end
41+
42+
function SqExponentialKernel(t::T) where {T<:Transform}
43+
SqExponentialKernel{eltype(t),T}(t)
44+
end
45+
46+
@inline kappa::SqExponentialKernel, d²::Real) where {T} = exp(-d²)
47+
48+
# function convert(
49+
# ::Type{K},
50+
# κ::SqExponentialKernel
51+
# ) where {K>:SqExponentialKernel{T,A} where {T,A}}
52+
# return SqExponentialKernel{T}(T.(κ.α))
53+
# end

src/kernels/squaredexponential.jl

Lines changed: 0 additions & 53 deletions
This file was deleted.

test/constructors.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,30 @@ using KernelFunctions, Test, Distances
33
# test type conversion
44
l = 2.0
55
vl = [l,l]
6+
s = ScaleTransform(3.0)
67

7-
## SquaredExponentialKernel
8-
@testset "SquaredExponentialKernel" begin
9-
@test KernelFunctions.metric(SquaredExponentialKernel(l)) == SqEuclidean()
10-
@test KernelFunctions.transform(SquaredExponentialKernel(l)) == ScaleTransform(l)
11-
@test KernelFunctions.transform(SquaredExponentialKernel(vl)) == ScaleTransform(vl)
8+
## SqExponentialKernel
9+
@testset "SqExponentialKernel" begin
10+
@test KernelFunctions.metric(SqExponentialKernel(l)) == SqEuclidean()
11+
@test KernelFunctions.transform(SqExponentialKernel(l)) == ScaleTransform(l)
12+
@test KernelFunctions.transform(SqExponentialKernel(vl)) == ScaleTransform(vl)
13+
@test KernelFunctions.transform(SqExponentialKernel(s)) == s
1214
end
1315

16+
## MaternKernel
17+
1418
@testset "MaternKernel" begin
1519
@test KernelFunctions.metric(MaternKernel(l)) == Euclidean()
16-
@test KernelFunctions.metric(MaternKernel(l,1.5)) == Euclidean()
17-
@test KernelFunctions.metric(MaternKernel(l,2.5)) == Euclidean()
20+
@test KernelFunctions.metric(Matern32Kernel(l)) == Euclidean()
21+
@test KernelFunctions.metric(Matern52Kernel(l)) == Euclidean()
1822
@test KernelFunctions.transform(MaternKernel(l)) == ScaleTransform(l)
23+
@test KernelFunctions.transform(Matern32Kernel(l)) == ScaleTransform(l)
24+
@test KernelFunctions.transform(Matern52Kernel(l)) == ScaleTransform(l)
1925
@test KernelFunctions.transform(MaternKernel(vl)) == ScaleTransform(vl)
20-
@test isa(MaternKernel(),Matern32Kernel)
21-
@test isa(MaternKernel(1.0,1.0),MaternKernel)
22-
@test isa(MaternKernel(1.0,1.5),Matern32Kernel)
23-
@test isa(MaternKernel(1.0,2.5),Matern52Kernel)
24-
@test isa(MaternKernel(1.0,Inf),SquaredExponentialKernel)
26+
@test KernelFunctions.transform(Matern32Kernel(vl)) == ScaleTransform(vl)
27+
@test KernelFunctions.transform(Matern52Kernel(vl)) == ScaleTransform(vl)
28+
@test KernelFunctions.transform(MaternKernel(s)) == s
29+
@test KernelFunctions.transform(Matern32Kernel(s)) == s
30+
@test KernelFunctions.transform(Matern52Kernel(s)) == s
31+
2532
end

test/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ A = rand(dims...)
77
B = rand(dims...)
88
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
99
Kdiag = [zeros(dims[1]),zeros(dims[2])]
10-
kernels = [SquaredExponentialKernel(),MaternKernel(1.0,1.0),Matern32Kernel(),Matern52Kernel()]
10+
kernels = [SqExponentialKernel(),MaternKernel(),Matern32Kernel(),Matern52Kernel()]
1111
@testset "Inplace Kernel Matrix" begin
1212
for k in kernels
1313
@testset "$k" begin

test/testAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ dims = [10,5]
77
A = rand(dims...)
88
B = rand(dims...)
99
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
10-
kernels = [SquaredExponentialKernel,MaternKernel]
10+
kernels = [SqExponentialKernel,MaternKernel,Matern32Kernel,Matern52Kernel]
1111
l = 2.0
1212
vl = l*ones(dims[1])
1313
testfunction(k,A,B) = det(kernelmatrix(k,A,B))

0 commit comments

Comments
 (0)