Skip to content

Commit 4140290

Browse files
Merge pull request #680 from AayushSabharwal/myb/batch
fix: batch observed function eval in plotting, bug fixes
2 parents b8e95a9 + 668f3f7 commit 4140290

File tree

5 files changed

+30
-9
lines changed

5 files changed

+30
-9
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ PyCall = "1.96"
7676
PythonCall = "0.9.15"
7777
RCall = "0.14.0"
7878
RecipesBase = "1.3.4"
79-
RecursiveArrayTools = "3.8.0"
79+
RecursiveArrayTools = "3.14.0"
8080
Reexport = "1"
8181
RuntimeGeneratedFunctions = "0.5.12"
8282
SciMLOperators = "0.3.7"
8383
SciMLStructures = "1.1"
8484
StaticArrays = "1.7"
8585
StaticArraysCore = "1.4"
8686
Statistics = "1.10"
87-
SymbolicIndexingInterface = "0.3.15"
87+
SymbolicIndexingInterface = "0.3.20"
8888
Tables = "1.11"
8989
Zygote = "0.6.67"
9090
julia = "1.10"

src/solutions/ode_solutions.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ end
172172
function (sol::AbstractODESolution)(t::Number, ::Type{deriv},
173173
idxs::AbstractVector{<:Integer},
174174
continuity) where {deriv}
175+
if eltype(sol.u) <: Number
176+
idxs = only(idxs)
177+
end
175178
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
176179
end
177180
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
@@ -183,6 +186,9 @@ end
183186
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
184187
idxs::AbstractVector{<:Integer},
185188
continuity) where {deriv}
189+
if eltype(sol.u) <: Number
190+
idxs = only(idxs)
191+
end
186192
A = sol.interp(t, idxs, deriv, sol.prob.p, continuity)
187193
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
188194
return DiffEqArray(A.u, A.t, p, sol)
@@ -203,7 +209,7 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
203209
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
204210
error("Incorrect specification of `idxs`")
205211
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
206-
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]
212+
first(interp_sol[idxs])
207213
end
208214

209215
function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
@@ -224,8 +230,9 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
224230
error("Incorrect specification of `idxs`")
225231
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
226232
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
233+
indexed_sol = interp_sol[idxs]
227234
return DiffEqArray(
228-
[[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol)
235+
[indexed_sol[i] for i in 1:length(t)], t, p, sol)
229236
end
230237

231238
function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},

src/solutions/solution_interface.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,18 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
436436
plot_vecs = []
437437
labels = String[]
438438
varsyms = variable_symbols(sol)
439+
batch_symbolic_vars = []
440+
for x in vars
441+
for j in 2:length(x)
442+
if (x[j] isa Integer && x[j] == 0) || isequal(x[j], getindepsym_defaultt(sol))
443+
else
444+
push!(batch_symbolic_vars, x[j])
445+
end
446+
end
447+
end
448+
batch_symbolic_vars = identity.(batch_symbolic_vars)
449+
indexed_solution = sol(plott; idxs = batch_symbolic_vars)
450+
idxx = 0
439451
for x in vars
440452
tmp = []
441453
strs = String[]
@@ -444,7 +456,8 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
444456
push!(tmp, plott)
445457
push!(strs, "t")
446458
else
447-
push!(tmp, sol(plott; idxs = x[j]))
459+
idxx += 1
460+
push!(tmp, indexed_solution[idxx, :])
448461
if !isempty(varsyms) && x[j] isa Integer
449462
push!(strs, String(getname(varsyms[x[j]])))
450463
elseif hasname(x[j])

test/downstream/modelingtoolkit_remake.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ eqs = [D(x) ~ Hold(ud)
137137
xd ~ Sample(t, dt)(x)]
138138
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [p3 => 2p1])
139139
prob = ODEProblem(sys, [x => 1.0], (0.0, 5.0),
140-
[p1 => 1.0, p2 => 2, ud(k - 1) => 3.0, xd(k - 1) => 4.0, xd(k - 2) => 5.0])
140+
[p1 => 1.0, p2 => 2, ud(k - 1) => 3.0,
141+
xd(k - 1) => 4.0, xd(k - 2) => 5.0, yd(k - 1) => 0.0])
141142

142143
# parameter dependencies
143144
prob2 = @inferred ODEProblem remake(prob; p = [p1 => 2.0])

test/downstream/symbol_indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ end
9393
@test length(sol[(lorenz1.x, lorenz2.x)]) == length(sol)
9494
@test all(length.(sol[(lorenz1.x, lorenz2.x)]) .== 2)
9595

96-
@test sol[[lorenz1.x, lorenz2.x], :] isa Matrix{Float64}
97-
@test size(sol[[lorenz1.x, lorenz2.x], :]) == (2, length(sol))
98-
@test size(sol[[lorenz1.x, lorenz2.x], :]) == size(sol[[1, 2], :]) == size(sol[1:2, :])
96+
@test sol[[lorenz1.x, lorenz2.x], :] isa Vector{Vector{Float64}}
97+
@test length(sol[[lorenz1.x, lorenz2.x], :]) == length(sol)
98+
@test length(sol[[lorenz1.x, lorenz2.x], :][1]) == 2
9999

100100
@variables q(t)[1:2] = [1.0, 2.0]
101101
eqs = [D(q[1]) ~ 2q[1]

0 commit comments

Comments
 (0)