Skip to content

Commit f2aad04

Browse files
refactor: store parameters from different clock partitions separately
1 parent 3b31b17 commit f2aad04

File tree

2 files changed

+206
-79
lines changed

2 files changed

+206
-79
lines changed

src/systems/index_cache.jl

Lines changed: 154 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ const UnknownIndexMap = Dict{
2424

2525
struct IndexCache
2626
unknown_idx::UnknownIndexMap
27+
discrete_clocks::Dict{Union{Symbol, BasicSymbolic}, Int}
2728
discrete_idx::ParamIndexMap
2829
tunable_idx::ParamIndexMap
2930
constant_idx::ParamIndexMap
3031
dependent_idx::ParamIndexMap
3132
nonnumeric_idx::ParamIndexMap
32-
discrete_buffer_sizes::Vector{BufferTemplate}
33+
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
3334
tunable_buffer_sizes::Vector{BufferTemplate}
3435
constant_buffer_sizes::Vector{BufferTemplate}
3536
dependent_buffer_sizes::Vector{BufferTemplate}
@@ -71,7 +72,8 @@ function IndexCache(sys::AbstractSystem)
7172
end
7273
end
7374

74-
disc_buffers = Dict{Any, Set{BasicSymbolic}}()
75+
disc_buffers = Dict{Int, Dict{Any, Set{BasicSymbolic}}}()
76+
disc_clocks = Dict{Union{Symbol, BasicSymbolic}, Int}()
7577
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
7678
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
7779
dependent_buffers = Dict{Any, Set{BasicSymbolic}}()
@@ -84,27 +86,107 @@ function IndexCache(sys::AbstractSystem)
8486
push!(buf, sym)
8587
end
8688

89+
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
90+
syss, inputs, continuous_id, _ = get_discrete_subsystems(sys)
91+
92+
for (i, (inps, disc_sys)) in enumerate(zip(inputs, syss))
93+
i == continuous_id && continue
94+
disc_buffers[i - 1] = Dict{Any, Set{BasicSymbolic}}()
95+
96+
for inp in inps
97+
inp = unwrap(inp)
98+
is_parameter(sys, inp) ||
99+
error("Discrete subsystem $i input $inp is not a parameter")
100+
disc_clocks[inp] = i - 1
101+
disc_clocks[default_toterm(inp)] = i - 1
102+
if hasname(inp) && (!istree(inp) || operation(inp) !== getindex)
103+
disc_clocks[getname(inp)] = i - 1
104+
disc_clocks[default_toterm(inp)] = i - 1
105+
end
106+
insert_by_type!(disc_buffers[i - 1], inp)
107+
end
108+
109+
for sym in unknowns(disc_sys)
110+
sym = unwrap(sym)
111+
is_parameter(sys, sym) ||
112+
error("Discrete subsystem $i unknown $sym is not a parameter")
113+
disc_clocks[sym] = i - 1
114+
disc_clocks[default_toterm(sym)] = i - 1
115+
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
116+
disc_clocks[getname(sym)] = i - 1
117+
disc_clocks[getname(default_toterm(sym))] = i - 1
118+
end
119+
insert_by_type!(disc_buffers[i - 1], sym)
120+
end
121+
t = get_iv(sys)
122+
for eq in observed(disc_sys)
123+
# TODO: Is this a valid check
124+
# FIXME: This shouldn't be necessary
125+
eq.rhs === -0.0 && continue
126+
sym = eq.lhs
127+
if istree(sym) && operation(sym) == Shift(t, 1)
128+
sym = only(arguments(sym))
129+
end
130+
# is_parameter(sys, sym) || is_parameter(sys, Hold(sym)) || continue
131+
disc_clocks[sym] = i - 1
132+
disc_clocks[sym] = i - 1
133+
disc_clocks[default_toterm(sym)] = i - 1
134+
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
135+
disc_clocks[getname(sym)] = i - 1
136+
disc_clocks[getname(default_toterm(sym))] = i - 1
137+
end
138+
end
139+
end
140+
141+
for par in inputs[continuous_id]
142+
is_parameter(sys, par) || error("Discrete subsystem input is not a parameter")
143+
istree(par) && operation(par) isa Hold ||
144+
error("Continuous subsystem input is not a Hold")
145+
if haskey(disc_clocks, par)
146+
sym = par
147+
else
148+
sym = first(arguments(par))
149+
end
150+
haskey(disc_clocks, sym) ||
151+
error("Variable $par not part of a discrete subsystem")
152+
disc_clocks[par] = disc_clocks[sym]
153+
insert_by_type!(disc_buffers[disc_clocks[sym]], par)
154+
end
155+
end
156+
87157
affs = vcat(affects(continuous_events(sys)), affects(discrete_events(sys)))
158+
user_affect_clock = maximum(values(disc_clocks); init = 1)
88159
for affect in affs
89160
if affect isa Equation
90161
is_parameter(sys, affect.lhs) || continue
91-
insert_by_type!(disc_buffers, affect.lhs)
162+
163+
disc_clocks[affect.lhs] = user_affect_clock
164+
disc_clocks[default_toterm(affect.lhs)] = user_affect_clock
165+
if hasname(affect.lhs) &&
166+
(!istree(affect.lhs) || operation(affect.lhs) !== getindex)
167+
disc_clocks[getname(affect.lhs)] = user_affect_clock
168+
disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock
169+
end
170+
buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}())
171+
insert_by_type!(buffer, affect.lhs)
92172
else
93173
discs = discretes(affect)
94174
for disc in discs
95175
is_parameter(sys, disc) ||
96176
error("Expected discrete variable $disc in callback to be a parameter")
97-
insert_by_type!(disc_buffers, disc)
177+
disc = unwrap(disc)
178+
disc_clocks[disc] = user_affect_clock
179+
disc_clocks[default_toterm(disc)] = user_affect_clock
180+
if hasname(disc) && (!istree(disc) || operation(disc) !== getindex)
181+
disc_clocks[getname(disc)] = user_affect_clock
182+
disc_clocks[getname(default_toterm(disc))] = user_affect_clock
183+
end
184+
buffer = get!(
185+
disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}())
186+
insert_by_type!(buffer, disc)
98187
end
99188
end
100189
end
101-
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
102-
_, inputs, continuous_id, _ = get_discrete_subsystems(sys)
103-
for par in inputs[continuous_id]
104-
is_parameter(sys, par) || error("Discrete subsystem input is not a parameter")
105-
insert_by_type!(disc_buffers, par)
106-
end
107-
end
108190

