Skip to content

Commit 4ac6301

Browse files
feat!: add MTKParameters to more systems/functions
1 parent b4246b8 commit 4ac6301

File tree

6 files changed

+132
-41
lines changed

6 files changed

+132
-41
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,19 @@ function generate_tgrad(sys::AbstractODESystem, dvs = unknowns(sys), ps = parame
8484
simplify = false, kwargs...)
8585
tgrad = calculate_tgrad(sys, simplify = simplify)
8686
pre = get_preprocess_constants(tgrad)
87-
if ps isa Tuple
88-
return build_function(tgrad,
89-
dvs,
90-
ps...,
91-
get_iv(sys);
92-
postprocess_fbody = pre,
93-
kwargs...)
87+
p = if has_index_cache(sys)
88+
reorder_parameters(get_index_cache(sys), ps)
89+
elseif ps isa Tuple
90+
ps
9491
else
95-
return build_function(tgrad,
96-
dvs,
97-
ps,
98-
get_iv(sys);
99-
postprocess_fbody = pre,
100-
kwargs...)
92+
(ps,)
10193
end
94+
return build_function(tgrad,
95+
dvs,
96+
p...,
97+
get_iv(sys);
98+
postprocess_fbody = pre,
99+
kwargs...)
102100
end
103101

104102
function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys), ps = parameters(sys);
@@ -135,7 +133,12 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
135133
@variables ˍ₋gamma
136134
jac = ˍ₋gamma * jac_du + jac_u
137135
pre = get_preprocess_constants(jac)
138-
return build_function(jac, derivatives, dvs, ps, ˍ₋gamma, get_iv(sys);
136+
p = if has_index_cache(sys)
137+
reorder_parameters(get_index_cache(sys), ps)
138+
else
139+
(ps,)
140+
end
141+
return build_function(jac, derivatives, dvs, p..., ˍ₋gamma, get_iv(sys);
139142
postprocess_fbody = pre, kwargs...)
140143
end
141144

@@ -399,15 +402,25 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = u
399402
build_explicit_observed_function(sys, obsvar)
400403
end
401404
if args === ()
402-
let obs = obs
403-
(u, p, t = Inf) -> if ps isa Tuple
405+
let obs = obs, ps_T = typeof(ps)
406+
(u, p, t = Inf) -> if p isa MTKParameters
407+
obs(u, raw_vectors(p)..., t)
408+
elseif ps_T <: Tuple
404409
obs(u, p..., t)
405410
else
406411
obs(u, p, t)
407412
end
408413
end
409414
else
410-
if ps isa Tuple
415+
if args[2] isa MTKParameters
416+
if length(args) == 2
417+
u, p = args
418+
obs(u, raw_vectors(p)..., Inf)
419+
else
420+
u, p, t = args
421+
obs(u, raw_vectors(p)..., t)
422+
end
423+
elseif ps isa Tuple
411424
if length(args) == 2
412425
u, p = args
413426
obs(u, p..., Inf)
@@ -437,16 +450,21 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = u
437450
ps)
438451
end
439452
if args === ()
440-
let obs = obs
441-
(u, p, t) -> if ps isa Tuple
453+
let obs = obs, ps_T = typeof(ps)
454+
(u, p, t) -> if p isa MTKParameters
455+
obs(u, raw_vectors(p)..., t)
456+
elseif ps_T <: Tuple
442457
obs(u, p..., t)
443458
else
444459
obs(u, p, t)
445460
end
446461
end
447462
else
448-
if ps isa Tuple # split parameters
463+
u, p, t = args
464+
if p isa MTKParameters
449465
u, p, t = args
466+
obs(u, raw_vectors(p)..., t)
467+
elseif ps isa Tuple # split parameters
450468
obs(u, p..., t)
451469
else
452470
obs(args...)
@@ -518,7 +536,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
518536
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
519537
f_gen
520538
f(du, u, p, t) = f_oop(du, u, p, t)
539+
f(du, u, p::MTKParameters, t) = f_oop(du, u, raw_vectors(p)..., t)
521540
f(out, du, u, p, t) = f_iip(out, du, u, p, t)
541+
f(out, du, u, p::MTKParameters, t) = f_iip(out, du, u, raw_vectors(p)..., t)
522542

523543
if jac
524544
jac_gen = generate_dae_jacobian(sys, dvs, ps;
@@ -530,8 +550,10 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
530550
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
531551
jac_gen
532552
_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)
553+
_jac(du, u, p::MTKParameters, ˍ₋gamma, t) = jac_oop(du, u, raw_vectors(p)..., ˍ₋gamma, t)
533554

