Skip to content

Commit e40faca

Browse files
feat: use new discrete saving, only allow split=true hybrid systems
1 parent 522ad07 commit e40faca

File tree

4 files changed

+34
-150
lines changed

4 files changed

+34
-150
lines changed

src/systems/clock_inference.jl

Lines changed: 17 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,14 @@ function generate_discrete_affect(
203203
@static if VERSION < v"1.7"
204204
error("The `generate_discrete_affect` function requires at least Julia 1.7")
205205
end
206-
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
206+
has_index_cache(osys) && get_index_cache(osys) !== nothing ||
207+
error("Hybrid systems require `split = true`")
207208
out = Sym{Any}(:out)
208209
appended_parameters = full_parameters(syss[continuous_id])
209210
offset = length(appended_parameters)
210-
param_to_idx = if use_index_cache
211-
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
212-
for p in appended_parameters)
213-
else
214-
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
215-
end
211+
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
212+
for p in appended_parameters)
216213
affect_funs = []
217-
init_funs = []
218-
svs = []
219214
clocks = TimeDomain[]
220215
for (i, (sys, input)) in enumerate(zip(syss, inputs))
221216
i == continuous_id && continue
@@ -231,11 +226,7 @@ function generate_discrete_affect(
231226
push!(fullvars, s)
232227
end
233228
needed_disc_to_cont_obs = []
234-
if use_index_cache
235-
disc_to_cont_idxs = ParameterIndex[]
236-
else
237-
disc_to_cont_idxs = Int[]
238-
end
229+
disc_to_cont_idxs = ParameterIndex[]
239230
for v in inputs[continuous_id]
240231
_v = arguments(v)[1]
241232
if _v in fullvars
@@ -255,7 +246,7 @@ function generate_discrete_affect(
255246
end
256247
append!(appended_parameters, input)
257248
cont_to_disc_obs = build_explicit_observed_function(
258-
use_index_cache ? osys : syss[continuous_id],
249+
osys,
259250
needed_cont_to_disc_obs,
260251
throw = false,
261252
expression = true,
@@ -281,56 +272,16 @@ function generate_discrete_affect(
281272
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
282273
save_expr = :($(SciMLBase.save_discretes!)(integrator, $i))
283274
empty_disc = isempty(disc_range)
284-
disc_init = if use_index_cache
285-
:(function (u, p, t)
286-
c2d_obs = $cont_to_disc_obs
287-
d2c_obs = $disc_to_cont_obs
288-
result = c2d_obs(u, p..., t)
289-
for (val, i) in zip(result, $cont_to_disc_idxs)
290-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
291-
end
292-
293-
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
294-
result = d2c_obs(disc_state, p..., t)
295-
for (val, i) in zip(result, $disc_to_cont_idxs)
296-
# prevent multiple updates to dependents
297-
_set_parameter_unchecked!(p, val, i; update_dependent = false)
298-
end
299-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
300-
$(SciMLStructures.Discrete()), p)
301-
repack(discretes) # to force recalculation of dependents
302-
end)
303-
else
304-
:(function (u, p, t)
305-
c2d_obs = $cont_to_disc_obs
306-
d2c_obs = $disc_to_cont_obs
307-
c2d_view = view(p, $cont_to_disc_idxs)
308-
d2c_view = view(p, $disc_to_cont_idxs)
309-
disc_unknowns = view(p, $disc_range)
310-
copyto!(c2d_view, c2d_obs(u, p, t))
311-
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
312-
end)
313-
end
314275

315276
# @show disc_to_cont_idxs
316277
# @show cont_to_disc_idxs
317278
# @show disc_range
318-
affect! = :(function (integrator, saved_values)
279+
affect! = :(function (integrator)
319280
@unpack u, p, t = integrator
320281
c2d_obs = $cont_to_disc_obs
321282
d2c_obs = $disc_to_cont_obs
322-
$(
323-
if use_index_cache
324-
:(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range])
325-
else
326-
quote
327-
c2d_view = view(p, $cont_to_disc_idxs)
328-
d2c_view = view(p, $disc_to_cont_idxs)
329-
disc_unknowns = view(p, $disc_range)
330-
end
331-
end
332-
)
333283
# TODO: find a way to do this without allocating
284+
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
334285
disc = $disc
335286

336287
# Write continuous into to discrete: handles `Sample`
@@ -353,69 +304,31 @@ function generate_discrete_affect(
353304
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
354305
end
355306
end
356-
else
357-
:(copyto!(c2d_view, c2d_obs(integrator.u, p, t)))
358-
end
359-
)
307+
end)
360308
# @show "after c2d", p
361-
$(
362-
if use_index_cache
363-
quote
364-
if !$empty_disc
365-
disc(disc_unknowns, integrator.u, p..., t)
366-
for (val, i) in zip(disc_unknowns, $disc_range)
367-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
368-
end
369-
end
370-
end
371-
else
372-
:($empty_disc || disc(disc_unknowns, disc_unknowns, p, t))
373-
end
374-
)
375309
# @show "after state update", p
376-
$(
377-
if use_index_cache
378-
quote
379-
result = d2c_obs(disc_unknowns, p..., t)
380-
for (val, i) in zip(result, $disc_to_cont_idxs)
381-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
382-
end
383-
end
384-
else
385-
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
310+
result = d2c_obs(disc_unknowns, p..., t)
311+
for (val, i) in zip(result, $disc_to_cont_idxs)
312+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
386313
end
387-
)
388314

389-
push!(saved_values.t, t)
390-
push!(saved_values.saveval, $save_vec)
315+
$save_expr
391316

392317
# @show "after d2c", p
393-
$(
394-
if use_index_cache
395-
quote
396-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
397-
$(SciMLStructures.Discrete()), p)
398-
repack(discretes)
399-
end
400-
end
401-
)
318+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
319+
$(SciMLStructures.Discrete()), p)
320+
repack(discretes)
402321
end)
403322

