Skip to content

Commit dff1b5f

Browse files
feat: add support for SciMLStructures.Constants portion, and dependent parameters
1 parent a397d1f commit dff1b5f

File tree

3 files changed

+155
-54
lines changed

3 files changed

+155
-54
lines changed

src/systems/abstractsystem.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,14 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
255255
if has_index_cache(sys) && get_index_cache(sys) !== nothing
256256
ic = get_index_cache(sys)
257257
h = getsymbolhash(sym)
258-
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h)
258+
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
259+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h)
259260
true
260261
else
261262
h = getsymbolhash(default_toterm(sym))
262-
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) || hasname(sym) && is_parameter(sys, getname(sym))
263+
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
264+
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
265+
hasname(sym) && is_parameter(sys, getname(sym))
263266
end
264267
end
265268
return any(isequal(sym), parameter_symbols(sys)) ||
@@ -284,12 +287,20 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
284287
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
285288
elseif haskey(ic.discrete_idx, h)
286289
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
290+
elseif haskey(ic.constant_idx, h)
291+
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
292+
elseif haskey(ic.dependent_idx, h)
293+
ParameterIndex(nothing, ic.dependent_idx[h])
287294
else
288295
h = getsymbolhash(default_toterm(sym))
289296
if haskey(ic.param_idx, h)
290297
ParameterIndex(SciMLStructures.Tunable(), ic.param_idx[h])
291298
elseif haskey(ic.discrete_idx, h)
292299
ParameterIndex(SciMLStructures.Discrete(), ic.discrete_idx[h])
300+
elseif haskey(ic.constant_idx, h)
301+
ParameterIndex(SciMLStructures.Constants(), ic.constant_idx[h])
302+
elseif haskey(ic.dependent_idx, h)
303+
ParameterIndex(nothing, ic.dependent_idx[h])
293304
else
294305
nothing
295306
end

src/systems/index_cache.jl

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@ struct IndexCache
1616
unknown_idx::Dict{UInt, Int}
1717
discrete_idx::Dict{UInt, Int}
1818
param_idx::Dict{UInt, Int}
19+
constant_idx::Dict{UInt, Int}
20+
dependent_idx::Dict{UInt, Int}
1921
discrete_buffer_sizes::Vector{BufferTemplate}
2022
param_buffer_sizes::Vector{BufferTemplate}
23+
constant_buffer_sizes::Vector{BufferTemplate}
24+
dependent_buffer_sizes::Vector{BufferTemplate}
2125
end
2226

2327
function IndexCache(sys::AbstractSystem)
@@ -31,6 +35,8 @@ function IndexCache(sys::AbstractSystem)
3135

3236
disc_buffers = Dict{DataType, Set{BasicSymbolic}}()
3337
tunable_buffers = Dict{DataType, Set{BasicSymbolic}}()
38+
constant_buffers = Dict{DataType, Set{BasicSymbolic}}()
39+
dependent_buffers = Dict{DataType, Set{BasicSymbolic}}()
3440

3541
function insert_by_type!(buffers::Dict{DataType, Set{BasicSymbolic}}, sym)
3642
sym = unwrap(sym)
@@ -53,40 +59,64 @@ function IndexCache(sys::AbstractSystem)
5359
end
5460
end
5561

62+
all_ps = Set(unwrap.(parameters(sys)))
63+
for (sym, value) in defaults(sys)
64+
sym = unwrap(sym)
65+
if sym in all_ps && symbolic_type(unwrap(value)) !== NotSymbolic()
66+
insert_by_type!(dependent_buffers, sym)
67+
end
68+
end
69+
5670
for p in parameters(sys)
5771
p = unwrap(p)
5872
ctype = concrete_symtype(p)
5973
haskey(disc_buffers, ctype) && p in disc_buffers[ctype] && continue
60-
61-
insert_by_type!(is_discrete_domain(p) ? disc_buffers : tunable_buffers, p)
74+
haskey(dependent_buffers, ctype) && p in dependent_buffers[ctype] && continue
75+
76+
insert_by_type!(
77+
if is_discrete_domain(p)
78+
disc_buffers
79+
elseif istunable(p, true)
80+
tunable_buffers
81+
else
82+
constant_buffers
83+
end,
84+
p
85+
)
6286
end
6387

