Skip to content

Commit f9f7ae1

Browse files
AayushSabharwalChrisRackauckas
authored andcommitted
refactor: support clock systems without index caches
1 parent 36388d0 commit f9f7ae1

File tree

2 files changed

+127
-43
lines changed

2 files changed

+127
-43
lines changed

src/systems/clock_inference.jl

Lines changed: 114 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,15 @@ function generate_discrete_affect(
146146
@static if VERSION < v"1.7"
147147
error("The `generate_discrete_affect` function requires at least Julia 1.7")
148148
end
149-
has_index_cache(osys) && get_index_cache(osys) !== nothing || error("System must have index_cache for clock support")
149+
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
150150
out = Sym{Any}(:out)
151151
appended_parameters = parameters(syss[continuous_id])
152-
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
153-
for p in appended_parameters)
154152
offset = length(appended_parameters)
153+
param_to_idx = if use_index_cache
154+
Dict{Any, ParameterIndex}(p => parameter_index(osys, p) for p in appended_parameters)
155+
else
156+
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
157+
end
155158
affect_funs = []
156159
init_funs = []
157160
svs = []
@@ -170,18 +173,20 @@ function generate_discrete_affect(
170173
push!(fullvars, s)
171174
end
172175
needed_disc_to_cont_obs = []
173-
disc_to_cont_idxs = ParameterIndex[]
176+
if use_index_cache
177+
disc_to_cont_idxs = ParameterIndex[]
178+
else
179+
disc_to_cont_idxs = Int[]
180+
end
174181
for v in inputs[continuous_id]
175182
vv = arguments(v)[1]
176183
if vv in fullvars
177184
push!(needed_disc_to_cont_obs, vv)
178-
# @show param_to_idx[v] v
179-
# @assert param_to_idx[v].portion isa SciMLStructures.Discrete # TOOD: remove
180185
push!(disc_to_cont_idxs, param_to_idx[v])
181186
end
182187
end
183188
append!(appended_parameters, input, unknowns(sys))
184-
cont_to_disc_obs = build_explicit_observed_function(osys,
189+
cont_to_disc_obs = build_explicit_observed_function(use_index_cache ? osys : syss[continuous_id],
185190
needed_cont_to_disc_obs,
186191
throw = false,
187192
expression = true,
@@ -192,36 +197,62 @@ function generate_discrete_affect(
192197
expression = true,
193198
output_type = SVector,
194199
ps = reorder_parameters(osys, parameters(sys)))
200+
ni = length(input)
201+
ns = length(unknowns(sys))
195202
disc = Func(
196203
[
197204
out,
198205
DestructuredArgs(unknowns(osys)),
199-
DestructuredArgs.(reorder_parameters(osys, parameters(osys)))...,
200-
# DestructuredArgs(appended_parameters),
206+
if use_index_cache
207+
DestructuredArgs.(reorder_parameters(osys, parameters(osys)))
208+
else
209+
(DestructuredArgs(appended_parameters),)
210+
end...,
201211
get_iv(sys)
202212
],
203213
[],
204214
let_block)
205-
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
206-
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
215+
if use_index_cache
216+
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
217+
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
218+
else
219+
cont_to_disc_idxs = (offset + 1):(offset += ni)
220+
input_offset = offset
221+
disc_range = (offset + 1):(offset += ns)
222+
end
207223
save_vec = Expr(:ref, :Float64)
208-
for unk in unknowns(sys)
209-
idx = parameter_index(osys, unk)
210-
push!(save_vec.args, :($(parameter_values)(p, $idx)))
224+
if use_index_cache
225+
for unk in unknowns(sys)
226+
idx = parameter_index(osys, unk)
227+
push!(save_vec.args, :($(parameter_values)(p, $idx)))
228+
end
229+
else
230+
for i in 1:ns
231+
push!(save_vec.args, :(p[$(input_offset + i)]))
232+
end
211233
end
212234
empty_disc = isempty(disc_range)
213-
disc_init = :(function (p, t)
214-
d2c_obs = $disc_to_cont_obs
215-
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
216-
result = d2c_obs(disc_state, p..., t)
217-
for (val, i) in zip(result, $disc_to_cont_idxs)
218-
# prevent multiple updates to dependents
219-
_set_parameter_unchecked!(p, val, i; update_dependent = false)
220-
end
221-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
222-
$(SciMLStructures.Discrete()), p)
223-
repack(discretes) # to force recalculation of dependents
224-
end)
235+
disc_init = if use_index_cache
236+
:(function (p, t)
237+
d2c_obs = $disc_to_cont_obs
238+
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
239+
result = d2c_obs(disc_state, p..., t)
240+
for (val, i) in zip(result, $disc_to_cont_idxs)
241+
# prevent multiple updates to dependents
242+
_set_parameter_unchecked!(p, val, i; update_dependent = false)
243+
end
244+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
245+
$(SciMLStructures.Discrete()), p)
246+
repack(discretes) # to force recalculation of dependents
247+
end)
248+
else
249+
:(function (p, t)
250+
d2c_obs = $disc_to_cont_obs
251+
d2c_view = view(p, $disc_to_cont_idxs)
252+
disc_state = view(p, $disc_range)
253+
copyto!(d2c_view, d2c_obs(disc_state, p, t))
254+
end)
255+
end
225256

226257
# @show disc_to_cont_idxs
227258
# @show cont_to_disc_idxs
@@ -230,8 +261,18 @@ function generate_discrete_affect(
230261
@unpack u, p, t = integrator
231262
c2d_obs = $cont_to_disc_obs
232263
d2c_obs = $disc_to_cont_obs
264+
$(
265+
if use_index_cache
266+
:(disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range])
267+
else
268+
quote
269+
c2d_view = view(p, $cont_to_disc_idxs)
270+
d2c_view = view(p, $disc_to_cont_idxs)
271+
disc_unknowns = view(p, $disc_range)
272+
end
273+
end
274+
)
233275
# TODO: find a way to do this without allocating
234-
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
235276
disc = $disc
236277

237278
push!(saved_values.t, t)
@@ -246,26 +287,56 @@ function generate_discrete_affect(
246287
# d2c comes last
247288
# @show t
248289
# @show "incoming", p
249-
result = c2d_obs(integrator.u, p..., t)
250-
for (val, i) in zip(result, $cont_to_disc_idxs)
251-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
252-
end
290+
$(
291+
if use_index_cache
292+
quote
293+
result = c2d_obs(integrator.u, p..., t)
294+
for (val, i) in zip(result, $cont_to_disc_idxs)
295+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
296+
end
297+
end
298+
else
299+
:(copyto!(c2d_view, c2d_obs(integrator.u, p, t)))
300+
end
301+
)
253302
# @show "after c2d", p
254-
if !$empty_disc
255-
disc(disc_unknowns, integrator.u, p..., t)
256-
for (val, i) in zip(disc_unknowns, $disc_range)
257-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
303+
$(
304+
if use_index_cache
305+
quote
306+
if !$empty_disc
307+
disc(disc_unknowns, integrator.u, p..., t)
308+
for (val, i) in zip(disc_unknowns, $disc_range)
309+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
310+
end
311+
end
312+
end
313+
else
314+
:($empty_disc || disc(disc_unknowns, disc_unknowns, p, t))
258315
end
259-
end
316+
)
260317
# @show "after state update", p
261-
result = d2c_obs(disc_unknowns, p..., t)
262-
for (val, i) in zip(result, $disc_to_cont_idxs)
263-
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
264-
end
318+
$(
319+
if use_index_cache
320+
quote
321+
result = d2c_obs(disc_unknowns, p..., t)
322+
for (val, i) in zip(result, $disc_to_cont_idxs)
323+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
324+
end
325+
end
326+
else
327+
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
328+
end
329+
)
265330
# @show "after d2c", p
266-
discretes, repack, _ = $(SciMLStructures.canonicalize)(
267-
$(SciMLStructures.Discrete()), p)
268-
repack(discretes)
331+
$(
332+
if use_index_cache
333+
quote
334+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
335+
$(SciMLStructures.Discrete()), p)
336+
repack(discretes)
337+
end
338+
end
339+
)
269340
end)
270341
sv = SavedValues(Float64, Vector{Float64})
271342
push!(affect_funs, affect!)

