Skip to content

Commit 1b17f09

Browse files
fixup! fix: create and solve initialization system in linearization_function
1 parent 4fca27d commit 1b17f09

File tree

2 files changed

+22
-93
lines changed

2 files changed

+22
-93
lines changed

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,8 +1760,8 @@ function linearization_function(sys::AbstractSystem, inputs,
17601760
op = merge(defs, op)
17611761
end
17621762
sys = ssys
1763-
initsys = complete(generate_algebraic_initializesystem(
1764-
sys, guesses = guesses(sys)))
1763+
initsys = complete(generate_initializesystem(
1764+
sys, guesses = guesses(sys), algebraic_only = true))
17651765
initfn = NonlinearFunction(initsys)
17661766
initprobmap = getu(initsys, unknowns(sys))
17671767
ps = parameters(sys)

src/systems/nonlinear/initializesystem.jl

Lines changed: 20 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function generate_initializesystem(sys::ODESystem;
88
name = nameof(sys),
99
guesses = Dict(), check_defguess = false,
1010
default_dd_value = 0.0,
11+
algebraic_only = false,
1112
kwargs...)
1213
sts, eqs = unknowns(sys), equations(sys)
1314
idxs_diff = isdiffeq.(eqs)
@@ -68,107 +69,35 @@ function generate_initializesystem(sys::ODESystem;
6869
defs = merge(defaults(sys), filtered_u0)
6970
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)
7071

71-
for st in full_states
72-
if st keys(defs)
73-
def = defs[st]
72+
if !algebraic_only
73+
for st in full_states
74+
if st keys(defs)
75+
def = defs[st]
7476

75-
if def isa Equation
76-
st keys(guesses) && check_defguess &&
77-
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
78-
push!(eqs_ics, def)
77+
if def isa Equation
78+
st keys(guesses) && check_defguess &&
79+
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
80+
push!(eqs_ics, def)
81+
push!(u0, st => guesses[st])
82+
else
83+
push!(eqs_ics, st ~ def)
84+
push!(u0, st => def)
85+
end
86+
elseif st keys(guesses)
7987
push!(u0, st => guesses[st])
80-
else
81-
push!(eqs_ics, st ~ def)
82-
push!(u0, st => def)
88+
elseif check_defguess
89+
error("Invalid setup: unknown $(st) has no default value or initial guess")
8390
end
84-
elseif st keys(guesses)
85-
push!(u0, st => guesses[st])
86-
elseif check_defguess
87-
error("Invalid setup: unknown $(st) has no default value or initial guess")
8891
end
8992
end
9093

9194
pars = [parameters(sys); get_iv(sys)]
92-
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]
93-
94-
sys_nl = NonlinearSystem(nleqs,
95-
full_states,
96-
pars;
97-
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
98-
name,
99-
kwargs...)
100-
101-
return sys_nl
102-
end
103-
104-
function generate_algebraic_initializesystem(sys::ODESystem;
105-
u0map = Dict(),
106-
name = nameof(sys),
107-
guesses = Dict(), check_defguess = false,
108-
default_dd_value = 0.0,
109-
kwargs...)
110-
sts, eqs = unknowns(sys), equations(sys)
111-
idxs_diff = isdiffeq.(eqs)
112-
idxs_alge = .!idxs_diff
113-
num_alge = sum(idxs_alge)
114-
115-
# Start the equations list with algebraic equations
116-
eqs_ics = eqs[idxs_alge]
117-
u0 = Vector{Pair}(undef, 0)
118-
119-
eqs_diff = eqs[idxs_diff]
120-
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
121-
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
122-
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
123-
124-
full_states = unique([sts; getfield.((observed(sys)), :lhs)])
125-
set_full_states = Set(full_states)
126-
guesses = todict(guesses)
127-
schedule = getfield(sys, :schedule)
128-
129-
if schedule !== nothing
130-
guessmap = [x[2] => get(guesses, x[1], default_dd_value)
131-
for x in schedule.dummy_sub]
132-
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
133-
if u0map === nothing || isempty(u0map)
134-
filtered_u0 = u0map
135-
else
136-
filtered_u0 = Pair[]
137-
for x in u0map
138-
y = get(schedule.dummy_sub, x[1], x[1])
139-
y = ModelingToolkit.fixpoint_sub(y, observed_diffmap)
140-
y = get(diffmap, y, y)
141-
142-
if y isa Symbolics.Arr
143-
_y = collect(y)
144-
145-
# TODO: Don't scalarize arrays
146-
for i in 1:length(_y)
147-
push!(filtered_u0, _y[i] => x[2][i])
148-
end
149-
elseif y isa ModelingToolkit.BasicSymbolic
150-
# y is a derivative expression expanded
151-
# add to the initialization equations
152-
push!(eqs_ics, y ~ x[2])
153-
elseif y set_full_states
154-
push!(filtered_u0, y => x[2])
155-
else
156-
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
157-
end
158-
end
159-
filtered_u0 = filtered_u0 isa Pair ? todict([filtered_u0]) : todict(filtered_u0)
160-
end
95+
nleqs = if algebraic_only
96+
eqs_ics
16197
else
162-
dd_guess = Dict()
163-
filtered_u0 = todict(u0map)
98+
[eqs_ics; get_initialization_eqs(sys); observed(sys)]
16499
end
165100

166-
defs = merge(defaults(sys), filtered_u0)
167-
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)
168-
169-
pars = [parameters(sys); get_iv(sys)]
170-
nleqs = eqs_ics
171-
172101
sys_nl = NonlinearSystem(nleqs,
173102
full_states,
174103
pars;

0 commit comments

Comments
 (0)