Skip to content

Commit b8fd2e9

Browse files
committed
SquaredExponentialKernel becomes SqExponentialKernel
1 parent 3f52bc2 commit b8fd2e9

File tree

10 files changed

+95
-87
lines changed

10 files changed

+95
-87
lines changed

benchmark/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using KernelFunctions
22

33
SUITE["KernelFunctions"] = BenchmarkGroup()
44

5-
kernelnames = ["SquaredExponentialKernel"]
5+
kernelnames = ["SqExponentialKernel"]
66
kerneltypes = ["ARD","ISO"]
77
kernels=Dict{String,Dict{String,KernelFunctions.Kernel}}()
88
for k in kernelnames

dev/debugAD.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,39 @@ testfunction(k,A,B) = det(kernelmatrix(k,A,B))
1212
testfunction(k,A) = sum(kernelmatrix(k,A))
1313
k = MaternKernel(vl)
1414
KernelFunctions.kappa(k,3)
15-
testfunction(SquaredExponentialKernel(vl),A)
15+
testfunction(SqExponentialKernel(vl),A)
1616
testfunction(MaternKernel(vl),A)
1717
@which kernelmatrix(MaternKernel(vl),A,B)
1818
#For debugging
1919
@info "Running Zygote gradients"
2020
Zygote.refresh()
2121
## Zygote
22-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
22+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A),vl)
2323
Zygote.gradient(x->testfunction(MaternKernel(x),A),vl)
24-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)[1]
24+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl)[1]
2525
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),vl)[1]
26-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
26+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A,B),l)
2727
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),l)
28-
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
28+
Zygote.gradient(x->testfunction(SqExponentialKernel(x),A),l)
2929
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
3030
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
3131
Zygote.gradient(x->kernelmatrix(MaternKernel(x,1.0),A)[1],l)
3232
@info "Running Tracker gradients"
3333
## Tracker
34-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
35-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(l),x[:,:]),A)
36-
# # Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
37-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
38-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
39-
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
34+
# Tracker.gradient(x->testfunction(SqExponentialKernel(vl),x,B),A)
35+
# Tracker.gradient(x->testfunction(SqExponentialKernel(l),x[:,:]),A)
36+
# # Tracker.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl)
37+
# Tracker.gradient(x->testfunction(SqExponentialKernel(x),A),vl)
38+
# Tracker.gradient(x->testfunction(SqExponentialKernel(x),A,B),l)
39+
# Tracker.gradient(x->testfunction(SqExponentialKernel(x),A),l)
4040

4141
@info "Running ForwardDiff gradients"
4242
## ForwardDiff
43-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl) #
43+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A,B),vl) #
4444
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A,B),vl) #
45-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl) #
45+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x),A),vl) #
4646
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A),vl) #
47-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A,B),[l])
47+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x[1]),A,B),[l])
4848
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A,B),[l])
49-
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
49+
ForwardDiff.gradient(x->testfunction(SqExponentialKernel(x[1]),A),[l])
5050
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A),[l])

dev/matrixvsvectors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ timekf = similar(timestheno); memkf = similar(timestheno)
1414
B = randn(D,1001)
1515

1616
# Standardised eq kernel with length-scale 0.1.
17-
medkf = median(@benchmark KernelFunctions.kernelmatrix(SquaredExponentialKernel(0.01),$A,$B,obsdim=2))
17+
medkf = median(@benchmark KernelFunctions.kernelmatrix(SqExponentialKernel(0.01),$A,$B,obsdim=2))
1818
timekf[i] = medkf.time/1e6; memkf[i] = medkf.memory/2^20
1919
medstheno = median(@benchmark pw(eq(; l=0.1), ColsAreObs($A), ColsAreObs($B)))
2020
timestheno[i] = medstheno.time/1e6; memstheno[i] = medstheno.memory/2^20

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
4-
export Kernel, SquaredExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
4+
export Kernel, SqExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
55

66
export Transform, ScaleTransform
77

@@ -18,7 +18,7 @@ include("utils.jl")
1818
include("transform/transform.jl")
1919
include("kernelmatrix.jl")
2020

21-
kernels = ["squaredexponential","matern"]
21+
kernels = ["sqexponential","matern"]
2222
for k in kernels
2323
include(joinpath("kernels",k*".jl"))
2424
end

src/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11

22
@inline metric::Kernel) = κ.metric
33
kernels =
4-
for k in [:SquaredExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
4+
for k in [:SqExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
55
eval(quote
66
@inline::$k)(d::Real) = kappa(κ,d)
77
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
88
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9+
@inline::$k)(x::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,obsdim=obsdim)
910
end)
1011
end
1112
### Transform generics