109191
if has_parameter_dependencies(sys) &&
110192
(pdeps = get_parameter_dependencies(sys)) !== nothing
@@ -117,13 +199,11 @@ function IndexCache(sys::AbstractSystem)
117199
for p in parameters(sys)
118200
p = unwrap(p)
119201
ctype = concrete_symtype(p)
120-
haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue
202+
haskey(disc_clocks, p) && continue
121203
haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
122204
insert_by_type!(
123205
if ctype <: Real || ctype <: AbstractArray{<:Real}
124-
if is_discrete_domain(p)
125-
disc_buffers
126-
elseif istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown()
206+
if istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown()
127207
tunable_buffers
128208
else
129209
constant_buffers
@@ -135,6 +215,31 @@ function IndexCache(sys::AbstractSystem)
135215
)
136216
end
137217

218+
disc_idxs = ParamIndexMap()
219+
disc_buffer_sizes = [BufferTemplate[] for _ in 1:length(disc_buffers)]
220+
disc_buffer_types = Set()
221+
for buffer in values(disc_buffers)
222+
union!(disc_buffer_types, keys(buffer))
223+
end
224+
225+
for (clockidx, buffer) in disc_buffers
226+
for (i, btype) in enumerate(disc_buffer_types)
227+
if !haskey(buffer, btype)
228+
push!(disc_buffer_sizes[clockidx], BufferTemplate(btype, 0))
229+
continue
230+
end
231+
push!(disc_buffer_sizes[clockidx], BufferTemplate(btype, length(buffer[btype])))
232+
for (j, sym) in enumerate(buffer[btype])
233+
disc_idxs[sym] = (i, j)
234+
disc_idxs[default_toterm(sym)] = (i, j)
235+
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
236+
disc_idxs[getname(sym)] = (i, j)
237+
disc_idxs[getname(default_toterm(sym))] = (i, j)
238+
end
239+
end
240+
end
241+
end
242+
138243
function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
139244
idxs = ParamIndexMap()
140245
buffer_sizes = BufferTemplate[]
@@ -152,20 +257,20 @@ function IndexCache(sys::AbstractSystem)
152257
return idxs, buffer_sizes
153258
end
154259

155-
disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers)
156260
tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
157261
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
158262
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
159263
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
160264

