Skip to content

Commit d363171

Browse files
Auto stash before checking out "origin/master"
1 parent 1774538 commit d363171

File tree

1 file changed

+50
-38
lines changed

1 file changed

+50
-38
lines changed

src/scimlfunctions.jl

Lines changed: 50 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ numerically-defined functions.
402402
"""
403403
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
404404
O, TCV,
405-
SYS} <: AbstractODEFunction{iip}
405+
SYS, IProb, IProbMap} <: AbstractODEFunction{iip}
406406
f::F
407407
mass_matrix::TMM
408408
analytic::Ta
@@ -419,6 +419,8 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
419419
observed::O
420420
colorvec::TCV
421421
sys::SYS
422+
initializeprob::IProb
423+
initializeprobmap::IProbMap
422424
end
423425

424426
TruncatedStacktraces.@truncate_stacktrace ODEFunction 1 2
@@ -2254,30 +2256,33 @@ end
22542256
######### Basic Constructor
22552257

22562258
function ODEFunction{iip, specialize}(f;
2257-
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
2258-
I,
2259-
analytic = __has_analytic(f) ? f.analytic : nothing,
2260-
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
2261-
jac = __has_jac(f) ? f.jac : nothing,
2262-
jvp = __has_jvp(f) ? f.jvp : nothing,
2263-
vjp = __has_vjp(f) ? f.vjp : nothing,
2264-
jac_prototype = __has_jac_prototype(f) ?
2265-
f.jac_prototype :
2266-
nothing,
2267-
sparsity = __has_sparsity(f) ? f.sparsity :
2268-
jac_prototype,
2269-
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
2270-
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
2271-
W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing,
2272-
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
2273-
syms = nothing,
2274-
indepsym = nothing,
2275-
paramsyms = nothing,
2276-
observed = __has_observed(f) ? f.observed :
2277-
DEFAULT_OBSERVED,
2278-
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
2279-
sys = __has_sys(f) ? f.sys : nothing) where {iip,
2280-
specialize
2259+
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
2260+
I,
2261+
analytic = __has_analytic(f) ? f.analytic : nothing,
2262+
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
2263+
jac = __has_jac(f) ? f.jac : nothing,
2264+
jvp = __has_jvp(f) ? f.jvp : nothing,
2265+
vjp = __has_vjp(f) ? f.vjp : nothing,
2266+
jac_prototype = __has_jac_prototype(f) ?
2267+
f.jac_prototype :
2268+
nothing,
2269+
sparsity = __has_sparsity(f) ? f.sparsity :
2270+
jac_prototype,
2271+
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
2272+
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
2273+
W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing,
2274+
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
2275+
syms = nothing,
2276+
indepsym = nothing,
2277+
paramsyms = nothing,
2278+
observed = __has_observed(f) ? f.observed :
2279+
DEFAULT_OBSERVED,
2280+
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
2281+
sys = __has_sys(f) ? f.sys : nothing,
2282+
initializeprob = nothing,
2283+
initializeprobmap = nothing
2284+
) where {iip,
2285+
specialize,
22812286
}
22822287
if mass_matrix === I && f isa Tuple
22832288
mass_matrix = ((I for i in 1:length(f))...,)
@@ -2321,18 +2326,21 @@ function ODEFunction{iip, specialize}(f;
23212326

23222327
_f = prepare_function(f)
23232328

2324-
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
2329+
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
2330+
2331+
@assert typeof(initializeprob) <: Union{NonlinearProblem, NonlinearLeastSquaresProblem}
2332+
23252333
if specialize === NoSpecialize
23262334
ODEFunction{iip, specialize,
23272335
Any, Any, Any, Any,
23282336
Any, Any, Any, typeof(jac_prototype),
23292337
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
23302338
Any,
23312339
typeof(_colorvec),
2332-
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
2340+
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
23332341
jvp, vjp, jac_prototype, sparsity, Wfact,
23342342
Wfact_t, W_prototype, paramjac,
2335-
observed, _colorvec, sys)
2343+
observed, _colorvec, sys, initializeprob, initializeprobmap)
23362344
elseif specialize === false
23372345
ODEFunction{iip, FunctionWrapperSpecialize,
23382346
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2341,10 +2349,11 @@ function ODEFunction{iip, specialize}(f;
23412349
typeof(paramjac),
23422350
typeof(observed),
23432351
typeof(_colorvec),
2344-
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
2352+
typeof(sys), typeof(initializeprob),
2353+
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
23452354
jvp, vjp, jac_prototype, sparsity, Wfact,
2346-
Wfact_t, W_prototype, paramjac,
2347-
observed, _colorvec, sys)
2355+
Wfact_t, W_prototype, paramjac,
2356+
observed, _colorvec, sys, initializeprob, initializeprobmap)
23482357
else
23492358
ODEFunction{iip, specialize,
23502359
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
@@ -2353,10 +2362,11 @@ function ODEFunction{iip, specialize}(f;
23532362
typeof(paramjac),
23542363
typeof(observed),
23552364
typeof(_colorvec),
2356-
typeof(sys)}(_f, mass_matrix, analytic, tgrad, jac,
2365+
typeof(sys), typeof(initializeprob),
2366+
typeof(initializeprobmap)}(_f, mass_matrix, analytic, tgrad, jac,
23572367
jvp, vjp, jac_prototype, sparsity, Wfact,
2358-
Wfact_t, W_prototype, paramjac,
2359-
observed, _colorvec, sys)
2368+
Wfact_t, W_prototype, paramjac,
2369+
observed, _colorvec, sys, initializeprob, initializeprobmap)
23602370
end
23612371
end
23622372

@@ -2373,21 +2383,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
23732383
Any, Any, Any, Any, typeof(f.jac_prototype),
23742384
typeof(f.sparsity), Any, Any, Any,
23752385
Any, typeof(f.colorvec),
2376-
typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2386+
typeof(f.sys), Any, Any}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
23772387
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
23782388
f.Wfact_t, f.W_prototype, f.paramjac,
2379-
f.observed, f.colorvec, f.sys)
2389+
f.observed, f.colorvec, f.sys, f.initializeprob, f.initializeprobmap)
23802390
else
23812391
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
23822392
typeof(f.analytic), typeof(f.tgrad),
23832393
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
23842394
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
23852395
typeof(f.paramjac),
23862396
typeof(f.observed), typeof(f.colorvec),
2387-
typeof(f.sys)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
2397+
typeof(f.sys), typeof(initializeprob),
2398+
typeof(initializeprobmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
23882399
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
23892400
f.Wfact_t, f.W_prototype, f.paramjac,
2390-
f.observed, f.colorvec, f.sys)
2401+
f.observed, f.colorvec, f.sys, f.initializeprob,
2402+
f.initializeprobmap)
23912403
end
23922404
end
23932405

0 commit comments

Comments
 (0)