Skip to content

Commit cb6c390

Browse files
Allow for DAE initialization on ODEs with initializeprob
This can come up from cases with ModelingToolkit, see SciML/ModelingToolkit.jl#2512
1 parent 019c186 commit cb6c390

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

src/alg_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
284284
CompositeAlgorithm(algs, alg.choice_function)
285285
end
286286

287+
has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false
288+
has_autodiff(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm, OrdinaryDiffEqImplicitAlgorithm, CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm}) = true
289+
287290
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
288291
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
289292
error("This algorithm does not have an autodifferentiation option defined.")

src/initialize_dae.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,12 @@ end
134134
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
135135
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
136136
initializeprob = prob.f.initializeprob
137-
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff
137+
138+
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
139+
# Since then it's the case of not a DAE but has initializeprob
140+
# In which case, it should be differentiable
141+
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff : true
142+
138143
alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
139144
nlsol = solve(initializeprob, alg)
140145
if isinplace === Val{true}()

src/solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ function DiffEqBase.__init(
493493
opts, stats, initializealg, differential_vars)
494494

495495
if initialize_integrator
496-
if isdae
496+
if isdae || SciMLBase.has_initializeprob(prob.f)
497497
DiffEqBase.initialize_dae!(integrator)
498498
end
499499

0 commit comments

Comments
 (0)