Skip to content

Commit e822993

Browse files
fixup! fix: fix view adjoints
1 parent 97edb58 commit e822993

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
44
version = "3.19.0"
55

66
[deps]
7+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -36,6 +37,7 @@ RecursiveArrayToolsTrackerExt = "Tracker"
3637
RecursiveArrayToolsZygoteExt = "Zygote"
3738

3839
[compat]
40+
Accessors = "0.1"
3941
Adapt = "3.4, 4"
4042
Aqua = "0.8"
4143
ArrayInterface = "7.6"

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module RecursiveArrayToolsZygoteExt
22

33
using RecursiveArrayTools
4+
using RecursiveArrayTools.Accessors: @set, @reset
45

56
if isdefined(Base, :get_extension)
67
using Zygote
@@ -125,9 +126,10 @@ end
125126
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
126127
view_adjoint = let A = A, I = I
127128
function (y)
128-
A = recursivecopy(A)
129-
A .= y
130-
(A, map(_ -> nothing, I)...)
129+
u = collect.(eachslice(y, dims=ndims(y)))
130+
B = @set A.u = u
131+
132+
(B, map(_ -> nothing, I)...)
131133
end
132134
end
133135
return view(A, I...), view_adjoint
@@ -136,11 +138,9 @@ end
136138
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
137139
view_adjoint = let A = A, I = I
138140
function (y)
139-
A = recursivecopy(A)
140-
recursivefill!(A, zero(eltype(A)))
141-
v = view(A, I...)
142-
v .= y
143-
return (A, map(_ -> nothing, I)...)
141+
B = @set A .= zero(eltype(A))
142+
@reset B[I...] = y
143+
return (B, map(_ -> nothing, I)...)
144144
end
145145
end
146146
view(A, I...), view_adjoint

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DocStringExtensions
88
using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
11+
import Accessors
1112
using SparseArrays
1213

1314
import Adapt

0 commit comments

Comments
 (0)