test/clock.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, Tf),
118118
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
119119
@test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud
120120
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
121+
122+
ss_nosplit = structural_simplify(sys; split = false)
123+
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0, y => 0.0], (0.0, Tf),
124+
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
125+
@test sort(prob_nosplit.p) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud
126+
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
121127
# For all inputs in parameters, just initialize them to 0.0, and then set them
122128
# in the callback.
123129

@@ -154,8 +160,11 @@ prob = ODEProblem(foo!, [0.0], (0.0, Tf), [1.0, 4.0, 2.0, 3.0], callback = cb)
154160
# ud initializes to kp * (r - yd) + z = 1 * (1 - 0) + 3 = 4
155161
sol2 = solve(prob, Tsit5())
156162
@test sol.u == sol2.u
163+
@test sol_nosplit.u == sol2.u
157164
@test saved_values.t == sol.prob.kwargs[:disc_saved_values][1].t
165+
@test saved_values.t == sol_nosplit.prob.kwargs[:disc_saved_values][1].t
158166
@test saved_values.saveval == sol.prob.kwargs[:disc_saved_values][1].saveval
167+
@test saved_values.saveval == sol_nosplit.prob.kwargs[:disc_saved_values][1].saveval
159168

160169
@info "Testing multi-rate hybrid system"
161170
dt = 0.1
@@ -280,10 +289,13 @@ ci, varmap = infer_clocks(cl)
280289
@test varmap[u] == Continuous()
281290

282291
ss = structural_simplify(cl)
292+
ss_nosplit = structural_simplify(cl; split = false)
283293

284294
if VERSION >= v"1.7"
285295
prob = ODEProblem(ss, [x => 0.0], (0.0, 1.0), [kp => 1.0])
296+
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0], (0.0, 1.0), [kp => 1.0])
286297
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
298+
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
287299

288300
function foo!(dx, x, p, t)
289301
kp, ud1, ud2 = p
@@ -314,6 +326,7 @@ if VERSION >= v"1.7"
314326
sol2 = solve(prob, Tsit5())
315327

316328
@test sol.usol2.u atol=1e-6
329+
@test sol_nosplit.usol2.u atol=1e-6
317330
end
318331

319332
##

0 commit comments

Comments
 (0)