Skip to content

Commit 46b493a

Browse files
feat: use new discrete saving, only allow split=true hybrid systems
1 parent 462e0fb commit 46b493a

File tree

4 files changed

+44
-174
lines changed

4 files changed

+44
-174
lines changed

src/systems/clock_inference.jl

Lines changed: 27 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,14 @@ function generate_discrete_affect(
200200
@static if VERSION < v"1.7"
201201
error("The `generate_discrete_affect` function requires at least Julia 1.7")
202202
end
203-
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
203+
has_index_cache(osys) && get_index_cache(osys) !== nothing ||
204+
error("Hybrid systems require `split = true`")
204205
out = Sym{Any}(:out)
205206
appended_parameters = full_parameters(syss[continuous_id])
206207
offset = length(appended_parameters)
207-
param_to_idx = if use_index_cache
208-
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
209-
for p in appended_parameters)
210-
else
211-
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
212-
end
208+
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
209+
for p in appended_parameters)
213210
affect_funs = []
214-
init_funs = []
215-
svs = []
216211
clocks = TimeDomain[]
217212
for (i, (sys, input)) in enumerate(zip(syss, inputs))
218213
i == continuous_id && continue
@@ -228,11 +223,7 @@ function generate_discrete_affect(
228223
push!(fullvars, s)
229224
end
230225
needed_disc_to_cont_obs = []
231-
if use_index_cache
232-
disc_to_cont_idxs = ParameterIndex[]
233-
else
234-
disc_to_cont_idxs = Int[]
235-
end
226+
disc_to_cont_idxs = ParameterIndex[]
236227
for v in inputs[continuous_id]
237228
_v = arguments(v)[1]
238229
if _v in fullvars
@@ -252,7 +243,7 @@ function generate_discrete_affect(
252243
end
253244
append!(appended_parameters, input)
254245
cont_to_disc_obs = build_explicit_observed_function(
255-
use_index_cache ? osys : syss[continuous_id],
246+
osys,
256247
needed_cont_to_disc_obs,
257248
throw = false,
258249
expression = true,
@@ -274,76 +265,20 @@ function generate_discrete_affect(
274265
],
275266
[],
276267
let_block) |> toexpr
277-
if use_index_cache
278-
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
279-
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
280-
else
281-
cont_to_disc_idxs = (offset + 1):(offset += ni)
282-
input_offset = offset
283-
disc_range = (offset + 1):(offset += ns)
284-
end
285-
save_vec = Expr(:ref, :Float64)
286-
if use_index_cache
287-
for unk in unknowns(sys)
288-
idx = parameter_index(osys, unk)
289-
push!(save_vec.args, :($(parameter_values)(p, $idx)))
290-
end
291-
else
292-
for i in 1:ns
293-
push!(save_vec.args, :(p[$(input_offset + i)]))
294-
end
295-
end
268+
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
269+
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
270+
save_expr = :($(SciMLBase.save_discretes!)(integrator, $(i - 1)))
296271
empty_disc = isempty(disc_range)
297-
disc_init = if use_index_cache
298-
:(function (u, p, t)
299-
c2d_obs = $cont_to_disc_obs
300-
d2c_obs = $disc_to_cont_obs
301-
result = c2d_obs(u, p..., t)
302-
for (val, i) in zip(result, $cont_to_disc_idxs)
303-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
304-
end
305-
306-
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
307-
result = d2c_obs(disc_state, p..., t)
308-
for (val, i) in zip(result, $disc_to_cont_idxs)
309-
# prevent multiple updates to dependents
310-
_set_parameter_unchecked!(p, val, i; update_dependent = false)
311-
end
312-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
313-
$(SciMLStructures.Discrete()), p)
314-
repack(discretes) # to force recalculation of dependents
315-
end)
316-
else
317-
:(function (u, p, t)
318-
c2d_obs = $cont_to_disc_obs
319-
d2c_obs = $disc_to_cont_obs
320-
c2d_view = view(p, $cont_to_disc_idxs)
321-
d2c_view = view(p, $disc_to_cont_idxs)
322-
disc_unknowns = view(p, $disc_range)
323-
copyto!(c2d_view, c2d_obs(u, p, t))
324-
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
325-
end)
326-
end
327272

328273
# @show disc_to_cont_idxs
329274
# @show cont_to_disc_idxs
330275
# @show disc_range
331-
affect! = :(function (integrator, saved_values)
276+
affect! = :(function (integrator)
332277
@unpack u, p, t = integrator
333278
c2d_obs = $cont_to_disc_obs
334279
d2c_obs = $disc_to_cont_obs
335-
$(
336-
if use_index_cache
337-
:(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range])
338-
else
339-
quote
340-
c2d_view = view(p, $cont_to_disc_idxs)
341-
d2c_view = view(p, $disc_to_cont_idxs)
342-
disc_unknowns = view(p, $disc_range)
343-
end
344-
end
345-
)
346280
# TODO: find a way to do this without allocating
281+
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
347282
disc = $disc
348283

349284
# Write continuous into to discrete: handles `Sample`
@@ -355,77 +290,41 @@ function generate_discrete_affect(
355290
# d2c comes last
356291
# @show t
357292
# @show "incoming", p
358-
$(
359-
if use_index_cache
293+
result = c2d_obs(integrator.u, p..., t)
294+
for (val, i) in zip(result, $cont_to_disc_idxs)
295+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
296+
end
297+
$(if !empty_disc
360298
quote
361-
result = c2d_obs(integrator.u, p..., t)
362-
for (val, i) in zip(result, $cont_to_disc_idxs)
299+
disc(disc_unknowns, integrator.u, p..., t)
300+
for (val, i) in zip(disc_unknowns, $disc_range)
363301
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
364302
end
365303
end
366-
else
367-
:(copyto!(c2d_view, c2d_obs(integrator.u, p, t)))
368-
end
369-
)
304+
end)
370305
# @show "after c2d", p
371-
$(
372-
if use_index_cache
373-
quote
374-
if !$empty_disc
375-
disc(disc_unknowns, integrator.u, p..., t)
376-
for (val, i) in zip(disc_unknowns, $disc_range)
377-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
378-
end
379-
end
380-
end
381-
else
382-
:($empty_disc || disc(disc_unknowns, disc_unknowns, p, t))
383-
end
384-
)
385306
# @show "after state update", p
386-
$(
387-
if use_index_cache
388-
quote
389-
result = d2c_obs(disc_unknowns, p..., t)
390-
for (val, i) in zip(result, $disc_to_cont_idxs)
391-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
392-
end
393-
end
394-
else
395-
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
307+
result = d2c_obs(disc_unknowns, p..., t)
308+
for (val, i) in zip(result, $disc_to_cont_idxs)
309+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
396310
end
397-
)
398311

399-
push!(saved_values.t, t)
400-
push!(saved_values.saveval, $save_vec)
312+
$save_expr
401313

402314
# @show "after d2c", p
403-
$(
404-
if use_index_cache
405-
quote
406-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
407-
$(SciMLStructures.Discrete()), p)
408-
repack(discretes)
409-
end
410-
end
411-
)
315+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
316+
$(SciMLStructures.Discrete()), p)
317+
repack(discretes)
412318
end)
413-
sv = SavedValues(Float64, Vector{Float64})
414319
push!(affect_funs, affect!)
415-
push!(init_funs, disc_init)
416-
push!(svs, sv)
417320
end
418321
if eval_expression
419322
affects = map(affect_funs) do a
420323
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
421324
end
422-
inits = map(init_funs) do a
423-
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
424-
end
425325
else
426326
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
427-
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
428327
end
429328
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
430-
return affects, inits, clocks, svs, appended_parameters, defaults
329+
return affects, clocks, appended_parameters, defaults
431330
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,15 +1086,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10861086
t = tspan !== nothing ? tspan[1] : tspan,
10871087
check_length, warn_initialize_determined, kwargs...)
10881088
cbs = process_events(sys; callback, kwargs...)
1089-
inits = []
10901089
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1091-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1092-
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
1090+
affects, clocks = ModelingToolkit.generate_discrete_affect(sys, dss...)
1091+
discrete_cbs = map(affects, clocks) do affect, clock
10931092
if clock isa Clock
1094-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1093+
PeriodicCallback(affect, clock.dt;
10951094
final_affect = true, initial_affect = true)
10961095
elseif clock isa SolverStepClock
1097-
affect = DiscreteSaveAffect(affect, sv)
10981096
DiscreteCallback(Returns(true), affect,
10991097
initialize = (c, u, t, integrator) -> affect(integrator))
11001098
else
@@ -1110,8 +1108,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
11101108
else
11111109
cbs = CallbackSet(cbs, discrete_cbs...)
11121110
end
1113-
else
1114-
svs = nothing
11151111
end
11161112
kwargs = filter_kwargs(kwargs)
11171113
pt = something(get_metadata(sys), StandardODEProblem())
@@ -1120,17 +1116,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
11201116
if cbs !== nothing
11211117
kwargs1 = merge(kwargs1, (callback = cbs,))
11221118
end
1123-
if svs !== nothing
1124-
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1125-
end
11261119