src/kernels/sqexponential.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
SqExponentialKernel([α=1])
3+
4+
The squared exponential kernel is an isotropic Mercer kernel given by the formula:
5+
6+
```
7+
κ(x,y) = exp(-‖x-y‖²)
8+
```
9+
10+
See also [`ExponentialKernel`](@ref) for a
11+
related form of the kernel or [`GammaExponentialKernel`](@ref) for a generalization.
12+
13+
# Examples
14+
15+
```jldoctest; setup = :(using KernelFunctions)
16+
julia> SqExponentialKernel()
17+
SqExponentialKernel{Float64,Float64}(1.0)
18+
19+
julia> SqExponentialKernel(2.0f0)
20+
SqExponentialKernel{Float32,Float32}(2.0)
21+
22+
julia> SqExponentialKernel([2.0,3.0])
23+
SqExponentialKernel{Float64,Array{Float64}}(1.0)
24+
```
25+
"""
26+
struct SqExponentialKernel{T,Tr<:Transform} <: Kernel{T,Tr}
27+
transform::Tr
28+
metric::SemiMetric
29+
function SqExponentialKernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
30+
return new{T,Tr}(transform,SqEuclidean())
31+
end
32+
end
33+
34+
function SqExponentialKernel::T=1.0) where {T<:Real}
35+
SqExponentialKernel{T,ScaleTransform{T}}(ScaleTransform(α))
36+
end
37+
38+
function SqExponentialKernel::A) where {A<:AbstractVector{<:Real}}
39+
SqExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(α))
40+
end
41+
42+
function SqExponentialKernel(t::T) where {T<:Transform}
43+
SqExponentialKernel{eltype(t),T}(t)
44+
end
45+
46+
@inline kappa::SqExponentialKernel, d²::Real) where {T} = exp(-d²)
47+
48+
# function convert(
49+
# ::Type{K},
50+
# κ::SqExponentialKernel
51+
# ) where {K>:SqExponentialKernel{T,A} where {T,A}}
52+
# return SqExponentialKernel{T}(T.(κ.α))
53+
# end

src/kernels/squaredexponential.jl

Lines changed: 0 additions & 53 deletions
This file was deleted.

test/constructors.jl

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,30 @@ using KernelFunctions, Test, Distances
33
# test type conversion
44
l = 2.0
55
vl = [l,l]
6+
s = ScaleTransform(3.0)
67

7-
## SquaredExponentialKernel
8-
@testset "SquaredExponentialKernel" begin
9-
@test KernelFunctions.metric(SquaredExponentialKernel(l)) == SqEuclidean()
10-
@test KernelFunctions.transform(SquaredExponentialKernel(l)) == ScaleTransform(l)
11-
@test KernelFunctions.transform(SquaredExponentialKernel(vl)) == ScaleTransform(vl)
8+
## SqExponentialKernel
9+
@testset "SqExponentialKernel" begin
10+
@test KernelFunctions.metric(SqExponentialKernel(l)) == SqEuclidean()
11+
@test KernelFunctions.transform(SqExponentialKernel(l)) == ScaleTransform(l)
12+
@test KernelFunctions.transform(SqExponentialKernel(vl)) == ScaleTransform(vl)
13+
@test KernelFunctions.transform(SqExponentialKernel(s)) == s
1214
end
1315

16+
## MaternKernel
17+
1418
@testset "MaternKernel" begin
1519
@test KernelFunctions.metric(MaternKernel(l)) == Euclidean()
16-
@test KernelFunctions.metric(MaternKernel(l,1.5)) == Euclidean()
17-
@test KernelFunctions.metric(MaternKernel(l,2.5)) == Euclidean()
20+
@test KernelFunctions.metric(Matern32Kernel(l)) == Euclidean()
21+
@test KernelFunctions.metric(Matern52Kernel(l)) == Euclidean()
1822
@test KernelFunctions.transform(MaternKernel(l)) == ScaleTransform(l)
23+
@test KernelFunctions.transform(Matern32Kernel(l)) == ScaleTransform(l)
24+
@test KernelFunctions.transform(Matern52Kernel(l)) == ScaleTransform(l)
1925
@test KernelFunctions.transform(MaternKernel(vl)) == ScaleTransform(vl)
20-
@test isa(MaternKernel(),Matern32Kernel)
21-
@test isa(MaternKernel(1.0,1.0),MaternKernel)
22-
@test isa(MaternKernel(1.0,1.5),Matern32Kernel)
23-
@test isa(MaternKernel(1.0,2.5),Matern52Kernel)
24-
@test isa(MaternKernel(1.0,Inf),SquaredExponentialKernel)
26+
@test KernelFunctions.transform(Matern32Kernel(vl)) == ScaleTransform(vl)
27+
@test KernelFunctions.transform(Matern52Kernel(vl)) == ScaleTransform(vl)
28+
@test KernelFunctions.transform(MaternKernel(s)) == s
29+
@test KernelFunctions.transform(Matern32Kernel(s)) == s
30+
@test KernelFunctions.transform(Matern52Kernel(s)) == s
31+
2532
end

test/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ A = rand(dims...)
77
B = rand(dims...)
88
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
99
Kdiag = [zeros(dims[1]),zeros(dims[2])]
10-
kernels = [SquaredExponentialKernel(),MaternKernel(1.0,1.0),Matern32Kernel(),Matern52Kernel()]
10+
kernels = [SqExponentialKernel(),MaternKernel(),Matern32Kernel(),Matern52Kernel()]
1111
@testset "Inplace Kernel Matrix" begin
1212
for k in kernels
1313
@testset "$k" begin

test/testAD.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ dims = [10,5]
77
A = rand(dims...)
88
B = rand(dims...)
99
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
10-
kernels = [SquaredExponentialKernel,MaternKernel]
10+
kernels = [SqExponentialKernel,MaternKernel,Matern32Kernel,Matern52Kernel]
1111
l = 2.0
1212
vl = l*ones(dims[1])
1313
testfunction(k,A,B) = det(kernelmatrix(k,A,B))

0 commit comments

Comments
 (0)