Skip to content

Commit 1394255

Browse files
feat: add timeseries parameters support to SymbolCache
1 parent 4650a34 commit 1394255

File tree

1 file changed

+168
-20
lines changed

1 file changed

+168
-20
lines changed

src/symbol_cache.jl

Lines changed: 168 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,103 @@
11
"""
2-
struct SymbolCache{V,P,I}
3-
function SymbolCache(vars, [params, [indepvars]])
2+
struct SymbolCache
3+
function SymbolCache(vars, [params, [indepvars]]; defaults = Dict(), timeseries_parameters = nothing)
44
55
A struct implementing the index provider interface for the trivial case of having a
66
vector of variables, parameters, and independent variables. It is considered time
77
dependent if it contains at least one independent variable. It returns `true` for
88
`is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using
99
`observed` for `Expr`s involving variables in the `SymbolCache` if it has at most one
10-
independent variable.
10+
independent variable. `defaults` is an `AbstractDict` mapping variables and/or parameters
11+
to their default initial values. The default initial values can also be other variables/
12+
parameters or expressions of them. `timeseries_parameters` is an `AbstractDict` the
13+
timeseries parameters in `params` to their [`ParameterTimeseriesIndex`](@ref) indexes.
14+
15+
Instead of arrays, the variables and parameters can also be provided as `AbstractDict`s
16+
mapping symbols to indices.
1117
1218
The independent variable may be specified as a single symbolic variable instead of an
1319
array containing a single variable if the system has only one independent variable.
1420
"""
1521
struct SymbolCache{
16-
V <: Union{Nothing, AbstractVector},
17-
P <: Union{Nothing, AbstractVector},
22+
V <: Union{Nothing, AbstractDict},
23+
P <: Union{Nothing, AbstractDict},
24+
T <: Union{Nothing, AbstractDict},
1825
I,
19-
D <: Dict
26+
D <: AbstractDict
2027
}
2128
variables::V
2229
parameters::P
30+
timeseries_parameters::T
2331
independent_variables::I
2432
defaults::D
2533
end
2634

35+
function to_dict_or_nothing(arr::Union{AbstractArray, Tuple})
36+
eltype(arr) <: Pair && return Dict(arr)
37+
isempty(arr) && return nothing
38+
return Dict(v => k for (k, v) in enumerate(arr))
39+
end
40+
to_dict_or_nothing(d::AbstractDict) = d
41+
to_dict_or_nothing(::Nothing) = nothing
42+
2743
function SymbolCache(vars = nothing, params = nothing, indepvars = nothing;
28-
defaults = Dict{Symbol, Union{Symbol, Expr, Number}}())
29-
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars), typeof(defaults)}(
44+
defaults = Dict(), timeseries_parameters = nothing)
45+
vars = to_dict_or_nothing(vars)
46+
params = to_dict_or_nothing(params)
47+
timeseries_parameters = to_dict_or_nothing(timeseries_parameters)
48+
if timeseries_parameters !== nothing
49+
if indepvars === nothing
50+
throw(ArgumentError("Independent variable is required for timeseries parameters to exist"))
51+
end
52+
for (k, v) in timeseries_parameters
53+
if !haskey(params, k)
54+
throw(ArgumentError("Timeseries parameter $k must also be present in parameters."))
55+
end
56+
if !isa(v, ParameterTimeseriesIndex)
57+
throw(TypeError(:SymbolCache, "index of timeseries parameter $k",
58+
ParameterTimeseriesIndex, v))
59+
end
60+
end
61+
end
62+
return SymbolCache{typeof(vars), typeof(params), typeof(timeseries_parameters),
63+
typeof(indepvars), typeof(defaults)}(
3064
vars,
3165
params,
66+
timeseries_parameters,
3267
indepvars,
3368
defaults)
3469
end
3570

