Skip to content

Commit f3734d5

Browse files
committed
Correction matern and tests
1 parent 682bae7 commit f3734d5

File tree

4 files changed

+25
-14
lines changed

4 files changed

+25
-14
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module KernelFunctions
22

3-
export kernelmatrix, kernelmatrix!, kappa
3+
export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa
44
export Kernel, SquaredExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
55

66
export Transform, ScaleTransform

src/kernels/matern.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ struct Matern32Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
8888
end
8989
end
9090

91-
function Matern32Kernel::T) where {T<:Real}
92-
Matern32Kernel{T,ScaleTransform{T}}(ScaleTransform(ρ))
91+
function Matern32Kernel::T=1.0) where {T<:Real}
92+
Matern32Kernel{T,ScaleTransform{T}}(ScaleTransform(ρ))
9393
end
9494

9595
function Matern32Kernel::A) where {A<:AbstractVector{<:Real}}
@@ -110,8 +110,8 @@ struct Matern52Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
110110
end
111111
end
112112

113-
function Matern52Kernel::T) where {T<:Real}
114-
Matern52Kernel{T,ScaleTransform{T}}(ScaleTransform(ρ))
113+
function Matern52Kernel::T=1.0) where {T<:Real}
114+
Matern52Kernel{T,ScaleTransform{T}}(ScaleTransform(ρ))
115115
end
116116

117117
function Matern52Kernel::A) where {A<:AbstractVector{<:Real}}

test/constructors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
using KernelFunctions, Test, Distances
22
#TODO Test metric weights for ARD, test equivalency for different constructors,
33
# test type conversion
44
l = 2.0
@@ -21,4 +21,5 @@ end
2121
@test isa(MaternKernel(1.0,1.0),MaternKernel)
2222
@test isa(MaternKernel(1.0,1.5),Matern32Kernel)
2323
@test isa(MaternKernel(1.0,2.5),Matern52Kernel)
24+
@test isa(MaternKernel(1.0,Inf),SquaredExponentialKernel)
2425
end

test/kernelmatrix.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
1-
using Distances
1+
using Distances, LinearAlgebra
2+
using KernelFunctions
23

34
dims = [10,5]
45

56
A = rand(dims...)
67
B = rand(dims...)
78
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
8-
k = SquaredExponentialKernel()
9-
k = MaternKernel()
10-
9+
Kdiag = [zeros(dims[1]),zeros(dims[2])]
10+
kernels = [SquaredExponentialKernel(),MaternKernel(1.0,1.0),Matern32Kernel(),Matern52Kernel()]
1111
@testset "Inplace Kernel Matrix" begin
12-
for obsdim in [1,2]
13-
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
12+
for k in kernels
13+
@testset "$k" begin
14+
for obsdim in [1,2]
15+
@test kernelmatrix!(K[obsdim],k,A,B,obsdim=obsdim) == kernelmatrix(k,A,B,obsdim=obsdim)
16+
@test kerneldiagmatrix!(Kdiag[obsdim],k,A,obsdim=obsdim) == kerneldiagmatrix(k,A,obsdim=obsdim)
17+
end
18+
end
1419
end
1520
end
1621

1722
@testset "Kernel matrix" begin
18-
for obsdim in [1,2]
19-
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
23+
for k in kernels
24+
@testset "$k" begin
25+
for obsdim in [1,2]
26+
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
27+
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.([k],pairwise(KernelFunctions.metric(k),A,dims=obsdim))
28+
end
29+
end
2030
end
2131
end

0 commit comments

Comments
 (0)