Skip to content

Commit e1befe0

Browse files
Merge pull request #2676 from AayushSabharwal/as/linearization-initialization
fix: create and solve initialization system in linearization_function
2 parents 016c891 + d7fa540 commit e1befe0

14 files changed

+244
-118
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ SparseArrays = "1"
107107
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
108108
StaticArrays = "0.10, 0.11, 0.12, 1.0"
109109
SymbolicIndexingInterface = "0.3.12"
110-
SymbolicUtils = "1.0"
110+
SymbolicUtils = "<1.6"
111111
Symbolics = "5.26"
112112
URIs = "1"
113113
UnPack = "0.1, 1.0"

src/systems/abstractsystem.jl

Lines changed: 114 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,14 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
494494
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
495495
end
496496

497+
function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
498+
return let _fn = build_explicit_observed_function(sys, sym)
499+
fn(u, p, t) = _fn(u, p, t)
500+
fn(u, p::MTKParameters, t) = _fn(u, p..., t)
501+
fn
502+
end
503+
end
504+
497505
function SymbolicIndexingInterface.default_values(sys::AbstractSystem)
498506
return merge(
499507
Dict(eq.lhs => eq.rhs for eq in observed(sys)),
@@ -1020,7 +1028,15 @@ function defaults(sys::AbstractSystem)
10201028
isempty(systems) ? defs : mapfoldr(namespace_defaults, merge, systems; init = defs)
10211029
end
10221030

1031+
function defaults_and_guesses(sys::AbstractSystem)
1032+
merge(guesses(sys), defaults(sys))
1033+
end
1034+
10231035
unknowns(sys::Union{AbstractSystem, Nothing}, v) = renamespace(sys, v)
1036+
for vType in [Symbolics.Arr, Symbolics.Symbolic{<:AbstractArray}]
1037+
@eval unknowns(sys::AbstractSystem, v::$vType) = renamespace(sys, v)
1038+
@eval parameters(sys::AbstractSystem, v::$vType) = toparam(unknowns(sys, v))
1039+
end
10241040
parameters(sys::Union{AbstractSystem, Nothing}, v) = toparam(unknowns(sys, v))
10251041
for f in [:unknowns, :parameters]
10261042
@eval function $f(sys::AbstractSystem, vs::AbstractArray)
@@ -1756,34 +1772,117 @@ function linearization_function(sys::AbstractSystem, inputs,
17561772
op = merge(defs, op)
17571773
end
17581774
sys = ssys
1759-
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
1760-
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
1761-
ps = parameters(sys)
1775+
initsys = complete(generate_initializesystem(
1776+
sys, guesses = guesses(sys), algebraic_only = true))
1777+
if p isa SciMLBase.NullParameters
1778+
p = Dict()
1779+
else
1780+
p = todict(p)
1781+
end
1782+
x0 = merge(defaults_and_guesses(sys), op)
17621783
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1763-
p = MTKParameters(sys, p, u0)
1784+
sys_ps = MTKParameters(sys, p, x0)
1785+
else
1786+
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
1787+
end
1788+
p[get_iv(sys)] = NaN
1789+
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
1790+
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
1791+
initsys_ps = parameters(initsys)
1792+
initsys_idxs = [parameter_index(initsys, param) for param in initsys_ps]
1793+
tunable_ps = [initsys_ps[i]
1794+
for i in eachindex(initsys_ps)
1795+
if initsys_idxs[i].portion == SciMLStructures.Tunable()]
1796+
tunable_getter = isempty(tunable_ps) ? nothing : getu(sys, tunable_ps)
1797+
discrete_ps = [initsys_ps[i]
1798+
for i in eachindex(initsys_ps)
1799+
if initsys_idxs[i].portion == SciMLStructures.Discrete()]
1800+
disc_getter = isempty(discrete_ps) ? nothing : getu(sys, discrete_ps)
1801+
constant_ps = [initsys_ps[i]
1802+
for i in eachindex(initsys_ps)
1803+
if initsys_idxs[i].portion == SciMLStructures.Constants()]
1804+
const_getter = isempty(constant_ps) ? nothing : getu(sys, constant_ps)
1805+
nonnum_ps = [initsys_ps[i]
1806+
for i in eachindex(initsys_ps)
1807+
if initsys_idxs[i].portion == NONNUMERIC_PORTION]
1808+
nonnum_getter = isempty(nonnum_ps) ? nothing : getu(sys, nonnum_ps)
1809+
u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
1810+
getu(sys, unknowns(initsys))
1811+
get_initprob_u_p = let tunable_getter = tunable_getter,
1812+
disc_getter = disc_getter,
1813+
const_getter = const_getter,
1814+
nonnum_getter = nonnum_getter,
1815+
oldps = oldps,
1816+
u_getter = u_getter
1817+
1818+
function (u, p, t)
1819+
state = ProblemState(; u, p, t)
1820+
if tunable_getter !== nothing
1821+
SciMLStructures.replace!(
1822+
SciMLStructures.Tunable(), oldps, tunable_getter(state))
1823+
end
1824+
if disc_getter !== nothing
1825+
SciMLStructures.replace!(
1826+
SciMLStructures.Discrete(), oldps, disc_getter(state))
1827+
end
1828+
if const_getter !== nothing
1829+
SciMLStructures.replace!(
1830+
SciMLStructures.Constants(), oldps, const_getter(state))
1831+
end
1832+
if nonnum_getter !== nothing
1833+
SciMLStructures.replace!(
1834+
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
1835+
end
1836+
newu = u_getter(state)
1837+
return newu, oldps
1838+
end
1839+
end
17641840
else
1765-
p = _p
1766-
p, split_idxs = split_parameters_by_type(p)
1767-
if p isa Tuple
1768-
ps = Base.Fix1(getindex, ps).(split_idxs)
1769-
ps = (ps...,) #if p is Tuple, ps should be Tuple
1841+
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
1842+
u_getter = getu(sys, unknowns(initsys))
1843+
1844+
function (u, p, t)
1845+
state = ProblemState(; u, p, t)
1846+
return u_getter(state), p_getter(state)
1847+
end
17701848
end
17711849
end
1850+
initfn = NonlinearFunction(initsys)
1851+
initprobmap = getu(initsys, unknowns(sys))
1852+
ps = full_parameters(sys)
17721853
lin_fun = let diff_idxs = diff_idxs,
17731854
alge_idxs = alge_idxs,
17741855
input_idxs = input_idxs,
17751856
sts = unknowns(sys),
1776-
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, unknowns(sys), ps; p = p),
1857+
get_initprob_u_p = get_initprob_u_p,
1858+
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
1859+
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
1860+
initfn = initfn,
17771861
h = build_explicit_observed_function(sys, outputs),
1778-
chunk = ForwardDiff.Chunk(input_idxs)
1862+
chunk = ForwardDiff.Chunk(input_idxs),
1863+
sys_ps = sys_ps,
1864+
initialize = initialize,
1865+
sys = sys
17791866

17801867
function (u, p, t)
1868+
if !isa(p, MTKParameters)
1869+
p = todict(p)
1870+
newps = deepcopy(sys_ps)
1871+
for (k, v) in p
1872+
setp(sys, k)(newps, v)
1873+
end
1874+
p = newps
1875+
end
1876+
17811877
if u !== nothing # Handle systems without unknowns
17821878
length(sts) == length(u) ||
17831879
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
17841880
if initialize && !isempty(alge_idxs) # This is expensive and can be omitted if the user knows that the system is already initialized
17851881
residual = fun(u, p, t)
17861882
if norm(residual[alge_idxs]) > (eps(eltype(residual)))
1883+
initu0, initp = get_initprob_u_p(u, p, t)
1884+
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
1885+
@set! fun.initializeprob = initprob
17871886
prob = ODEProblem(fun, u, (t, t + 1), p)
17881887
integ = init(prob, OrdinaryDiffEq.Rodas5P())
17891888
u = integ.u
@@ -2051,21 +2150,20 @@ lsys_sym, _ = ModelingToolkit.linearize_symbolic(cl, [f.u], [p.x])
20512150
"""
20522151
function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives = false,
20532152
p = DiffEqBase.NullParameters())
2054-
x0 = merge(defaults(sys), op)
2055-
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
2153+
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
2154+
u0, defs = get_u0(sys, x0, p)
20562155
if has_index_cache(sys) && get_index_cache(sys) !== nothing
20572156
if p isa SciMLBase.NullParameters
2058-
p = op
2157+
p = Dict()
20592158
elseif p isa Dict
20602159
p = merge(p, op)
20612160
elseif p isa Vector && eltype(p) <: Pair
20622161
p = merge(Dict(p), op)
20632162
elseif p isa Vector
20642163
p = merge(Dict(parameters(sys) .=> p), op)
20652164
end
2066-
p2 = MTKParameters(sys, p, Dict(unknowns(sys) .=> u0))
20672165
end
2068-
linres = lin_fun(u0, p2, t)
2166+
linres = lin_fun(u0, p, t)
20692167
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
20702168

20712169
nx, nu = size(f_u)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,10 +1632,16 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16321632
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
16331633
[get_iv(sys) => t] :
16341634
merge(todict(parammap), Dict(get_iv(sys) => t))
1635-
1635+
if isempty(u0map)
1636+
u0map = Dict()
1637+
end
1638+
if isempty(guesses)
1639+
guesses = Dict()
1640+
end
1641+
u0map = merge(todict(guesses), todict(u0map))
16361642
if neqs == nunknown
1637-
NonlinearProblem(isys, guesses, parammap)
1643+
NonlinearProblem(isys, u0map, parammap)
16381644
else
1639-
NonlinearLeastSquaresProblem(isys, guesses, parammap)
1645+
NonlinearLeastSquaresProblem(isys, u0map, parammap)
16401646
end
16411647
end

src/systems/index_cache.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ function BufferTemplate(s::Type{<:Symbolics.Struct}, length::Int)
88
BufferTemplate(T, length)
99
end
1010

11-
const DEPENDENT_PORTION = :dependent
12-
const NONNUMERIC_PORTION = :nonnumeric
11+
struct Dependent <: SciMLStructures.AbstractPortion end
12+
struct Nonnumeric <: SciMLStructures.AbstractPortion end
13+
const DEPENDENT_PORTION = Dependent()
14+
const NONNUMERIC_PORTION = Nonnumeric()
1315

1416
struct ParameterIndex{P, I}
1517
portion::P

src/systems/nonlinear/initializesystem.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ function generate_initializesystem(sys::ODESystem;
88
name = nameof(sys),
99
guesses = Dict(), check_defguess = false,
1010
default_dd_value = 0.0,
11+
algebraic_only = false,
1112
kwargs...)
1213
sts, eqs = unknowns(sys), equations(sys)
1314
idxs_diff = isdiffeq.(eqs)
@@ -68,28 +69,34 @@ function generate_initializesystem(sys::ODESystem;
6869
defs = merge(defaults(sys), filtered_u0)
6970
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)
7071

71-
for st in full_states
72-
if st keys(defs)
73-
def = defs[st]
72+
if !algebraic_only
73+
for st in full_states
74+
if st keys(defs)
75+
def = defs[st]
7476

75-
if def isa Equation
76-
st keys(guesses) && check_defguess &&
77-
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
78-
push!(eqs_ics, def)
77+
if def isa Equation
78+
st keys(guesses) && check_defguess &&
79+
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
80+
push!(eqs_ics, def)
81+
push!(u0, st => guesses[st])
82+
else
83+
push!(eqs_ics, st ~ def)
84+
push!(u0, st => def)
85+
end
86+
elseif st keys(guesses)
7987
push!(u0, st => guesses[st])
80-
else
81-
push!(eqs_ics, st ~ def)
82-
push!(u0, st => def)
88+
elseif check_defguess
89+
error("Invalid setup: unknown $(st) has no default value or initial guess")
8390
end
84-
elseif st keys(guesses)
85-
push!(u0, st => guesses[st])
86-
elseif check_defguess
87-
error("Invalid setup: unknown $(st) has no default value or initial guess")
8891
end
8992
end
9093

9194
pars = [parameters(sys); get_iv(sys)]
92-
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]
95+
nleqs = if algebraic_only
96+
[eqs_ics; observed(sys)]
97+
else
98+
[eqs_ics; get_initialization_eqs(sys); observed(sys)]
99+
end
93100

94101
sys_nl = NonlinearSystem(nleqs,
95102
full_states,

src/systems/parameter_buffer.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function _update_tuple_helper(buf_v::T, raw, idx) where {T}
190190
end
191191

192192
function _update_tuple_helper(::Type{<:AbstractArray}, buf_v, raw, idx)
193-
ntuple(i -> _update_tuple_helper(buf_v[i], raw, idx), Val(length(buf_v)))
193+
ntuple(i -> _update_tuple_helper(buf_v[i], raw, idx), length(buf_v))
194194
end
195195

196196
function _update_tuple_helper(::Any, buf_v, raw, idx)
@@ -210,7 +210,8 @@ SciMLStructures.ismutablescimlstructure(::MTKParameters) = true
210210

211211
for (Portion, field) in [(SciMLStructures.Tunable, :tunable)
212212
(SciMLStructures.Discrete, :discrete)
213-
(SciMLStructures.Constants, :constant)]
213+
(SciMLStructures.Constants, :constant)
214+
(Nonnumeric, :nonnumeric)]
214215
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
215216
as_vector = buffer_to_arraypartition(p.$field)
216217
repack = let as_vector = as_vector, p = p

test/downstream/inversemodel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
150150
p = ModelingToolkit.MTKParameters(simplified_sys, op)
151151
matrices1 = Sf(x, p, 0)
152152
matrices2, _ = Blocks.get_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
153-
@test matrices1.f_x matrices2.A[1:7, 1:7]
153+
@test_broken matrices1.f_x matrices2.A[1:7, 1:7]
154154
nsys = get_named_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
155155
@test matrices2.A nsys.A
156156

@@ -161,6 +161,6 @@ x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
161161
p = ModelingToolkit.MTKParameters(simplified_sys, op)
162162
matrices1 = Sf(x, p, 0)
163163
matrices2, _ = Blocks.get_comp_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
164-
@test matrices1.f_x matrices2.A[1:7, 1:7]
164+
@test_broken matrices1.f_x matrices2.A[1:7, 1:7]
165165
nsys = get_named_comp_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
166166
@test matrices2.A nsys.A

test/downstream/linearization_dd.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
## Test that dummy_derivatives can be set to zero
2+
# The call to Link(; m = 0.2, l = 10, I = 1, g = -9.807) hangs forever on Julia v1.6
3+
using LinearAlgebra
4+
using ModelingToolkit
5+
using ModelingToolkitStandardLibrary
6+
using ModelingToolkitStandardLibrary.Blocks
7+
using ModelingToolkitStandardLibrary.Mechanical.MultiBody2D
8+
using ModelingToolkitStandardLibrary.Mechanical.TranslationalPosition
9+
using Test
10+
11+
using ControlSystemsMTK
12+
using ControlSystemsMTK.ControlSystemsBase: sminreal, minreal, poles
13+
connect = ModelingToolkit.connect
14+
15+
@parameters t
16+
D = Differential(t)
17+
18+
@named link1 = Link(; m = 0.2, l = 10, I = 1, g = -9.807)
19+
@named cart = TranslationalPosition.Mass(; m = 1, s = 0)
20+
@named fixed = Fixed()
21+
@named force = Force(use_support = false)
22+
23+
eqs = [connect(link1.TX1, cart.flange)
24+
connect(cart.flange, force.flange)
25+
connect(link1.TY1, fixed.flange)]
26+
27+
@named model = ODESystem(eqs, t, [], []; systems = [link1, cart, force, fixed])
28+
def = ModelingToolkit.defaults(model)
29+
def[cart.s] = 10
30+
def[cart.v] = 0
31+
def[link1.A] = -pi / 2
32+
def[link1.dA] = 0
33+
lin_outputs = [cart.s, cart.v, link1.A, link1.dA]
34+
lin_inputs = [force.f.u]
35+
36+
@test_broken begin
37+
@info "named_ss"
38+
G = named_ss(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
39+
allow_input_derivatives = true, zero_dummy_der = true)
40+
G = sminreal(G)
41+
@info "minreal"
42+
G = minreal(G)
43+
@info "poles"
44+
ps = poles(G)
45+
46+
@test minimum(abs, ps) < 1e-6
47+
@test minimum(abs, complex(0, 1.3777260367206716) .- ps) < 1e-10
48+
49+
lsys, syss = linearize(model, lin_inputs, lin_outputs, allow_symbolic = true, op = def,
50+
allow_input_derivatives = true, zero_dummy_der = true)
51+
lsyss, sysss = ModelingToolkit.linearize_symbolic(model, lin_inputs, lin_outputs;
52+
allow_input_derivatives = true)
53+
54+
dummyder = setdiff(unknowns(sysss), unknowns(model))
55+
def = merge(ModelingToolkit.guesses(model), def, Dict(x => 0.0 for x in dummyder))
56+
def[link1.fy1] = -def[link1.g] * def[link1.m]
57+
58+
@test substitute(lsyss.A, def) lsys.A
59+
# We cannot pivot symbolically, so the part where a linear solve is required
60+
# is not reliable.
61+
@test substitute(lsyss.B, def)[1:6, 1] lsys.B[1:6, 1]
62+
@test substitute(lsyss.C, def) lsys.C
63+
@test substitute(lsyss.D, def) lsys.D
64+
end

0 commit comments

Comments
 (0)