21
21
BrownFullBasicInit (abstol) = BrownFullBasicInit (; abstol = abstol, nlsolve = nothing )
22
22
23
23
default_nlsolve (alg, isinplace, u, autodiff = false ) = alg
24
- function default_nlsolve (:: Nothing , isinplace, u, autodiff = false )
25
- TrustRegion (; autodiff = autodiff ? AutoForwardDiff () : AutoFiniteDiff ())
24
+ function default_nlsolve (:: Nothing , isinplace, u, :: NonlinearProblem , autodiff = false )
25
+ FastShortcutNonlinearPolyalg (;
26
+ autodiff = autodiff ? AutoForwardDiff () : AutoFiniteDiff ())
26
27
end
27
- function default_nlsolve (:: Nothing , isinplace:: Val{false} , u:: StaticArray , autodiff = false )
28
+ function default_nlsolve (:: Nothing , isinplace:: Val{false} , u:: StaticArray ,
29
+ :: NonlinearProblem , autodiff = false )
28
30
SimpleTrustRegion (autodiff = autodiff ? AutoForwardDiff () : AutoFiniteDiff ())
29
31
end
30
32
33
+ function default_nlsolve (
34
+ :: Nothing , isinplace, u, :: NonlinearLeastSquaresProblem , autodiff = false )
35
+ FastShortcutNLLSPolyalg (; autodiff = autodiff ? AutoForwardDiff () : AutoFiniteDiff ())
36
+ end
37
+ function default_nlsolve (:: Nothing , isinplace:: Val{false} , u:: StaticArray ,
38
+ :: NonlinearLeastSquaresProblem , autodiff = false )
39
+ SimpleGaussNewton (autodiff = autodiff ? AutoForwardDiff () : AutoFiniteDiff ())
40
+ end
41
+
42
+ struct OverrideInit{T, F} <: DiffEqBase.DAEInitializationAlgorithm
43
+ abstol:: T
44
+ nlsolve:: F
45
+ end
46
+
47
+ function OverrideInit (; abstol = 1e-10 , nlsolve = nothing )
48
+ OverrideInit (abstol, nlsolve)
49
+ end
50
+ OverrideInit (abstol) = OverrideInit (; abstol = abstol, nlsolve = nothing )
51
+
31
52
# # Notes
32
53
33
54
#=
54
75
55
76
function _initialize_dae! (integrator, prob:: ODEProblem ,
56
77
alg:: DefaultInit , x:: Val{true} )
57
- _initialize_dae! (integrator, prob,
58
- BrownFullBasicInit (integrator. opts. abstol), x)
78
+ if SciMLBase. has_initializeprob (prob. f)
79
+ _initialize_dae! (integrator, prob,
80
+ OverrideInit (integrator. opts. abstol), x)
81
+ else
82
+ _initialize_dae! (integrator, prob,
83
+ BrownFullBasicInit (integrator. opts. abstol), x)
84
+ end
59
85
end
60
86
61
87
function _initialize_dae! (integrator, prob:: ODEProblem ,
62
88
alg:: DefaultInit , x:: Val{false} )
63
- _initialize_dae! (integrator, prob,
64
- BrownFullBasicInit (integrator. opts. abstol), x)
89
+ if SciMLBase. has_initializeprob (prob. f)
90
+ _initialize_dae! (integrator, prob,
91
+ OverrideInit (integrator. opts. abstol), x)
92
+ else
93
+ _initialize_dae! (integrator, prob,
94
+ BrownFullBasicInit (integrator. opts. abstol), x)
95
+ end
65
96
end
66
97
67
98
function _initialize_dae! (integrator, prob:: DAEProblem ,
68
99
alg:: DefaultInit , x:: Val{false} )
69
- if prob. differential_vars === nothing
100
+ if SciMLBase. has_initializeprob (prob. f)
101
+ _initialize_dae! (integrator, prob,
102
+ OverrideInit (integrator. opts. abstol), x)
103
+ elseif prob. differential_vars === nothing
70
104
_initialize_dae! (integrator, prob,
71
105
ShampineCollocationInit (), x)
72
106
else
77
111
78
112
function _initialize_dae! (integrator, prob:: DAEProblem ,
79
113
alg:: DefaultInit , x:: Val{true} )
80
- if prob. differential_vars === nothing
114
+ if SciMLBase. has_initializeprob (prob. f)
115
+ _initialize_dae! (integrator, prob,
116
+ OverrideInit (integrator. opts. abstol), x)
117
+ elseif prob. differential_vars === nothing
81
118
_initialize_dae! (integrator, prob,
82
119
ShampineCollocationInit (), x)
83
120
else
@@ -92,6 +129,28 @@ function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
92
129
alg:: NoInit , x:: Union{Val{true}, Val{false}} )
93
130
end
94
131
132
+ # # OverrideInit
133
+
134
+ function _initialize_dae! (integrator, prob:: Union{ODEProblem, DAEProblem} ,
135
+ alg:: OverrideInit , isinplace:: Union{Val{true}, Val{false}} )
136
+ initializeprob = prob. f. initializeprob
137
+ isAD = alg_autodiff (integrator. alg) isa AutoForwardDiff
138
+ alg = default_nlsolve (alg. nlsolve, isinplace, initializeprob. u0, initializeprob, isAD)
139
+ nlsol = solve (initializeprob, alg)
140
+ if isinplace === Val {true} ()
141
+ integrator. u .= prob. f. initializeprobmap (nlsol)
142
+ elseif isinplace === Val {false} ()
143
+ integrator. u = prob. f. initializeprobmap (nlsol)
144
+ else
145
+ error (" Unreachable reached. Report this error." )
146
+ end
147
+
148
+ if nlsol. retcode != ReturnCode. Success
149
+ integrator. sol = SciMLBase. solution_new_retcode (integrator. sol,
150
+ ReturnCode. InitialFailure)
151
+ end
152
+ end
153
+
95
154
# # ShampineCollocationInit
96
155
97
156
#=
@@ -196,7 +255,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
196
255
jac_prototype = f. jac_prototype,
197
256
jac = jac)
198
257
nlprob = NonlinearProblem (nlfunc, integrator. u, p)
199
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
258
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, nlprob, isAD)
200
259
nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
201
260
reltol = integrator. opts. reltol)
202
261
integrator. u .= nlsol. u
@@ -266,12 +325,12 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
266
325
end
267
326
end
268
327
269
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0)
270
-
271
328
nlfunc = NonlinearFunction (nlequation_oop;
272
329
jac_prototype = f. jac_prototype,
273
330
jac = jac)
274
331
nlprob = NonlinearProblem (nlfunc, u0)
332
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, nlprob, u0)
333
+
275
334
nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
276
335
reltol = integrator. opts. reltol)
277
336
integrator. u = nlsol. u
@@ -351,7 +410,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
351
410
jac_prototype = f. jac_prototype,
352
411
jac = jac)
353
412
nlprob = NonlinearProblem (nlfunc, u0, p)
354
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
413
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, nlprob, isAD)
355
414
nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
356
415
reltol = integrator. opts. reltol)
357
416
@@ -395,7 +454,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,
395
454
nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype,
396
455
jac = jac)
397
456
nlprob = NonlinearProblem (nlfunc, u0)
398
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0)
457
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, nlprob, u0)
399
458
400
459
nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype)
401
460
nlprob = NonlinearProblem (nlfunc, u0)
@@ -493,11 +552,10 @@ function _initialize_dae!(integrator, prob::ODEProblem,
493
552
end
494
553
495
554
J = algebraic_jacobian (f. jac_prototype, algebraic_eqs, algebraic_vars)
496
-
497
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u, isAD)
498
-
499
555
nlfunc = NonlinearFunction (nlequation!; jac_prototype = J)
500
556
nlprob = NonlinearProblem (nlfunc, alg_u, p)
557
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u, nlprob, isAD)
558
+
501
559
nlsol = solve (nlprob, nlsolve; abstol = alg. abstol, reltol = integrator. opts. reltol)
502
560
alg_u .= nlsol
503
561
@@ -554,11 +612,10 @@ function _initialize_dae!(integrator, prob::ODEProblem,
554
612
end
555
613
556
614
J = algebraic_jacobian (f. jac_prototype, algebraic_eqs, algebraic_vars)
557
-
558
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
559
-
560
615
nlfunc = NonlinearFunction (nlequation; jac_prototype = J)
561
616
nlprob = NonlinearProblem (nlfunc, u0[algebraic_vars])
617
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, nlprob, isAD)
618
+
562
619
nlsol = solve (nlprob, nlsolve)
563
620
564
621
u[algebraic_vars] .= nlsol. u
@@ -689,11 +746,11 @@ function _initialize_dae!(integrator, prob::DAEProblem,
689
746
f (du, u, p, t)
690
747
end
691
748
692
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, integrator. u)
693
-
694
749
nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype)
695
750
nlprob = NonlinearProblem (nlfunc, ifelse .(differential_vars, du, u))
696
751
752
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, nlprob, integrator. u)
753
+
697
754
nlsol = solve (nlprob, nlsolve)
698
755
699
756
du = ifelse .(differential_vars, nlsol. u, du)
0 commit comments