Skip to content

Commit d0ad0ef

Browse files
refactor: improve Symbol indexing
1 parent 74922ed commit d0ad0ef

File tree

1 file changed

+40
-84
lines changed

1 file changed

+40
-84
lines changed

src/systems/index_cache.jl

Lines changed: 40 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ end
2121

2222
ParameterIndex(portion, idx) = ParameterIndex(portion, idx, false)
2323

24-
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
24+
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
2525
const UnknownIndexMap = Dict{
26-
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
26+
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
2727

2828
struct IndexCache
2929
unknown_idx::UnknownIndexMap
30-
discrete_idx::Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int, Int}}
30+
discrete_idx::Dict{BasicSymbolic, Tuple{Int, Int, Int}}
3131
tunable_idx::ParamIndexMap
3232
constant_idx::ParamIndexMap
3333
dependent_idx::ParamIndexMap
3434
nonnumeric_idx::ParamIndexMap
35-
observed_syms::Set{Union{Symbol, BasicSymbolic}}
35+
observed_syms::Set{BasicSymbolic}
3636
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
3737
tunable_buffer_sizes::Vector{BufferTemplate}
3838
constant_buffer_sizes::Vector{BufferTemplate}
@@ -57,14 +57,6 @@ function IndexCache(sys::AbstractSystem)
5757
end
5858
unk_idxs[usym] = sym_idx
5959
unk_idxs[rsym] = sym_idx
60-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
61-
name = getname(usym)
62-
rname = getname(rsym)
63-
unk_idxs[name] = sym_idx
64-
unk_idxs[rname] = sym_idx
65-
symbol_to_variable[name] = sym
66-
symbol_to_variable[rname] = sym
67-
end
6860
idx += length(sym)
6961
end
7062
for sym in unks
@@ -80,14 +72,6 @@ function IndexCache(sys::AbstractSystem)
8072
rsym = renamespace(sys, arrsym)
8173
unk_idxs[arrsym] = idxs
8274
unk_idxs[rsym] = idxs
83-
if hasname(arrsym)
84-
name = getname(arrsym)
85-
rname = getname(rsym)
86-
unk_idxs[name] = idxs
87-
unk_idxs[rname] = idxs
88-
symbol_to_variable[name] = arrsym
89-
symbol_to_variable[rname] = arrsym
90-
end
9175
end
9276
end
9377

@@ -102,16 +86,6 @@ function IndexCache(sys::AbstractSystem)
10286
push!(observed_syms, ttsym)
10387
push!(observed_syms, rsym)
10488
push!(observed_syms, rttsym)
105-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
106-
symbol_to_variable[getname(sym)] = eq.lhs
107-
symbol_to_variable[getname(ttsym)] = eq.lhs
108-
symbol_to_variable[getname(rsym)] = eq.lhs
109-
symbol_to_variable[getname(rttsym)] = eq.lhs
110-
push!(observed_syms, getname(sym))
111-
push!(observed_syms, getname(ttsym))
112-
push!(observed_syms, getname(rsym))
113-
push!(observed_syms, getname(rttsym))
114-
end
11589
end
11690
end
11791

@@ -143,16 +117,12 @@ function IndexCache(sys::AbstractSystem)
143117
rttinp = renamespace(sys, ttinp)
144118
is_parameter(sys, inp) ||
145119
error("Discrete subsystem $i input $inp is not a parameter")
120+
146121
disc_clocks[inp] = i
147122
disc_clocks[ttinp] = i
148123
disc_clocks[rinp] = i
149124
disc_clocks[rttinp] = i
150-
if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex)
151-
disc_clocks[getname(inp)] = i
152-
disc_clocks[getname(ttinp)] = i
153-
disc_clocks[getname(rinp)] = i
154-
disc_clocks[getname(rttinp)] = i
155-
end
125+
156126
insert_by_type!(disc_buffers[i], inp)
157127
end
158128

@@ -163,16 +133,12 @@ function IndexCache(sys::AbstractSystem)
163133
rttsym = renamespace(sys, ttsym)
164134
is_parameter(sys, sym) ||
165135
error("Discrete subsystem $i unknown $sym is not a parameter")
136+
166137
disc_clocks[sym] = i
167138
disc_clocks[ttsym] = i
168139
disc_clocks[rsym] = i
169140
disc_clocks[rttsym] = i
170-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
171-
disc_clocks[getname(sym)] = i
172-
disc_clocks[getname(ttsym)] = i
173-
disc_clocks[getname(rsym)] = i
174-
disc_clocks[getname(rttsym)] = i
175-
end
141+
176142
insert_by_type!(disc_buffers[i], sym)
177143
end
178144
t = get_iv(sys)
@@ -191,12 +157,6 @@ function IndexCache(sys::AbstractSystem)
191157
disc_clocks[ttsym] = i
192158
disc_clocks[rsym] = i
193159
disc_clocks[rttsym] = i
194-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
195-
disc_clocks[getname(sym)] = i
196-
disc_clocks[getname(ttsym)] = i
197-
disc_clocks[getname(rsym)] = i
198-
disc_clocks[getname(rttsym)] = i
199-
end
200160
end
201161
end
202162

