Skip to content

Commit f584815

Browse files
feat: support indexing with Symbols in IndexCache
1 parent 9b6f861 commit f584815

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

src/systems/abstractsystem.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
346346
end
347347

348348
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
349+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
350+
return haskey(ic.unknown_idx, hash(sym))
351+
end
349352
return any(isequal(sym), getname.(variable_symbols(sys))) ||
350353
count('', string(sym)) == 1 &&
351354
count(isequal(sym), Symbol.(nameof(sys), :₊, getname.(variable_symbols(sys)))) ==
@@ -377,6 +380,9 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
377380
end
378381

379382
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
383+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
384+
return get(ic.unknown_idx, h, nothing)
385+
end
380386
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
381387
if idx !== nothing
382388
return idx
@@ -401,19 +407,24 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
401407
ic = get_index_cache(sys)
402408
h = getsymbolhash(sym)
403409
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
404-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
410+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
411+
haskey(ic.nonnumeric_idx, h)
405412
true
406413
else
407414
h = getsymbolhash(default_toterm(sym))
408415
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
409-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
416+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
417+
haskey(ic.nonnumeric_idx, h)
410418
end
411419
end
412420
return any(isequal(sym), parameter_symbols(sys)) ||
413421
hasname(sym) && is_parameter(sys, getname(sym))
414422
end
415423

416424
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
425+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
426+
return ParameterIndex(ic, sym) !== nothing
427+
end
417428
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
418429
count('', string(sym)) == 1 &&
419430
count(isequal(sym),
@@ -426,7 +437,6 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
426437
end
427438
if has_index_cache(sys) && get_index_cache(sys) !== nothing
428439
ic = get_index_cache(sys)
429-
h = getsymbolhash(sym)
430440
return if (idx = ParameterIndex(ic, sym)) !== nothing
431441
idx
432442
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
@@ -444,6 +454,9 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
444454
end
445455

446456
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
457+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
458+
return ParameterIndex(ic, sym)
459+
end
447460
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
448461
if idx !== nothing
449462
return idx

src/systems/index_cache.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,16 @@ function IndexCache(sys::AbstractSystem)
4040
let idx = 1
4141
for sym in unks
4242
h = getsymbolhash(sym)
43-
if Symbolics.isarraysymbolic(sym)
44-
unk_idxs[h] = idx:(idx + length(sym) - 1)
43+
sym_idx = if Symbolics.isarraysymbolic(sym)
44+
idx:(idx + length(sym) - 1)
4545
else
46-
unk_idxs[h] = idx
46+
idx
47+
end
48+
unk_idxs[h] = sym_idx
49+
50+
if hasname(sym)
51+
h = hash(getname(sym))
52+
unk_idxs[h] = sym_idx
4753
end
4854
idx += length(sym)
4955
end
@@ -122,6 +128,10 @@ function IndexCache(sys::AbstractSystem)
122128
idxs[h] = (i, j)
123129
h = getsymbolhash(default_toterm(p))
124130
idxs[h] = (i, j)
131+
if hasname(p)
132+
h = hash(getname(p))
133+
idxs[h] = (i, j)
134+
end
125135
end
126136
push!(buffer_sizes, BufferTemplate(T, length(buf)))
127137
end
@@ -151,7 +161,7 @@ end
151161

152162
function ParameterIndex(ic::IndexCache, p, sub_idx = ())
153163
p = unwrap(p)
154-
h = getsymbolhash(p)
164+
h = p isa Symbol ? hash(p) : getsymbolhash(p)
155165
return if haskey(ic.param_idx, h)
156166
ParameterIndex(SciMLStructures.Tunable(), (ic.param_idx[h]..., sub_idx...))
157167
elseif haskey(ic.discrete_idx, h)

0 commit comments

Comments
 (0)