64-
disc_idxs = Dict{UInt, Int}()
65-
discrete_buffer_sizes = BufferTemplate[]
66-
didx = 1
67-
for (T, buf) in disc_buffers
68-
for p in buf
69-
h = hash(p)
70-
setmetadata(p, SymbolHash, h)
71-
disc_idxs[h] = didx
72-
didx += 1
73-
end
74-
push!(discrete_buffer_sizes, BufferTemplate(T, length(buf)))
75-
end
76-
param_idxs = Dict{UInt, Int}()
77-
param_buffer_sizes = BufferTemplate[]
78-
pidx = 1
79-
for (T, buf) in tunable_buffers
80-
for p in buf
81-
h = hash(p)
82-
setmetadata(p, SymbolHash, h)
83-
param_idxs[h] = pidx
84-
pidx += 1
88+
function get_buffer_sizes_and_idxs(buffers::Dict{DataType, Set{BasicSymbolic}})
89+
idxs = Dict{UInt, Int}()
90+
buffer_sizes = BufferTemplate[]
91+
idx = 1
92+
for (T, buf) in buffers
93+
for p in buf
94+
h = hash(p)
95+
setmetadata(p, SymbolHash, h)
96+
idxs[h] = idx
97+
idx += 1
98+
end
99+
push!(buffer_sizes, BufferTemplate(T, length(buf)))
85100
end
86-
push!(param_buffer_sizes, BufferTemplate(T, length(buf)))
101+
return idxs, buffer_sizes
87102
end
88103

89-
return IndexCache(unk_idxs, disc_idxs, param_idxs, discrete_buffer_sizes, param_buffer_sizes)
104+
disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers)
105+
param_idxs, param_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
106+
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
107+
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
108+
109+
return IndexCache(
110+
unk_idxs,
111+
disc_idxs,
112+
param_idxs,
113+
const_idxs,
114+
dependent_idxs,
115+
discrete_buffer_sizes,
116+
param_buffer_sizes,
117+
const_buffer_sizes,
118+
dependent_buffer_sizes,
119+
)
90120
end
91121

92122
function reorder_parameters(sys::AbstractSystem, ps; kwargs...)
@@ -102,19 +132,25 @@ end
102132
function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
103133
param_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.param_buffer_sizes)...)
104134
disc_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.discrete_buffer_sizes)...)
135+
const_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.constant_buffer_sizes)...)
136+
dep_buf = ArrayPartition((fill(variable(:DEF), temp.length) for temp in ic.dependent_buffer_sizes)...)
105137

106138
for p in ps
107139
h = getsymbolhash(p)
108140
if haskey(ic.discrete_idx, h)
109141
disc_buf[ic.discrete_idx[h]] = unwrap(p)
110142
elseif haskey(ic.param_idx, h)
111143
param_buf[ic.param_idx[h]] = unwrap(p)
144+
elseif haskey(ic.constant_idx, h)
145+
const_buf[ic.constant_idx[h]] = unwrap(p)
146+
elseif haskey(ic.dependent_idx, h)
147+
dep_buf[ic.dependent_idx[h]] = unwrap(p)
112148
else
113149
error("Invalid parameter $p")
114150
end
115151
end
116152

117-
result = broadcast.(unwrap, (param_buf.x..., disc_buf.x...))
153+
result = broadcast.(unwrap, (param_buf.x..., disc_buf.x..., const_buf.x..., dep_buf.x...))
118154
if drop_missing
119155
result = map(result) do buf
120156
filter(buf) do sym

src/systems/parameter_buffer.jl

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
struct MTKParameters{T, D}
1+
struct MTKParameters{T, D, C, E, F}
22
tunable::T
33
discrete::D
4+
constant::C
5+
dependent::E
6+
dependent_update::F
47
end
58

69
function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat = false, use_union = false)
@@ -10,7 +13,7 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
1013
error("Cannot create MTKParameters if system does not have index_cache")
1114
end
1215
all_ps = Set(unwrap.(parameters(sys)))
13-
if p isa Vector && !(eltype(p) <: Pair)
16+
if p isa Vector && !(eltype(p) <: Pair) && !isempty(p)
1417
ps = parameters(sys)
1518
length(p) == length(ps) || error("Invalid parameters")
1619
p = ps .=> p
@@ -24,12 +27,20 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
2427

