Skip to content

Commit b568dde

Browse files
Merge pull request #367 from DhairyaLGandhi/dg/sym
Feat: Support for symbolic indexing of solution objects
2 parents 7fa7a4f + a8e204c commit b568dde

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,12 @@ end
145145

146146
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{
147147
AbstractArray, AbstractVectorOfArray})
148-
arr = reshape(x, p.sz)
149-
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
148+
if eltype(x) <: Number
149+
arr = reshape(x, p.sz)
150+
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
151+
elseif eltype(x) <: AbstractArray
152+
return VectorOfArray(x)
153+
end
150154
end
151155

152156
@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,

test/downstream/symbol_indexing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ sol_new = DiffEqArray(sol.u[1:10],
3434
@test_throws Exception sol[τ]
3535
@test_throws Exception sol_new[τ]
3636

37+
gs, = Zygote.gradient(sol) do sol
38+
sum(sol[fol_separate.x])
39+
end
40+
41+
@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u))
42+
3743
# Tables interface
3844
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))
3945

0 commit comments

Comments
 (0)