Skip to content

Commit 27e1145

Browse files
fix: run affect! for hybrid systems during problem initialization
1 parent 6bfd3a9 commit 27e1145

File tree

3 files changed

+17
-44
lines changed

3 files changed

+17
-44
lines changed

src/systems/clock_inference.jl

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ function generate_discrete_affect(
157157
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
158158
end
159159
affect_funs = []
160-
init_funs = []
161160
svs = []
162161
clocks = TimeDomain[]
163162
for (i, (sys, input)) in enumerate(zip(syss, inputs))
@@ -249,27 +248,6 @@ function generate_discrete_affect(
249248
end
250249
end
251250
empty_disc = isempty(disc_range)
252-
disc_init = if use_index_cache
253-
:(function (p, t)
254-
d2c_obs = $disc_to_cont_obs
255-
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
256-
result = d2c_obs(disc_state, p..., t)
257-
for (val, i) in zip(result, $disc_to_cont_idxs)
258-
# prevent multiple updates to dependents
259-
_set_parameter_unchecked!(p, val, i; update_dependent = false)
260-
end
261-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
262-
$(SciMLStructures.Discrete()), p)
263-
repack(discretes) # to force recalculation of dependents
264-
end)
265-
else
266-
:(function (p, t)
267-
d2c_obs = $disc_to_cont_obs
268-
d2c_view = view(p, $disc_to_cont_idxs)
269-
disc_state = view(p, $disc_range)
270-
copyto!(d2c_view, d2c_obs(disc_state, p, t))
271-
end)
272-
end
273251

274252
# @show disc_to_cont_idxs
275253
# @show cont_to_disc_idxs
@@ -357,20 +335,15 @@ function generate_discrete_affect(
357335
end)
358336
sv = SavedValues(Float64, Vector{Float64})
359337
push!(affect_funs, affect!)
360-
push!(init_funs, disc_init)
361338
push!(svs, sv)
362339
end
363340
if eval_expression
364341
affects = map(affect_funs) do a
365342
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
366343
end
367-
inits = map(init_funs) do a
368-
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
369-
end
370344
else
371345
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
372-
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
373346
end
374347
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
375-
return affects, inits, clocks, svs, appended_parameters, defaults
348+
return affects, clocks, svs, appended_parameters, defaults
376349
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,9 +1039,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10391039
t = tspan !== nothing ? tspan[1] : tspan,
10401040
check_length, warn_initialize_determined, kwargs...)
10411041
cbs = process_events(sys; callback, kwargs...)
1042-
inits = []
1042+
affects = []
10431043
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1044-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1044+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
10451045
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10461046
if clock isa Clock
10471047
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1062,6 +1062,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10621062
else
10631063
cbs = CallbackSet(cbs, discrete_cbs...)
10641064
end
1065+
# initialize by running affects
1066+
dummy_saveval = (; t = [], saveval = [])
1067+
for affect! in affects
1068+
affect!(
1069+
(; u = u0, p = p, t = tspan !== nothing ? tspan[1] : tspan), dummy_saveval)
1070+
end
10651071
else
10661072
svs = nothing
10671073
end
@@ -1075,13 +1081,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10751081
if svs !== nothing
10761082
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
10771083
end
1078-
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1079-
if !isempty(inits)
1080-
for init in inits
1081-
init(prob.p, tspan[1])
1082-
end
1083-
end
1084-
prob
1084+
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
10851085
end
10861086
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
10871087

test/clock.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ eqs = [yd ~ Sample(t, dt)(y)
105105
ss = structural_simplify(sys);
106106

107107
Tf = 1.0
108-
prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, Tf),
108+
prob = ODEProblem(ss, [x => 0.1], (0.0, Tf),
109109
[kp => 1.0; ud(k - 1) => 2.0; ud(k - 2) => 2.0])
110-
@test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 2.0, 2.0] # yd, Hold(ud), kp, ud(k - 1)
110+
@test sort(vcat(prob.p...)) == [0.1, 1.0, 2.0, 2.1, 2.1] # yd, kp, ud(k-2), ud(k-1), Hold(ud)
111111
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
112112

113113
ss_nosplit = structural_simplify(sys; split = false)
114-
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0, y => 0.0], (0.0, Tf),
114+
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.1], (0.0, Tf),
115115
[kp => 1.0; ud(k - 1) => 2.0; ud(k - 2) => 2.0])
116-
@test sort(prob_nosplit.p) == [0, 1.0, 2.0, 2.0, 2.0] # yd, Hold(ud), kp, ud(k - 1)
116+
@test sort(prob_nosplit.p) == [0.1, 1.0, 2.0, 2.1, 2.1] # yd, kp, ud(k-2), ud(k-1), Hold(ud)
117117
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
118118
# For all inputs in parameters, just initialize them to 0.0, and then set them
119119
# in the callback.
@@ -141,7 +141,7 @@ end
141141
saved_values = SavedValues(Float64, Vector{Float64})
142142
cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1)
143143
# kp ud
144-
prob = ODEProblem(foo!, [0.0], (0.0, Tf), [1.0, 2.0, 2.0], callback = cb)
144+
prob = ODEProblem(foo!, [0.1], (0.0, Tf), [1.0, 2.1, 2.0], callback = cb)
145145
sol2 = solve(prob, Tsit5())
146146
@test sol.u == sol2.u
147147
@test sol_nosplit.u == sol2.u
@@ -433,8 +433,8 @@ prob = ODEProblem(ssys,
433433
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
434434
(0.0, Tf))
435435

436-
@test prob.ps[Hold(ssys.holder.input.u)] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
437-
@test prob.ps[ssys.controller.x(k - 1)] == 0 # c2d
436+
@test prob.ps[Hold(ssys.holder.input.u)] == 2 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
437+
@test prob.ps[ssys.controller.x(k - 1)] == 1 # c2d
438438
@test prob.ps[Sample(d)(ssys.sampler.input.u)] == 0 # disc state
439439
sol = solve(prob,
440440
Tsit5(),

0 commit comments

Comments
 (0)