Skip to content

Commit 490b4bc

Browse files
fix: properly initialize initialization problem in linearization_function
1 parent bb641ee commit 490b4bc

File tree

4 files changed

+135
-66
lines changed

4 files changed

+135
-66
lines changed

src/systems/abstractsystem.jl

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,18 +1774,87 @@ function linearization_function(sys::AbstractSystem, inputs,
17741774
sys = ssys
17751775
initsys = complete(generate_initializesystem(
17761776
sys, guesses = guesses(sys), algebraic_only = true))
1777+
if p isa SciMLBase.NullParameters
1778+
p = Dict()
1779+
else
1780+
p = todict(p)
1781+
end
1782+
p[get_iv(sys)] = 0.0
1783+
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
1784+
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
1785+
initsys_ps = parameters(initsys)
1786+
initsys_idxs = [parameter_index(initsys, param) for param in initsys_ps]
1787+
tunable_ps = [initsys_ps[i]
1788+
for i in eachindex(initsys_ps)
1789+
if initsys_idxs[i].portion == SciMLStructures.Tunable()]
1790+
tunable_getter = isempty(tunable_ps) ? nothing : getu(sys, tunable_ps)
1791+
discrete_ps = [initsys_ps[i]
1792+
for i in eachindex(initsys_ps)
1793+
if initsys_idxs[i].portion == SciMLStructures.Discrete()]
1794+
disc_getter = isempty(discrete_ps) ? nothing : getu(sys, discrete_ps)
1795+
constant_ps = [initsys_ps[i]
1796+
for i in eachindex(initsys_ps)
1797+
if initsys_idxs[i].portion == SciMLStructures.Constants()]
1798+
const_getter = isempty(constant_ps) ? nothing : getu(sys, constant_ps)
1799+
nonnum_ps = [initsys_ps[i]
1800+
for i in eachindex(initsys_ps)
1801+
if initsys_idxs[i].portion == NONNUMERIC_PORTION]
1802+
nonnum_getter = isempty(nonnum_ps) ? nothing : getu(sys, nonnum_ps)
1803+
u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
1804+
getu(sys, unknowns(initsys))
1805+
get_initprob_u_p = let tunable_getter = tunable_getter,
1806+
disc_getter = disc_getter,
1807+
const_getter = const_getter,
1808+
nonnum_getter = nonnum_getter,
1809+
oldps = oldps,
1810+
u_getter = u_getter
1811+
1812+
function (u, p, t)
1813+
state = ProblemState(; u, p, t)
1814+
if tunable_getter !== nothing
1815+
oldps = SciMLStructures.replace!(
1816+
SciMLStructures.Tunable(), oldps, tunable_getter(state))
1817+
end
1818+
if disc_getter !== nothing
1819+
oldps = SciMLStructures.replace!(
1820+
SciMLStructures.Discrete(), oldps, disc_getter(state))
1821+
end
1822+
if const_getter !== nothing
1823+
oldps = SciMLStructures.replace!(
1824+
SciMLStructures.Constants(), oldps, const_getter(state))
1825+
end
1826+
if nonnum_getter !== nothing
1827+
oldps = SciMLStructures.replace!(
1828+
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
1829+
end
1830+
newu = u_getter(state)
1831+
return newu, oldps
1832+
end
1833+
end
1834+
else
1835+
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
1836+
u_getter = getu(sys, unknowns(initsys))
1837+
1838+
function (u, p, t)
1839+
state = ProblemState(; u, p, t)
1840+
return u_getter(state), p_getter(state)
1841+
end
1842+
end
1843+
end
17771844
initfn = NonlinearFunction(initsys)
17781845
initprobmap = getu(initsys, unknowns(sys))
17791846
ps = parameters(sys)
17801847
lin_fun = let diff_idxs = diff_idxs,
17811848
alge_idxs = alge_idxs,
17821849
input_idxs = input_idxs,
17831850
sts = unknowns(sys),
1851+
get_initprob_u_p = get_initprob_u_p,
17841852
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
17851853
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
17861854
initfn = initfn,
17871855
h = build_explicit_observed_function(sys, outputs),
1788-
chunk = ForwardDiff.Chunk(input_idxs)
1856+
chunk = ForwardDiff.Chunk(input_idxs),
1857+
initialize = initialize
17891858

17901859
function (u, p, t)
17911860
if u !== nothing # Handle systems without unknowns
@@ -1794,7 +1863,8 @@ function linearization_function(sys::AbstractSystem, inputs,
17941863
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
17951864
residual = fun(u, p, t)
17961865
if norm(residual[alge_idxs]) > (eps(eltype(residual)))
1797-
initprob = NonlinearLeastSquaresProblem(initfn, u, p)
1866+
initu0, initp = get_initprob_u_p(u, p, t)
1867+
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
17981868
@set! fun.initializeprob = initprob
17991869
prob = ODEProblem(fun, u, (t, t + 1), p)
18001870
integ = init(prob, OrdinaryDiffEq.Rodas5P())

test/downstream/linearization_dd.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
## Test that dummy_derivatives can be set to zero
2+
# The call to Link(; m = 0.2, l = 10, I = 1, g = -9.807) hangs forever on Julia v1.6
3+
using LinearAlgebra
4+
using ModelingToolkit
5+
using ModelingToolkitStandardLibrary
6+
using ModelingToolkitStandardLibrary.Blocks
7+
using ModelingToolkitStandardLibrary.Mechanical.MultiBody2D
8+
using ModelingToolkitStandardLibrary.Mechanical.TranslationalPosition
9+
using Test
10+
11+
using ControlSystemsMTK
12+
using ControlSystemsMTK.ControlSystemsBase: sminreal, minreal, poles
13+
connect = ModelingToolkit.connect
14+
15+
@parameters t
16+
D = Differential(t)
17+
18+
@named link1 = Link(; m = 0.2, l = 10, I = 1, g = -9.807)
19+
@named cart = TranslationalPosition.Mass(; m = 1, s = 0)
20+
@named fixed = Fixed()
21+
@named force = Force(use_support = false)
22+
23+
eqs = [connect(link1.TX1, cart.flange)
24+
connect(cart.flange, force.flange)
25+
connect(link1.TY1, fixed.flange)]
26+
27+
@named model = ODESystem(eqs, t, [], []; systems = [link1, cart, force, fixed])
28+
def = ModelingToolkit.defaults(model)
29+
def[cart.s] = 10
30+
def[cart.v] = 0
31+
def[link1.A] = -pi / 2
32+
def[link1.dA] = 0
33+
lin_outputs = [cart.s, cart.v, link1.A, link1.dA]
34+
lin_inputs = [force.f.u]
35+
36+
@info "named_ss"
37+
G = named_ss(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
38+
allow_input_derivatives = true, zero_dummy_der = true)
39+
G = sminreal(G)
40+
@info "minreal"
41+
G = minreal(G)
42+
@info "poles"
43+
ps = poles(G)
44+
45+
@test minimum(abs, ps) < 1e-6
46+
@test minimum(abs, complex(0, 1.3777260367206716) .- ps) < 1e-10
47+
48+
lsys, syss = linearize(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
49+
allow_input_derivatives = true, zero_dummy_der = true)
50+
lsyss, sysss = ModelingToolkit.linearize_symbolic(model, lin_inputs, lin_outputs;
51+
allow_input_derivatives = true)
52+
53+
dummyder = setdiff(unknowns(sysss), unknowns(model))
54+
def = merge(ModelingToolkit.guesses(model), def, Dict(x => 0.0 for x in dummyder))
55+
def[link1.fy1] = -def[link1.g] * def[link1.m]
56+
57+
@test substitute(lsyss.A, def) lsys.A
58+
# We cannot pivot symbolically, so the part where a linear solve is required
59+
# is not reliable.
60+
@test substitute(lsyss.B, def)[1:6, 1] lsys.B[1:6, 1]
61+
@test substitute(lsyss.C, def) lsys.C
62+
@test substitute(lsyss.D, def) lsys.D

test/downstream/linearize.jl

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -195,67 +195,3 @@ lsys, ssys = linearize(sat, [u], [y]; op = Dict(u => 2))
195195
@test isempty(lsys.B)
196196
@test isempty(lsys.C)
197197
@test lsys.D[] == 0
198-
199-
## Test that dummy_derivatives can be set to zero
200-
if VERSION >= v"1.8"
201-
# The call to Link(; m = 0.2, l = 10, I = 1, g = -9.807) hangs forever on Julia v1.6
202-
using LinearAlgebra
203-
using ModelingToolkit
204-
using ModelingToolkitStandardLibrary
205-
using ModelingToolkitStandardLibrary.Blocks
206-
using ModelingToolkitStandardLibrary.Mechanical.MultiBody2D
207-
using ModelingToolkitStandardLibrary.Mechanical.TranslationalPosition
208-
209-
using ControlSystemsMTK
210-
using ControlSystemsMTK.ControlSystemsBase: sminreal, minreal, poles
211-
connect = ModelingToolkit.connect
212-
213-
@parameters t
214-
D = Differential(t)
215-
216-
@named link1 = Link(; m = 0.2, l = 10, I = 1, g = -9.807)
217-
@named cart = TranslationalPosition.Mass(; m = 1, s = 0)
218-
@named fixed = Fixed()
219-
@named force = Force(use_support = false)
220-
221-
eqs = [connect(link1.TX1, cart.flange)
222-
connect(cart.flange, force.flange)
223-
connect(link1.TY1, fixed.flange)]
224-
225-
@named model = ODESystem(eqs, t, [], []; systems = [link1, cart, force, fixed])
226-
def = ModelingToolkit.defaults(model)
227-
def[cart.s] = 10
228-
def[cart.v] = 0
229-
def[link1.A] = -pi / 2
230-
def[link1.dA] = 0
231-
lin_outputs = [cart.s, cart.v, link1.A, link1.dA]
232-
lin_inputs = [force.f.u]
233-
234-
@info "named_ss"
235-
G = named_ss(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
236-
allow_input_derivatives = true, zero_dummy_der = true)
237-
G = sminreal(G)
238-
@info "minreal"
239-
G = minreal(G)
240-
@info "poles"
241-
ps = poles(G)
242-
243-
@test minimum(abs, ps) < 1e-6
244-
@test minimum(abs, complex(0, 1.3777260367206716) .- ps) < 1e-10
245-
246-
lsys, syss = linearize(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
247-
allow_input_derivatives = true, zero_dummy_der = true)
248-
lsyss, sysss = ModelingToolkit.linearize_symbolic(model, lin_inputs, lin_outputs;
249-
allow_input_derivatives = true)
250-
251-
dummyder = setdiff(unknowns(sysss), unknowns(model))
252-
def = merge(def, Dict(x => 0.0 for x in dummyder))
253-
def[link1.fy1] = -def[link1.g] * def[link1.m]
254-
255-
@test substitute(lsyss.A, def) lsys.A
256-
# We cannot pivot symbolically, so the part where a linear solve is required
257-
# is not reliable.
258-
@test substitute(lsyss.B, def)[1:6, 1] lsys.B[1:6, 1]
259-
@test substitute(lsyss.C, def) lsys.C
260-
@test substitute(lsyss.D, def) lsys.D
261-
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ end
9292
if GROUP == "All" || GROUP == "Downstream"
9393
activate_downstream_env()
9494
@safetestset "Linearization Tests" include("downstream/linearize.jl")
95+
@safetestset "Linearization Dummy Derivative Tests" include("downstream/linearization_dd.jl")
9596
@safetestset "Inverse Models Test" include("downstream/inversemodel.jl")
9697
end
9798

0 commit comments

Comments
 (0)