Skip to content

Commit 090a087

Browse files
Merge pull request #2151 from SciML/overrideinit
Allow for using the initializeprob with OverrideInit
2 parents ad86386 + 19feb17 commit 090a087

File tree

6 files changed

+86
-27
lines changed

6 files changed

+86
-27
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ Logging = "1.9"
6565
MacroTools = "0.5"
6666
MuladdMacro = "0.2.1"
6767
NLsolve = "4"
68-
NonlinearSolve = "3.3"
68+
NonlinearSolve = "3.7.3"
6969
Polyester = "0.7"
7070
PreallocationTools = "0.4.15"
7171
PrecompileTools = "1"
7272
Preferences = "1.3"
7373
RecursiveArrayTools = "2.36, 3"
7474
Reexport = "1.0"
75-
SciMLBase = "2.26"
75+
SciMLBase = "2.27.1"
7676
SciMLOperators = "0.3"
7777
SimpleNonlinearSolve = "1"
7878
SimpleUnPack = "1"

src/alg_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ end
291291
_alg_autodiff(::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
292292
_alg_autodiff(::DAEAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
293293
_alg_autodiff(::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
294+
_alg_autodiff(alg::CompositeAlgorithm) = _alg_autodiff(alg.algs[2])
294295
function _alg_autodiff(::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
295296
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}
296297
}) where {

src/initialize_dae.jl

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,34 @@ end
2121
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)
2222

2323
default_nlsolve(alg, isinplace, u, autodiff = false) = alg
24-
function default_nlsolve(::Nothing, isinplace, u, autodiff = false)
25-
TrustRegion(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
24+
function default_nlsolve(::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
25+
FastShortcutNonlinearPolyalg(;
26+
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
2627
end
27-
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray, autodiff = false)
28+
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
29+
::NonlinearProblem, autodiff = false)
2830
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
2931
end
3032

33+
function default_nlsolve(
34+
::Nothing, isinplace, u, ::NonlinearLeastSquaresProblem, autodiff = false)
35+
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
36+
end
37+
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
38+
::NonlinearLeastSquaresProblem, autodiff = false)
39+
SimpleGaussNewton(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
40+
end
41+
42+
struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm
43+
abstol::T
44+
nlsolve::F
45+
end
46+
47+
function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
48+
OverrideInit(abstol, nlsolve)
49+
end
50+
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)
51+
3152
## Notes
3253

3354
#=
@@ -54,19 +75,32 @@ end
5475

5576
function _initialize_dae!(integrator, prob::ODEProblem,
5677
alg::DefaultInit, x::Val{true})
57-
_initialize_dae!(integrator, prob,
58-
BrownFullBasicInit(integrator.opts.abstol), x)
78+
if SciMLBase.has_initializeprob(prob.f)
79+
_initialize_dae!(integrator, prob,
80+
OverrideInit(integrator.opts.abstol), x)
81+
else
82+
_initialize_dae!(integrator, prob,
83+
BrownFullBasicInit(integrator.opts.abstol), x)
84+
end
5985
end
6086

6187
function _initialize_dae!(integrator, prob::ODEProblem,
6288
alg::DefaultInit, x::Val{false})
63-
_initialize_dae!(integrator, prob,
64-
BrownFullBasicInit(integrator.opts.abstol), x)
89+
if SciMLBase.has_initializeprob(prob.f)
90+
_initialize_dae!(integrator, prob,
91+
OverrideInit(integrator.opts.abstol), x)
92+
else
93+
_initialize_dae!(integrator, prob,
94+
BrownFullBasicInit(integrator.opts.abstol), x)
95+
end
6596
end
6697

6798
function _initialize_dae!(integrator, prob::DAEProblem,
6899
alg::DefaultInit, x::Val{false})
69-
if prob.differential_vars === nothing
100+
if SciMLBase.has_initializeprob(prob.f)
101+
_initialize_dae!(integrator, prob,
102+
OverrideInit(integrator.opts.abstol), x)
103+
elseif prob.differential_vars === nothing
70104
_initialize_dae!(integrator, prob,
71105
ShampineCollocationInit(), x)
72106
else
@@ -77,7 +111,10 @@ end
77111

78112
function _initialize_dae!(integrator, prob::DAEProblem,
79113
alg::DefaultInit, x::Val{true})
80-
if prob.differential_vars === nothing
114+
if SciMLBase.has_initializeprob(prob.f)
115+
_initialize_dae!(integrator, prob,
116+
OverrideInit(integrator.opts.abstol), x)
117+
elseif prob.differential_vars === nothing
81118
_initialize_dae!(integrator, prob,
82119
ShampineCollocationInit(), x)
83120
else
@@ -92,6 +129,28 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
92129
alg::NoInit, x::Union{Val{true}, Val{false}})
93130
end
94131

