Skip to content

Commit 298c29c

Browse files
YingboMaChrisRackauckas
authored andcommitted
Array equations/variables support in structural_simplify
1 parent 65799a1 commit 298c29c

File tree

6 files changed

+52
-35
lines changed

6 files changed

+52
-35
lines changed

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -703,9 +703,9 @@ function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys
703703
# metadata from the rescoped variable
704704
rescoped = renamespace(n, O)
705705
similarterm(O, operation(rescoped), renamed,
706-
metadata = metadata(rescoped))::T
706+
metadata = metadata(rescoped))
707707
else
708-
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
708+
similarterm(O, operation(O), renamed, metadata = metadata(O))
709709
end
710710
elseif isvariable(O)
711711
renamespace(n, O)

src/systems/connectors.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -292,18 +292,20 @@ function connection2set!(connectionsets, namespace, ss, isouter)
292292
end
293293
end
294294

295-
function generate_connection_set(sys::AbstractSystem, find = nothing, replace = nothing)
295+
function generate_connection_set(
296+
sys::AbstractSystem, find = nothing, replace = nothing; scalarize = false)
296297
connectionsets = ConnectionSet[]
297298
domain_csets = ConnectionSet[]
298-
sys = generate_connection_set!(connectionsets, domain_csets, sys, find, replace)
299+
sys = generate_connection_set!(
300+
connectionsets, domain_csets, sys, find, replace, scalarize)
299301
csets = merge(connectionsets)
300302
domain_csets = merge([csets; domain_csets], true)
301303

302304
sys, (csets, domain_csets)
303305
end
304306

305307
function generate_connection_set!(connectionsets, domain_csets,
306-
sys::AbstractSystem, find, replace, namespace = nothing)
308+
sys::AbstractSystem, find, replace, scalarize, namespace = nothing)
307309
subsys = get_systems(sys)
308310

309311
isouter = generate_isouter(sys)
@@ -325,8 +327,13 @@ function generate_connection_set!(connectionsets, domain_csets,
325327
end
326328
neweq isa AbstractArray ? append!(eqs, neweq) : push!(eqs, neweq)
327329
else
328-
if lhs isa Number || lhs isa Symbolic
329-
push!(eqs, eq) # split connections and equations
330+
if lhs isa Number || lhs isa Symbolic || eltype(lhs) <: Symbolic
331+
# split connections and equations
332+
if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray
333+
append!(eqs, Symbolics.scalarize(eq))
334+
else
335+
push!(eqs, eq)
336+
end
330337
elseif lhs isa Connection && get_systems(lhs) === :domain
331338
connection2set!(domain_csets, namespace, get_systems(rhs), isouter)
332339
else
@@ -356,7 +363,7 @@ function generate_connection_set!(connectionsets, domain_csets,
356363
end
357364
@set! sys.systems = map(
358365
s -> generate_connection_set!(connectionsets, domain_csets, s,
359-
find, replace,
366+
find, replace, scalarize,
360367
renamespace(namespace, s)),
361368
subsys)
362369
@set! sys.eqs = eqs
@@ -471,8 +478,8 @@ function domain_defaults(sys, domain_csets)
471478
end
472479

473480
function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
474-
debug = false, tol = 1e-10)
475-
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace)
481+
debug = false, tol = 1e-10, scalarize = true)
482+
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
476483
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
477484
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)
478485
sys = flatten(sys, true)

src/systems/diffeqs/odesystem.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
195195
gui_metadata = nothing)
196196
name === nothing &&
197197
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
198-
deqs = scalarize(deqs)
198+
#deqs = scalarize(deqs)
199199
@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."
200200

201-
iv′ = value(scalarize(iv))
202-
ps′ = value.(scalarize(ps))
203-
ctrl′ = value.(scalarize(controls))
204-
dvs′ = value.(scalarize(dvs))
201+
iv′ = value(iv)
202+
ps′ = value.(ps)
203+
ctrl′ = value.(controls)
204+
dvs′ = value.(dvs)
205205
dvs′ = filter(x -> !isdelay(x, iv), dvs′)
206206

207207
if !(isempty(default_u0) && isempty(default_p))

src/systems/systemstructure.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ function TearingState(sys; quick_cancel = false, check = true)
268268
end
269269