1127-
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1128-
if !isempty(inits)
1129-
for init in inits
1130-
# init(prob.u0, prob.p, tspan[1])
1131-
end
1132-
end
1133-
prob
1120+
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
11341121
end
11351122
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
11361123

test/clock.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,12 @@ Tf = 1.0
111111
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
112112
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0])
113113
# create integrator so callback is evaluated at t=0 and we can test correct param values
114-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
114+
int = init(prob, Tsit5())
115115
@test sort(vcat(int.p...)) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
116116
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
117117
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values
118-
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
118+
sol = solve(prob, Tsit5())
119119

120-
ss_nosplit = structural_simplify(sys; split = false)
121-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
122-
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0])
123-
int = init(prob_nosplit, Tsit5(); kwargshandle = KeywordArgSilent)
124-
@test sort(int.p) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
125-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
126-
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values
127-
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
128120
# For all inputs in parameters, just initialize them to 0.0, and then set them
129121
# in the callback.
130122

@@ -155,11 +147,8 @@ cb = PeriodicCallback(
155147
prob = ODEProblem(foo!, [0.1], (0.0, Tf), [1.0, 2.1, 2.0], callback = cb)
156148
sol2 = solve(prob, Tsit5())
157149
@test sol.u == sol2.u
158-
@test sol_nosplit.u == sol2.u
159-
@test saved_values.t == sol.prob.kwargs[:disc_saved_values][1].t
160-
@test saved_values.t == sol_nosplit.prob.kwargs[:disc_saved_values][1].t
161-
@test saved_values.saveval == sol.prob.kwargs[:disc_saved_values][1].saveval
162-
@test saved_values.saveval == sol_nosplit.prob.kwargs[:disc_saved_values][1].saveval
150+
@test saved_values.t sol.discretes[1].t
151+
@test saved_values.saveval == sol.ps[[ud, ud(k - 1)]]
163152

164153
@info "Testing multi-rate hybrid system"
165154
dt = 0.1
@@ -284,13 +273,10 @@ ci, varmap = infer_clocks(cl)
284273
@test varmap[u] == Continuous()
285274

286275
ss = structural_simplify(cl)
287-
ss_nosplit = structural_simplify(cl; split = false)
288276

289277
if VERSION >= v"1.7"
290278
prob = ODEProblem(ss, [x => 0.0], (0.0, 1.0), [kp => 1.0])
291-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0], (0.0, 1.0), [kp => 1.0])
292-
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
293-
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
279+
sol = solve(prob, Tsit5())
294280

