Skip to content

Commit 3f52bc2

Browse files
committed
Corrected Matern inference bug : fixes #7
1 parent 9f18535 commit 3f52bc2

File tree

1 file changed

+16
-46
lines changed

1 file changed

+16
-46
lines changed

src/kernels/matern.jl

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""
2-
MaternKernel([[ρ=1],ν=3/2])
2+
MaternKernel([ρ=1.0,[ν=1.0]])
33
44
The matern kernel is an isotropic Mercer kernel given by the formula:
55
66
```
7-
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
7+
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖)^ν K_ν(√(2ν)‖x-y‖)
88
```
99
10-
For `ν=n+1/2, n=0,1,2,...` it can be simplified (it will be converted automatically).
11-
`ρ` is a lengthscale parameter.
10+
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use `ExponentialKernel` for `n=0`, `Matern32Kernel`, for `n=1`, Matern52Kernel for `n=2` and `SqExponentialKernel` for `n=∞`.
11+
`ρ` is the lengthscale parameter(s) or the transform object.
1212
1313
# Examples
1414
1515
```jldoctest; setup = :(using KernelFunctions)
1616
julia> MaternKernel()
17-
Matern3_2Kernel{Float64,Float64}(1.0)
17+
MaternKernel{Float64,Float64}(1.0,1.0)
1818
1919
julia> MaternKernel(2.0f0,3.0)
2020
MaternKernel{Float32,Float32}(2.0,3.0)
2121
22-
julia> MaternKernel([2.0,3.0],5/2)
23-
Matern5_2Kernel{Float64,Array{Float64}}([2.0,3.0])
22+
julia> MaternKernel([2.0,3.0],2.5)
23+
MaternKernel{Float64,Array{Float64}}([2.0,3.0],2.5)
2424
```
2525
"""
2626
struct MaternKernel{T,Tr<:Transform} <: Kernel{T,Tr}
@@ -34,47 +34,17 @@ end
3434

3535
function MaternKernel::T₁=1.0::T₂=1.5) where {T₁<:Real,T₂<:Real}
3636
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
37-
if ν == 0.5
38-
ExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
39-
elseif ν == 1.5
40-
Matern32Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
41-
elseif ν == 2.5
42-
Matern52Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
43-
elseif ν == Inf
44-
SquaredExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
45-
else
46-
MaternKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ),ν)
47-
end
37+
MaternKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ),ν)
4838
end
4939

5040
function MaternKernel::A::T=1.5) where {A<:AbstractVector{<:Real},T<:Real}
5141
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
52-
if ν == 0.5
53-
ExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
54-
elseif ν == 1.5
55-
Matern32Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
56-
elseif ν == 2.5
57-
Matern52Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
58-
elseif ν == Inf
59-
SquaredExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
60-
else
61-
MaternKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ),ν)
62-
end
42+
MaternKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ),ν)
6343
end
6444

65-
function MaternKernel(t::T₁::T₂=1.5) where {T₁<:Transform,T₂<:Real}
66-
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
67-
if ν == 0.5
68-
ExponentialKernel{eltype(t),T₁}(t)
69-
elseif ν == 1.5
70-
Matern32Kernel{eltype(t),T₁}(t)
71-
elseif ν == 2.5
72-
Matern52Kernel{eltype(t),T₁}(t)
73-
elseif ν == Inf
74-
SquaredExponentialKernel{eltype(t),T₁}(t)
75-
else
76-
MaternKernel{eltype(t),T₁}(t,ν)
77-
end
45+
function MaternKernel(t::Tr::T=1.5) where {Tr<:Transform,T<:Real}
46+
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
47+
MaternKernel{eltype(t),Tr}(t,ν)
7848
end
7949

8050
@inline kappa::MaternKernel, d::Real) where {T} = exp((1.0-κ.ν)*logtwo - lgamma.ν) - κ.ν*log(sqrt(2κ.ν)*d))*besselk.ν,sqrt(2κ.ν)*d)
@@ -96,8 +66,8 @@ function Matern32Kernel(ρ::A) where {A<:AbstractVector{<:Real}}
9666
Matern32Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
9767
end
9868

99-
function Matern32Kernel(t::Transform)
100-
Matern52Kernel{eltype(A),ScaleTransform{A}}(t)
69+
function Matern32Kernel(t::Tr) where {Tr<:Transform}
70+
Matern52Kernel{eltype(Tr),Tr}(t)
10171
end
10272

10373
@inline kappa::Matern32Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
@@ -118,8 +88,8 @@ function Matern52Kernel(ρ::A) where {A<:AbstractVector{<:Real}}
11888
Matern52Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
11989
end
12090

121-
function Matern52Kernel(t::Transform)
122-
Matern52Kernel{eltype(A),ScaleTransform{A}}(t)
91+
function Matern52Kernel(t::Tr) where {Tr<:Transform}
92+
Matern52Kernel{eltype(Tr),Tr}(t)
12393
end
12494

12595
@inline kappa::Matern52Kernel, d::Real) where {T} = (1+sqrt(5)*d+5*d^2/3)*exp(-sqrt(5)*d)

0 commit comments

Comments
 (0)