Skip to content

Commit 49d8119

Browse files
Handle steady state initializations
1 parent 1274908 commit 49d8119

File tree

4 files changed

+85
-10
lines changed

4 files changed

+85
-10
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ function generate_initializesystem(sys::ODESystem;
1818
eqs_ics = eqs[idxs_alge]
1919
u0 = Vector{Pair}(undef, 0)
2020

21-
full_states = [sts; getfield.((observed(sys)), :lhs)]
21+
eqs_diff = eqs[idxs_diff]
22+
diffmap = Dict(getfield.(eqs_diff,:lhs) .=> getfield.(eqs_diff,:rhs))
23+
24+
full_states = unique([sts; getfield.((observed(sys)), :lhs)])
2225
set_full_states = Set(full_states)
2326
guesses = todict(guesses)
2427
schedule = getfield(sys, :schedule)
@@ -30,10 +33,26 @@ function generate_initializesystem(sys::ODESystem;
3033
if u0map === nothing || isempty(u0map)
3134
filtered_u0 = u0map
3235
else
33-
# TODO: Don't scalarize arrays
34-
filtered_u0 = map(u0map) do x
36+
filtered_u0 = []
37+
for x in u0map
3538
y = get(schedule.dummy_sub, x[1], x[1])
36-
y isa Symbolics.Arr ? collect(x[1]) .=> x[2] : x[1] => x[2]
39+
y = get(diffmap, y, y)
40+
if y isa Symbolics.Arr
41+
_y = collect(y)
42+
43+
# TODO: Don't scalarize arrays
44+
for i in 1:length(_y)
45+
push!(filtered_u0, _y[i] => x[2][i])
46+
end
47+
elseif y isa ModelingToolkit.BasicSymbolic
48+
# y is a derivative expression expanded
49+
# add to the initialization equations
50+
push!(eqs_ics, y ~ x[2])
51+
elseif y set_full_states
52+
push!(filtered_u0, y => x[2])
53+
else
54+
error("Unreachable. Open an issue")
55+
end
3756
end
3857
filtered_u0 = reduce(vcat, filtered_u0)
3958
filtered_u0 = filtered_u0 isa Pair ? todict([filtered_u0]) : todict(filtered_u0)

src/variables.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,23 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
170170
end
171171
end
172172

173+
const MISSING_VARIABLES_MESSAGE = """
174+
Initial condition underdefined. Some are missing from the variable map.
175+
Please provide a default (`u0`), initialization equation, or guess
176+
for the following variables:
177+
"""
178+
179+
struct MissingVariablesError <: Exception
180+
vars::Any
181+
end
182+
183+
function Base.showerror(io::IO, e::MissingVariablesError)
184+
println(io, MISSING_VARIABLES_MESSAGE)
185+
println(io, e.vars)
186+
end
187+
173188
function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false,
174-
toterm = Symbolics.diff2term)
189+
toterm = Symbolics.diff2term, initialization_phase = false)
175190
varmap = merge(defaults, varmap) # prefers the `varmap`
176191
varmap = Dict(toterm(value(k)) => value(varmap[k]) for k in keys(varmap))
177192
# resolve symbolic parameter expressions
@@ -180,7 +195,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
180195
end
181196

182197
missingvars = setdiff(varlist, collect(keys(varmap)))
183-
check && (isempty(missingvars) || throw_missingvars(missingvars))
198+
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))
184199

185200
out = [varmap[var] for var in varlist]
186201
end

test/initializationsystem.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,44 @@ prob = ODEProblem(sys, [], (0, 1), guesses = [sys.dx => 1])
377377
sol = solve(prob, Tsit5())
378378
@test SciMLBase.successful_retcode(sol)
379379
@test sol[1] == [1.0]
380+
381+
# Steady state initialization
382+
383+
@parameters σ ρ β
384+
@variables x(t) y(t) z(t)
385+
386+
eqs = [D(D(x)) ~ σ * (y - x),
387+
D(y) ~ x *- z) - y,
388+
D(z) ~ x * y - β * z]
389+
390+
@named sys = ODESystem(eqs, t)
391+
sys = structural_simplify(sys)
392+
393+
u0 = [D(x) => 2.0,
394+
x => 1.0,
395+
D(y) => 0.0,
396+
z => 0.0]
397+
398+
p ==> 28.0,
399+
ρ => 10.0,
400+
β => 8 / 3]
401+
402+
tspan = (0.0, 0.2)
403+
prob_mtk = ODEProblem(sys, u0, tspan, p)
404+
sol = solve(prob_mtk, Tsit5())
405+
@test sol[x *- z) - y][1] == 0.0
406+
407+
@variables x(t) y(t) z(t)
408+
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0
409+
410+
eqs = [D(x) ~ α * x - β * x * y
411+
D(y) ~ -γ * y + δ * x * y
412+
z ~ x + y]
413+
414+
@named sys = ODESystem(eqs, t)
415+
simpsys = structural_simplify(sys)
416+
tspan = (0.0, 10.0)
417+
418+
prob = ODEProblem(simpsys, [D(x) => 0.0, y => 0.0], tspan, guesses = [x => 0.0])
419+
sol = solve(prob, Tsit5())
420+
@test sol[1] == [0.0,0.0]

test/serialization.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ sys = include_string(@__MODULE__, str)
3232
# check answer
3333
ss = structural_simplify(rc_model)
3434
all_obs = [o.lhs for o in observed(ss)]
35-
prob = ODEProblem(ss, [], (0, 0.1))
35+
prob = ODEProblem(ss, [capacitor.v => 0.0], (0, 0.1))
3636
sol = solve(prob, ImplicitEuler())
3737

3838
## Check ODESystem with Observables ----------
3939
ss_exp = ModelingToolkit.toexpr(ss)
4040
ss_ = complete(eval(ss_exp))
41-
prob_ = ODEProblem(ss_, [], (0, 0.1))
41+
prob_ = ODEProblem(ss_, [capacitor.v => 0.0], (0, 0.1))
4242
sol_ = solve(prob_, ImplicitEuler())
4343
@test sol[all_obs] == sol_[all_obs]
4444

@@ -61,8 +61,8 @@ observedfun_exp = :(function (var, u0, p, t)
6161
end)
6262

6363
# ODEProblemExpr with observedfun_exp included
64-
probexpr = ODEProblemExpr{true}(ss, [], (0, 0.1); observedfun_exp);
64+
probexpr = ODEProblemExpr{true}(ss, [capacitor.v => 0.0], (0, 0.1); observedfun_exp);
6565
prob_obs = eval(probexpr)
6666
sol_obs = solve(prob_obs, ImplicitEuler())
6767
@show all_obs
68-
@test sol_obs[all_obs] == sol[all_obs]
68+
@test_broken sol_obs[all_obs] == sol[all_obs]

0 commit comments

Comments
 (0)