Skip to content

Commit 3ea493a

Browse files
handle dds and t
1 parent 33c361e commit 3ea493a

File tree

7 files changed

+72
-27
lines changed

7 files changed

+72
-27
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
2323
IncrementalCycleTracker, add_edge_checked!, topological_sort,
2424
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
2525
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
26-
fast_substitute, get_fullvars, has_equations, observed
26+
fast_substitute, get_fullvars, has_equations, observed,
27+
Schedule
2728

2829
using ModelingToolkit.BipartiteGraphs
2930
import .BipartiteGraphs: invview, complete

src/structural_transformation/symbolics_tearing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
555555
# TODO: compute the dependency correctly so that we don't have to do this
556556
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
557557
@set! sys.observed = obs
558+
559+
# Only makes sense for time-dependent
560+
# TODO: generalize to SDE
561+
if sys isa ODESystem
562+
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
563+
end
558564
@set! state.sys = sys
559565
@set! sys.tearing_state = state
560566
return invalidate_cache!(sys)

src/systems/abstractsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ for prop in [:eqs
576576
:preface
577577
:torn_matching
578578
:initializesystem
579+
:schedule
579580
:tearing_state
580581
:substitutions
581582
:metadata

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
struct Schedule
2+
var_eq_matching
3+
dummy_sub
4+
end
5+
16
function filter_kwargs(kwargs)
27
kwargs = Dict(kwargs)
38
for key in keys(kwargs)
@@ -326,7 +331,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
326331
expression_module = eval_module, checkbounds = checkbounds,
327332
kwargs...)
328333
f_oop, f_iip = eval_expression ?
329-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
334+
((@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
330335
f_gen
331336
f(u, p, t) = f_oop(u, p, t)
332337
f(du, u, p, t) = f_iip(du, u, p, t)
@@ -351,7 +356,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
351356
expression_module = eval_module,
352357
checkbounds = checkbounds, kwargs...)
353358
tgrad_oop, tgrad_iip = eval_expression ?
354-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
359+
((@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
355360
tgrad_gen
356361
if p isa Tuple
357362
__tgrad(u, p, t) = tgrad_oop(u, p..., t)
@@ -373,7 +378,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
373378
expression_module = eval_module,
374379
checkbounds = checkbounds, kwargs...)
375380
jac_oop, jac_iip = eval_expression ?
376-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
381+
((@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
377382
jac_gen
378383
_jac(u, p, t) = jac_oop(u, p, t)
379384
_jac(J, u, p, t) = jac_iip(J, u, p, t)
@@ -541,7 +546,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
541546
expression_module = eval_module, checkbounds = checkbounds,
542547
kwargs...)
543548
f_oop, f_iip = eval_expression ?
544-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
549+
((@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
545550
f_gen
546551
f(du, u, p, t) = f_oop(du, u, p, t)
547552
f(du, u, p::MTKParameters, t) = f_oop(du, u, p..., t)
@@ -555,7 +560,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
555560
expression_module = eval_module,
556561
checkbounds = checkbounds, kwargs...)
557562
jac_oop, jac_iip = eval_expression ?
558-
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
563+
((@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
559564
jac_gen
560565
_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)
561566
_jac(du, u, p::MTKParameters, ˍ₋gamma, t) = jac_oop(du, u, p..., ˍ₋gamma, t)
@@ -624,7 +629,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
624629
expression = Val{true},
625630
expression_module = eval_module, checkbounds = checkbounds,
626631
kwargs...)
627-
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
632+
f_oop, f_iip = ((@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
628633
f(u, h, p, t) = f_oop(u, h, p, t)
629634
f(u, h, p::MTKParameters, t) = f_oop(u, h, p..., t)
630635
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
@@ -649,10 +654,10 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
649654
expression = Val{true},
650655
expression_module = eval_module, checkbounds = checkbounds,
651656
kwargs...)
652-
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
657+
f_oop, f_iip = ((@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
653658
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
654659
isdde = true, kwargs...)
655-
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
660+
g_oop, g_iip = ((@RuntimeGeneratedFunction(ex)) for ex in g_gen)
656661
f(u, h, p, t) = f_oop(u, h, p, t)
657662
f(u, h, p::MTKParameters, t) = f_oop(u, h, p..., t)
658663
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
@@ -849,6 +854,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
849854
symbolic_u0 = false,
850855
u0_constructor = identity,
851856
guesses = Dict(),
857+
t = nothing,
852858
warn_initialize_determined = true,
853859
kwargs...)
854860
eqs = equations(sys)
@@ -870,7 +876,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
870876
u0map = unknowns(sys) .=> u0map
871877
end
872878
initializeprob = ModelingToolkit.InitializationProblem(
873-
sys, u0map, parammap; guesses, warn_initialize_determined)
879+
sys, t, u0map, parammap; guesses, warn_initialize_determined)
874880
initializeprobmap = getu(initializeprob, unknowns(sys))
875881

876882
zerovars = setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0
@@ -1101,6 +1107,7 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
11011107
end
11021108
f, du0, u0, p = process_DEProblem(DAEFunction{iip}, sys, u0map, parammap;
11031109
implicit_dae = true, du0map = du0map, check_length,
1110+
t = tspan !== nothing ? tspan[1] : tspan,
11041111
warn_initialize_determined, kwargs...)
11051112
diffvars = collect_differential_variables(sys)
11061113
sts = unknowns(sys)
@@ -1277,6 +1284,7 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
12771284
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `ODEProblemExpr`")
12781285
end
12791286
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; check_length,
1287+
t = tspan !== nothing ? tspan[1] : tspan,
12801288
kwargs...)
12811289
linenumbers = get(kwargs, :linenumbers, true)
12821290
kwargs = filter_kwargs(kwargs)
@@ -1322,6 +1330,7 @@ function DAEProblemExpr{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
13221330
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblemExpr`")
13231331
end
13241332
f, du0, u0, p = process_DEProblem(DAEFunctionExpr{iip}, sys, u0map, parammap;
1333+
t = tspan !== nothing ? tspan[1] : tspan,
13251334
implicit_dae = true, du0map = du0map, check_length,
13261335
kwargs...)
13271336
linenumbers = get(kwargs, :linenumbers, true)
@@ -1505,11 +1514,11 @@ function InitializationProblem(sys::AbstractODESystem, args...; kwargs...)
15051514
InitializationProblem{true}(sys, args...; kwargs...)
15061515
end
15071516

1508-
function InitializationProblem(sys::AbstractODESystem,
1517+
function InitializationProblem(sys::AbstractODESystem, t,
15091518
u0map::StaticArray,
15101519
args...;
15111520
kwargs...)
1512-
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
1521+
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, t, u0map, args...; kwargs...)
15131522
end
15141523

15151524
function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...)
@@ -1520,7 +1529,8 @@ function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...
15201529
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
15211530
end
15221531

1523-
function InitializationProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
1532+
function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
1533+
t, u0map = [],
15241534
parammap = DiffEqBase.NullParameters();
15251535
guesses = [],
15261536
check_length = true,
@@ -1530,6 +1540,8 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, u0map =
15301540
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
15311541
end
15321542

1543+
@show u0map
1544+
15331545
if isempty(u0map) && get_initializesystem(sys) !== nothing
15341546
isys = get_initializesystem(sys)
15351547
elseif isempty(u0map) && get_initializesystem(sys) === nothing
@@ -1548,6 +1560,8 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, u0map =
15481560
if warn_initialize_determined && neqs < nunknown
15491561
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
15501562
end
1563+
1564+
parammap isa DiffEqBase.NullParameters ? [independent_variable(sys) => t] : merge(todict(parammap), Dict(independent_variable(sys) => t))
15511565

15521566
if neqs == nunknown
15531567
NonlinearProblem(isys, guesses, parammap)

src/systems/diffeqs/odesystem.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ struct ODESystem <: AbstractODESystem
101101
"""
102102
initializesystem::Union{Nothing, NonlinearSystem}
103103
"""
104+
The schedule for the code generation process.
105+
"""
106+
schedule::Any
107+
"""
104108
Type of the system.
105109
"""
106110
connector_type::Any
@@ -165,11 +169,13 @@ struct ODESystem <: AbstractODESystem
165169
"""
166170
parent::Any
167171

172+
168173
function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
169174
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses,
170-
torn_matching, initializesystem, connector_type, preface, cevents,
171-
devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
172-
tearing_state = nothing,
175+
torn_matching, initializesystem, schedule, connector_type, preface, cevents,
176+
devents, parameter_dependencies,
177+
metadata = nothing, gui_metadata = nothing,
178+
tearing_state = nothing,
173179
substitutions = nothing, complete = false, index_cache = nothing,
174180
discrete_subsystems = nothing, solved_unknowns = nothing,
175181
split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
@@ -185,7 +191,8 @@ struct ODESystem <: AbstractODESystem
185191
end
186192
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
187193
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, torn_matching,
188-
initializesystem, connector_type, preface, cevents, devents, parameter_dependencies, metadata,
194+
initializesystem, schedule, connector_type, preface, cevents, devents, parameter_dependencies,
195+
metadata,
189196
gui_metadata, tearing_state, substitutions, complete, index_cache,
190197
discrete_subsystems, solved_unknowns, split_idxs, parent)
191198
end
@@ -202,12 +209,13 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
202209
defaults = _merge(Dict(default_u0), Dict(default_p)),
203210
guesses = Dict(),
204211
initializesystem = nothing,
212+
schedule = nothing,
205213
connector_type = nothing,
206214
preface = nothing,
207215
continuous_events = nothing,
208216
discrete_events = nothing,
209217
parameter_dependencies = nothing,
210-
checks = true,
218+
checks = true,
211219
metadata = nothing,
212220
gui_metadata = nothing)
213221
name === nothing &&
@@ -253,7 +261,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
253261
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
254262
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
255263
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, nothing, initializesystem,
256-
connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
264+
schedule, connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
257265
metadata, gui_metadata, checks = checks)
258266
end
259267

src/systems/nonlinear/initializesystem.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
66
function generate_initializesystem(sys::ODESystem;
77
u0map = Dict(),
88
name = nameof(sys),
9-
guesses = Dict(), check_defguess = false, kwargs...)
9+
guesses = Dict(), check_defguess = false,
10+
default_dd_value = 0.0,
11+
kwargs...)
1012
sts, eqs = unknowns(sys), equations(sys)
1113
idxs_diff = isdiffeq.(eqs)
1214
idxs_alge = .!idxs_diff
@@ -18,11 +20,18 @@ function generate_initializesystem(sys::ODESystem;
1820
defs = merge(defaults(sys), todict(u0map))
1921

2022
full_states = [sts; getfield.((observed(sys)), :lhs)]
23+
set_full_states = Set(full_states)
24+
guesses = todict(guesses)
25+
schedule = getfield(sys, :schedule)
2126

22-
# Refactor to ODESystem construction
23-
# should be ModelingToolkit.guesses(sys)
27+
dd_guess = if schedule !== nothing
28+
guessmap = [x[2]=>get(guesses, x[1], default_dd_value) for x in schedule.dummy_sub]
29+
Dict(filter(x->!isnothing(x[1]) && x[1]set_full_states,guessmap))
30+
else
31+
Dict()
32+
end
2433

25-
guesses = merge(get_guesses(sys), todict(guesses))
34+
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)
2635

2736
for st in full_states
2837
if st keys(defs)
@@ -44,13 +53,13 @@ function generate_initializesystem(sys::ODESystem;
4453
end
4554
end
4655

47-
pars = parameters(sys)
56+
pars = [parameters(sys); independent_variable(sys)]
4857
nleqs = [eqs_ics; observed(sys)]
49-
58+
5059
sys_nl = NonlinearSystem(nleqs,
5160
full_states,
5261
pars;
53-
defaults = merge(ModelingToolkit.defaults(sys), todict(u0)),
62+
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
5463
name,
5564
kwargs...)
5665

src/systems/systems.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@ function structural_simplify(
2626
else
2727
newsys = newsys′
2828
end
29-
@set! newsys.parent = complete(sys; split)
29+
if newsys isa ODESystem
30+
schedule = newsys.schedule
31+
@set! newsys.parent = complete(sys; split)
32+
@set! newsys.schedule = schedule
33+
else
34+
@set! newsys.parent = complete(sys; split)
35+
end
3036
newsys = complete(newsys; split)
3137
if newsys′ isa Tuple
3238
idxs = [parameter_index(newsys, i) for i in io[1]]

0 commit comments

Comments
 (0)