404323
push!(affect_funs, affect!)
405-
push!(init_funs, disc_init)
406-
push!(svs, sv)
407324
end
408325
if eval_expression
409326
affects = map(affect_funs) do a
410327
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
411328
end
412-
inits = map(init_funs) do a
413-
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
414-
end
415329
else
416330
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
417-
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
418331
end
419332
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
420-
return affects, inits, clocks, svs, appended_parameters, defaults
333+
return affects, clocks, appended_parameters, defaults
421334
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,15 +1087,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10871087
t = tspan !== nothing ? tspan[1] : tspan,
10881088
check_length, warn_initialize_determined, kwargs...)
10891089
cbs = process_events(sys; callback, kwargs...)
1090-
inits = []
10911090
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1092-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1093-
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
1091+
affects, clocks = ModelingToolkit.generate_discrete_affect(sys, dss...)
1092+
discrete_cbs = map(affects, clocks) do affect, clock
10941093
if clock isa Clock
1095-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1094+
PeriodicCallback(affect, clock.dt;
10961095
final_affect = true, initial_affect = true)
10971096
elseif clock isa SolverStepClock
1098-
affect = DiscreteSaveAffect(affect, sv)
10991097
DiscreteCallback(Returns(true), affect,
11001098
initialize = (c, u, t, integrator) -> affect(integrator))
11011099
else
@@ -1111,8 +1109,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
11111109
else
11121110
cbs = CallbackSet(cbs, discrete_cbs...)
11131111
end
1114-
else
1115-
svs = nothing
11161112
end
11171113
kwargs = filter_kwargs(kwargs)
11181114
pt = something(get_metadata(sys), StandardODEProblem())
@@ -1121,17 +1117,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
11211117
if cbs !== nothing
11221118
kwargs1 = merge(kwargs1, (callback = cbs,))
11231119
end
1124-
if svs !== nothing
1125-
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
1126-
end
11271120

1128-
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1129-
if !isempty(inits)
1130-
for init in inits
1131-
# init(prob.u0, prob.p, tspan[1])
1132-
end
1133-
end
1134-
prob
1121+
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
11351122
end
11361123
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
11371124

test/clock.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,12 @@ Tf = 1.0
112112
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
113113
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0])
114114
# create integrator so callback is evaluated at t=0 and we can test correct param values
115-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
115+
int = init(prob, Tsit5())
116116
@test sort(vcat(int.p...)) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
117117
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
118118
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values
119-
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
119+
sol = solve(prob, Tsit5())
120120

121-
ss_nosplit = structural_simplify(sys; split = false)
122-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
123-
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0])
124-
int = init(prob_nosplit, Tsit5(); kwargshandle = KeywordArgSilent)
125-
@test sort(int.p) == [0.1, 1.0, 2.1, 2.1, 2.1] # yd, kp, ud(k-1), ud, Hold(ud)
126-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
127-
[kp => 1.0; ud(k - 1) => 2.1; ud(k - 2) => 2.0]) # recreate problem to empty saved values
128-
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
129121
# For all inputs in parameters, just initialize them to 0.0, and then set them
130122
# in the callback.
131123

