Skip to content

Commit ef728cc

Browse files
feat: use new discrete saving, only allow split=true hybrid systems
1 parent bfc73a3 commit ef728cc

File tree

3 files changed

+27
-130
lines changed

3 files changed

+27
-130
lines changed

src/systems/clock_inference.jl

Lines changed: 17 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,14 @@ function generate_discrete_affect(
203203
@static if VERSION < v"1.7"
204204
error("The `generate_discrete_affect` function requires at least Julia 1.7")
205205
end
206-
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
206+
has_index_cache(osys) && get_index_cache(osys) !== nothing ||
207+
error("Hybrid systems require `split = true`")
207208
out = Sym{Any}(:out)
208209
appended_parameters = full_parameters(syss[continuous_id])
209210
offset = length(appended_parameters)
210-
param_to_idx = if use_index_cache
211-
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
212-
for p in appended_parameters)
213-
else
214-
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
215-
end
211+
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
212+
for p in appended_parameters)
216213
affect_funs = []
217-
init_funs = []
218-
svs = []
219214
clocks = TimeDomain[]
220215
for (i, (sys, input)) in enumerate(zip(syss, inputs))
221216
i == continuous_id && continue
@@ -231,11 +226,7 @@ function generate_discrete_affect(
231226
push!(fullvars, s)
232227
end
233228
needed_disc_to_cont_obs = []
234-
if use_index_cache
235-
disc_to_cont_idxs = ParameterIndex[]
236-
else
237-
disc_to_cont_idxs = Int[]
238-
end
229+
disc_to_cont_idxs = ParameterIndex[]
239230
for v in inputs[continuous_id]
240231
_v = arguments(v)[1]
241232
if _v in fullvars
@@ -255,7 +246,7 @@ function generate_discrete_affect(
255246
end
256247
append!(appended_parameters, input)
257248
cont_to_disc_obs = build_explicit_observed_function(
258-
use_index_cache ? osys : syss[continuous_id],
249+
osys,
259250
needed_cont_to_disc_obs,
260251
throw = false,
261252
expression = true,
@@ -281,56 +272,16 @@ function generate_discrete_affect(
281272
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
282273
save_expr = :($(SciMLBase.save_discretes!)(integrator, $i))
283274
empty_disc = isempty(disc_range)
284-
disc_init = if use_index_cache
285-
:(function (u, p, t)
286-
c2d_obs = $cont_to_disc_obs
287-
d2c_obs = $disc_to_cont_obs
288-
result = c2d_obs(u, p..., t)
289-
for (val, i) in zip(result, $cont_to_disc_idxs)
290-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
291-
end
292-
293-
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
294-
result = d2c_obs(disc_state, p..., t)
295-
for (val, i) in zip(result, $disc_to_cont_idxs)
296-
# prevent multiple updates to dependents
297-
_set_parameter_unchecked!(p, val, i; update_dependent = false)
298-
end
299-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
300-
$(SciMLStructures.Discrete()), p)
301-
repack(discretes) # to force recalculation of dependents
302-
end)
303-
else
304-
:(function (u, p, t)
305-
c2d_obs = $cont_to_disc_obs
306-
d2c_obs = $disc_to_cont_obs
307-
c2d_view = view(p, $cont_to_disc_idxs)
308-
d2c_view = view(p, $disc_to_cont_idxs)
309-
disc_unknowns = view(p, $disc_range)
310-
copyto!(c2d_view, c2d_obs(u, p, t))
311-
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
312-
end)
313-
end
314275

315276
# @show disc_to_cont_idxs
316277
# @show cont_to_disc_idxs
317278
# @show disc_range
318-
affect! = :(function (integrator, saved_values)
279+
affect! = :(function (integrator)
319280
@unpack u, p, t = integrator
320281
c2d_obs = $cont_to_disc_obs
321282
d2c_obs = $disc_to_cont_obs
322-
$(
323-
if use_index_cache
324-
:(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range])
325-
else
326-
quote
327-
c2d_view = view(p, $cont_to_disc_idxs)
328-
d2c_view = view(p, $disc_to_cont_idxs)
329-
disc_unknowns = view(p, $disc_range)
330-
end
331-
end
332-
)
333283
# TODO: find a way to do this without allocating
284+
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
334285
disc = $disc
335286

336287
# Write continuous into to discrete: handles `Sample`
@@ -353,71 +304,32 @@ function generate_discrete_affect(
353304
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
354305
end
355306
end
356-
else
357-
:(copyto!(c2d_view, c2d_obs(integrator.u, p, t)))
358-
end
359-
)
307+
end)
360308
# @show "after c2d", p
361-
$(
362-
if use_index_cache
363-
quote
364-
if !$empty_disc
365-
disc(disc_unknowns, integrator.u, p..., t)
366-
for (val, i) in zip(disc_unknowns, $disc_range)
367-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
368-
end
369-
end
370-
end
371-
else
372-
:($empty_disc || disc(disc_unknowns, disc_unknowns, p, t))
373-
end
374-
)
375309
# @show "after state update", p
376-
$(
377-
if use_index_cache
378-
quote
379-
result = d2c_obs(disc_unknowns, p..., t)
380-
for (val, i) in zip(result, $disc_to_cont_idxs)
381-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
382-
end
383-
end
384-
else
385-
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
310+
result = d2c_obs(disc_unknowns, p..., t)
311+
for (val, i) in zip(result, $disc_to_cont_idxs)
312+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
386313
end
387-
)
388314

389-
push!(saved_values.t, t)
390-
push!(saved_values.saveval, $save_vec)
315+
$save_expr
391316

392317
# @show "after d2c", p
393-
$(
394-
if use_index_cache
395-
quote
396-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
397-
$(SciMLStructures.Discrete()), p)
398-
repack(discretes)
399-
end
400-
end
401-
)
318+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
319+
$(SciMLStructures.Discrete()), p)
320+
repack(discretes)
402321
end)
403322