3671
function is_variable(sc::SymbolCache, sym)
37-
sc.variables !== nothing && any(isequal(sym), sc.variables)
72+
sc.variables !== nothing && haskey(sc.variables, sym)
3873
end
3974
function variable_index(sc::SymbolCache, sym)
40-
sc.variables === nothing ? nothing : findfirst(isequal(sym), sc.variables)
75+
sc.variables === nothing ? nothing : get(sc.variables, sym, nothing)
76+
end
77+
function variable_symbols(sc::SymbolCache, i = nothing)
78+
sc.variables === nothing && return []
79+
buffer = collect(keys(sc.variables))
80+
for (k, v) in sc.variables
81+
buffer[v] = k
82+
end
83+
return buffer
4184
end
42-
variable_symbols(sc::SymbolCache, i = nothing) = something(sc.variables, [])
4385
function is_parameter(sc::SymbolCache, sym)
44-
sc.parameters !== nothing && any(isequal(sym), sc.parameters)
86+
sc.parameters !== nothing && haskey(sc.parameters, sym)
4587
end
4688
function parameter_index(sc::SymbolCache, sym)
47-
sc.parameters === nothing ? nothing : findfirst(isequal(sym), sc.parameters)
89+
sc.parameters === nothing ? nothing : get(sc.parameters, sym, nothing)
90+
end
91+
function parameter_symbols(sc::SymbolCache)
92+
sc.parameters === nothing ? [] : collect(keys(sc.parameters))
93+
end
94+
function is_timeseries_parameter(sc::SymbolCache, sym)
95+
sc.timeseries_parameters !== nothing && haskey(sc.timeseries_parameters, sym)
96+
end
97+
function timeseries_parameter_index(sc::SymbolCache, sym)
98+
sc.timeseries_parameters === nothing ? nothing :
99+
get(sc.timeseries_parameters, sym, nothing)
48100
end
49-
parameter_symbols(sc::SymbolCache) = something(sc.parameters, [])
50101
function is_independent_variable(sc::SymbolCache, sym)
51102
sc.independent_variables === nothing && return false
52103
if symbolic_type(sc.independent_variables) == NotSymbolic()
@@ -72,12 +123,14 @@ is_observed(::SymbolCache, ::Expr) = true
72123
is_observed(::SymbolCache, ::AbstractArray{Expr}) = true
73124
is_observed(::SymbolCache, ::Tuple{Vararg{Expr}}) = true
74125

126+
# TODO: Make this less hacky
75127
struct ExpressionSearcher
128+
parameters::Set{Symbol}
76129
declared::Set{Symbol}
77130
fnbody::Expr
78131
end
79132

80-
ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Expr(:block))
133+
ExpressionSearcher() = ExpressionSearcher(Set{Symbol}(), Set{Symbol}(), Expr(:block))
81134

