Skip to content

Commit 84396da

Browse files
fix: create and solve initialization problem in linearization_function
1 parent f96fc45 commit 84396da

File tree

4 files changed

+45
-36
lines changed

4 files changed

+45
-36
lines changed

src/systems/abstractsystem.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,24 +1756,18 @@ function linearization_function(sys::AbstractSystem, inputs,
17561756
op = merge(defs, op)
17571757
end
17581758
sys = ssys
1759-
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
1760-
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1759+
initsys = complete(generate_initializesystem(
1760+
sys, guesses = guesses(sys), algebraic_only = true))
1761+
initfn = NonlinearFunction(initsys)
1762+
initprobmap = getu(initsys, unknowns(sys))
17611763
ps = parameters(sys)
1762-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1763-
p = MTKParameters(sys, p, u0)
1764-
else
1765-
p = _p
1766-
p, split_idxs = split_parameters_by_type(p)
1767-
if p isa Tuple
1768-
ps = Base.Fix1(getindex, ps).(split_idxs)
1769-
ps = (ps...,) #if p is Tuple, ps should be Tuple
1770-
end
1771-
end
17721764
lin_fun = let diff_idxs = diff_idxs,
17731765
alge_idxs = alge_idxs,
17741766
input_idxs = input_idxs,
17751767
sts = unknowns(sys),
1776-
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, unknowns(sys), ps; p = p),
1768+
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
1769+
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
1770+
initfn = initfn,
17771771
h = build_explicit_observed_function(sys, outputs),
17781772
chunk = ForwardDiff.Chunk(input_idxs)
17791773

@@ -1782,6 +1776,8 @@ function linearization_function(sys::AbstractSystem, inputs,
17821776
length(sts) == length(u) ||
17831777
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
17841778
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
1779+
initprob = NonlinearLeastSquaresProblem(initfn, u, p)
1780+
@set! fun.initializeprob = initprob
17851781
residual = fun(u, p, t)
17861782
if norm(residual[alge_idxs]) > (eps(eltype(residual)))
17871783
prob = ODEProblem(fun, u, (t, t + 1), p)
@@ -2051,8 +2047,8 @@ lsys_sym, _ = ModelingToolkit.linearize_symbolic(cl, [f.u], [p.x])
20512047
"""
20522048
function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives = false,
20532049
p = DiffEqBase.NullParameters())
2054-
x0 = merge(defaults(sys), op)
2055-
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
2050+
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
2051+
u0, defs = get_u0(sys, x0, p)
20562052
if has_index_cache(sys) && get_index_cache(sys) !== nothing
20572053
if p isa SciMLBase.NullParameters
20582054
p = op
@@ -2063,7 +2059,7 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
20632059
elseif p isa Vector
20642060
p = merge(Dict(parameters(sys) .=> p), op)
20652061
end
2066-
p2 = MTKParameters(sys, p, Dict(unknowns(sys) .=> u0))
2062+
p2 = MTKParameters(sys, p, merge(Dict(unknowns(sys) .=> u0), x0, guesses(sys)))
20672063
end
20682064
linres = lin_fun(u0, p2, t)
20692065
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres

src/systems/nonlinear/initializesystem.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function generate_initializesystem(sys::ODESystem;
88
name = nameof(sys),
99
guesses = Dict(), check_defguess = false,
1010
default_dd_value = 0.0,
11+
algebraic_only = false,
1112
kwargs...)
1213
sts, eqs = unknowns(sys), equations(sys)
1314
idxs_diff = isdiffeq.(eqs)
@@ -68,28 +69,34 @@ function generate_initializesystem(sys::ODESystem;
6869
defs = merge(defaults(sys), filtered_u0)
6970
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)
7071

71-
for st in full_states
72-
if st keys(defs)
73-
def = defs[st]
72+
if !algebraic_only
73+
for st in full_states
74+
if st keys(defs)
75+
def = defs[st]
7476

75-
if def isa Equation
76-
st keys(guesses) && check_defguess &&
77-
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
78-
push!(eqs_ics, def)
77+
if def isa Equation
78+
st keys(guesses) && check_defguess &&
79+
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
80+
push!(eqs_ics, def)
81+
push!(u0, st => guesses[st])
82+
else
83+
push!(eqs_ics, st ~ def)
84+
push!(u0, st => def)
85+
end
86+
elseif st keys(guesses)
7987
push!(u0, st => guesses[st])
80-
else
81-
push!(eqs_ics, st ~ def)
82-
push!(u0, st => def)
88+
elseif check_defguess
89+
error("Invalid setup: unknown $(st) has no default value or initial guess")
8390
end
84-
elseif st keys(guesses)
85-
push!(u0, st => guesses[st])
86-
elseif check_defguess
87-
error("Invalid setup: unknown $(st) has no default value or initial guess")
8891
end
8992
end
9093

9194
pars = [parameters(sys); get_iv(sys)]
92-
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]
95+
nleqs = if algebraic_only
96+
[eqs_ics; observed(sys)]
97+
else
98+
[eqs_ics; get_initialization_eqs(sys); observed(sys)]
99+
end
93100

94101
sys_nl = NonlinearSystem(nleqs,
95102
full_states,

test/downstream/linearize.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,14 @@ lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
120120
lsyss, _ = ModelingToolkit.linearize_symbolic(pid, [reference.u, measurement.u],
121121
[ctr_output.u])
122122

123-
@test substitute(lsyss.A, ModelingToolkit.defaults(pid)) == lsys.A
124-
@test substitute(lsyss.B, ModelingToolkit.defaults(pid)) == lsys.B
125-
@test substitute(lsyss.C, ModelingToolkit.defaults(pid)) == lsys.C
126-
@test substitute(lsyss.D, ModelingToolkit.defaults(pid)) == lsys.D
123+
@test substitute(
124+
lsyss.A, merge(ModelingToolkit.defaults(pid), ModelingToolkit.guesses(pid))) == lsys.A
125+
@test substitute(
126+
lsyss.B, merge(ModelingToolkit.defaults(pid), ModelingToolkit.guesses(pid))) == lsys.B
127+
@test substitute(
128+
lsyss.C, merge(ModelingToolkit.defaults(pid), ModelingToolkit.guesses(pid))) == lsys.C
129+
@test substitute(
130+
lsyss.D, merge(ModelingToolkit.defaults(pid), ModelingToolkit.guesses(pid))) == lsys.D
127131

128132
# Test with the reverse desired unknown order as well to verify that similarity transform and reoreder_unknowns really works
129133
lsys = ModelingToolkit.reorder_unknowns(lsys, unknowns(ssys), reverse(desired_order))

test/input_output_handling.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ if VERSION >= v"1.8" # :opaque_closure not supported before
144144
drop_expr = identity)
145145
x = randn(size(A, 1))
146146
u = randn(size(B, 2))
147-
p = getindex.(Ref(ModelingToolkit.defaults(ssys)), parameters(ssys))
147+
p = getindex.(
148+
Ref(merge(ModelingToolkit.defaults(ssys), ModelingToolkit.guesses(ssys))),
149+
parameters(ssys))
148150
y1 = obsf(x, u, p, 0)
149151
y2 = C * x + D * u
150152
@test y1[] y2[]

0 commit comments

Comments
 (0)