Skip to content

Commit 75d343f

Browse files
YingboMaAayushSabharwal
authored andcommitted
Support array states in ODEProblem/ODEFunction
Co-authored-by: Aayush Sabharwal <[email protected]>
1 parent 298c29c commit 75d343f

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
152152
nothing,
153153
isdde = false,
154154
kwargs...)
155+
array_vars = Dict{Any, Vector{Int}}()
156+
for (j, x) in enumerate(dvs)
157+
if istree(x) && operation(x) == getindex
158+
arg = arguments(x)[1]
159+
inds = get!(() -> Int[], array_vars, arg)
160+
push!(inds, j)
161+
end
162+
end
163+
subs = Dict()
164+
for (k, inds) in array_vars
165+
if inds == (inds′ = inds[1]:inds[end])
166+
inds = inds′
167+
end
168+
subs[k] = term(view, Sym{Any}(Symbol("ˍ₋arg1")), inds)
169+
end
155170
if isdde
156171
eqs = delay_to_function(sys)
157172
else
@@ -164,6 +179,7 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
164179
# substitute x(t) by just x
165180
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
166181
[eq.rhs for eq in eqs]
182+
rhss = fast_substitute(rhss, subs)
167183

168184
# TODO: add an optional check on the ordering of observed equations
169185
u = map(x -> time_varying_as_func(value(x), sys), dvs)
@@ -764,6 +780,17 @@ function get_u0_p(sys,
764780
defs = mergedefaults(defs, parammap, ps)
765781
end
766782
defs = mergedefaults(defs, u0map, dvs)
783+
for (k, v) in defs
784+
if Symbolics.isarraysymbolic(k)
785+
ks = scalarize(k)
786+
length(ks) == length(v) || error("$k has default value $v with unmatched size")
787+
for (kk, vv) in zip(ks, v)
788+
if !haskey(defs, kk)
789+
defs[kk] = vv
790+
end
791+
end
792+
end
793+
end
767794

768795
if symbolic_u0
769796
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,7 @@ end
810810
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
811811
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
812812
end
813-
fast_substitute(eqs::AbstractArray{<:Eq}, subs) = fast_substitute.(eqs, (subs,))
813+
fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,))
814814
fast_substitute(a, b) = substitute(a, b)
815815
function fast_substitute(expr, pair::Pair)
816816
a, b = pair

0 commit comments

Comments
 (0)