Skip to content

Commit 74a3814

Browse files
committed
initialize c2d communication parameter
closes #2356
1 parent 1e36fc7 commit 74a3814

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

src/systems/clock_inference.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
150150
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
151151
offset = length(appended_parameters)
152152
affect_funs = []
153+
init_funs = []
153154
svs = []
154155
clocks = TimeDomain[]
155156
for (i, (sys, input)) in enumerate(zip(syss, inputs))
@@ -202,6 +203,14 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
202203
push!(save_vec.args, :(p[$(input_offset + i)]))
203204
end
204205
empty_disc = isempty(disc_range)
206+
207+
disc_init = :(function (p, t)
208+
d2c_obs = $disc_to_cont_obs
209+
d2c_view = view(p, $disc_to_cont_idxs)
210+
disc_state = view(p, $disc_range)
211+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
212+
end)
213+
205214
affect! = :(function (integrator, saved_values)
206215
@unpack u, p, t = integrator
207216
c2d_obs = $cont_to_disc_obs
@@ -223,15 +232,20 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
223232
end)
224233
sv = SavedValues(Float64, Vector{Float64})
225234
push!(affect_funs, affect!)
235+
push!(init_funs, disc_init)
226236
push!(svs, sv)
227237
end
228238
if eval_expression
229239
affects = map(affect_funs) do a
230240
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
231241
end
242+
inits = map(init_funs) do a
243+
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
244+
end
232245
else
233246
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
247+
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
234248
end
235249
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
236-
return affects, clocks, svs, appended_parameters, defaults
250+
return affects, inits, clocks, svs, appended_parameters, defaults
237251
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
938938
has_difference = has_difference,
939939
check_length, kwargs...)
940940
cbs = process_events(sys; callback, has_difference, kwargs...)
941+
inits = []
941942
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
942-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
943+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
943944
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
944945
if clock isa Clock
945946
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -969,7 +970,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
969970
if svs !== nothing
970971
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
971972
end
972-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
973+
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
974+
if !isempty(inits)
975+
for init in inits
976+
init(prob.p, tspan[1])
977+
end
978+
end
979+
prob
973980
end
974981
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
975982

@@ -1038,8 +1045,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10381045
h = h_oop
10391046
u0 = h(p, tspan[1])
10401047
cbs = process_events(sys; callback, has_difference, kwargs...)
1048+
inits = []
10411049
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1042-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1050+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
10431051
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10441052
if clock isa Clock
10451053
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1068,7 +1076,13 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10681076
if svs !== nothing
10691077
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
10701078
end
1071-
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1079+
prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1080+
if !isempty(inits)
1081+
for init in inits
1082+
init(prob.p, tspan[1])
1083+
end
1084+
end
1085+
prob
10721086
end
10731087

10741088
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1092,8 +1106,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10921106
h(p, t) = h_oop(p, t)
10931107
u0 = h(p, tspan[1])
10941108
cbs = process_events(sys; callback, has_difference, kwargs...)
1109+
inits = []
10951110
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1096-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
1111+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
10971112
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10981113
if clock isa Clock
10991114
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
@@ -1133,8 +1148,15 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11331148
else
11341149
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
11351150
end
1136-
SDDEProblem{iip}(f, f.g, u0, h, tspan, p; noise_rate_prototype =
1151+
prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1152+
noise_rate_prototype =
11371153
noise_rate_prototype, kwargs1..., kwargs...)
1154+
if !isempty(inits)
1155+
for init in inits
1156+
init(prob.p, tspan[1])
1157+
end
1158+
end
1159+
prob
11381160
end
11391161

11401162
"""

test/clock.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,7 @@ prob = ODEProblem(ssys,
428428
(0.0, 10.0),
429429
[model.controller.kp => 2.0; model.controller.ki => 2.0])
430430

431-
@test_broken prob.p[9] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
432-
prob.p[9] = 1 # constant output * kp
431+
@test prob.p[9] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
433432
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
434433
# plot([sol(timevec .+ 1e-12, idxs=model.plant.output.u) y])
435434

0 commit comments

Comments
 (0)