@@ -156,11 +148,8 @@ cb = PeriodicCallback(
156148
prob = ODEProblem(foo!, [0.1], (0.0, Tf), [1.0, 2.1, 2.0], callback = cb)
157149
sol2 = solve(prob, Tsit5())
158150
@test sol.u == sol2.u
159-
@test sol_nosplit.u == sol2.u
160-
@test saved_values.t == sol.prob.kwargs[:disc_saved_values][1].t
161-
@test saved_values.t == sol_nosplit.prob.kwargs[:disc_saved_values][1].t
162-
@test saved_values.saveval == sol.prob.kwargs[:disc_saved_values][1].saveval
163-
@test saved_values.saveval == sol_nosplit.prob.kwargs[:disc_saved_values][1].saveval
151+
@test saved_values.t sol.discretes[1].t
152+
@test saved_values.saveval == sol.ps[[ud, ud(k - 1)]]
164153

165154
@info "Testing multi-rate hybrid system"
166155
dt = 0.1
@@ -285,13 +274,10 @@ ci, varmap = infer_clocks(cl)
285274
@test varmap[u] == Continuous()
286275

287276
ss = structural_simplify(cl)
288-
ss_nosplit = structural_simplify(cl; split = false)
289277

290278
if VERSION >= v"1.7"
291279
prob = ODEProblem(ss, [x => 0.0], (0.0, 1.0), [kp => 1.0])
292-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0], (0.0, 1.0), [kp => 1.0])
293-
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
294-
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
280+
sol = solve(prob, Tsit5())
295281

296282
function foo!(dx, x, p, t)
297283
kp, ud1, ud2 = p
@@ -322,7 +308,6 @@ if VERSION >= v"1.7"
322308
sol2 = solve(prob, Tsit5())
323309

324310
@test sol.usol2.u atol=1e-6
325-
@test sol_nosplit.usol2.u atol=1e-6
326311
end
327312

328313
##
@@ -441,13 +426,12 @@ y = res.y[:]
441426
prob = ODEProblem(ssys,
442427
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
443428
(0.0, Tf))
444-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
429+
int = init(prob, Tsit5())
445430
@test_broken int.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
446431
@test int.ps[ssys.controller.x] == 1 # c2d
447432
@test int.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state
448433
sol = solve(prob,
449434
Tsit5(),
450-
kwargshandle = KeywordArgSilent,
451435
abstol = 1e-8,
452436
reltol = 1e-8)
453437
@test_skip begin
@@ -509,9 +493,9 @@ end
509493

510494
@mtkbuild model = FirstOrderWithStepCounter()
511495
prob = ODEProblem(model, [], (0.0, 10.0))
512-
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
496+
sol = solve(prob, Tsit5())
513497

514-
@test sol.prob.kwargs[:disc_saved_values][1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
498+
@test sol.discretes[1].t == sol.t[1:2:end] # Test that the discrete-tiem system executed at every step of the continuous solver. The solver saves each time step twice, one state value before discrete affect and one after.
515499
@test_nowarn ModelingToolkit.build_explicit_observed_function(
516500
model, model.counter.ud)(sol.u[1], prob.p..., sol.t[1])
517501

@@ -520,12 +504,12 @@ eqs = [D(y) ~ Hold(x)
520504
x ~ x(k - 1) + x(k - 2)]
521505
@mtkbuild sys = ODESystem(eqs, t)
522506
prob = ODEProblem(sys, [], (0.0, 10.0))
523-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
507+
int = init(prob, Tsit5())
524508
@test int.ps[x] == 2.0
525509
@test int.ps[x(k - 1)] == 1.0
526510

527511
@test_throws ErrorException ODEProblem(sys, [], (0.0, 10.0), [x => 2.0])
528512
prob = ODEProblem(sys, [], (0.0, 10.0), [x(k - 1) => 2.0])
529-
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
513+
int = init(prob, Tsit5())
530514
@test int.ps[x] == 3.0
531515
@test int.ps[x(k - 1)] == 2.0

test/parameter_dependencies.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,18 @@ end
8686
Tf = 1.0
8787
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
8888
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
89-
@test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent)
89+
@test_nowarn solve(prob, Tsit5())
9090

9191
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
9292
discrete_events = [[0.5] => [kp ~ 2.0]])
9393
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
9494
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
9595
@test prob.ps[kp] == 1.0
9696
@test prob.ps[kq] == 2.0
97-
@test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
97+
@test_nowarn solve(prob, Tsit5())
9898
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
9999
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
100-
integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent)
100+
integ = init(prob, Tsit5())
101101
@test integ.ps[kp] == 1.0
102102
@test integ.ps[kq] == 2.0
103103
step!(integ, 0.6)

0 commit comments

Comments
 (0)