Skip to content

Commit 11eb759

Browse files
authored
Improve testing (#182)
* Bump Distances bound * Fix up the tests * Bump patch * Relax bound on exponentiated * Improve exponential kernel robustness * Revert changes * Produce some output during basekernel tests
1 parent 4b644ad commit 11eb759

File tree

6 files changed

+45
-19
lines changed

6 files changed

+45
-19
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.8.4"
3+
version = "0.8.5"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -18,7 +18,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1818

1919
[compat]
2020
Compat = "2.2, 3"
21-
Distances = "0.9"
21+
Distances = "0.9.1"
2222
Functors = "0.1"
2323
Requires = "1.0.1"
2424
SpecialFunctions = "0.8, 0.9, 0.10"

src/test_utils.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
module TestUtils
22

33
const __ATOL = 1e-9
4+
const __RTOL = 1e-9
45

6+
using Distances
57
using LinearAlgebra
68
using KernelFunctions
79
using Random
@@ -33,6 +35,7 @@ function test_interface(
3335
x1::AbstractVector,
3436
x2::AbstractVector;
3537
atol=__ATOL,
38+
rtol=__RTOL,
3639
)
3740
# Ensure that we have the required inputs.
3841
@assert length(x0) == length(x1)
@@ -69,8 +72,8 @@ function test_interface(
6972
@test eigmin(Matrix(kernelmatrix(k, x0))) > -atol
7073

7174
# Check that unary elementwise / pairwise are consistent with the binary versions.
72-
@test kerneldiagmatrix(k, x0) kerneldiagmatrix(k, x0, x0) atol=atol
73-
@test kernelmatrix(k, x0) kernelmatrix(k, x0, x0) atol=atol
75+
@test kerneldiagmatrix(k, x0) kerneldiagmatrix(k, x0, x0) atol=atol rtol=rtol
76+
@test kernelmatrix(k, x0) kernelmatrix(k, x0, x0) atol=atol rtol=rtol
7477

7578
# Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`.
7679
@test k(first(x0), first(x1)) isa Real
@@ -89,17 +92,17 @@ end
8992
function test_interface(
9093
rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
9194
) where {T<:Real}
92-
test_interface(k, randn(rng, T, 3), randn(rng, T, 3), randn(rng, T, 2); kwargs...)
95+
test_interface(k, randn(rng, T, 1001), randn(rng, T, 1001), randn(rng, T, 1000); kwargs...)
9396
end
9497

9598
function test_interface(
9699
rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...,
97100
) where {T<:Real}
98101
test_interface(
99102
k,
100-
ColVecs(randn(rng, T, dim_in, 3)),
101-
ColVecs(randn(rng, T, dim_in, 3)),
102-
ColVecs(randn(rng, T, dim_in, 2));
103+
ColVecs(randn(rng, T, dim_in, 1001)),
104+
ColVecs(randn(rng, T, dim_in, 1001)),
105+
ColVecs(randn(rng, T, dim_in, 1000));
103106
kwargs...,
104107
)
105108
end
@@ -109,9 +112,9 @@ function test_interface(
109112
) where {T<:Real}
110113
test_interface(
111114
k,
112-
RowVecs(randn(rng, T, 3, dim_in)),
113-
RowVecs(randn(rng, T, 3, dim_in)),
114-
RowVecs(randn(rng, T, 2, dim_in));
115+
RowVecs(randn(rng, T, 1001, dim_in)),
116+
RowVecs(randn(rng, T, 1001, dim_in)),
117+
RowVecs(randn(rng, T, 1000, dim_in));
115118
kwargs...,
116119
)
117120
end
@@ -121,9 +124,15 @@ function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
121124
end
122125

123126
function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...)
124-
test_interface(rng, k, Vector{T}; kwargs...)
125-
test_interface(rng, k, ColVecs{T}; kwargs...)
126-
test_interface(rng, k, RowVecs{T}; kwargs...)
127+
@testset "Vector{$T}" begin
128+
test_interface(rng, k, Vector{T}; kwargs...)
129+
end
130+
@testset "ColVecs{$T}" begin
131+
test_interface(rng, k, ColVecs{T}; kwargs...)
132+
end
133+
@testset "RowVecs{$T}" begin
134+
test_interface(rng, k, RowVecs{T}; kwargs...)
135+
end
127136
end
128137

129138
function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)

test/basekernels/exponential.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@test KernelFunctions.iskroncompatible(k) == true
1717

1818
# Standardised tests.
19-
TestUtils.test_interface(k, Float64)
19+
TestUtils.test_interface(k)
2020
test_ADs(SEKernel)
2121
end
2222
@testset "ExponentialKernel" begin
@@ -30,7 +30,7 @@
3030
@test KernelFunctions.iskroncompatible(k) == true
3131

3232
# Standardised tests.
33-
TestUtils.test_interface(k, Float64)
33+
TestUtils.test_interface(k)
3434
test_ADs(ExponentialKernel)
3535
end
3636
@testset "GammaExponentialKernel" begin
@@ -46,6 +46,8 @@
4646

4747
test_ADs-> GammaExponentialKernel(gamma=first(γ)), [γ])
4848
test_params(k, ([γ],))
49+
TestUtils.test_interface(GammaExponentialKernel=1.36))
50+
4951
#Coherence :
5052
@test isapprox(
5153
GammaExponentialKernel=2.0)(sqrt(0.5) * v1, sqrt(0.5) * v2),

test/basekernels/exponentiated.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
@test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct()
1212
@test repr(k) == "Exponentiated Kernel"
1313

14-
# Standardised tests.
15-
TestUtils.test_interface(k, Float64)
14+
# Standardised tests. This kernel appears to be fairly numerically unstable.
15+
TestUtils.test_interface(k; atol=1e-3)
1616
test_ADs(ExponentiatedKernel)
1717
end

test/basekernels/gabor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@test k.p 1.0 atol=1e-5
2020
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
2121

22-
test_interface(k)
22+
test_interface(k, Vector{Float64})
2323

2424
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:Zygote])
2525

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,35 @@ include("test_utils.jl")
8282

8383
@testset "basekernels" begin
8484
include(joinpath("basekernels", "constant.jl"))
85+
print(" ")
8586
include(joinpath("basekernels", "cosine.jl"))
87+
print(" ")
8688
include(joinpath("basekernels", "exponential.jl"))
89+
print(" ")
8790
include(joinpath("basekernels", "exponentiated.jl"))
91+
print(" ")
8892
include(joinpath("basekernels", "fbm.jl"))
93+
print(" ")
8994
include(joinpath("basekernels", "gabor.jl"))
95+
print(" ")
9096
include(joinpath("basekernels", "maha.jl"))
97+
print(" ")
9198
include(joinpath("basekernels", "matern.jl"))
99+
print(" ")
92100
include(joinpath("basekernels", "nn.jl"))
101+
print(" ")
93102
include(joinpath("basekernels", "periodic.jl"))
103+
print(" ")
94104
include(joinpath("basekernels", "piecewisepolynomial.jl"))
105+
print(" ")
95106
include(joinpath("basekernels", "polynomial.jl"))
107+
print(" ")
96108
include(joinpath("basekernels", "rationalquad.jl"))
109+
print(" ")
97110
include(joinpath("basekernels", "sm.jl"))
111+
print(" ")
98112
include(joinpath("basekernels", "wiener.jl"))
113+
print(" ")
99114
end
100115
@info "Ran tests on BaseKernel"
101116

0 commit comments

Comments
 (0)