Skip to content

Commit 8d7c677

Browse files
fix: refactor IndexCache for non-scalarized unknowns
1 parent 0c26977 commit 8d7c677

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

src/systems/abstractsystem.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
192192
ic = get_index_cache(sys)
193193
h = getsymbolhash(sym)
194194
return haskey(ic.unknown_idx, h) ||
195-
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym)))
195+
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
196+
(istree(sym) && operation(sym) === getindex &&
197+
is_variable(sys, first(arguments(sym))))
196198
else
197199
return any(isequal(sym), variable_symbols(sys)) ||
198200
hasname(sym) && is_variable(sys, getname(sym))
@@ -213,16 +215,15 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
213215
if has_index_cache(sys) && get_index_cache(sys) !== nothing
214216
ic = get_index_cache(sys)
215217
h = getsymbolhash(sym)
216-
return if haskey(ic.unknown_idx, h)
217-
ic.unknown_idx[h]
218-
else
219-
h = getsymbolhash(default_toterm(sym))
220-
if haskey(ic.unknown_idx, h)
221-
ic.unknown_idx[h]
222-
else
223-
nothing
224-
end
225-
end
218+
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
219+
220+
h = getsymbolhash(default_toterm(sym))
221+
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
222+
sym = unwrap(sym)
223+
istree(sym) && operation(sym) === getindex || return nothing
224+
idx = variable_index(sys, first(arguments(sym)))
225+
idx === nothing && return nothing
226+
return idx[arguments(sym)[(begin + 1):end]...]
226227
end
227228
idx = findfirst(isequal(sym), variable_symbols(sys))
228229
if idx === nothing && hasname(sym)

src/systems/index_cache.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
const IndexMap = Dict{UInt, Tuple{Int, Int}}
2222

2323
struct IndexCache
24-
unknown_idx::Dict{UInt, Int}
24+
unknown_idx::Dict{UInt, Union{Int, UnitRange{Int}}}
2525
discrete_idx::IndexMap
2626
param_idx::IndexMap
2727
constant_idx::IndexMap
@@ -36,10 +36,17 @@ end
3636

3737
function IndexCache(sys::AbstractSystem)
3838
unks = solved_unknowns(sys)
39-
unk_idxs = Dict{UInt, Int}()
40-
for (i, sym) in enumerate(unks)
41-
h = getsymbolhash(sym)
42-
unk_idxs[h] = i
39+
unk_idxs = Dict{UInt, Union{Int, UnitRange{Int}}}()
40+
let idx = 1
41+
for sym in unks
42+
h = getsymbolhash(sym)
43+
if Symbolics.isarraysymbolic(sym)
44+
unk_idxs[h] = idx:(idx + length(sym) - 1)
45+
else
46+
unk_idxs[h] = idx
47+
end
48+
idx += length(sym)
49+
end
4350
end
4451

4552
disc_buffers = Dict{DataType, Set{BasicSymbolic}}()

0 commit comments

Comments
 (0)