@@ -24,12 +24,13 @@ const UnknownIndexMap = Dict{
24
24
25
25
struct IndexCache
26
26
unknown_idx:: UnknownIndexMap
27
+ discrete_clocks:: Dict{Union{Symbol, BasicSymbolic}, Int}
27
28
discrete_idx:: ParamIndexMap
28
29
tunable_idx:: ParamIndexMap
29
30
constant_idx:: ParamIndexMap
30
31
dependent_idx:: ParamIndexMap
31
32
nonnumeric_idx:: ParamIndexMap
32
- discrete_buffer_sizes:: Vector{BufferTemplate}
33
+ discrete_buffer_sizes:: Vector{Vector{ BufferTemplate} }
33
34
tunable_buffer_sizes:: Vector{BufferTemplate}
34
35
constant_buffer_sizes:: Vector{BufferTemplate}
35
36
dependent_buffer_sizes:: Vector{BufferTemplate}
@@ -71,7 +72,8 @@ function IndexCache(sys::AbstractSystem)
71
72
end
72
73
end
73
74
74
- disc_buffers = Dict {Any, Set{BasicSymbolic}} ()
75
+ disc_buffers = Dict {Int, Dict{Any, Set{BasicSymbolic}}} ()
76
+ disc_clocks = Dict {Union{Symbol, BasicSymbolic}, Int} ()
75
77
tunable_buffers = Dict {Any, Set{BasicSymbolic}} ()
76
78
constant_buffers = Dict {Any, Set{BasicSymbolic}} ()
77
79
dependent_buffers = Dict {Any, Set{BasicSymbolic}} ()
@@ -84,27 +86,107 @@ function IndexCache(sys::AbstractSystem)
84
86
push! (buf, sym)
85
87
end
86
88
89
+ if has_discrete_subsystems (sys) && get_discrete_subsystems (sys) != = nothing
90
+ syss, inputs, continuous_id, _ = get_discrete_subsystems (sys)
91
+
92
+ for (i, (inps, disc_sys)) in enumerate (zip (inputs, syss))
93
+ i == continuous_id && continue
94
+ disc_buffers[i - 1 ] = Dict {Any, Set{BasicSymbolic}} ()
95
+
96
+ for inp in inps
97
+ inp = unwrap (inp)
98
+ is_parameter (sys, inp) ||
99
+ error (" Discrete subsystem $i input $inp is not a parameter" )
100
+ disc_clocks[inp] = i - 1
101
+ disc_clocks[default_toterm (inp)] = i - 1
102
+ if hasname (inp) && (! istree (inp) || operation (inp) != = getindex)
103
+ disc_clocks[getname (inp)] = i - 1
104
+ disc_clocks[default_toterm (inp)] = i - 1
105
+ end
106
+ insert_by_type! (disc_buffers[i - 1 ], inp)
107
+ end
108
+
109
+ for sym in unknowns (disc_sys)
110
+ sym = unwrap (sym)
111
+ is_parameter (sys, sym) ||
112
+ error (" Discrete subsystem $i unknown $sym is not a parameter" )
113
+ disc_clocks[sym] = i - 1
114
+ disc_clocks[default_toterm (sym)] = i - 1
115
+ if hasname (sym) && (! istree (sym) || operation (sym) != = getindex)
116
+ disc_clocks[getname (sym)] = i - 1
117
+ disc_clocks[getname (default_toterm (sym))] = i - 1
118
+ end
119
+ insert_by_type! (disc_buffers[i - 1 ], sym)
120
+ end
121
+ t = get_iv (sys)
122
+ for eq in observed (disc_sys)
123
+ # TODO : Is this a valid check
124
+ # FIXME : This shouldn't be necessary
125
+ eq. rhs === - 0.0 && continue
126
+ sym = eq. lhs
127
+ if istree (sym) && operation (sym) == Shift (t, 1 )
128
+ sym = only (arguments (sym))
129
+ end
130
+ # is_parameter(sys, sym) || is_parameter(sys, Hold(sym)) || continue
131
+ disc_clocks[sym] = i - 1
132
+ disc_clocks[sym] = i - 1
133
+ disc_clocks[default_toterm (sym)] = i - 1
134
+ if hasname (sym) && (! istree (sym) || operation (sym) != = getindex)
135
+ disc_clocks[getname (sym)] = i - 1
136
+ disc_clocks[getname (default_toterm (sym))] = i - 1
137
+ end
138
+ end
139
+ end
140
+
141
+ for par in inputs[continuous_id]
142
+ is_parameter (sys, par) || error (" Discrete subsystem input is not a parameter" )
143
+ istree (par) && operation (par) isa Hold ||
144
+ error (" Continuous subsystem input is not a Hold" )
145
+ if haskey (disc_clocks, par)
146
+ sym = par
147
+ else
148
+ sym = first (arguments (par))
149
+ end
150
+ haskey (disc_clocks, sym) ||
151
+ error (" Variable $par not part of a discrete subsystem" )
152
+ disc_clocks[par] = disc_clocks[sym]
153
+ insert_by_type! (disc_buffers[disc_clocks[sym]], par)
154
+ end
155
+ end
156
+
87
157
affs = vcat (affects (continuous_events (sys)), affects (discrete_events (sys)))
158
+ user_affect_clock = maximum (values (disc_clocks); init = 1 )
88
159
for affect in affs
89
160
if affect isa Equation
90
161
is_parameter (sys, affect. lhs) || continue
91
- insert_by_type! (disc_buffers, affect. lhs)
162
+
163
+ disc_clocks[affect. lhs] = user_affect_clock
164
+ disc_clocks[default_toterm (affect. lhs)] = user_affect_clock
165
+ if hasname (affect. lhs) &&
166
+ (! istree (affect. lhs) || operation (affect. lhs) != = getindex)
167
+ disc_clocks[getname (affect. lhs)] = user_affect_clock
168
+ disc_clocks[getname (default_toterm (affect. lhs))] = user_affect_clock
169
+ end
170
+ buffer = get! (disc_buffers, user_affect_clock, Dict {Any, Set{BasicSymbolic}} ())
171
+ insert_by_type! (buffer, affect. lhs)
92
172
else
93
173
discs = discretes (affect)
94
174
for disc in discs
95
175
is_parameter (sys, disc) ||
96
176
error (" Expected discrete variable $disc in callback to be a parameter" )
97
- insert_by_type! (disc_buffers, disc)
177
+ disc = unwrap (disc)
178
+ disc_clocks[disc] = user_affect_clock
179
+ disc_clocks[default_toterm (disc)] = user_affect_clock
180
+ if hasname (disc) && (! istree (disc) || operation (disc) != = getindex)
181
+ disc_clocks[getname (disc)] = user_affect_clock
182
+ disc_clocks[getname (default_toterm (disc))] = user_affect_clock
183
+ end
184
+ buffer = get! (
185
+ disc_buffers, user_affect_clock, Dict {Any, Set{BasicSymbolic}} ())
186
+ insert_by_type! (buffer, disc)
98
187
end
99
188
end
100
189
end
101
- if has_discrete_subsystems (sys) && get_discrete_subsystems (sys) != = nothing
102
- _, inputs, continuous_id, _ = get_discrete_subsystems (sys)
103
- for par in inputs[continuous_id]
104
- is_parameter (sys, par) || error (" Discrete subsystem input is not a parameter" )
105
- insert_by_type! (disc_buffers, par)
106
- end
107
- end
108
190
109
191
if has_parameter_dependencies (sys) &&
110
192
(pdeps = get_parameter_dependencies (sys)) != = nothing
@@ -117,13 +199,11 @@ function IndexCache(sys::AbstractSystem)
117
199
for p in parameters (sys)
118
200
p = unwrap (p)
119
201
ctype = concrete_symtype (p)
120
- haskey (disc_buffers, ctype) && p in disc_buffers[ctype] && continue
202
+ haskey (disc_clocks, p) && continue
121
203
haskey (dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
122
204
insert_by_type! (
123
205
if ctype <: Real || ctype <: AbstractArray{<:Real}
124
- if is_discrete_domain (p)
125
- disc_buffers
126
- elseif istunable (p, true ) && Symbolics. shape (p) != = Symbolics. Unknown ()
206
+ if istunable (p, true ) && Symbolics. shape (p) != = Symbolics. Unknown ()
127
207
tunable_buffers
128
208
else
129
209
constant_buffers
@@ -135,6 +215,31 @@ function IndexCache(sys::AbstractSystem)
135
215
)
136
216
end
137
217
218
+ disc_idxs = ParamIndexMap ()
219
+ disc_buffer_sizes = [BufferTemplate[] for _ in 1 : length (disc_buffers)]
220
+ disc_buffer_types = Set ()
221
+ for buffer in values (disc_buffers)
222
+ union! (disc_buffer_types, keys (buffer))
223
+ end
224
+
225
+ for (clockidx, buffer) in disc_buffers
226
+ for (i, btype) in enumerate (disc_buffer_types)
227
+ if ! haskey (buffer, btype)
228
+ push! (disc_buffer_sizes[clockidx], BufferTemplate (btype, 0 ))
229
+ continue
230
+ end
231
+ push! (disc_buffer_sizes[clockidx], BufferTemplate (btype, length (buffer[btype])))
232
+ for (j, sym) in enumerate (buffer[btype])
233
+ disc_idxs[sym] = (i, j)
234
+ disc_idxs[default_toterm (sym)] = (i, j)
235
+ if hasname (sym) && (! istree (sym) || operation (sym) != = getindex)
236
+ disc_idxs[getname (sym)] = (i, j)
237
+ disc_idxs[getname (default_toterm (sym))] = (i, j)
238
+ end
239
+ end
240
+ end
241
+ end
242
+
138
243
function get_buffer_sizes_and_idxs (buffers:: Dict{Any, Set{BasicSymbolic}} )
139
244
idxs = ParamIndexMap ()
140
245
buffer_sizes = BufferTemplate[]
@@ -152,20 +257,20 @@ function IndexCache(sys::AbstractSystem)
152
257
return idxs, buffer_sizes
153
258
end
154
259
155
- disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs (disc_buffers)
156
260
tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs (tunable_buffers)
157
261
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs (constant_buffers)
158
262
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs (dependent_buffers)
159
263
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs (nonnumeric_buffers)
160
264
161
265
return IndexCache (
162
266
unk_idxs,
267
+ disc_clocks,
163
268
disc_idxs,
164
269
tunable_idxs,
165
270
const_idxs,
166
271
dependent_idxs,
167
272
nonnumeric_idxs,
168
- discrete_buffer_sizes ,
273
+ disc_buffer_sizes ,
169
274
tunable_buffer_sizes,
170
275
const_buffer_sizes,
171
276
dependent_buffer_sizes,
@@ -193,7 +298,8 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
193
298
return if (idx = check_index_map (ic. tunable_idx, sym)) != = nothing
194
299
ParameterIndex (SciMLStructures. Tunable (), idx)
195
300
elseif (idx = check_index_map (ic. discrete_idx, sym)) != = nothing
196
- ParameterIndex (SciMLStructures. Discrete (), idx)
301
+ ParameterIndex (
302
+ SciMLStructures. Discrete (), (check_index_map (ic. discrete_clocks, sym), idx... ))
197
303
elseif (idx = check_index_map (ic. constant_idx, sym)) != = nothing
198
304
ParameterIndex (SciMLStructures. Constants (), idx)
199
305
elseif (idx = check_index_map (ic. nonnumeric_idx, sym)) != = nothing
@@ -205,6 +311,18 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
205
311
end
206
312
end
207
313
314
+ function SymbolicIndexingInterface. is_timeseries_parameter (ic:: IndexCache , sym)
315
+ return check_index_map (ic. discrete_clocks, sym) != = nothing
316
+ end
317
+
318
+ function SymbolicIndexingInterface. timeseries_parameter_index (ic:: IndexCache , sym)
319
+ clockid = check_index_map (ic. discrete_clocks, sym)
320
+ clockid === nothing && return nothing
321
+ partitionid = check_index_map (ic. discrete_idx, sym)
322
+ partitionid === nothing && return nothing
323
+ return ParameterTimeseriesIndex (clockid, partitionid)
324
+ end
325
+
208
326
function check_index_map (idxmap, sym)
209
327
if (idx = get (idxmap, sym, nothing )) != = nothing
210
328
return idx
@@ -229,7 +347,8 @@ function ParameterIndex(ic::IndexCache, p, sub_idx = ())
229
347
return if haskey (ic. tunable_idx, p)
230
348
ParameterIndex (SciMLStructures. Tunable (), (ic. tunable_idx[p]. .. , sub_idx... ))
231
349
elseif haskey (ic. discrete_idx, p)
232
- ParameterIndex (SciMLStructures. Discrete (), (ic. discrete_idx[p]. .. , sub_idx... ))
350
+ ParameterIndex (SciMLStructures. Discrete (),
351
+ (ic. discrete_clocks[p], ic. discrete_idx[p]. .. , sub_idx... ))
233
352
elseif haskey (ic. constant_idx, p)
234
353
ParameterIndex (SciMLStructures. Constants (), (ic. constant_idx[p]. .. , sub_idx... ))
235
354
elseif haskey (ic. dependent_idx, p)
@@ -247,10 +366,14 @@ end
247
366
function discrete_linear_index (ic:: IndexCache , idx:: ParameterIndex )
248
367
idx. portion isa SciMLStructures. Discrete || error (" Discrete variable index expected" )
249
368
ind = sum (temp. length for temp in ic. tunable_buffer_sizes; init = 0 )
369
+ for clockbuftemps in Iterators. take (ic. discrete_buffer_sizes, idx. idx[1 ] - 1 )
370
+ ind += sum (temp. length for temp in clockbuftemps; init = 0 )
371
+ end
250
372
ind += sum (
251
- temp. length for temp in Iterators. take (ic. discrete_buffer_sizes, idx. idx[1 ] - 1 );
373
+ temp. length
374
+ for temp in Iterators. take (ic. discrete_buffer_sizes[idx. idx[1 ]], idx. idx[2 ] - 1 );
252
375
init = 0 )
253
- ind += idx. idx[2 ]
376
+ ind += idx. idx[3 ]
254
377
return ind
255
378
end
256
379
@@ -269,30 +392,32 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
269
392
param_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
270
393
for temp in ic. tunable_buffer_sizes)
271
394
disc_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
272
- for temp in ic. discrete_buffer_sizes)
395
+ for temp in Iterators . flatten ( ic. discrete_buffer_sizes) )
273
396
const_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
274
397
for temp in ic. constant_buffer_sizes)
275
398
dep_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
276
399
for temp in ic. dependent_buffer_sizes)
277
400
nonnumeric_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
278
401
for temp in ic. nonnumeric_buffer_sizes)
279
-
280
402
for p in ps
403
+ p = unwrap (p)
281
404
if haskey (ic. discrete_idx, p)
282
- i, j = ic. discrete_idx[p]
283
- disc_buf[i][j] = unwrap (p)
405
+ disc_offset = length (first (ic. discrete_buffer_sizes))
406
+ i = ic. discrete_clocks[p]
407
+ j, k = ic. discrete_idx[p]
408
+ disc_buf[(i - 1 ) * disc_offset + j][k] = p
284
409
elseif haskey (ic. tunable_idx, p)
285
410
i, j = ic. tunable_idx[p]
286
- param_buf[i][j] = unwrap (p)
411
+ param_buf[i][j] = p
287
412
elseif haskey (ic. constant_idx, p)
288
413
i, j = ic. constant_idx[p]
289
- const_buf[i][j] = unwrap (p)
414
+ const_buf[i][j] = p
290
415
elseif haskey (ic. dependent_idx, p)
291
416
i, j = ic. dependent_idx[p]
292
- dep_buf[i][j] = unwrap (p)
417
+ dep_buf[i][j] = p
293
418
elseif haskey (ic. nonnumeric_idx, p)
294
419
i, j = ic. nonnumeric_idx[p]
295
- nonnumeric_buf[i][j] = unwrap (p)
420
+ nonnumeric_buf[i][j] = p
296
421
else
297
422
error (" Invalid parameter $p " )
298
423
end
0 commit comments