Skip to content

Commit 71c5ac0

Browse files
authored
Fix single number (v/h)cat (#152)
* fix single number (v/h)cat * fix some indents * fix tests
1 parent c14e0e2 commit 71c5ac0

File tree

2 files changed

+47
-37
lines changed

2 files changed

+47
-37
lines changed

src/derivatives/arrays.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ end
5858

5959
@grad function vcat(xs::Union{Number, AbstractVecOrMat}...)
6060
xs_value = value.(xs)
61-
out_value = reduce(vcat,xs_value)
61+
out_value = vcat(xs_value...,)
6262
function back(Δ)
6363
start = 0
6464
Δs = map(xs) do xsi
65-
if xsi isa Number
66-
d = Δ[start+1]
67-
else
68-
d = Δ[start+1:start+size(xsi,1), :]
69-
end
70-
start += size(xsi, 1)
71-
d
65+
if xsi isa Number
66+
d = Δ[start+1]
67+
else
68+
d = Δ[start+1:start+size(xsi,1), :]
69+
end
70+
start += size(xsi, 1)
71+
d
7272
end
7373
return (Δs...,)
7474
end
@@ -77,20 +77,20 @@ end
7777

7878
@grad function hcat(xs::Union{Number, AbstractVecOrMat}...)
7979
xs_value = value.(xs)
80-
out_value = reduce(hcat, xs_value)
80+
out_value = hcat(xs_value...,)
8181
function back(Δ)
8282
start = 0
8383
Δs = map(xs) do xsi
84-
d = if ndims(xsi) == 0
85-
Δ[start+1]
86-
elseif ndims(xsi) == 1
87-
Δ[:, start+1]
88-
else
89-
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
90-
Δ[:, start+1:start+size(xsi,2), i...]
91-
end
92-
start += size(xsi, 2)
93-
d
84+
d = if ndims(xsi) == 0
85+
Δ[start+1]
86+
elseif ndims(xsi) == 1
87+
Δ[:, start+1]
88+
else
89+
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
90+
Δ[:, start+1:start+size(xsi,2), i...]
91+
end
92+
start += size(xsi, 2)
93+
d
9494
end
9595
return (Δs...,)
9696
end
@@ -106,17 +106,17 @@ end
106106
return cat(Xs_value...; dims = dims), Δ -> begin
107107
start = ntuple(i -> 0, Val(ndims(Δ)))
108108
Δs = map(Xs) do xs
109-
if xs isa Number
110-
d = Δ[start+1]
111-
start = start .+ 1
112-
else
113-
dim_xs = 1:ndims(xs)
114-
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
115-
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
116-
d = reshape(Δ[xs_in_Δ...],size(xs))
117-
start = start .+ till_xs
118-
end
119-
d
109+
if xs isa Number
110+
d = Δ[start+1]
111+
start = start .+ 1
112+
else
113+
dim_xs = 1:ndims(xs)
114+
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
115+
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
116+
d = reshape(Δ[xs_in_Δ...],size(xs))
117+
start = start .+ till_xs
118+
end
119+
d
120120
end
121121
return (Δs...,)
122122
end

test/derivatives/ArrayFunctionTests.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,25 @@ end
1616
@test any(iszero, track([ones(2); 0.0]))
1717
end
1818

19-
function testcat(f, args::Tuple{Any, Any}, type, kwargs=NamedTuple())
19+
function testcat(f, args::Tuple, type, kwargs=NamedTuple())
2020
x = f(track.(args)...; kwargs...)
2121
@test x isa type
2222
@test value(x) == f(args...; kwargs...)
2323

24-
x = f(track(args[1]), args[2]; kwargs...)
25-
@test x isa type
26-
@test value(x) == f(args...; kwargs...)
24+
if length(args) == 1
25+
x = f(track(args[1]); kwargs...)
26+
@test x isa type
27+
@test value(x) == f(args...; kwargs...)
28+
else
29+
@assert length(args) == 2
30+
x = f(track(args[1]), args[2]; kwargs...)
31+
@test x isa type
32+
@test value(x) == f(args...; kwargs...)
2733

28-
x = f(args[1], track(args[2]); kwargs...)
29-
@test x isa type
30-
@test value(x) == f(args...; kwargs...)
34+
x = f(args[1], track(args[2]); kwargs...)
35+
@test x isa type
36+
@test value(x) == f(args...; kwargs...)
37+
end
3138

3239
args = (args..., args...)
3340
x = f(track.(args)...; kwargs...)
@@ -64,6 +71,7 @@ end
6471
a = rand(3,3,3)
6572
n = rand()
6673

74+
testcat(cat, (n,), TrackedVector, (dims=1,))
6775
testcat(cat, (n, n), TrackedVector, (dims=1,))
6876
testcat(cat, (n, n), TrackedMatrix, (dims=2,))
6977
testcat(cat, (v, n), TrackedVector, (dims=1,))
@@ -79,11 +87,13 @@ end
7987
testcat(cat, (a, a), TrackedArray, (dims=3,))
8088
testcat(cat, (a, m), TrackedArray, (dims=3,))
8189

90+
testcat(vcat, (n,), TrackedVector)
8291
testcat(vcat, (n, n), TrackedVector)
8392
testcat(vcat, (v, n), TrackedVector)
8493
testcat(vcat, (n, v), TrackedVector)
8594
testcat(vcat, (v, v), TrackedVector)
8695

96+
testcat(hcat, (n,), TrackedMatrix)
8797
testcat(hcat, (n, n), TrackedMatrix)
8898
testcat(hcat, (v, v), TrackedMatrix)
8999
testcat(hcat, (v, m), TrackedMatrix)

0 commit comments

Comments
 (0)