534555
_jac(J, du, u, p, ˍ₋gamma, t) = jac_iip(J, du, u, p, ˍ₋gamma, t)
556+
_jac(J, du, u, p::MTKParameters, ˍ₋gamma, t) = jac_iip(J, du, u, raw_vectors(p)..., ˍ₋gamma, t)
535557
else
536558
_jac = nothing
537559
end
@@ -544,10 +566,13 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
544566
end
545567
if args === ()
546568
let obs = obs
547-
(u, p, t) -> obs(u, p, t)
569+
fun(u, p, t) = obs(u, p, t)
570+
fun(u, p::MTKParameters, t) = obs(u, raw_vectors(p)..., t)
571+
fun
548572
end
549573
else
550-
obs(args...)
574+
u, p, t = args
575+
p isa MTKParameters ? obs(u, raw_vectors(p)..., t) : obs(u, p, t)
551576
end
552577
end
553578
end
@@ -591,7 +616,9 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
591616
kwargs...)
592617
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
593618
f(u, h, p, t) = f_oop(u, h, p, t)
619+
f(u, h, p::MTKParameters, t) = f_oop(u, h, raw_vectors(p)..., t)
594620
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
621+
f(du, u, h, p::MTKParameters, t) = f_iip(du, u, h, raw_vectors(p)..., t)
595622

596623
DDEFunction{iip}(f, sys = sys)
597624
end
@@ -617,9 +644,13 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
617644
isdde = true, kwargs...)
618645
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
619646
f(u, h, p, t) = f_oop(u, h, p, t)
647+
f(u, h, p::MTKParameters, t) = f_oop(u, h, raw_vectors(p)..., t)
620648
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
649+
f(du, u, h, p::MTKParameters, t) = f_iip(du, u, h, raw_vectors(p)..., t)
621650
g(u, h, p, t) = g_oop(u, h, p, t)
651+
g(u, h, p::MTKParameters, t) = g_oop(u, h, raw_vectors(p)..., t)
622652
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
653+
g(du, u, h, p::MTKParameters, t) = g_iip(du, u, h, raw_vectors(p)..., t)
623654

624655
SDDEFunction{iip}(f, g, sys = sys)
625656
end

src/systems/diffeqs/odesystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,9 @@ function build_explicit_observed_function(sys, ts;
401401
if inputs !== nothing
402402
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
403403
end
404-
if ps isa Tuple
404+
if has_index_cache(sys)
405+
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
406+
elseif ps isa Tuple
405407
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
406408
else
407409
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)

