Skip to content

Commit 5b4e9d3

Browse files
authored
Merge branch 'master' into dw/fix_matern
2 parents 8473386 + 35de8d2 commit 5b4e9d3

File tree

3 files changed

+26
-13
lines changed

3 files changed

+26
-13
lines changed

src/KernelFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ include("mokernels/lmm.jl")
124124
include("chainrules.jl")
125125
include("zygoterules.jl")
126126

127-
include("test_utils.jl")
127+
include("TestUtils.jl")
128128

129129
function __init__()
130130
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin

src/test_utils.jl renamed to src/TestUtils.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,42 @@
11
module TestUtils
22

3-
const __ATOL = 1e-9
4-
const __RTOL = 1e-9
5-
63
using Distances
74
using LinearAlgebra
85
using KernelFunctions
96
using Random
107
using Test
118

9+
# default tolerance values for test_interface:
10+
const __ATOL = sqrt(eps(Float64))
11+
const __RTOL = sqrt(eps(Float64))
12+
# ≈ 1.5e-8; chosen for no particular reason other than because it seems to
13+
# satisfy our own test cases within KernelFunctions.jl
14+
1215
"""
1316
test_interface(
1417
k::Kernel,
1518
x0::AbstractVector,
1619
x1::AbstractVector,
1720
x2::AbstractVector;
1821
atol=__ATOL,
22+
rtol=__RTOL,
1923
)
2024
2125
Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`.
2226
`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should
2327
be of different lengths.
2428
25-
test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:AbstractVector}; atol=__ATOL)
29+
These tests are intended to pick up on really substantial issues with a kernel implementation
30+
(e.g. substantial asymmetry in the kernel matrix, large negative eigenvalues), rather than to
31+
test the numerics in detail, which can be kernel-specific.
32+
The default value of `__ATOL` and `__RTOL` is `sqrt(eps(Float64)) ≈ 1.5e-8`, which satisfied
33+
this intention in the cases tested within KernelFunctions.jl itself.
34+
35+
test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:Real}; atol=__ATOL, rtol=__RTOL)
2636
27-
`test_interface` offers certain types of test data generation to make running these tests
28-
require less code for common input types. For example, `Vector{<:Real}`, `ColVecs{<:Real}`,
29-
and `RowVecs{<:Real}` are supported. For other input vector types, please provide the data
30-
manually.
37+
`test_interface` offers automated test data generation for kernels whose inputs are reals.
38+
This will run the tests for `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
39+
For other input vector types, please provide the data manually.
3140
"""
3241
function test_interface(
3342
k::Kernel,
@@ -50,11 +59,12 @@ function test_interface(
5059
@test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2))
5160

5261
# Check that elementwise is consistent with pairwise.
53-
@test kernelmatrix_diag(k, x0, x1) diag(kernelmatrix(k, x0, x1)) atol = atol
62+
@test kernelmatrix_diag(k, x0, x1) diag(kernelmatrix(k, x0, x1)) atol = atol rtol =
63+
rtol
5464

5565
# Check additional binary elementwise properties for kernels.
5666
@test kernelmatrix_diag(k, x0, x1) kernelmatrix_diag(k, x1, x0)
57-
@test kernelmatrix(k, x0, x2) kernelmatrix(k, x2, x0)' atol = atol
67+
@test kernelmatrix(k, x0, x2) kernelmatrix(k, x2, x0)' atol = atol rtol = rtol
5868

5969
# Check that unary elementwise basically works.
6070
@test kernelmatrix_diag(k, x0) isa AbstractVector
@@ -63,10 +73,10 @@ function test_interface(
6373
# Check that unary pairwise basically works.
6474
@test kernelmatrix(k, x0) isa AbstractMatrix
6575
@test size(kernelmatrix(k, x0)) == (length(x0), length(x0))
66-
@test kernelmatrix(k, x0) kernelmatrix(k, x0)' atol = atol
76+
@test kernelmatrix(k, x0) kernelmatrix(k, x0)' atol = atol rtol = rtol
6777

6878
# Check that unary elementwise is consistent with unary pairwise.
69-
@test kernelmatrix_diag(k, x0) diag(kernelmatrix(k, x0)) atol = atol
79+
@test kernelmatrix_diag(k, x0) diag(kernelmatrix(k, x0)) atol = atol rtol = rtol
7080

7181
# Check that unary pairwise produces a positive definite matrix (approximately).
7282
@test eigmin(Matrix(kernelmatrix(k, x0))) > -atol

src/transform/transform.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ Abstract type defining a transformation of the input.
55
"""
66
abstract type Transform end
77

8+
# We introduce our own _map for Transform so that we can work around
9+
# https://github.com/FluxML/Zygote.jl/issues/646 and define our own pullback
10+
# (see zygoterules.jl)
811
Base.map(t::Transform, x::AbstractVector) = _map(t, x)
912
_map(t::Transform, x::AbstractVector) = t.(x)
1013

0 commit comments

Comments
 (0)