270270
vars = OrderedSet()
271+
varsvec = []
271272
for (i, eq′) in enumerate(eqs)
272273
if eq′.lhs isa Connection
273274
check ? error("$(nameof(sys)) has unexpanded `connect` statements") :
@@ -282,9 +283,18 @@ function TearingState(sys; quick_cancel = false, check = true)
282283
eq = 0 ~ rhs - lhs
283284
end
284285
vars!(vars, eq.rhs, op = Symbolics.Operator)
286+
for v in vars
287+
v = scalarize(v)
288+
if v isa AbstractArray
289+
v = setmetadata.(v, VariableIrreducible, true)
290+
append!(varsvec, v)
291+
else
292+
push!(varsvec, v)
293+
end
294+
end
285295
isalgeq = true
286296
unknownvars = []
287-
for var in vars
297+
for var in varsvec
288298
ModelingToolkit.isdelay(var, iv) && continue
289299
set_incidence = true
290300
@label ANOTHER_VAR
@@ -340,6 +350,7 @@ function TearingState(sys; quick_cancel = false, check = true)
340350
push!(symbolic_incidence, copy(unknownvars))
341351
empty!(unknownvars)
342352
empty!(vars)
353+
empty!(varsvec)
343354
if isalgeq
344355
eqs[i] = eq
345356
else
@@ -350,9 +361,10 @@ function TearingState(sys; quick_cancel = false, check = true)
350361
# sort `fullvars` such that the mass matrix is as diagonal as possible.
351362
dervaridxs = collect(dervaridxs)
352363
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
364+
var_to_old_var = Dict(zip(fullvars, fullvars))
353365
for dervaridx in dervaridxs
354366
dervar = fullvars[dervaridx]
355-
diffvar = lower_order_var(dervar)
367+
diffvar = var_to_old_var[lower_order_var(dervar)]
356368
if !(diffvar in sorted_fullvars)
357369
push!(sorted_fullvars, diffvar)
358370
end

src/utils.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,19 +348,22 @@ function vars!(vars, eq::Equation; op = Differential)
348348
(vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars)
349349
end
350350
function vars!(vars, O; op = Differential)
351+
!istree(O) && return vars
352+
if operation(O) === (getindex)
353+
arr = first(arguments(O))
354+
!istree(arr) && return vars
355+
operation(arr) isa op && return push!(vars, O)
356+
isvariable(operation(O)) && return push!(vars, O)
357+
end
358+
351359
if isvariable(O)
352360
return push!(vars, O)
353361
end
354-
!istree(O) && return vars
355362

356363
operation(O) isa op && return push!(vars, O)
357364

358-
if operation(O) === (getindex) &&
359-
isvariable(first(arguments(O)))
360-
return push!(vars, O)
361-
end
362-
363365
isvariable(operation(O)) && push!(vars, O)
366+
364367
for arg in arguments(O)
365368
vars!(vars, arg; op = op)
366369
end

test/odesystem.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -517,21 +517,16 @@ eqs = [D(x) ~ x * y
517517
using StaticArrays
518518
using SymbolicUtils: term
519519
using SymbolicUtils.Code
520-
using Symbolics: unwrap, wrap
521-
function foo(a::Num, ms::AbstractVector)
522-
a = unwrap(a)
523-
ms = map(unwrap, ms)
524-
wrap(term(foo, a, term(SVector, ms...)))
525-
end
520+
using Symbolics: unwrap, wrap, @register_symbolic
526521
foo(a, ms::AbstractVector) = a + sum(ms)
527-
@variables x(t) ms(t)[1:3]
528-
ms = collect(ms)
529-
eqs = [D(x) ~ foo(x, ms); D.(ms) .~ 1]
522+
@register_symbolic foo(a, ms::AbstractVector)
523+
@variables t x(t) ms(t)[1:3]
524+
D = Differential(t)
525+
eqs = [D(x) ~ foo(x, ms); D(ms) ~ ones(3)]
530526
@named sys = ODESystem(eqs, t, [x; ms], [])
531527
@named emptysys = ODESystem(Equation[], t)
532-
@named outersys = compose(emptysys, sys)
533-
outersys = complete(outersys)
534-
prob = ODEProblem(outersys, [sys.x => 1.0; collect(sys.ms) .=> 1:3], (0, 1.0))
528+
@mtkbuild outersys = compose(emptysys, sys)
529+
prob = ODEProblem(outersys, [sys.x => 1.0, sys.ms => 1:3], (0, 1.0))
535530
@test_nowarn solve(prob, Tsit5())
536531

537532
# x/x

0 commit comments

Comments
 (0)