2528
tunable_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.param_buffer_sizes)...)
2629
disc_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.discrete_buffer_sizes)...)
30+
const_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.constant_buffer_sizes)...)
31+
dep_buffer = ArrayPartition((Vector{temp.type}(undef, temp.length) for temp in ic.dependent_buffer_sizes)...)
32+
dependencies = Dict{Num, Num}()
2733
function set_value(sym, val)
2834
h = getsymbolhash(sym)
2935
if haskey(ic.param_idx, h)
3036
tunable_buffer[ic.param_idx[h]] = val
3137
elseif haskey(ic.discrete_idx, h)
3238
disc_buffer[ic.discrete_idx[h]] = val
39+
elseif haskey(ic.constant_idx, h)
40+
const_buffer[ic.constant_idx[h]] = val
41+
elseif haskey(ic.dependent_idx, h)
42+
dep_buffer[ic.dependent_idx[h]] = val
43+
dependencies[wrap(sym)] = wrap(p[sym])
3344
end
3445
end
3546

@@ -49,22 +60,46 @@ function MTKParameters(sys::AbstractSystem, p; toterm = default_toterm, tofloat
4960
end
5061
end
5162

63+
dep_exprs = ArrayPartition((wrap.(v) for v in dep_buffer.x)...)
64+
for (sym, val) in dependencies
65+
h = getsymbolhash(sym)
66+
idx = ic.dependent_idx[h]
67+
dep_exprs[idx] = wrap(fixpoint_sub(val, dependencies))
68+
end
69+
p = reorder_parameters(ic, parameters(sys))[begin:end-length(dep_buffer.x)]
70+
update_function = if isempty(dep_exprs.x)
71+
(_...) -> ()
72+
else
73+
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(build_function(dep_exprs, p...)[2])
74+
end
5275
# everything is an ArrayPartition so it's easy to figure out how many
5376
# distinct vectors we have for each portion as `ArrayPartition.x`
54-
if tunable_buffer isa ArrayPartition && isempty(tunable_buffer.x) || isempty(tunable_buffer)
77+
if isempty(tunable_buffer.x)
5578
tunable_buffer = ArrayPartition(Float64[])
5679
end
57-
if disc_buffer isa ArrayPartition && isempty(disc_buffer.x) || isempty(disc_buffer)
80+
if isempty(disc_buffer.x)
5881
disc_buffer = ArrayPartition(Float64[])
5982
end
83+
if isempty(const_buffer.x)
84+
const_buffer = ArrayPartition(Float64[])
85+
end
86+
if isempty(dep_buffer.x)
87+
dep_buffer = ArrayPartition(Float64[])
88+
end
6089
if use_union
6190
tunable_buffer = ArrayPartition(restrict_array_to_union(tunable_buffer))
6291
disc_buffer = ArrayPartition(restrict_array_to_union(disc_buffer))
92+
const_buffer = ArrayPartition(restrict_array_to_union(const_buffer))
93+
dep_buffer = ArrayPartition(restrict_array_to_union(dep_buffer))
6394
elseif tofloat
6495
tunable_buffer = ArrayPartition(Float64.(tunable_buffer))
6596
disc_buffer = ArrayPartition(Float64.(disc_buffer))
97+
const_buffer = ArrayPartition(Float64.(const_buffer))
98+
dep_buffer = ArrayPartition(Float64.(dep_buffer))
6699
end
67-
return MTKParameters{typeof(tunable_buffer), typeof(disc_buffer)}(tunable_buffer, disc_buffer)
100+
return MTKParameters{typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer),
101+
typeof(dep_buffer), typeof(update_function)}(tunable_buffer,
102+
disc_buffer, const_buffer, dep_buffer, update_function)
68103
end
69104

70105
SciMLStructures.isscimlstructure(::MTKParameters) = true
@@ -74,20 +109,24 @@ SciMLStructures.ismutablescimlstructure(::MTKParameters) = true
74109
for (Portion, field) in [
75110
(SciMLStructures.Tunable, :tunable)
76111
(SciMLStructures.Discrete, :discrete)
112+
(SciMLStructures.Constants, :constant)
77113
]
78114
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
79115
function repack(values)
80116
p.$field .= values
117+
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
81118
end
82119
return p.$field, repack, true
83120
end
84121

85122
@eval function SciMLStructures.replace(::$Portion, p::MTKParameters, newvals)
86-
@set p.$field = newvals
123+
@set! p.$field = newvals
124+
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
87125
end
88126

