Skip to content

Commit bc0bcf5

Browse files
committed
revert to simpler workaround
1 parent ff49855 commit bc0bcf5

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

src/basekernels/matern.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
3737

3838
@functor MaternKernel
3939

40-
# Work-around for Zygote -- `NotImplemented` doesn't appear to play nicely with whatever
41-
# rule currently exists for `only`.
42-
_get_ν(k::MaternKernel) = only(k.ν)
43-
function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel}
44-
function _get_ν_pullback(Δ)
45-
= ChainRulesCore.@not_implemented("Derivatives w.r.t. ν are not implemented.")
46-
return Tangent{T}=dν, metric=NoTangent())
47-
end
48-
return _get_ν(k), _get_ν_pullback
49-
end
40+
@inline _get_ν(k::MaternKernel) = ChainRulesCore.@ignore_derivatives only(k.ν) # work-around for Zygote AD
5041

5142
@inline function kappa(k::MaternKernel, d::Real)
5243
result = _matern(_get_ν(k), d)

0 commit comments

Comments
 (0)