Skip to content

Commit 5770229

Browse files
fixup! fix: fix view adjoints
1 parent 97edb58 commit 5770229

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

ext/RecursiveArrayToolsReverseDiffExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,16 @@ end
2323
end
2424
return Array(VA), Array_adjoint
2525
end
26+
27+
@adjoint function Base.view(A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
28+
view_adjoint = let A = A, I = I
29+
function (y)
30+
A = recursivecopy(A)
31+
trackedarraycopyto!(A, y)
32+
(A, map(_ -> nothing, I)...)
33+
end
34+
end
35+
return view(A, I...), view_adjoint
36+
end
37+
2638
end # module

0 commit comments

Comments
 (0)