Skip to content

Commit f3de1aa

Browse files
Merge pull request #2512 from SciML/initialize_nondae
Initialization on non-DAE models
2 parents ef8171a + a477f43 commit f3de1aa

14 files changed

+204
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ Libdl = "1"
8989
LinearAlgebra = "1"
9090
MLStyle = "0.4.17"
9191
NaNMath = "0.3, 1"
92-
OrdinaryDiffEq = "6.72.0"
92+
OrdinaryDiffEq = "6.73.0"
9393
PrecompileTools = "1"
9494
RecursiveArrayTools = "2.3, 3"
9595
Reexport = "0.2, 1"

examples/electrical_components.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using ModelingToolkit, OrdinaryDiffEq
33
using ModelingToolkit: t_nounits as t, D_nounits as D
44

55
@connector function Pin(; name)
6-
sts = @variables v(t)=1.0 i(t)=1.0 [connect = Flow]
6+
sts = @variables v(t) [guess = 1.0] i(t) [guess = 1.0, connect = Flow]
77
ODESystem(Equation[], t, sts, []; name = name)
88
end
99

@@ -16,7 +16,7 @@ end
1616
@component function OnePort(; name)
1717
@named p = Pin()
1818
@named n = Pin()
19-
sts = @variables v(t)=1.0 i(t)=1.0
19+
sts = @variables v(t) [guess = 1.0] i(t) [guess = 1.0]
2020
eqs = [v ~ p.v - n.v
2121
0 ~ p.i + n.i
2222
i ~ p.i]
@@ -64,7 +64,7 @@ end
6464
end
6565

6666
@connector function HeatPort(; name)
67-
@variables T(t)=293.15 Q_flow(t)=0.0 [connect = Flow]
67+
@variables T(t) [guess = 293.15] Q_flow(t) [guess = 0.0, connect = Flow]
6868
ODESystem(Equation[], t, [T, Q_flow], [], name = name)
6969
end
7070

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ for prop in [:eqs
576576
:preface
577577
:torn_matching
578578
:initializesystem
579+
:initialization_eqs
579580
:schedule
580581
:tearing_state
581582
:substitutions

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,14 +862,22 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
862862
ps = full_parameters(sys)
863863
iv = get_iv(sys)
864864

865+
# TODO: Pass already computed information to varmap_to_vars call
866+
# in process_u0? That would just be a small optimization
867+
varmap = u0map === nothing || isempty(u0map) || eltype(u0map) <: Number ?
868+
defaults(sys) :
869+
merge(defaults(sys), todict(u0map))
870+
varlist = collect(map(unwrap, dvs))
871+
missingvars = setdiff(varlist, collect(keys(varmap)))
872+
865873
# Append zeros to the variables which are determined by the initialization system
866874
# This essentially bypasses the check for if initial conditions are defined for DAEs
867875
# since they will be checked in the initialization problem's construction
868876
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
869877
ci = infer_clocks!(ClockInference(TearingState(sys)))
870878
# TODO: make it work with clocks
871879
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
872-
if (implicit_dae || calculate_massmatrix(sys) !== I) &&
880+
if sys isa ODESystem && (implicit_dae || !isempty(missingvars)) &&
873881
all(isequal(Continuous()), ci.var_domain) &&
874882
ModelingToolkit.get_tearing_state(sys) !== nothing
875883
if eltype(u0map) <: Number
@@ -881,6 +889,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
881889

882890
zerovars = setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0
883891
trueinit = identity.([zerovars; u0map])
892+
u0map isa StaticArraysCore.StaticArray &&
893+
(trueinit = SVector{length(trueinit)}(trueinit))
884894
else
885895
initializeprob = nothing
886896
initializeprobmap = nothing
@@ -1530,6 +1540,21 @@ function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...
15301540
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
15311541
end
15321542

1543+
const INCOMPLETE_INITIALIZATION_MESSAGE = """
1544+
Initialization incomplete. Not all of the state variables of the
1545+
DAE system can be determined by the initialization. Missing
1546+
variables:
1547+
"""
1548+
1549+
struct IncompleteInitializationError <: Exception
1550+
uninit::Any
1551+
end
1552+
1553+
function Base.showerror(io::IO, e::IncompleteInitializationError)
1554+
println(io, INCOMPLETE_INITIALIZATION_MESSAGE)
1555+
println(io, e.uninit)
1556+
end
1557+
15331558
function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
15341559
t::Number, u0map = [],
15351560
parammap = DiffEqBase.NullParameters();
@@ -1550,6 +1575,14 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
15501575
generate_initializesystem(sys; u0map); fully_determined = false)
15511576
end
15521577

