Skip to content

Commit afe209c

Browse files
committed
Add basic tests for rrules
1 parent cb3a7f3 commit afe209c

File tree

4 files changed

+21
-4
lines changed

4 files changed

+21
-4
lines changed

src/chainrules.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
116116
val = sum(abs2_sind_r)
117117
gradx = twoπ .* cospi.(d) .* sind ./ s.r .^ 2
118118
function evaluate_pullback::Any)
119-
return (r=-2Δ .* abs2_sind_r ./ s.r,), Δ * gradx, -Δ * gradx
119+
= -2Δ .* abs2_sind_r ./ s.r
120+
= ChainRulesCore.Tangent{typeof(s)}(; r=r̄)
121+
return s̄, Δ * gradx, -Δ * gradx
120122
end
121123
return val, evaluate_pullback
122124
end
@@ -149,7 +151,8 @@ function ChainRulesCore.rrule(
149151
x̄[:, j] -= ds
150152
end
151153
end
152-
return NoTangent(), (r=r̄,), @thunk(project_x(x̄))
154+
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
155+
return NoTangent(), d̄, @thunk(project_x(x̄))
153156
end
154157
return Distances.pairwise(d, x; dims), pairwise_pullback
155158
end
@@ -185,7 +188,8 @@ function ChainRulesCore.rrule(
185188
ȳ[:, j] -= ds
186189
end
187190
end
188-
return NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ))
191+
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
192+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
189193
end
190194
return Distances.pairwise(d, x, y; dims), pairwise_pullback
191195
end
@@ -209,7 +213,8 @@ function ChainRulesCore.rrule(
209213
x̄[:, i] += ds
210214
ȳ[:, i] -= ds
211215
end
212-
return NoTangent(), (r=r̄,), @thunk(project_x(x̄)), @thunk(project_y(ȳ))
216+
= ChainRulesCore.Tangent{typeof(d)}(; r=r̄)
217+
return NoTangent(), d̄, @thunk(project_x(x̄)), @thunk(project_y(ȳ))
213218
end
214219
return Distances.colwise(d, x, y), colwise_pullback
215220
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
[deps]
22
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
35
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
46
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
57
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

test/chainrules.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,12 @@
2828
SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3])
2929
end
3030
end
31+
32+
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
33+
dist = KernelFunctions.Sinus(r)
34+
ddist = (r = ones(length(r)),)
35+
test_rrule(dist, rand(3), rand(3))
36+
test_rrule(Distances.pairwise, dist, rand(3, 2); fkwargs=(dims=2,))
37+
test_rrule(Distances.pairwise, dist, rand(3, 2), rand(3, 3); fkwargs=(dims=2,))
38+
end
3139
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using KernelFunctions
22
using AxisArrays
3+
using ChainRulesCore
4+
using ChainRulesTestUtils
35
using Distances
46
using Documenter
57
using Functors: functor

0 commit comments

Comments
 (0)