Skip to content

Commit c603e9c

Browse files
Merge pull request #47 from SciML/as/model-defaults
feat: add `default_values` to interface
2 parents 2a45e9c + 7713bc0 commit c603e9c

File tree

8 files changed

+160
-12
lines changed

8 files changed

+160
-12
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@ uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
33
authors = ["Aayush Sabharwal <[email protected]> and contributors"]
44
version = "0.3.10"
55

6+
[deps]
7+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
8+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
9+
610
[compat]
711
Aqua = "0.8"
12+
MacroTools = "0.5.13"
13+
RuntimeGeneratedFunctions = "0.5"
814
SafeTestsets = "0.0.1"
915
Test = "1"
1016
julia = "1.10"

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ parameter_symbols
1313
is_independent_variable
1414
independent_variable_symbols
1515
is_observed
16+
default_values
1617
is_time_dependent
1718
constant_structure
1819
all_variable_symbols
@@ -73,6 +74,7 @@ NotSymbolic
7374
symbolic_type
7475
hasname
7576
getname
77+
symbolic_evaluate
7678
```
7779

7880
# Types

docs/src/complete_sii.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ struct ExampleSystem
88
state_index::Dict{Symbol,Int}
99
parameter_index::Dict{Symbol,Int}
1010
independent_variable::Union{Symbol,Nothing}
11+
defaults::Dict{Symbol, Float64}
1112
# mapping from observed variable to Expr to calculate its value
1213
observed::Dict{Symbol,Expr}
1314
end
@@ -77,6 +78,10 @@ function SymbolicIndexingInterface.all_symbols(sys::ExampleSystem)
7778
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
7879
)
7980
end
81+
82+
function SymbolicIndexingInterface.default_values(sys::ExampleSystem)
83+
return sys.defaults
84+
end
8085
```
8186