1578+
uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
1579+
1580+
# TODO: throw on uninitialized arrays
1581+
filter!(x -> !(x isa Symbolics.Arr), uninit)
1582+
if !isempty(uninit)
1583+
throw(IncompleteInitializationError(uninit))
1584+
end
1585+
15531586
neqs = length(equations(isys))
15541587
nunknown = length(unknowns(isys))
15551588

src/systems/diffeqs/odesystem.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ struct ODESystem <: AbstractODESystem
101101
"""
102102
initializesystem::Union{Nothing, NonlinearSystem}
103103
"""
104+
Extra equations to be enforced during the initialization sequence.
105+
"""
106+
initialization_eqs::Vector{Equation}
107+
"""
104108
The schedule for the code generation process.
105109
"""
106110
schedule::Any
@@ -171,7 +175,8 @@ struct ODESystem <: AbstractODESystem
171175

172176
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
173177
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses,
174-
torn_matching, initializesystem, schedule, connector_type, preface, cevents,
178+
torn_matching, initializesystem, initialization_eqs, schedule,
179+
connector_type, preface, cevents,
175180
devents, parameter_dependencies,
176181
metadata = nothing, gui_metadata = nothing,
177182
tearing_state = nothing,
@@ -190,8 +195,8 @@ struct ODESystem <: AbstractODESystem
190195
end
191196
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
192197
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, torn_matching,
193-
initializesystem, schedule, connector_type, preface, cevents, devents, parameter_dependencies,
194-
metadata,
198+
initializesystem, initialization_eqs, schedule, connector_type, preface,
199+
cevents, devents, parameter_dependencies, metadata,
195200
gui_metadata, tearing_state, substitutions, complete, index_cache,
196201
discrete_subsystems, solved_unknowns, split_idxs, parent)
197202
end
@@ -208,6 +213,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
208213
defaults = _merge(Dict(default_u0), Dict(default_p)),
209214
guesses = Dict(),
210215
initializesystem = nothing,
216+
initialization_eqs = Equation[],
211217
schedule = nothing,
212218
connector_type = nothing,
213219
preface = nothing,
@@ -260,7 +266,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
260266
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
261267
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
262268
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, nothing, initializesystem,
263-
schedule, connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
269+
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
270+
disc_callbacks, parameter_dependencies,
264271
metadata, gui_metadata, checks = checks)
265272
end
266273

src/systems/nonlinear/initializesystem.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,51 @@ function generate_initializesystem(sys::ODESystem;
1717
# Start the equations list with algebraic equations
1818
eqs_ics = eqs[idxs_alge]
1919
u0 = Vector{Pair}(undef, 0)
20-
defs = merge(defaults(sys), todict(u0map))
2120

22-
full_states = [sts; getfield.((observed(sys)), :lhs)]
21+
eqs_diff = eqs[idxs_diff]
22+
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
23+
24+
full_states = unique([sts; getfield.((observed(sys)), :lhs)])
2325
set_full_states = Set(full_states)
2426
guesses = todict(guesses)
2527
schedule = getfield(sys, :schedule)
2628

27-
dd_guess = if schedule !== nothing
29+
if schedule !== nothing
2830
guessmap = [x[2] => get(guesses, x[1], default_dd_value)
2931
for x in schedule.dummy_sub]
30-
Dict(filter(x -> !isnothing(x[1]), guessmap))
32+
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
33+
if u0map === nothing || isempty(u0map)
34+
filtered_u0 = u0map
35+
else
36+
filtered_u0 = Pair[]
37+
for x in u0map
38+
y = get(schedule.dummy_sub, x[1], x[1])
39+
y = get(diffmap, y, y)
40+
if y isa Symbolics.Arr
41+
_y = collect(y)
42+
43+
# TODO: Don't scalarize arrays
44+
for i in 1:length(_y)
45+
push!(filtered_u0, _y[i] => x[2][i])
46+
end
47+
elseif y isa ModelingToolkit.BasicSymbolic
48+
# y is a derivative expression expanded
49+
# add to the initialization equations
50+
push!(eqs_ics, y ~ x[2])
51+
elseif y set_full_states
52+
push!(filtered_u0, y => x[2])
53+
else
54+
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
55+
end
56+
end
57+
filtered_u0 = filtered_u0 isa Pair ? todict([filtered_u0]) : todict(filtered_u0)
58+
end
3159
else
32-
Dict()
60+
dd_guess = Dict()
61+
filtered_u0 = u0map
3362
end
3463

64+
defs = merge(defaults(sys), filtered_u0)
3565
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)
3666

3767
for st in full_states
@@ -55,7 +85,7 @@ function generate_initializesystem(sys::ODESystem;
5585
end
5686

5787
pars = [parameters(sys); get_iv(sys)]
58-
nleqs = [eqs_ics; observed(sys)]
88+
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]
5989

6090
sys_nl = NonlinearSystem(nleqs,
6191
full_states,

src/variables.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,23 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
170170
end
171171
end
172172

173+
const MISSING_VARIABLES_MESSAGE = """
174+
Initial condition underdefined. Some are missing from the variable map.
175+
Please provide a default (`u0`), initialization equation, or guess
176+
for the following variables:
177+
"""
178+
179+
struct MissingVariablesError <: Exception
180+
vars::Any
181+
end
182+
183+
function Base.showerror(io::IO, e::MissingVariablesError)
184+
println(io, MISSING_VARIABLES_MESSAGE)
185+
println(io, e.vars)
186+
end
187+
173188
function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false,
174-
toterm = Symbolics.diff2term)
189+
toterm = Symbolics.diff2term, initialization_phase = false)
175190
varmap = merge(defaults, varmap) # prefers the `varmap`
176191
varmap = Dict(toterm(value(k)) => value(varmap[k]) for k in keys(varmap))
177192
# resolve symbolic parameter expressions
@@ -180,7 +195,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
180195
end
181196

182197
missingvars = setdiff(varlist, collect(keys(varmap)))
183-
check && (isempty(missingvars) || throw_missingvars(missingvars))
198+
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))
184199

185200
out = [varmap[var] for var in varlist]
186201
end

test/components.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ let
9898
@named rc_model2 = compose(_rc_model2,
9999
[resistor, resistor2, capacitor, source, ground])
100100
sys2 = structural_simplify(rc_model2)
101-
prob2 = ODEProblem(sys2, [], (0, 10.0), guesses = u0)
101+
prob2 = ODEProblem(sys2, [source.p.i => 0.0], (0, 10.0), guesses = u0)
102102
sol2 = solve(prob2, Rosenbrock23())
103-
@test sol2[source.p.i] sol2[rc_model2.source.p.i] sol2[capacitor.i]
103+
@test sol2[source.p.i] sol2[rc_model2.source.p.i] -sol2[capacitor.i]
104104
end
105105

106106
# Outer/inner connections
@@ -157,7 +157,7 @@ sys = structural_simplify(ll_model)
157157
u0 = unknowns(sys) .=> 0
158158
@test_nowarn ODEProblem(
159159
sys, [], (0, 10.0), guesses = u0, warn_initialize_determined = false)
160-
prob = DAEProblem(sys, D.(unknowns(sys)) .=> 0, u0, (0, 0.5))
160+
prob = DAEProblem(sys, D.(unknowns(sys)) .=> 0, [], (0, 0.5), guesses = u0)
161161
sol = solve(prob, DFBDF())
162162
@test sol.retcode == SciMLBase.ReturnCode.Success
163163

test/initializationsystem.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,4 +331,90 @@ p = [σ => 28.0,
331331
β => 8 / 3]
332332

333333
tspan = (0.0, 100.0)
334-
@test_throws ArgumentError prob=ODEProblem(sys, u0, tspan, p, jac = true)
334+
@test_throws ModelingToolkit.IncompleteInitializationError prob=ODEProblem(
335+
sys, u0, tspan, p, jac = true)
336+
337+
# DAE Initialization on ODE with nonlinear system for initial conditions
338+
# https://github.com/SciML/ModelingToolkit.jl/issues/2508
339+
340+
using ModelingToolkit, OrdinaryDiffEq, Test
341+
using ModelingToolkit: t_nounits as t, D_nounits as D
342+
343+
function System2(; name)
344+
vars = @variables begin
345+
dx(t), [guess = 0]
346+
ddx(t), [guess = 0]
347+
end
348+
eqs = [D(dx) ~ ddx
349+
0 ~ ddx + dx + 1]
350+
return ODESystem(eqs, t, vars, []; name)
351+
end
352+
353+
@mtkbuild sys = System2()
354+
prob = ODEProblem(sys, [sys.dx => 1], (0, 1)) # OK
355+
prob = ODEProblem(sys, [sys.ddx => -2], (0, 1), guesses = [sys.dx => 1])
356+
sol = solve(prob, Tsit5())
357+
@test SciMLBase.successful_retcode(sol)
358+
@test sol[1] == [1.0]
359+
360+
## Late binding initialization_eqs
361+
362+
function System3(; name)
363+
vars = @variables begin
364+
dx(t), [guess = 0]
365+
ddx(t), [guess = 0]
366+
end
367+
eqs = [D(dx) ~ ddx
368+
0 ~ ddx + dx + 1]
369+
initialization_eqs = [
370+
ddx ~ -2
371+
]
372+
return ODESystem(eqs, t, vars, []; name, initialization_eqs)
373+
end
374+
375+
@mtkbuild sys = System3()
376+
prob = ODEProblem(sys, [], (0, 1), guesses = [sys.dx => 1])
377+
sol = solve(prob, Tsit5())
378+
@test SciMLBase.successful_retcode(sol)
379+
@test sol[1] == [1.0]
380+
381+
# Steady state initialization
382+
383+
@parameters σ ρ β
384+
@variables x(t) y(t) z(t)
385+
386+
eqs = [D(D(x)) ~ σ * (y - x),
387+
D(y) ~ x *- z) - y,
388+
D(z) ~ x * y - β * z]
389+
390+
@named sys = ODESystem(eqs, t)
391+
sys = structural_simplify(sys)
392+
393+
u0 = [D(x) => 2.0,
394+
x => 1.0,
395+
D(y) => 0.0,
396+
z => 0.0]
397+
398+
p ==> 28.0,
399+
ρ => 10.0,
400+
β => 8 / 3]
401+
402+
tspan = (0.0, 0.2)
403+
prob_mtk = ODEProblem(sys, u0, tspan, p)
404+
sol = solve(prob_mtk, Tsit5())
405+
@test sol[x *- z) - y][1] == 0.0
406+
407+
@variables x(t) y(t) z(t)
408+
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
409+
410+
eqs = [D(x) ~ α * x - β * x * y
411+
D(y) ~ -γ * y + δ * x * y
412+
z ~ x + y]
413+
414+
@named sys = ODESystem(eqs, t)
415+
simpsys = structural_simplify(sys)
416+
tspan = (0.0, 10.0)
417+
418+
prob = ODEProblem(simpsys, [D(x) => 0.0, y => 0.0], tspan, guesses = [x => 0.0])
419+
sol = solve(prob, Tsit5())
420+
@test sol[1] == [0.0, 0.0]

test/input_output_handling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ model = ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper]
131131
name = :name)
132132
model_outputs = [inertia1.w, inertia2.w, inertia1.phi, inertia2.phi]
133133
model_inputs = [torque.tau.u]
134-
matrices, ssys = linearize(model, model_inputs, model_outputs)
134+
matrices, ssys = linearize(model, model_inputs, model_outputs);
135135
@test length(ModelingToolkit.outputs(ssys)) == 4
136136

137137
if VERSION >= v"1.8" # :opaque_closure not supported before

0 commit comments

Comments
 (0)