404323
push!(affect_funs, affect!)
405-
push!(init_funs, disc_init)
406-
push!(svs, sv)
407324
end
408325
if eval_expression
409326
affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs)
410-
inits = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), init_funs)
411327
else
412328
affects = map(affect_funs) do a
413329
drop_expr(RuntimeGeneratedFunction(
414330
eval_module, eval_module, toexpr(LiteralExpr(a))))
415331
end
416-
inits = map(init_funs) do a
417-
drop_expr(RuntimeGeneratedFunction(
418-
eval_module, eval_module, toexpr(LiteralExpr(a))))
419-
end
420332
end
421333
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
422-
return affects, inits, clocks, svs, appended_parameters, defaults
334+
return affects, clocks, appended_parameters, defaults
423335
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,14 +1008,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10081008
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
10091009
inits = []
10101010
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1011-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(
1011+
affects, clocks = ModelingToolkit.generate_discrete_affect(
10121012
sys, dss...; eval_expression, eval_module)
1013-
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
1013+
discrete_cbs = map(affects, clocks) do affect, clock
10141014
if clock isa Clock
1015-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1015+
PeriodicCallback(affect, clock.dt;
10161016
final_affect = true, initial_affect = true)
10171017
elseif clock isa SolverStepClock
1018-
affect = DiscreteSaveAffect(affect, sv)
10191018
DiscreteCallback(Returns(true), affect,
10201019
initialize = (c, u, t, integrator) -> affect(integrator))
10211020
else
@@ -1031,8 +1030,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10311030
else
10321031
cbs = CallbackSet(cbs, discrete_cbs...)
10331032
end
1034-
else
1035-
svs = nothing
10361033
end
10371034
kwargs = filter_kwargs(kwargs)
10381035
pt = something(get_metadata(sys), StandardODEProblem())
@@ -1041,17 +1038,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10411038
if cbs !== nothing
10421039
kwargs1 = merge(kwargs1, (callback = cbs,))
10431040
end
1044-
if svs !== nothing
1045-
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1046-
end
10471041

1048-
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1049-
if !isempty(inits)
1050-
for init in inits
1051-
# init(prob.u0, prob.p, tspan[1])
1052-
end
1053-
end
1054-
prob
1042+
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
10551043
end
10561044
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
10571045

test/parameter_dependencies.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -173,22 +173,19 @@ end
173173
@test_skip begin
174174
Tf = 1.0
175175
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
176-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
177-
yd(k - 2) => 2.0])
178-
@test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent)
176+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
177+
@test_nowarn solve(prob, Tsit5())
179178

180179
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
181180
discrete_events = [[0.5] => [kp ~ 2.0]])
182181
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
183-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
184-
yd(k - 2) => 2.0])
182+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
185183
@test prob.ps[kp] == 1.0
186184
@test prob.ps[kq] == 2.0
187-
@test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
185+
@test_nowarn solve(prob, Tsit5())
188186
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
189-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
190-
yd(k - 2) => 2.0])
191-
integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent)
187+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
188+
integ = init(prob, Tsit5())
192189
@test integ.ps[kp] == 1.0
193190
@test integ.ps[kq] == 2.0
194191
step!(integ, 0.6)

0 commit comments

Comments
 (0)