82135
function (exs::ExpressionSearcher)(sys, expr::Expr)
83136
for arg in expr.args
@@ -94,7 +147,8 @@ function (exs::ExpressionSearcher)(sys, sym::Symbol)
94147
push!(exs.fnbody.args, :($sym = u[$idx]))
95148
elseif is_parameter(sys, sym)
96149
idx = parameter_index(sys, sym)
97-
push!(exs.fnbody.args, :($sym = p[$idx]))
150+
push!(exs.parameters, sym)
151+
push!(exs.fnbody.args, :($sym = parameter_values(p, $idx)))
98152
elseif is_independent_variable(sys, sym)
99153
push!(exs.fnbody.args, :($sym = t))
100154
end
@@ -124,11 +178,104 @@ function observed(sc::SymbolCache, expr::Expr)
124178
end
125179
end
126180
end
127-
function observed(sc::SymbolCache, exprs::AbstractArray{Expr})
128-
return observed(sc, :(reshape([$(exprs...)], $(size(exprs)))))
181+
182+
to_expr(exprs::AbstractArray) = :(reshape([$(exprs...)], $(size(exprs))))
183+
to_expr(exprs::Tuple) = :(($(exprs...),))
184+
185+
function inplace_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple})
186+
let cache = Dict{Expr, Function}()
187+
return get!(cache, to_expr(exprs)) do
188+
exs = ExpressionSearcher()
189+
for expr in exprs
190+
exs(sc, expr)
191+
end
192+
update_expr = Expr(:block)
193+
for (i, expr) in enumerate(exprs)
194+
push!(update_expr.args, :(buffer[$i] = $expr))
195+
end
196+
fnexpr = if is_time_dependent(sc)
197+
:(function (buffer, u, p, t)
198+
$(exs.fnbody)
199+
$update_expr
200+
return buffer
201+
end)
202+
else
203+
:(function (buffer, u, p)
204+
$(exs.fnbody)
205+
$update_expr
206+
return buffer
207+
end)
208+
end
209+
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)
210+
end
211+
end
212+
end
213+
214+
function observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple})
215+
for expr in exprs
216+
if !(expr isa Union{Symbol, Expr})
217+
throw(TypeError(:observed, "SymbolCache", Union{Symbol, Expr}, expr))
218+
end
219+
end
220+
return observed(sc, to_expr(exprs))
129221
end
130-
function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}})
131-
return observed(sc, :(($(exprs...),)))
222+
223+
function parameter_observed(sc::SymbolCache, expr::Expr)
224+
if is_time_dependent(sc)
225+
exs = ExpressionSearcher()
226+
exs(sc, expr)
227+
ts_idxs = Set()
228+
for p in exs.parameters
229+
is_timeseries_parameter(sc, p) || continue
230+
push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx)
231+
end
232+
f = let fn = observed(sc, expr)
233+
f1(p, t) = fn(nothing, p, t)
234+
end
235+
if length(ts_idxs) == 1
236+
return ParameterObservedFunction(only(ts_idxs), f)
237+
else
238+
return ParameterObservedFunction(nothing, f)
239+
end
240+
else
241+
f = let fn = observed(sc, expr)
242+
f2(p) = fn(nothing, p)
243+
end
244+
return ParameterObservedFunction(nothing, f)
245+
end
246+
end
247+
248+
function parameter_observed(sc::SymbolCache, exprs::Union{AbstractArray, Tuple})
249+
for ex in exprs
250+
if !(ex isa Union{Symbol, Expr})
251+
throw(TypeError(:parameter_observed, "SymbolCache", Union{Symbol, Expr}, ex))
252+
end
253+
end
254+
if is_time_dependent(sc)
255+
exs = ExpressionSearcher()
256+
exs(sc, to_expr(exprs))
257+
ts_idxs = Set()
258+
for p in exs.parameters
259+
is_timeseries_parameter(sc, p) || continue
260+
push!(ts_idxs, timeseries_parameter_index(sc, p).timeseries_idx)
261+
end
262+
263+
f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs)
264+
f1(p, t) = oop(nothing, p, t)
265+
f1(buffer, p, t) = iip(buffer, nothing, p, t)
266+
end
267+
if length(ts_idxs) == 1
268+
return ParameterObservedFunction(only(ts_idxs), f)
269+
else
270+
return ParameterObservedFunction(nothing, f)
271+
end
272+
else
273+
f = let oop = observed(sc, to_expr(exprs)), iip = inplace_observed(sc, exprs)
274+
f2(p) = oop(nothing, p)
275+
f2(buffer, p) = iip(buffer, nothing, p)
276+
end
277+
return ParameterObservedFunction(nothing, f)
278+
end
132279
end
133280

134281
function is_time_dependent(sc::SymbolCache)
@@ -149,6 +296,7 @@ default_values(sc::SymbolCache) = sc.defaults
149296
function Base.copy(sc::SymbolCache)
150297
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
151298
sc.parameters === nothing ? nothing : copy(sc.parameters),
299+
sc.timeseries_parameters === nothing ? nothing : copy(sc.timeseries_parameters),
152300
sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) :
153301
sc.independent_variables, copy(sc.defaults))
154302
end

0 commit comments

Comments
 (0)