Skip to content

Commit 7882284

Browse files
Merge pull request #633 from SciML/initialize_prob
Allow for tagging an initialization problem to ODEFunction/DAEFunction
2 parents 0d0eed9 + a5f2942 commit 7882284

File tree

1 file changed

+49
-19
lines changed

1 file changed

+49
-19
lines changed

src/scimlfunctions.jl

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS} <: AbstractODEFunction{iip}
405+
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -419,6 +419,8 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
419419
observed::O
420420
colorvec::TCV
421421
sys::SYS
422+
initializeprob::IProb
423+
initializeprobmap::IProbMap
422424
end
423425

424426
TruncatedStacktraces.@truncate_stacktrace ODEFunction 1 2
@@ -1522,7 +1524,7 @@ automatically symbolically generating the Jacobian and more from the
15221524
numerically-defined functions.
15231525
"""
15241526
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
1525-
SYS} <:
1527+
SYS, IProb, IProbMap} <:
15261528
AbstractDAEFunction{iip}
15271529
f::F
15281530
analytic::Ta
@@ -1538,6 +1540,8 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
15381540
observed::O
15391541
colorvec::TCV
15401542
sys::SYS
1543+
initializeprob::IProb
1544+
initializeprobmap::IProbMap
15411545
end
15421546

15431547
TruncatedStacktraces.@truncate_stacktrace DAEFunction 1 2
@@ -2276,7 +2280,10 @@ function ODEFunction{iip, specialize}(f;
22762280
observed = __has_observed(f) ? f.observed :
22772281
DEFAULT_OBSERVED,
22782282
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
2279-
sys = __has_sys(f) ? f.sys : nothing) where {iip,
2283+
sys = __has_sys(f) ? f.sys : nothing,
2284+
initializeprob = __has_initializeprob(f) ? f.sys : nothing,
2285+
initializeprobmap = __has_initializeprobmap(f) ? f.sys : nothing
2286+
) where {iip,
22802287
specialize
22812288
}
22822289
if mass_matrix === I && f isa Tuple
@@ -2321,18 +2328,22 @@ function ODEFunction{iip, specialize}(f;
23212328

23222329
_f = prepare_function(f)
23232330

2324-
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
2331+
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
2332+
2333+
@assert typeof(initializeprob) <:
2334+
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
2335+
23252336
if specialize === NoSpecialize
23262337
ODEFunction{iip, specialize,
23272338
Any, Any, Any, Any,
23282339
Any, Any, Any, typeof(jac_prototype),
23292340
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
23302341
Any,
23312342
typeof(_colorvec),
2332-
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
2343+
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
23332344
jvp, vjp, jac_prototype, sparsity, Wfact,
23342345
Wfact_t, W_prototype, paramjac,
2335-
observed, _colorvec, sys)
2346+
observed, _colorvec, sys, initializeprob, initializeprobmap)
23362347
elseif specialize === false
23372348
ODEFunction{iip, FunctionWrapperSpecialize,
23382349
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2341,10 +2352,11 @@ function ODEFunction{iip, specialize}(f;
23412352
typeof(paramjac),
23422353
typeof(observed),
23432354
typeof(_colorvec),
2344-
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
2355+
typeof(sys), typeof(initializeprob),
2356+
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
23452357
jvp, vjp, jac_prototype, sparsity, Wfact,
23462358
Wfact_t, W_prototype, paramjac,
2347-
observed, _colorvec, sys)
2359+
observed, _colorvec, sys, initializeprob, initializeprobmap)
23482360
else
23492361
ODEFunction{iip, specialize,
23502362
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2353,10 +2365,11 @@ function ODEFunction{iip, specialize}(f;
23532365
typeof(paramjac),
23542366
typeof(observed),
23552367
typeof(_colorvec),
2356-
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
2368+
typeof(sys), typeof(initializeprob),
2369+
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
23572370
jvp, vjp, jac_prototype, sparsity, Wfact,
23582371
Wfact_t, W_prototype, paramjac,
2359-
observed, _colorvec, sys)
2372+
observed, _colorvec, sys, initializeprob, initializeprobmap)
23602373
end
23612374
end
23622375

@@ -2373,21 +2386,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
23732386
Any, Any, Any, Any, typeof(f.jac_prototype),
23742387
typeof(f.sparsity), Any, Any, Any,
23752388
Any, typeof(f.colorvec),
2376-
typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2389+
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
23772390
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
23782391
f.Wfact_t, f.W_prototype, f.paramjac,
2379-
f.observed, f.colorvec, f.sys)
2392+
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
23802393
else
23812394
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
23822395
typeof(f.analytic), typeof(f.tgrad),
23832396
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
23842397
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
23852398
typeof(f.paramjac),
23862399
typeof(f.observed), typeof(f.colorvec),
2387-
typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2400+
typeof(f.sys), typeof(f.initializeprob),
2401+
typeof(f.initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
23882402
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
23892403
f.Wfact_t, f.W_prototype, f.paramjac,
2390-
f.observed, f.colorvec, f.sys)
2404+
f.observed, f.colorvec, f.sys, f.initializeprob,
2405+
f.initializeprobmap)
23912406
end
23922407
end
23932408

@@ -3177,7 +3192,9 @@ function DAEFunction{iip, specialize}(f;
31773192
observed = __has_observed(f) ? f.observed :
31783193
DEFAULT_OBSERVED,
31793194
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
3180-
sys = __has_sys(f) ? f.sys : nothing) where {iip,
3195+
sys = __has_sys(f) ? f.sys : nothing,
3196+
initializeprob = __has_initializeprob(f) ? f.sys : nothing,
3197+
initializeprobmap = __has_initializeprobmap(f) ? f.sys : nothing) where {iip,
31813198
specialize
31823199
}
31833200
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
@@ -3209,24 +3226,28 @@ function DAEFunction{iip, specialize}(f;
32093226
_f = prepare_function(f)
32103227
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
32113228

3229+
@assert typeof(initializeprob) <:
3230+
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}
3231+
32123232
if specialize === NoSpecialize
32133233
DAEFunction{iip, specialize, Any, Any, Any,
32143234
Any, Any, Any, Any, Any,
32153235
Any, Any, Any,
3216-
Any, typeof(_colorvec), Any}(_f, analytic, tgrad, jac, jvp,
3236+
Any, typeof(_colorvec), Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
32173237
vjp, jac_prototype, sparsity,
32183238
Wfact, Wfact_t, paramjac, observed,
3219-
_colorvec, sys)
3239+
_colorvec, sys, initializeprob, initializeprobmap)
32203240
else
32213241
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
32223242
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
32233243
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
32243244
typeof(paramjac),
32253245
typeof(observed), typeof(_colorvec),
3226-
typeof(sys)}(_f, analytic, tgrad, jac, jvp, vjp,
3246+
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
3247+
_f, analytic, tgrad, jac, jvp, vjp,
32273248
jac_prototype, sparsity, Wfact, Wfact_t,
32283249
paramjac, observed,
3229-
_colorvec, sys)
3250+
_colorvec, sys, initializeprob, initializeprobmap)
32303251
end
32313252
end
32323253

@@ -3945,6 +3966,8 @@ __has_colorvec(f) = isdefined(f, :colorvec)
39453966
__has_sys(f) = isdefined(f, :sys)
39463967
__has_analytic_full(f) = isdefined(f, :analytic_full)
39473968
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
3969+
__has_initializeprob(f) = isdefined(f, :initializeprob)
3970+
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
39483971

39493972
# compatibility
39503973
has_invW(f::AbstractSciMLFunction) = false
@@ -3957,6 +3980,13 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
39573980
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
39583981
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
39593982
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
3983+
function has_initializeprob(f::AbstractSciMLFunction)
3984+
__has_initializeprob(f) && f.initializeprob !== nothing
3985+
end
3986+
function has_initializeprobmap(f::AbstractSciMLFunction)
3987+
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
3988+
end
3989+
39603990
function has_syms(f::AbstractSciMLFunction)
39613991
if __has_syms(f)
39623992
f.syms !== nothing

0 commit comments

Comments
 (0)