Skip to content

Commit ec687fd

Browse files
feat: add IndexCache to ODESystem
1 parent 9fb2316 commit ec687fd

File tree

3 files changed

+88
-3
lines changed

3 files changed

+88
-3
lines changed

src/systems/abstractsystem.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,9 @@ Mark a system as completed. If a system is complete, the system will no longer
313313
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
314314
"""
315315
function complete(sys::AbstractSystem)
316+
if has_index_cache(sys)
317+
@set! sys.index_cache = IndexCache(sys)
318+
end
316319
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
317320
end
318321

@@ -354,7 +357,8 @@ for prop in [:eqs
354357
:discrete_subsystems
355358
:solved_unknowns
356359
:split_idxs
357-
:parent]
360+
:parent
361+
:index_cache]
358362
fname1 = Symbol(:get_, prop)
359363
fname2 = Symbol(:has_, prop)
360364
@eval begin

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ struct ODESystem <: AbstractODESystem
131131
"""
132132
complete::Bool
133133
"""
134+
Cached data for fast symbolic indexing.
135+
"""
136+
index_cache::Union{Nothing, IndexCache}
137+
"""
134138
A list of discrete subsystems.
135139
"""
136140
discrete_subsystems::Any
@@ -152,7 +156,7 @@ struct ODESystem <: AbstractODESystem
152156
torn_matching, connector_type, preface, cevents,
153157
devents, metadata = nothing, gui_metadata = nothing,
154158
tearing_state = nothing,
155-
substitutions = nothing, complete = false,
159+
substitutions = nothing, complete = false, index_cache = nothing,
156160
discrete_subsystems = nothing, solved_unknowns = nothing,
157161
split_idxs = nothing, parent = nothing; checks::Union{Bool, Int} = true)
158162
if checks == true || (checks & CheckComponents) > 0
@@ -168,7 +172,7 @@ struct ODESystem <: AbstractODESystem
168172
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
169173
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
170174
connector_type, preface, cevents, devents, metadata, gui_metadata,
171-
tearing_state, substitutions, complete, discrete_subsystems,
175+
tearing_state, substitutions, complete, index_cache, discrete_subsystems,
172176
solved_unknowns, split_idxs, parent)
173177
end
174178
end

src/systems/index_cache.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
abstract type SymbolHash end
2+
3+
struct IndexCache
4+
unknown_idx::Dict{UInt, Int}
5+
discrete_idx::Dict{UInt, Int}
6+
param_idx::Dict{UInt, Int}
7+
discrete_buffer_type_and_size::Vector{Tuple{DataType, Int}}
8+
param_buffer_type_and_size::Vector{Tuple{DataType, Int}}
9+
end
10+
11+
function IndexCache(sys::AbstractSystem)
12+
unks = solved_unknowns(sys)
13+
unk_idxs = Dict{UInt, Int}()
14+
for (i, sym) in enumerate(unks)
15+
h = hash(unwrap(sym))
16+
unk_idxs[h] = i
17+
setmetadata(sym, SymbolHash, h)
18+
end
19+
20+
# split parameters, also by type
21+
discrete_params = Dict{DataType, Any}()
22+
tunable_params = Dict{DataType, Any}()
23+
24+
for p in parameters(sys)
25+
T = symtype(p)
26+
buf = get!(is_discrete_domain(p) ? discrete_params : tunable_params, T, [])
27+
push!(buf, unwrap(p))
28+
end
29+
30+
disc_idxs = Dict{UInt, Int}()
31+
discrete_buffer_type_and_size = Tuple{DataType, Int}[]
32+
didx = 1
33+
34+
for (T, ps) in discrete_params
35+
push!(discrete_buffer_type_and_size, (T, length(ps)))
36+
for p in ps
37+
h = hash(p)
38+
disc_idxs[h] = didx
39+
didx += 1
40+
setmetadata(p, SymbolHash, h)
41+
end
42+
end
43+
44+
param_idxs = Dict{UInt, Int}()
45+
param_buffer_type_and_size = Tuple{DataType, Int}[]
46+
pidx = 1
47+
48+
for (T, ps) in tunable_params
49+
push!(param_buffer_type_and_size, (T, length(ps)))
50+
for p in ps
51+
h = hash(p)
52+
param_idxs[h] = pidx
53+
pidx += 1
54+
setmetadata(p, SymbolHash, h)
55+
end
56+
end
57+
58+
return IndexCache(unk_idxs, disc_idxs, param_idxs, discrete_buffer_type_and_size, param_buffer_type_and_size)
59+
end
60+
61+
function reorder_parameters(ic::IndexCache, ps)
62+
param_bufs = ArrayPartition((Vector{BasicSymbolic{T}}(undef, sz) for (T, sz) in ic.param_buffer_type_and_size)...)
63+
disc_bufs = ArrayPartition((Vector{BasicSymbolic{T}}(undef, sz) for (T, sz) in ic.discrete_buffer_type_and_size)...)
64+
65+
for p in ps
66+
h = hasmetadata(p, SymbolHash) ? getmetadata(p, SymbolHash) : hash(unwrap(p))
67+
if haskey(ic.discrete_idx, h)
68+
disc_bufs[ic.discrete_idx[h]] = unwrap(p)
69+
elseif haskey(ic.param_idx, h)
70+
param_bufs[ic.param_idx[h]] = unwrap(p)
71+
else
72+
error("Invalid parameter $p")
73+
end
74+
end
75+
76+
return (param_bufs.x..., disc_bufs.x...)
77+
end

0 commit comments

Comments
 (0)