Skip to content

Commit 73d940d

Browse files
committed
Cover StaticArrays
1 parent fbc1fd6 commit 73d940d

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

src/chainrules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ function ChainRulesCore.rrule(
130130
function pairwise_pullback(z̄)
131131
Δ = unthunk(z̄)
132132
n = size(x, dims)
133-
= zero(x)
133+
= collect(zero(x))
134134
= zero(d.r)
135135
if dims == 1
136136
for j in 1:n, i in 1:n
@@ -166,8 +166,8 @@ function ChainRulesCore.rrule(
166166
Δ = unthunk(z̄)
167167
n = size(x, dims)
168168
m = size(y, dims)
169-
= zero(x)
170-
= zero(y)
169+
= collect(zero(x))
170+
= collect(zero(y))
171171
= zero(d.r)
172172
if dims == 1
173173
for j in 1:m, i in 1:n
@@ -202,8 +202,8 @@ function ChainRulesCore.rrule(
202202
function colwise_pullback(z̄)
203203
Δ = unthunk(z̄)
204204
n = size(x, 2)
205-
= zero(x)
206-
= zero(y)
205+
= collect(zero(x))
206+
= collect(zero(y))
207207
= zero(d.r)
208208
for i in 1:n
209209
xi = view(x, :, i)

test/chainrules.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,20 @@
3131

3232
@testset "rrules for Sinus(r=$r)" for r in (rand(3),)
3333
dist = KernelFunctions.Sinus(r)
34-
test_rrule(dist, rand(3), rand(3))
35-
test_rrule(Distances.pairwise, dist, rand(3, 2); fkwargs=(dims=2,))
36-
test_rrule(Distances.pairwise, dist, rand(3, 2), rand(3, 3); fkwargs=(dims=2,))
37-
test_rrule(Distances.colwise, dist, rand(3, 2), rand(3, 2))
34+
@testset "$type" for type in (Vector, SVector{3})
35+
test_rrule(dist, type(rand(3)), type(rand(3)))
36+
end
37+
@testset "$type1, $type2" for type1 in (Matrix, SMatrix{3, 2}),
38+
type2 in (Matrix, SMatrix{3, 4})
39+
test_rrule(
40+
Distances.pairwise, dist, type1(rand(3, 2));
41+
fkwargs=(dims=2,)
42+
)
43+
test_rrule(
44+
Distances.pairwise, dist, type1(rand(3, 2)), type2(rand(3, 4));
45+
fkwargs=(dims=2,)
46+
)
47+
test_rrule(Distances.colwise, dist, type1(rand(3, 2)), type1(rand(3, 2)))
48+
end
3849
end
3950
end

0 commit comments

Comments
 (0)