src/systems/diffeqs/sdesystem.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ struct SDESystem <: AbstractODESystem
116116
"""
117117
complete::Bool
118118
"""
119+
Cached data for fast symbolic indexing.
120+
"""
121+
index_cache::Union{Nothing, IndexCache}
122+
"""
119123
The hierarchical parent system before simplification.
120124
"""
121125
parent::Any
@@ -125,7 +129,7 @@ struct SDESystem <: AbstractODESystem
125129
jac,
126130
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type,
127131
cevents, devents, metadata = nothing, gui_metadata = nothing,
128-
complete = false, parent = nothing;
132+
complete = false, index_cache = nothing, parent = nothing;
129133
checks::Union{Bool, Int} = true)
130134
if checks == true || (checks & CheckComponents) > 0
131135
check_variables(dvs, iv)
@@ -140,7 +144,7 @@ struct SDESystem <: AbstractODESystem
140144
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
141145
ctrl_jac,
142146
Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents,
143-
metadata, gui_metadata, complete, parent)
147+
metadata, gui_metadata, complete, index_cache, parent)
144148
end
145149
end
146150

@@ -221,11 +225,15 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
221225
eqs = delay_to_function(sys, eqs)
222226
end
223227
u = map(x -> time_varying_as_func(value(x), sys), dvs)
224-
p = map(x -> time_varying_as_func(value(x), sys), ps)
228+
p = if has_index_cache(sys)
229+
reorder_parameters(get_index_cache(sys), ps)
230+
else
231+
(map(x -> time_varying_as_func(value(x), sys), ps),)
232+
end
225233
if isdde
226-
return build_function(eqs, u, DDE_HISTORY_FUN, p, get_iv(sys); kwargs...)
234+
return build_function(eqs, u, DDE_HISTORY_FUN, p..., get_iv(sys); kwargs...)
227235
else
228-
return build_function(eqs, u, p, get_iv(sys); kwargs...)
236+
return build_function(eqs, u, p..., get_iv(sys); kwargs...)
229237
end
230238
end
231239

@@ -408,9 +416,13 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
408416
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen) : g_gen
409417

410418
f(u, p, t) = f_oop(u, p, t)
419+
f(u, p::MTKParameters, t) = f_oop(u, raw_vectors(p)..., t)
411420
f(du, u, p, t) = f_iip(du, u, p, t)
421+
f(du, u, p::MTKParameters, t) = f_iip(du, u, raw_vectors(p)..., t)
412422
g(u, p, t) = g_oop(u, p, t)
423+
g(u, p::MTKParameters, t) = g_oop(u, raw_vectors(p)..., t)
413424
g(du, u, p, t) = g_iip(du, u, p, t)
425+
g(du, u, p::MTKParameters, t) = g_iip(du, u, raw_vectors(p)..., t)
414426

415427
if tgrad
416428
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{eval_expression},
@@ -419,7 +431,9 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
419431
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tgrad_gen) :
420432
tgrad_gen
421433
_tgrad(u, p, t) = tgrad_oop(u, p, t)
434+
_tgrad(u, p::MTKParameters, t) = tgrad_oop(u, raw_vectors(p)..., t)
422435
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
436+
_tgrad(J, u, p::MTKParameters, t) = tgrad_iip(J, u, raw_vectors(p)..., t)
423437
else
424438
_tgrad = nothing
425439
end
@@ -431,7 +445,9 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
431445
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen) :
432446
jac_gen
433447
_jac(u, p, t) = jac_oop(u, p, t)
448+
_jac(u, p::MTKParameters, t) = jac_oop(u, raw_vectors(p)..., t)
434449
_jac(J, u, p, t) = jac_iip(J, u, p, t)
450+
_jac(J, u, p::MTKParameters, t) = jac_iip(J, u, raw_vectors(p)..., t)
435451
else
436452
_jac = nothing
437453
end
@@ -446,9 +462,13 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
446462
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact_t) :
447463
tmp_Wfact_t
448464
_Wfact(u, p, dtgamma, t) = Wfact_oop(u, p, dtgamma, t)
465+
_Wfact(u, p::MTKParameters, dtgamma, t) = Wfact_oop(u, raw_vectors(p)..., dtgamma, t)
449466
_Wfact(W, u, p, dtgamma, t) = Wfact_iip(W, u, p, dtgamma, t)
467+
_Wfact(W, u, p::MTKParameters, dtgamma, t) = Wfact_iip(W, u, raw_vectors(p)..., dtgamma, t)
450468
_Wfact_t(u, p, dtgamma, t) = Wfact_oop_t(u, p, dtgamma, t)
469+
_Wfact_t(u, p::MTKParameters, dtgamma, t) = Wfact_oop_t(u, raw_vectors(p)..., dtgamma, t)
451470
_Wfact_t(W, u, p, dtgamma, t) = Wfact_iip_t(W, u, p, dtgamma, t)
471+
_Wfact_t(W, u, p::MTKParameters, dtgamma, t) = Wfact_iip_t(W, u, raw_vectors(p)..., dtgamma, t)
452472
else
453473
_Wfact, _Wfact_t = nothing, nothing
454474
end
@@ -462,11 +482,14 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
462482
obs = get!(dict, value(obsvar)) do
463483
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
464484
end
465-
obs(u, p, t)
485+
if p isa MTKParameters
486+
obs(u, raw_vectors(p)..., t)
487+
else
488+
obs(u, p, t)
489+
end
466490
end
467491
end
468492

469-
sts = unknowns(sys)
470493
SDEFunction{iip}(f, g,
471494
sys = sys,
472495
jac = _jac === nothing ? nothing : _jac,

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,15 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
101101
If a model `sys` is complete, then `sys.x` no longer performs namespacing.
102102
"""
103103
complete::Bool
104+
"""
105+
Cached data for fast symbolic indexing.
106+
"""
107+
index_cache::Union{Nothing, IndexCache}
104108

105109
function JumpSystem{U}(tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, systems,
106110
defaults, connector_type, devents,
107111
metadata = nothing, gui_metadata = nothing,
108-
complete = false;
112+
complete = false, index_cache = nothing;
109113
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
110114
if checks == true || (checks & CheckComponents) > 0
111115
check_variables(unknowns, iv)
@@ -116,7 +120,7 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
116120
check_units(u, ap, iv)
117121
end
118122
new{U}(tag, ap, iv, unknowns, ps, var_to_name, observed, name, systems, defaults,
119-
connector_type, devents, metadata, gui_metadata, complete)
123+
connector_type, devents, metadata, gui_metadata, complete, index_cache)
120124
end
121125
end
122126
JumpSystem(tag, ap, iv, states, ps, var_to_name, args...; kwargs...) = JumpSystem{typeof(ap)}(tag, ap, iv, states, ps, var_to_name, args...; kwargs...)

0 commit comments

Comments
 (0)