Skip to content

Commit ae6bd49

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
feat: scalarize equations in ODESystem, fix vars! and namespace_expr
1 parent 75d343f commit ae6bd49

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

src/systems/abstractsystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,9 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys
704704
rescoped = renamespace(n, O)
705705
similarterm(O, operation(rescoped), renamed,
706706
metadata = metadata(rescoped))
707+
elseif Symbolics.isarraysymbolic(O)
708+
# promote_symtype doesn't work for array symbolics
709+
similarterm(O, operation(O), renamed, symtype(O), metadata = metadata(O))
707710
else
708711
similarterm(O, operation(O), renamed, metadata = metadata(O))
709712
end

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
195195
gui_metadata = nothing)
196196
name === nothing &&
197197
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
198-
#deqs = scalarize(deqs)
199198
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
200-
199+
deqs = reduce(vcat, scalarize(deqs); init = Equation[])
201200
iv′ = value(iv)
202201
ps′ = value.(ps)
203202
ctrl′ = value.(controls)
@@ -236,7 +235,6 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
236235
end
237236

238237
function ODESystem(eqs, iv; kwargs...)
239-
eqs = scalarize(eqs)
240238
# NOTE: this assumes that the order of algebraic equations doesn't matter
241239
diffvars = OrderedSet()
242240
allunknowns = OrderedSet()

src/utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,22 +342,22 @@ v == Set([D(y), u])
342342
function vars(exprs::Symbolic; op = Differential)
343343
istree(exprs) ? vars([exprs]; op = op) : Set([exprs])
344344
end
345+
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
346+
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
345347
vars(exprs; op = Differential) = foldl((x, y) -> vars!(x, y; op = op), exprs; init = Set())
346348
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
347349
function vars!(vars, eq::Equation; op = Differential)
348350
(vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars)
349351
end
350352
function vars!(vars, O; op = Differential)
353+
if isvariable(O) && !(istree(O) && operation(O) === getindex)
354+
return push!(vars, O)
355+
end
356+
351357
!istree(O) && return vars
352358
if operation(O) === (getindex)
353359
arr = first(arguments(O))
354-
!istree(arr) && return vars
355-
operation(arr) isa op && return push!(vars, O)
356-
isvariable(operation(O)) && return push!(vars, O)
357-
end
358-
359-
if isvariable(O)
360-
return push!(vars, O)
360+
return vars!(vars, arr)
361361
end
362362

363363
operation(O) isa op && return push!(vars, O)

0 commit comments

Comments
 (0)