161265
return IndexCache(
162266
unk_idxs,
267+
disc_clocks,
163268
disc_idxs,
164269
tunable_idxs,
165270
const_idxs,
166271
dependent_idxs,
167272
nonnumeric_idxs,
168-
discrete_buffer_sizes,
273+
disc_buffer_sizes,
169274
tunable_buffer_sizes,
170275
const_buffer_sizes,
171276
dependent_buffer_sizes,
@@ -193,7 +298,8 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
193298
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
194299
ParameterIndex(SciMLStructures.Tunable(), idx)
195300
elseif (idx = check_index_map(ic.discrete_idx, sym)) !== nothing
196-
ParameterIndex(SciMLStructures.Discrete(), idx)
301+
ParameterIndex(
302+
SciMLStructures.Discrete(), (check_index_map(ic.discrete_clocks, sym), idx...))
197303
elseif (idx = check_index_map(ic.constant_idx, sym)) !== nothing
198304
ParameterIndex(SciMLStructures.Constants(), idx)
199305
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
@@ -205,6 +311,18 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
205311
end
206312
end
207313

314+
function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym)
315+
return check_index_map(ic.discrete_clocks, sym) !== nothing
316+
end
317+
318+
function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym)
319+
clockid = check_index_map(ic.discrete_clocks, sym)
320+
clockid === nothing && return nothing
321+
partitionid = check_index_map(ic.discrete_idx, sym)
322+
partitionid === nothing && return nothing
323+
return ParameterTimeseriesIndex(clockid, partitionid)
324+
end
325+
208326
function check_index_map(idxmap, sym)
209327
if (idx = get(idxmap, sym, nothing)) !== nothing
210328
return idx
@@ -229,7 +347,8 @@ function ParameterIndex(ic::IndexCache, p, sub_idx = ())
229347
return if haskey(ic.tunable_idx, p)
230348
ParameterIndex(SciMLStructures.Tunable(), (ic.tunable_idx[p]..., sub_idx...))
231349
elseif haskey(ic.discrete_idx, p)
232-
ParameterIndex(SciMLStructures.Discrete(), (ic.discrete_idx[p]..., sub_idx...))
350+
ParameterIndex(SciMLStructures.Discrete(),
351+
(ic.discrete_clocks[p], ic.discrete_idx[p]..., sub_idx...))
233352
elseif haskey(ic.constant_idx, p)
234353
ParameterIndex(SciMLStructures.Constants(), (ic.constant_idx[p]..., sub_idx...))
235354
elseif haskey(ic.dependent_idx, p)
@@ -247,10 +366,14 @@ end
247366
function discrete_linear_index(ic::IndexCache, idx::ParameterIndex)
248367
idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected")
249368
ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0)
369+
for clockbuftemps in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1)
370+
ind += sum(temp.length for temp in clockbuftemps; init = 0)
371+
end
250372
ind += sum(
251-
temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1);
373+
temp.length
374+
for temp in Iterators.take(ic.discrete_buffer_sizes[idx.idx[1]], idx.idx[2] - 1);
252375
init = 0)
253-
ind += idx.idx[2]
376+
ind += idx.idx[3]
254377
return ind
255378
end
256379

@@ -269,30 +392,32 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
269392
param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
270393
for temp in ic.tunable_buffer_sizes)
271394
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
272-
for temp in ic.discrete_buffer_sizes)
395+
for temp in Iterators.flatten(ic.discrete_buffer_sizes))
273396
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
274397
for temp in ic.constant_buffer_sizes)
275398
dep_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
276399
for temp in ic.dependent_buffer_sizes)
277400
nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
278401
for temp in ic.nonnumeric_buffer_sizes)
279-
280402
for p in ps
403+
p = unwrap(p)
281404
if haskey(ic.discrete_idx, p)
282-
i, j = ic.discrete_idx[p]
283-
disc_buf[i][j] = unwrap(p)
405+
disc_offset = length(first(ic.discrete_buffer_sizes))
406+
i = ic.discrete_clocks[p]
407+
j, k = ic.discrete_idx[p]
408+
disc_buf[(i - 1) * disc_offset + j][k] = p
284409
elseif haskey(ic.tunable_idx, p)
285410
i, j = ic.tunable_idx[p]
286-
param_buf[i][j] = unwrap(p)
411+
param_buf[i][j] = p
287412
elseif haskey(ic.constant_idx, p)
288413
i, j = ic.constant_idx[p]
289-
const_buf[i][j] = unwrap(p)
414+
const_buf[i][j] = p
290415
elseif haskey(ic.dependent_idx, p)
291416
i, j = ic.dependent_idx[p]
292-
dep_buf[i][j] = unwrap(p)
417+
dep_buf[i][j] = p
293418
elseif haskey(ic.nonnumeric_idx, p)
294419
i, j = ic.nonnumeric_idx[p]
295-
nonnumeric_buf[i][j] = unwrap(p)
420+
nonnumeric_buf[i][j] = p
296421
else
297422
error("Invalid parameter $p")
298423
end

0 commit comments

Comments
 (0)