@@ -237,13 +197,7 @@ function IndexCache(sys::AbstractSystem)
237197
disc_clocks[ttsym] = user_affect_clock
238198
disc_clocks[rsym] = user_affect_clock
239199
disc_clocks[rttsym] = user_affect_clock
240-
if hasname(sym) &&
241-
(!iscall(sym) || operation(sym) !== getindex)
242-
disc_clocks[getname(sym)] = user_affect_clock
243-
disc_clocks[getname(ttsym)] = user_affect_clock
244-
disc_clocks[getname(rsym)] = user_affect_clock
245-
disc_clocks[getname(rttsym)] = user_affect_clock
246-
end
200+
247201
buffer = get!(disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}())
248202
insert_by_type!(buffer, affect.lhs)
249203
else
@@ -259,12 +213,7 @@ function IndexCache(sys::AbstractSystem)
259213
disc_clocks[ttdisc] = user_affect_clock
260214
disc_clocks[rdisc] = user_affect_clock
261215
disc_clocks[rttdisc] = user_affect_clock
262-
if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex)
263-
disc_clocks[getname(disc)] = user_affect_clock
264-
disc_clocks[getname(ttdisc)] = user_affect_clock
265-
disc_clocks[getname(rdisc)] = user_affect_clock
266-
disc_clocks[getname(rttdisc)] = user_affect_clock
267-
end
216+
268217
buffer = get!(
269218
disc_buffers, user_affect_clock, Dict{Any, Set{BasicSymbolic}}())
270219
insert_by_type!(buffer, disc)
@@ -316,21 +265,13 @@ function IndexCache(sys::AbstractSystem)
316265
for (j, sym) in enumerate(buffer[btype])
317266
disc_idxs[sym] = (clockidx, i, j)
318267
disc_idxs[default_toterm(sym)] = (clockidx, i, j)
319-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
320-
disc_idxs[getname(sym)] = (clockidx, i, j)
321-
disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j)
322-
end
323268
end
324269
end
325270
end
326271
for (sym, clockid) in disc_clocks
327272
haskey(disc_idxs, sym) && continue
328273
disc_idxs[sym] = (clockid, 0, 0)
329274
disc_idxs[default_toterm(sym)] = (clockid, 0, 0)
330-
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
331-
disc_idxs[getname(sym)] = (clockid, 0, 0)
332-
disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0)
333-
end
334275
end
335276

336277
function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
@@ -345,16 +286,6 @@ function IndexCache(sys::AbstractSystem)
345286
idxs[ttp] = (i, j)
346287
idxs[rp] = (i, j)
347288
idxs[rttp] = (i, j)
348-
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
349-
idxs[getname(p)] = (i, j)
350-
idxs[getname(ttp)] = (i, j)
351-
idxs[getname(rp)] = (i, j)
352-
idxs[getname(rttp)] = (i, j)
353-
symbol_to_variable[getname(p)] = p
354-
symbol_to_variable[getname(ttp)] = p
355-
symbol_to_variable[getname(rp)] = p
356-
symbol_to_variable[getname(rttp)] = p
357-
end
358289
end
359290
push!(buffer_sizes, BufferTemplate(T, length(buf)))
360291
end
@@ -366,6 +297,14 @@ function IndexCache(sys::AbstractSystem)
366297
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
367298
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
368299

300+
for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs),
301+
keys(const_idxs), keys(dependent_idxs), keys(nonnumeric_idxs),
302+
observed_syms, independent_variable_symbols(sys)))
303+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
304+
symbol_to_variable[getname(sym)] = sym
305+
end
306+
end
307+
369308
return IndexCache(
370309
unk_idxs,
371310
disc_idxs,
@@ -384,18 +323,26 @@ function IndexCache(sys::AbstractSystem)
384323
end
385324

386325
function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym)
387-
return check_index_map(ic.unknown_idx, sym) !== nothing
388-
end
389-
390-
function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym::Symbol)
326+
if sym isa Symbol
327+
sym = get(ic.symbol_to_variable, sym, nothing)
328+
sym === nothing && return false
329+
end
391330
return check_index_map(ic.unknown_idx, sym) !== nothing
392331
end
393332

394333
function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym)
334+
if sym isa Symbol
335+
sym = get(ic.symbol_to_variable, sym, nothing)
336+
sym === nothing && return nothing
337+
end
395338
return check_index_map(ic.unknown_idx, sym)
396339
end
397340

398341
function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
342+
if sym isa Symbol
343+
sym = get(ic.symbol_to_variable, sym, nothing)
344+
sym === nothing && return false
345+
end
399346
return check_index_map(ic.tunable_idx, sym) !== nothing ||
400347
check_index_map(ic.discrete_idx, sym) !== nothing ||
401348
check_index_map(ic.constant_idx, sym) !== nothing ||
@@ -405,7 +352,8 @@ end
405352

406353
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
407354
if sym isa Symbol
408-
sym = ic.symbol_to_variable[sym]
355+
sym = get(ic.symbol_to_variable, sym, nothing)
356+
sym === nothing && return nothing
409357
end
410358
validate_size = Symbolics.isarraysymbolic(sym) &&
411359
Symbolics.shape(sym) !== Symbolics.Unknown()
@@ -425,10 +373,18 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
425373
end
426374

427375
function SymbolicIndexingInterface.is_timeseries_parameter(ic::IndexCache, sym)
376+
if sym isa Symbol
377+
sym = get(ic.symbol_to_variable, sym, nothing)
378+
sym === nothing && return false
379+
end
428380
return check_index_map(ic.discrete_idx, sym) !== nothing
429381
end
430382

431383
function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sym)
384+
if sym isa Symbol
385+
sym = get(ic.symbol_to_variable, sym, nothing)
386+
sym === nothing && return nothing
387+
end
432388
idx = check_index_map(ic.discrete_idx, sym)
433389
idx === nothing && return nothing
434390
clockid, partitionid... = idx

0 commit comments

Comments
 (0)