Skip to content

Commit d571a82

Browse files
handle DAEFunction as well
1 parent d363171 commit d571a82

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

src/scimlfunctions.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,7 @@ automatically symbolically generating the Jacobian and more from the
15241524
numerically-defined functions.
15251525
"""
15261526
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1527-
SYS} <:
1527+
SYS, IProb, IProbMap} <:
15281528
AbstractDAEFunction{iip}
15291529
f::F
15301530
analytic::Ta
@@ -1540,6 +1540,8 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15401540
observed::O
15411541
colorvec::TCV
15421542
sys::SYS
1543+
initializeprob::IProb
1544+
initializeprobmap::IProbMap
15431545
end
15441546

15451547
TruncatedStacktraces.@truncate_stacktrace DAEFunction 1 2
@@ -2279,8 +2281,8 @@ function ODEFunction{iip, specialize}(f;
22792281
DEFAULT_OBSERVED,
22802282
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
22812283
sys = __has_sys(f) ? f.sys : nothing,
2282-
initializeprob = nothing,
2283-
initializeprobmap = nothing
2284+
initializeprob = _has_initializeprob(f) ? f.sys : nothing,
2285+
initializeprobmap = _has_initializeprobmap(f) ? f.sys : nothing
22842286
) where {iip,
22852287
specialize,
22862288
}
@@ -2328,7 +2330,7 @@ function ODEFunction{iip, specialize}(f;
23282330

23292331
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
23302332

2331-
@assert typeof(initializeprob) <: Union{NonlinearProblem, NonlinearLeastSquaresProblem}
2333+
@assert typeof(initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
23322334

23332335
if specialize === NoSpecialize
23342336
ODEFunction{iip, specialize,
@@ -3189,7 +3191,9 @@ function DAEFunction{iip, specialize}(f;
31893191
observed = __has_observed(f) ? f.observed :
31903192
DEFAULT_OBSERVED,
31913193
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
3192-
sys = __has_sys(f) ? f.sys : nothing) where {iip,
3194+
sys = __has_sys(f) ? f.sys : nothing,
3195+
initializeprob = _has_initializeprob(f) ? f.sys : nothing,
3196+
initializeprobmap = _has_initializeprobmap(f) ? f.sys : nothing) where {iip,
31933197
specialize
31943198
}
31953199
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
@@ -3221,14 +3225,16 @@ function DAEFunction{iip, specialize}(f;
32213225
_f = prepare_function(f)
32223226
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
32233227

3228+
@assert typeof(initializeprob) <: Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
3229+
32243230
if specialize === NoSpecialize
32253231
DAEFunction{iip, specialize, Any, Any, Any,
32263232
Any, Any, Any, Any, Any,
32273233
Any, Any, Any,
32283234
Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp,
32293235
vjp, jac_prototype, sparsity,
32303236
Wfact, Wfact_t, paramjac, observed,
3231-
_colorvec, sys)
3237+
_colorvec, sys, initializeprob, initializeprobmap)
32323238
else
32333239
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
32343240
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
@@ -3238,7 +3244,7 @@ function DAEFunction{iip, specialize}(f;
32383244
typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp,
32393245
jac_prototype, sparsity, Wfact, Wfact_t,
32403246
paramjac, observed,
3241-
_colorvec, sys)
3247+
_colorvec, sys, initializeprob, initializeprobmap)
32423248
end
32433249
end
32443250

@@ -3957,6 +3963,8 @@ __has_colorvec(f) = isdefined(f, :colorvec)
39573963
__has_sys(f) = isdefined(f, :sys)
39583964
__has_analytic_full(f) = isdefined(f, :analytic_full)
39593965
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
3966+
__has_initializeprob(f) = isdefined(f, :initializeprob)
3967+
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
39603968

39613969
# compatibility
39623970
has_invW(f::AbstractSciMLFunction) = false
@@ -3969,6 +3977,9 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
39693977
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
39703978
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
39713979
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
3980+
has_initializeprob(f::AbstractSciMLFunction) = __has_initializeprob(f) && f.initializeprob !== nothing
3981+
has_initializeprobmap(f::AbstractSciMLFunction) = __has_initializeprobmap(f) && f.initializeprobmap !== nothing
3982+
39723983
function has_syms(f::AbstractSciMLFunction)
39733984
if __has_syms(f)
39743985
f.syms !== nothing

0 commit comments

Comments
 (0)