Skip to content

Commit 5fd53b8

Browse files
fix: improve performance of linearization
1 parent 001e146 commit 5fd53b8

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

src/systems/abstractsystem.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,7 +1779,13 @@ function linearization_function(sys::AbstractSystem, inputs,
17791779
else
17801780
p = todict(p)
17811781
end
1782-
p[get_iv(sys)] = 0.0
1782+
x0 = merge(defaults_and_guesses(sys), op)
1783+
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1784+
sys_ps = MTKParameters(sys, p, x0)
1785+
else
1786+
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
1787+
end
1788+
p[get_iv(sys)] = NaN
17831789
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
17841790
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
17851791
initsys_ps = parameters(initsys)
@@ -1812,19 +1818,19 @@ function linearization_function(sys::AbstractSystem, inputs,
18121818
function (u, p, t)
18131819
state = ProblemState(; u, p, t)
18141820
if tunable_getter !== nothing
1815-
oldps = SciMLStructures.replace!(
1821+
SciMLStructures.replace!(
18161822
SciMLStructures.Tunable(), oldps, tunable_getter(state))
18171823
end
18181824
if disc_getter !== nothing
1819-
oldps = SciMLStructures.replace!(
1825+
SciMLStructures.replace!(
18201826
SciMLStructures.Discrete(), oldps, disc_getter(state))
18211827
end
18221828
if const_getter !== nothing
1823-
oldps = SciMLStructures.replace!(
1829+
SciMLStructures.replace!(
18241830
SciMLStructures.Constants(), oldps, const_getter(state))
18251831
end
18261832
if nonnum_getter !== nothing
1827-
oldps = SciMLStructures.replace!(
1833+
SciMLStructures.replace!(
18281834
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
18291835
end
18301836
newu = u_getter(state)
@@ -1843,7 +1849,7 @@ function linearization_function(sys::AbstractSystem, inputs,
18431849
end
18441850
initfn = NonlinearFunction(initsys)
18451851
initprobmap = getu(initsys, unknowns(sys))
1846-
ps = parameters(sys)
1852+
ps = full_parameters(sys)
18471853
lin_fun = let diff_idxs = diff_idxs,
18481854
alge_idxs = alge_idxs,
18491855
input_idxs = input_idxs,
@@ -1854,9 +1860,20 @@ function linearization_function(sys::AbstractSystem, inputs,
18541860
initfn = initfn,
18551861
h = build_explicit_observed_function(sys, outputs),
18561862
chunk = ForwardDiff.Chunk(input_idxs),
1857-
initialize = initialize
1863+
sys_ps = sys_ps,
1864+
initialize = initialize,
1865+
sys = sys
18581866

18591867
function (u, p, t)
1868+
if !isa(p, MTKParameters)
1869+
p = todict(p)
1870+
newps = deepcopy(sys_ps)
1871+
for (k, v) in p
1872+
setp(sys, k)(newps, v)
1873+
end
1874+
p = newps
1875+
end
1876+
18601877
if u !== nothing # Handle systems without unknowns
18611878
length(sts) == length(u) ||
18621879
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
@@ -2137,17 +2154,16 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
21372154
u0, defs = get_u0(sys, x0, p)
21382155
if has_index_cache(sys) && get_index_cache(sys) !== nothing
21392156
if p isa SciMLBase.NullParameters
2140-
p = op
2157+
p = Dict()
21412158
elseif p isa Dict
21422159
p = merge(p, op)
21432160
elseif p isa Vector && eltype(p) <: Pair
21442161
p = merge(Dict(p), op)
21452162
elseif p isa Vector
21462163
p = merge(Dict(parameters(sys) .=> p), op)
21472164
end
2148-
p2 = MTKParameters(sys, p, merge(Dict(unknowns(sys) .=> u0), x0, guesses(sys)))
21492165
end
2150-
linres = lin_fun(u0, p2, t)
2166+
linres = lin_fun(u0, p, t)
21512167
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
21522168

21532169
nx, nu = size(f_u)

0 commit comments

Comments
 (0)