132+
## OverrideInit
133+
134+
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
135+
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
136+
initializeprob = prob.f.initializeprob
137+
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
138+
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
139+
nlsol = solve(initializeprob, alg)
140+
if isinplace === Val{true}()
141+
integrator.u .= prob.f.initializeprobmap(nlsol)
142+
elseif isinplace === Val{false}()
143+
integrator.u = prob.f.initializeprobmap(nlsol)
144+
else
145+
error("Unreachable reached. Report this error.")
146+
end
147+
148+
if nlsol.retcode != ReturnCode.Success
149+
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol,
150+
ReturnCode.InitialFailure)
151+
end
152+
end
153+
95154
## ShampineCollocationInit
96155

97156
#=
@@ -196,7 +255,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
196255
jac_prototype = f.jac_prototype,
197256
jac = jac)
198257
nlprob = NonlinearProblem(nlfunc, integrator.u, p)
199-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
258+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
200259
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
201260
reltol = integrator.opts.reltol)
202261
integrator.u .= nlsol.u
@@ -266,12 +325,12 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
266325
end
267326
end
268327

269-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0)
270-
271328
nlfunc = NonlinearFunction(nlequation_oop;
272329
jac_prototype = f.jac_prototype,
273330
jac = jac)
274331
nlprob = NonlinearProblem(nlfunc, u0)
332+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, u0)
333+
275334
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
276335
reltol = integrator.opts.reltol)
277336
integrator.u = nlsol.u
@@ -351,7 +410,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
351410
jac_prototype = f.jac_prototype,
352411
jac = jac)
353412
nlprob = NonlinearProblem(nlfunc, u0, p)
354-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
413+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
355414
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
356415
reltol = integrator.opts.reltol)
357416

@@ -395,7 +454,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
395454
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype,
396455
jac = jac)
397456
nlprob = NonlinearProblem(nlfunc, u0)
398-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0)
457+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, u0)
399458

400459
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
401460
nlprob = NonlinearProblem(nlfunc, u0)
@@ -493,11 +552,10 @@ function _initialize_dae!(integrator, prob::ODEProblem,
493552
end
494553

495554
J = algebraic_jacobian(f.jac_prototype, algebraic_eqs, algebraic_vars)
496-
497-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u, isAD)
498-
499555
nlfunc = NonlinearFunction(nlequation!; jac_prototype = J)
500556
nlprob = NonlinearProblem(nlfunc, alg_u, p)
557+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u, nlprob, isAD)
558+
501559
nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
502560
alg_u .= nlsol
503561

@@ -554,11 +612,10 @@ function _initialize_dae!(integrator, prob::ODEProblem,
554612
end
555613

556614
J = algebraic_jacobian(f.jac_prototype, algebraic_eqs, algebraic_vars)
557-
558-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
559-
560615
nlfunc = NonlinearFunction(nlequation; jac_prototype = J)
561616
nlprob = NonlinearProblem(nlfunc, u0[algebraic_vars])
617+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, nlprob, isAD)
618+
562619
nlsol = solve(nlprob, nlsolve)
563620

564621
u[algebraic_vars] .= nlsol.u
@@ -689,11 +746,11 @@ function _initialize_dae!(integrator, prob::DAEProblem,
689746
f(du, u, p, t)
690747
end
691748

692-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, integrator.u)
693-
694749
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)
695750
nlprob = NonlinearProblem(nlfunc, ifelse.(differential_vars, du, u))
696751

752+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, nlprob, integrator.u)
753+
697754
nlsol = solve(nlprob, nlsolve)
698755

699756
du = ifelse.(differential_vars, nlsol.u, du)

test/interface/dae_initialize_integration.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ n2 = ODESystem(single_neuron_eqs, t, [v, w, F], [g, e, b], name = :n2)
1313
connections = [0 ~ n1.F - D * Dk * max(n1.v - n2.v, 0)
1414
0 ~ n2.F - D * max(n2.v - n1.v, 0)]
1515
connected = ODESystem(connections, t, [], [D, Dk], systems = [n1, n2], name = :connected)
16+
connected = complete(connected)
1617

1718
u0 = [
1819
n1.v => -2,

test/interface/jacobian_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ function lotka(du, u, p, t)
5555
end
5656

5757
prob = ODEProblem(lotka, [1.0, 1.0], (0.0, 1.0), [1.5, 1.0, 3.0, 1.0])
58-
de = ModelingToolkit.modelingtoolkitize(prob)
59-
prob2 = remake(prob, f = ODEFunction(de; jac = true))
58+
de = ModelingToolkit.modelingtoolkitize(prob) |> complete
59+
prob2 = ODEProblem(de; jac = true)
6060

6161
sol = solve(prob, TRBDF2())
6262

test/regression/psos_and_energy_conservation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ end
9090

9191
# energy conserving callback:
9292
# important to use save = false, I dont want rescaling points
93-
cb = ManifoldProjection(ghh, nlopts = Dict(:ftol => 1e-13), save = false)
93+
cb = ManifoldProjection(ghh, abstol = 1e-13, save = false)
9494

9595
# Callback for Poincare surface of section
9696
function psos_callback(j, direction = +1, offset::Real = 0,

0 commit comments

Comments
 (0)