Skip to content

Commit c664868

Browse files
committed
Batch observed function eval if possible
1 parent b568dde commit c664868

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/vector_of_array.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,20 @@ Base.@propagate_inbounds function _getindex(
406406
if all(x -> is_parameter(A, x), sym)
407407
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
408408
else
409-
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
409+
return A[sym, eachindex(A.t)]
410410
end
411411
end
412412

413413
Base.@propagate_inbounds function _getindex(
414414
A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...)
415-
return reduce(vcat, map(s -> A[s, args...]', sym))
415+
u = A.u[args...]
416+
t = A.t[args...]
417+
observed_fn = observed(A, sym)
418+
if t isa AbstractArray
419+
return observed_fn.(u, (parameter_values(A),), t)
420+
else
421+
return observed_fn(u, parameter_values(A), t)
422+
end
416423
end
417424

418425
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,

test/downstream/symbol_indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ sol_new = DiffEqArray(sol.u[1:10],
3535
@test_throws Exception sol_new[τ]
3636

3737
gs, = Zygote.gradient(sol) do sol
38-
sum(sol[fol_separate.x])
38+
sum(sol[fol_separate.x])
3939
end
4040

4141
@test "Symbolic Indexing ADjoint" all(all.(isone, gs.u))

0 commit comments

Comments
 (0)