89127
@eval function SciMLStructures.replace!(::$Portion, p::MTKParameters, newvals)
90128
p.$field .= newvals
129+
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
91130
nothing
92131
end
93132
end
@@ -98,6 +137,10 @@ function SymbolicIndexingInterface.parameter_values(p::MTKParameters, i::Paramet
98137
return p.tunable[idx]
99138
elseif portion isa SciMLStructures.Discrete
100139
return p.discrete[idx]
140+
elseif portion isa SciMLStructures.Constants
141+
return p.constant[idx]
142+
elseif portion === nothing
143+
return p.dependent[idx]
101144
else
102145
error("Unhandled portion $portion")
103146
end
@@ -109,51 +152,62 @@ function SymbolicIndexingInterface.set_parameter!(p::MTKParameters, val, idx::Pa
109152
p.tunable[idx] = val
110153
elseif portion isa SciMLStructures.Discrete
111154
p.discrete[idx] = val
155+
elseif portion isa SciMLStructures.Constants
156+
p.constant[idx] = val
157+
elseif portion === nothing
158+
error("Cannot set value of parameter: ")
112159
else
113160
error("Unhandled portion $portion")
114161
end
162+
p.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
115163
end
116164

117165
# for compiling callbacks
118166
# getindex indexes the vectors, setindex! linearly indexes values
119167
# it's inconsistent, but we need it to be this way
120168
function Base.getindex(buf::MTKParameters, i)
121-
if i <= length(buf.tunable.x)
122-
buf.tunable.x[i]
123-
else
124-
buf.discrete.x[i - length(buf.tunable.x)]
169+
if !isempty(buf.tunable)
170+
i <= length(buf.tunable.x) && return buf.tunable.x[i]
171+
i -= length(buf.tunable.x)
172+
end
173+
if !isempty(buf.discrete)
174+
i <= length(buf.discrete.x) && return buf.discrete.x[i]
175+
i -= length(buf.discrete.x)
176+
end
177+
if !isempty(buf.constant)
178+
i <= length(buf.constant.x) && return buf.constant.x[i]
179+
i -= length(buf.constant.x)
125180
end
181+
isempty(buf.dependent) || return buf.dependent.x[i]
182+
throw(BoundsError(buf, i))
126183
end
127184
function Base.setindex!(buf::MTKParameters, val, i)
128185
if i <= length(buf.tunable)
129186
buf.tunable[i] = val
130-
else
187+
elseif i <= length(buf.tunable) + length(buf.discrete)
131188
buf.discrete[i - length(buf.tunable)] = val
189+
else
190+
buf.constant[i - length(buf.tunable) - length(buf.discrete)] = val
132191
end
192+
buf.dependent_update(p.dependent, p.tunable.x..., p.discrete.x..., p.constant.x...)
133193
end
134194

135195
function Base.iterate(buf::MTKParameters, state = 1)
136-
tunable = if isempty(buf.tunable)
137-
()
138-
elseif buf.tunable isa ArrayPartition
139-
buf.tunable.x
140-
end
141-
discrete = if isempty(buf.discrete)
142-
()
143-
elseif buf.discrete isa ArrayPartition
144-
buf.discrete.x
145-
end
146-
if state <= length(tunable)
147-
return (tunable[state], state + 1)
148-
elseif state <= length(tunable) + length(discrete)
149-
return (discrete[state - length(tunable)], state + 1)
196+
total_len = 0
197+
isempty(buf.tunable) || (total_len += length(buf.tunable.x))
198+
isempty(buf.discrete) || (total_len += length(buf.discrete.x))
199+
isempty(buf.constant) || (total_len += length(buf.constant.x))
200+
isempty(buf.dependent) || (total_len += length(buf.dependent.x))
201+
if state <= total_len
202+
return (buf[state], state + 1)
150203
else
151204
return nothing
152205
end
153206
end
154207

155208
function Base.:(==)(a::MTKParameters, b::MTKParameters)
156-
return a.tunable == b.tunable && a.discrete == b.discrete
209+
return a.tunable == b.tunable && a.discrete == b.discrete &&
210+
a.constant == b.constant && a.dependent == b.dependent
157211
end
158212

159213
# to support linearize/linearization_function

0 commit comments

Comments
 (0)