Skip to content

Commit 7a0414d

Browse files
feat: handle vectorofarray better while projecting
1 parent 7fa7a4f commit 7a0414d

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ end
3131

3232
@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
3333
function AbstractVectorOfArray_getindex_adjoint(Δ)
34+
@show "in hete at vecint"
3435
iter = 0
3536
Δ′ = [(j i ? Δ[iter += 1] : FillArrays.Fill(zero(eltype(x)), size(x)))
3637
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -77,14 +78,14 @@ end
7778
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
7879
end
7980

80-
@adjoint function VectorOfArray(u)
81-
VectorOfArray(u),
82-
y -> begin
83-
y isa Ref && (y = VectorOfArray(y[].u))
84-
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
85-
for i in 1:size(y)[end]]),)
86-
end
87-
end
81+
# @adjoint function VectorOfArray(u)
82+
# VectorOfArray(u),
83+
# y -> begin
84+
# y isa Ref && (y = VectorOfArray(y[].u))
85+
# (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
86+
# for i in 1:size(y)[end]]),)
87+
# end
88+
# end
8889

8990
@adjoint function Base.copy(u::VectorOfArray)
9091
copy(u),
@@ -145,8 +146,12 @@ end
145146

146147
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{
147148
AbstractArray, AbstractVectorOfArray})
148-
arr = reshape(x, p.sz)
149-
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
149+
if eltype(x) <: Number
150+
arr = reshape(x, p.sz)
151+
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
152+
elseif eltype(x) <: AbstractArray
153+
return VectorOfArray(x)
154+
end
150155
end
151156

152157
@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,
@@ -271,6 +276,7 @@ end
271276
ȳ -> (nothing, Zygote._project(x, ȳ))
272277

273278
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
279+
@show
274280
N = ndims(x̄)
275281
if length(x) == length(x̄)
276282
Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors

0 commit comments

Comments
 (0)