|
31 | 31 |
|
32 | 32 | @adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
|
33 | 33 | function AbstractVectorOfArray_getindex_adjoint(Δ)
|
| 34 | + @show "in hete at vecint" |
34 | 35 | iter = 0
|
35 | 36 | Δ′ = [(j ∈ i ? Δ[iter += 1] : FillArrays.Fill(zero(eltype(x)), size(x)))
|
36 | 37 | for (x, j) in zip(VA.u, 1:length(VA))]
|
|
77 | 78 | ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
|
78 | 79 | end
|
79 | 80 |
|
80 |
| -@adjoint function VectorOfArray(u) |
81 |
| - VectorOfArray(u), |
82 |
| - y -> begin |
83 |
| - y isa Ref && (y = VectorOfArray(y[].u)) |
84 |
| - (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] |
85 |
| - for i in 1:size(y)[end]]),) |
86 |
| - end |
87 |
| -end |
| 81 | +# @adjoint function VectorOfArray(u) |
| 82 | +# VectorOfArray(u), |
| 83 | +# y -> begin |
| 84 | +# y isa Ref && (y = VectorOfArray(y[].u)) |
| 85 | +# (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] |
| 86 | +# for i in 1:size(y)[end]]),) |
| 87 | +# end |
| 88 | +# end |
88 | 89 |
|
89 | 90 | @adjoint function Base.copy(u::VectorOfArray)
|
90 | 91 | copy(u),
|
|
145 | 146 |
|
146 | 147 | function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{
|
147 | 148 | AbstractArray, AbstractVectorOfArray})
|
148 |
| - arr = reshape(x, p.sz) |
149 |
| - return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) |
| 149 | + if eltype(x) <: Number |
| 150 | + arr = reshape(x, p.sz) |
| 151 | + return VectorOfArray([arr[:, i] for i in 1:p.sz[end]]) |
| 152 | + elseif eltype(x) <: AbstractArray |
| 153 | + return VectorOfArray(x) |
| 154 | + end |
150 | 155 | end
|
151 | 156 |
|
152 | 157 | @adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,
|
|
271 | 276 | ȳ -> (nothing, Zygote._project(x, ȳ))
|
272 | 277 |
|
273 | 278 | function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
|
| 279 | + @show x̄ |
274 | 280 | N = ndims(x̄)
|
275 | 281 | if length(x) == length(x̄)
|
276 | 282 | Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
|
|
0 commit comments