Skip to content

Commit c9015fa

Browse files
Fix derivative of _get_ν
Co-authored-by: willtebbutt <[email protected]>
1 parent 8e805ef commit c9015fa

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/basekernels/matern.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,16 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
3737

3838
@functor MaternKernel
3939

40+
# workaround for Zygote
41+
# unclear why it's needed but it is fine since it's stated officially that we don't support differentiation with respect to ν
4042
@inline _get_ν(k::MaternKernel) = only(k.ν)
41-
ChainRulesCore.@non_differentiable _get_ν(k) # work-around; should be "NotImplemented" rather than NoTangent
43+
function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel}
44+
function _get_ν_pullback(Δ)
45+
= ChainRulesCore.@not_implemented("derivatives of `MaternKernel` w.r.t. order `ν` are not implemented.")
46+
return Tangent{T}=dν, metric=NoTangent())
47+
end
48+
return _get_ν(k), _get_ν_pullback
49+
end
4250

4351
@inline function kappa(k::MaternKernel, d::Real)
4452
result = _matern(_get_ν(k), d)

0 commit comments

Comments
 (0)