Skip to content

Commit e0a7af6

Browse files
devmotionwilltebbuttgithub-actions[bot]
authored
Remove _get_ν (#452)
* Fix derivative of `_get_ν` Co-authored-by: willtebbutt <[email protected]> * Update Project.toml * Update src/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix `rrule` * Add test * Load ChainRulesTestUtils * Try to specify cotangent * Update test/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update matern.jl * Update test/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Remove workaround * Update Project.toml * Fix test_interface Co-authored-by: willtebbutt <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 35de8d2 commit e0a7af6

File tree

4 files changed

+17
-25
lines changed

4 files changed

+17
-25
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.38"
3+
version = "0.10.39"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/TestUtils.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,14 @@ using KernelFunctions
66
using Random
77
using Test
88

9-
# default tolerance values for test_interface:
10-
const __ATOL = sqrt(eps(Float64))
11-
const __RTOL = sqrt(eps(Float64))
12-
# ≈ 1.5e-8; chosen for no particular reason other than because it seems to
13-
# satisfy our own test cases within KernelFunctions.jl
14-
159
"""
1610
test_interface(
1711
k::Kernel,
1812
x0::AbstractVector,
1913
x1::AbstractVector,
2014
x2::AbstractVector;
21-
atol=__ATOL,
22-
rtol=__RTOL,
15+
rtol=1e-6,
16+
atol=rtol,
2317
)
2418
2519
Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`.
@@ -29,22 +23,14 @@ be of different lengths.
2923
These tests are intended to pick up on really substantial issues with a kernel implementation
3024
(e.g. substantial asymmetry in the kernel matrix, large negative eigenvalues), rather than to
3125
test the numerics in detail, which can be kernel-specific.
32-
The default value of `__ATOL` and `__RTOL` is `sqrt(eps(Float64)) ≈ 1.5e-8`, which satisfied
33-
this intention in the cases tested within KernelFunctions.jl itself.
34-
35-
test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:Real}; atol=__ATOL, rtol=__RTOL)
36-
37-
`test_interface` offers automated test data generation for kernels whose inputs are reals.
38-
This will run the tests for `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
39-
For other input vector types, please provide the data manually.
4026
"""
4127
function test_interface(
4228
k::Kernel,
4329
x0::AbstractVector,
4430
x1::AbstractVector,
4531
x2::AbstractVector;
46-
atol=__ATOL,
47-
rtol=__RTOL,
32+
rtol=1e-6,
33+
atol=rtol,
4834
)
4935
# Ensure that we have the required inputs.
5036
@assert length(x0) == length(x1)
@@ -160,7 +146,16 @@ function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
160146
return test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
161147
end
162148

163-
function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...)
149+
"""
150+
test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
151+
152+
Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`.
153+
154+
For other input types, please provide the data manually.
155+
156+
The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the randomly generated inputs.
157+
"""
158+
function test_interface(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T<:Real}
164159
@testset "Vector{$T}" begin
165160
test_interface(rng, k, Vector{T}; kwargs...)
166161
end

src/basekernels/matern.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
3737

3838
@functor MaternKernel
3939

40-
@inline _get_ν(k::MaternKernel) = only(k.ν)
41-
ChainRulesCore.@non_differentiable _get_ν(k) # work-around; should be "NotImplemented" rather than NoTangent
42-
4340
@inline function kappa(k::MaternKernel, d::Real)
44-
result = _matern(_get_ν(k), d)
41+
result = _matern(only(k.ν), d)
4542
return ifelse(iszero(d), one(result), result)
4643
end
4744

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ LogExpFunctions = "0.2, 0.3"
3030
PDMats = "0.9, 0.10, 0.11"
3131
ReverseDiff = "1.2"
3232
SpecialFunctions = "0.10, 1, 2"
33-
Zygote = "0.4, 0.5, 0.6"
33+
Zygote = "0.6.38"

0 commit comments

Comments
 (0)