Skip to content

Commit e7f658b

Browse files
authored
Remove workaround
1 parent d87cc84 commit e7f658b

File tree

4 files changed

+2
-27
lines changed

4 files changed

+2
-27
lines changed

src/basekernels/matern.jl

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,8 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
3737

3838
@functor MaternKernel
3939

40-
# workaround for Zygote
41-
# unclear why it's needed but it is fine since it's stated officially that we don't support differentiation with respect to ν
42-
@inline _get_ν(k::MaternKernel) = only(k.ν)
43-
function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel}
44-
function _get_ν_pullback(Δ)
45-
= ChainRulesCore.@not_implemented(
46-
"derivatives of `MaternKernel` w.r.t. order `ν` are not implemented."
47-
)
48-
return NoTangent(), Tangent{T}(; ν=dν, metric=NoTangent())
49-
end
50-
return _get_ν(k), _get_ν_pullback
51-
end
52-
5340
@inline function kappa(k::MaternKernel, d::Real)
54-
result = _matern(_get_ν(k), d)
41+
result = _matern(only(k.ν), d)
5542
return ifelse(iszero(d), one(result), result)
5643
end
5744

test/Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[deps]
22
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
3-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
43
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
54
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
65
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
@@ -20,7 +19,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2019

2120
[compat]
2221
AxisArrays = "0.4.3"
23-
ChainRulesTestUtils = "1.7"
2422
Compat = "3"
2523
Distances = "0.10"
2624
Documenter = "0.25, 0.26, 0.27"
@@ -32,4 +30,4 @@ LogExpFunctions = "0.2, 0.3"
3230
PDMats = "0.9, 0.10, 0.11"
3331
ReverseDiff = "1.2"
3432
SpecialFunctions = "0.10, 1, 2"
35-
Zygote = "0.4, 0.5, 0.6"
33+
Zygote = "0.6.38"

test/basekernels/matern.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,6 @@
1818
@test metric(k2) isa WeightedEuclidean
1919
@test k2(v1, v2) k(v1, v2)
2020

21-
# Test custom `rrule` (Zygote workaround).
22-
k = MaternKernel(; ν=rand())
23-
test_rrule(
24-
KernelFunctions._get_ν,
25-
k ChainRulesTestUtils.Tangent{typeof(k)}(;
26-
ν=randn(), metric=ChainRulesTestUtils.NoTangent()
27-
),
28-
)
29-
3021
# Standardised tests.
3122
TestUtils.test_interface(k, Float64)
3223
test_ADs(() -> MaternKernel(; nu=ν))

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using KernelFunctions
22
using AxisArrays
3-
using ChainRulesTestUtils
43
using Distances
54
using Documenter
65
using Functors: functor

0 commit comments

Comments
 (0)