8287
### Observed Equation Handling
@@ -339,6 +344,37 @@ end
339344
[`hasname`](@ref) is not required to always be `true` for symbolic types. For example,
340345
`Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression.
341346

347+
Introducing a type to represent expression trees:
348+
349+
```julia
350+
struct MyExpr
351+
op::Function
352+
args::Vector{Union{MyExpr, MySym, MySymArr, Number, Array}}
353+
end
354+
```
355+
356+
[`symbolic_evaluate`](@ref) can be implemented as follows:
357+
358+
```julia
359+
function symbolic_evaluate(expr::Union{MySym, MySymArr}, syms::Dict)
360+
get(syms, expr, expr)
361+
end
362+
function symbolic_evaluate(expr::MyExpr, syms::Dict)
363+
for i in eachindex(expr.args)
364+
if expr.args[i] isa Union{MySym, MySymArr, MyExpr}
365+
expr.args[i] = symbolic_evaluate(expr.args[i], syms)
366+
end
367+
end
368+
if all(x -> symbolic_type(x) === NotSymbolic(), expr.args)
369+
return expr.op(expr.args...)
370+
end
371+
end
372+
```
373+
374+
Note the evaluation of the operation if all of the arguments are not symbolic. This is
375+
required since `symbolic_evaluate` must return an evaluated value if all symbolic variables
376+
are substituted.
377+
342378
## Parameter Timeseries
343379

344380
If a solution object saves modified parameter values (such as through callbacks) during the

src/SymbolicIndexingInterface.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
module SymbolicIndexingInterface
22

3+
import MacroTools
4+
using RuntimeGeneratedFunctions
5+
RuntimeGeneratedFunctions.init(@__MODULE__)
6+
37
export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
48
Timeseries, NotTimeseries, is_timeseries
59
include("trait.jl")
@@ -9,7 +13,7 @@ export is_variable, variable_index, variable_symbols, is_parameter, parameter_in
913
is_observed,
1014
observed, is_time_dependent, constant_structure, symbolic_container,
1115
all_variable_symbols,
12-
all_symbols, solvedvariables, allvariables
16+
all_symbols, solvedvariables, allvariables, default_values, symbolic_evaluate
1317
include("interface.jl")
1418

1519
export SymbolCache

src/interface.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ variables.
134134
"""
135135
all_symbols(sys) = all_symbols(symbolic_container(sys))
136136

137+
"""
138+
default_values(sys)
139+
140+
Return a dictionary mapping symbols in the system to their default value, if any. This
141+
includes parameter symbols. The dictionary must be mutable.
142+
"""
143+
default_values(sys) = default_values(symbolic_container(sys))
144+
137145
struct SolvedVariables end
138146

139147
"""

src/symbol_cache.jl

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,35 @@
22
struct SymbolCache{V,P,I}
33
function SymbolCache(vars, [params, [indepvars]])
44
5-
A struct implementing the symbolic indexing interface for the trivial case
6-
of having a vector of variables, parameters, and independent variables. This
7-
struct does not implement `observed`, and `is_observed` returns `false` for
8-
all input symbols. It is considered time dependent if it contains
9-
at least one independent variable.
5+
A struct implementing the symbolic indexing interface for the trivial case of having a
6+
vector of variables, parameters, and independent variables. It is considered time
7+
dependent if it contains at least one independent variable. It returns `true` for
8+
`is_observed(::SymbolCache, sym)` if `sym isa Expr`. Functions can be generated using
9+
`observed` for `Expr`s involving variables in the `SymbolCache` if it has at most one
10+
independent variable.
1011
1112
The independent variable may be specified as a single symbolic variable instead of an
1213
array containing a single variable if the system has only one independent variable.
1314
"""
1415
struct SymbolCache{
1516
V <: Union{Nothing, AbstractVector},
1617
P <: Union{Nothing, AbstractVector},
17-
I
18+
I,
19+
D <: Dict
1820
}
1921
variables::V
2022
parameters::P
2123
independent_variables::I
24+
defaults::D
2225
end
2326

24-
function SymbolCache(vars = nothing, params = nothing, indepvars = nothing)
25-
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars)}(vars,
27+
function SymbolCache(vars = nothing, params = nothing, indepvars = nothing;
28+
defaults = Dict{Symbol, Union{Symbol, Expr, Number}}())
29+
return SymbolCache{typeof(vars), typeof(params), typeof(indepvars), typeof(defaults)}(
30+
vars,
2631
params,
27-
indepvars)
32+
indepvars,
33+
defaults)
2834
end
2935

3036
function is_variable(sc::SymbolCache, sym)
@@ -62,6 +68,45 @@ function independent_variable_symbols(sc::SymbolCache)
6268
end
6369
end
6470
is_observed(sc::SymbolCache, sym) = false
71+
is_observed(::SymbolCache, ::Expr) = true
72+
function observed(sc::SymbolCache, expr::Expr)
73+
let cache = Dict{Expr, Function}()
74+
return get!(cache, expr) do
75+
fnbody = Expr(:block)
76+
declared = Set{Symbol}()
77+
MacroTools.postwalk(expr) do sym
78+
sym isa Symbol || return
79+
sym in declared && return
80+
if sc.variables !== nothing &&
81+
(idx = findfirst(isequal(sym), sc.variables)) !== nothing
82+
push!(fnbody.args, :($sym = u[$idx]))
83+
push!(declared, sym)
84+
elseif sc.parameters !== nothing &&
85+
(idx = findfirst(isequal(sym), sc.parameters)) !== nothing
86+
push!(fnbody.args, :($sym = p[$idx]))
87+
push!(declared, sym)
88+
elseif sym === sc.independent_variables ||
89+
sc.independent_variables isa Vector &&
90+
sym == only(sc.independent_variables)
91+
push!(fnbody.args, :($sym = t))
92+
push!(declared, sym)
93+
end
94+
end
95+
fnexpr = if is_time_dependent(sc)
96+
:(function (u, p, t)
97+
$fnbody
98+
return $expr
99+
end)
100+
else
101+
:(function (u, p)
102+
$fnbody
103+
return $expr
104+
end)
105+
end
106+
return RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(fnexpr)
107+
end
108+
end
109+
end
65110
function is_time_dependent(sc::SymbolCache)
66111
sc.independent_variables === nothing && return false
67112
if symbolic_type(sc.independent_variables) == NotSymbolic()
@@ -75,10 +120,11 @@ all_variable_symbols(sc::SymbolCache) = variable_symbols(sc)
75120
function all_symbols(sc::SymbolCache)
76121
vcat(variable_symbols(sc), parameter_symbols(sc), independent_variable_symbols(sc))
77122
end
123+
default_values(sc::SymbolCache) = sc.defaults
78124

79125
function Base.copy(sc::SymbolCache)
80126
return SymbolCache(sc.variables === nothing ? nothing : copy(sc.variables),
81127
sc.parameters === nothing ? nothing : copy(sc.parameters),
82128
sc.independent_variables isa AbstractArray ? copy(sc.independent_variables) :
83-
sc.independent_variables)
129+
sc.independent_variables, copy(sc.defaults))
84130
end

src/trait.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,33 @@ Get the name of a symbolic variable as a `Symbol`
6161
"""
6262
function getname end
6363

64+
"""
65+
symbolic_evaluate(expr, syms::Dict)
66+
67+
Return the value of symbolic expression `expr` where the values of variables involved are
68+
obtained from the dictionary `syms`. The keys of `syms` are symbolic variables (not
69+
expressions of variables). The values of `syms` can be values or symbolic
70+
expressions.
71+
72+
The returned value should either be a value or an expression involving symbolic variables
73+
not present as keys in `syms`.
74+
75+
This is already implemented for
76+
`symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})`.
77+
"""
78+
function symbolic_evaluate(expr::Union{Symbol, Expr}, syms::Dict{Symbol})
79+
while (new_expr = MacroTools.postwalk(expr) do sym
80+
return get(syms, sym, sym)
81+
end) != expr
82+
expr = new_expr
83+
end
84+
return try
85+
eval(expr)
86+
catch
87+
expr
88+
end
89+
end
90+
6491
############ IsTimeseriesTrait
6592

6693
abstract type IsTimeseriesTrait end

test/symbol_cache_test.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using SymbolicIndexingInterface
22
using Test
33

4-
sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
4+
sc = SymbolCache(
5+
[:x, :y, :z], [:a, :b], [:t]; defaults = Dict(:x => 1, :y => :(2b), :b => :(2a + x)))
56

67
@test all(is_variable.((sc,), [:x, :y, :z]))
78
@test all(.!is_variable.((sc,), [:a, :b, :t, :q]))
@@ -19,10 +20,26 @@ sc = SymbolCache([:x, :y, :z], [:a, :b], [:t])
1920
@test independent_variable_symbols(sc) == [:t]
2021
@test all_variable_symbols(sc) == [:x, :y, :z]
2122
@test sort(all_symbols(sc)) == [:a, :b, :t, :x, :y, :z]
23+
@test default_values(sc)[:x] == 1
24+
@test default_values(sc)[:y] == :(2b)
25+
@test default_values(sc)[:b] == :(2a + x)
26+
27+
@test symbolic_evaluate(:x, default_values(sc)) == 1
28+
@test symbolic_evaluate(:y, default_values(sc)) == :(2 * (2a + 1))
29+
@test symbolic_evaluate(:(x + y), merge(default_values(sc), Dict(:a => 2))) == 11
30+
31+
@test is_observed(sc, :(x + a + t))
32+
obsfn = observed(sc, :(x + a + t))
33+
@test obsfn(ones(3), 2ones(2), 3.0) == 6.0
34+
obsfn2 = observed(sc, :(x + a + t))
35+
@test obsfn === obsfn2
2236

2337
sc = SymbolCache([:x, :y], [:a, :b])
2438
@test !is_time_dependent(sc)
2539
@test sort(all_symbols(sc)) == [:a, :b, :x, :y]
40+
@test is_observed(sc, :(x + b))
41+
obsfn = observed(sc, :(x + b))
42+
@test obsfn(ones(2), 2ones(2)) == 3.0
2643
# make sure the constructor works
2744
@test_nowarn SymbolCache([:x, :y])
2845

@@ -38,6 +55,7 @@ sc = SymbolCache()
3855
@test !is_time_dependent(sc)
3956
@test all_variable_symbols(sc) == []
4057
@test all_symbols(sc) == []
58+
@test isempty(default_values(sc))
4159

4260
sc = SymbolCache(nothing, nothing, :t)
4361
@test all(.!is_independent_variable.((sc,), [:x, :y, :a, :b]))
@@ -46,6 +64,7 @@ sc = SymbolCache(nothing, nothing, :t)
4664
@test is_time_dependent(sc)
4765
@test all_variable_symbols(sc) == []
4866
@test all_symbols(sc) == [:t]
67+
@test isempty(default_values(sc))
4968

5069
sc2 = copy(sc)
5170
@test sc.variables == sc2.variables

0 commit comments

Comments
 (0)