295281
function foo!(dx, x, p, t)
296282
kp, ud1, ud2 = p
@@ -321,7 +307,6 @@ if VERSION >= v"1.7"
321307
sol2 = solve(prob, Tsit5())
322308

323309
@test sol.usol2.u atol=1e-6
324-
@test sol_nosplit.usol2.u atol=1e-6
325310
end
326311

327312
##
@@ -440,13 +425,12 @@ y = res.y[:]
440425
prob = ODEProblem(ssys,
441426
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
442427
(0.0, Tf))
443-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
428+
int = init(prob, Tsit5())
444429
@test_broken int.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
445430
@test int.ps[ssys.controller.x] == 1 # c2d
446431
@test int.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state
447432
sol = solve(prob,
448433
Tsit5(),
449-
kwargshandle = KeywordArgSilent,
450434
abstol = 1e-8,
451435
reltol = 1e-8)
452436
@test_skip begin
@@ -508,9 +492,9 @@ end
508492

509493
@mtkbuild model = FirstOrderWithStepCounter()
510494
prob = ODEProblem(model, [], (0.0, 10.0))
511-
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
495+
sol = solve(prob, Tsit5())
512496

513-
@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
497+
@test sol.discretes[1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
514498
@test_nowarn ModelingToolkit.build_explicit_observed_function(
515499
model, model.counter.ud)(sol.u[1], prob.p..., sol.t[1])
516500

@@ -519,12 +503,12 @@ eqs = [D(y) ~ Hold(x)
519503
x ~ x(k - 1) + x(k - 2)]
520504
@mtkbuild sys = ODESystem(eqs, t)
521505
prob = ODEProblem(sys, [], (0.0, 10.0))
522-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
506+
int = init(prob, Tsit5())
523507
@test int.ps[x] == 2.0
524508
@test int.ps[x(k - 1)] == 1.0
525509

526510
@test_throws ErrorException ODEProblem(sys, [], (0.0, 10.0), [x => 2.0])
527511
prob = ODEProblem(sys, [], (0.0, 10.0), [x(k - 1) => 2.0])
528-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
512+
int = init(prob, Tsit5())
529513
@test int.ps[x] == 3.0
530514
@test int.ps[x(k - 1)] == 2.0

test/parameter_dependencies.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,18 @@ end
8686
Tf = 1.0
8787
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
8888
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
89-
@test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent)
89+
@test_nowarn solve(prob, Tsit5())
9090

9191
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
9292
discrete_events = [[0.5] => [kp ~ 2.0]])
9393
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
9494
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
9595
@test prob.ps[kp] == 1.0
9696
@test prob.ps[kq] == 2.0
97-
@test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
97+
@test_nowarn solve(prob, Tsit5())
9898
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
9999
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
100-
integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent)
100+
integ = init(prob, Tsit5())
101101
@test integ.ps[kp] == 1.0
102102
@test integ.ps[kq] == 2.0
103103
step!(integ, 0.6)

0 commit comments

Comments
 (0)