Skip to content

Commit bb5f578

Browse files
Merge pull request #53 from SciML/as/discrete-indexing
feat: support indexing in mixed discrete-continuous systems
2 parents fdf7ba8 + 258355e commit bb5f578

File tree

9 files changed

+368
-67
lines changed

9 files changed

+368
-67
lines changed

docs/src/api.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ getu
5252
setu
5353
```
5454

55+
### Parameter timeseries
56+
57+
If a solution object saves a timeseries of parameter values that are updated during the
58+
simulation (such as by callbacks), it must implement the following methods to ensure
59+
correct functioning of [`getu`](@ref) and [`getp`](@ref).
60+
61+
```@docs
62+
parameter_timeseries
63+
parameter_values_at_time
64+
parameter_values_at_state_time
65+
```
66+
5567
# Symbolic Trait
5668

5769
```@docs

docs/src/complete_sii.md

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,3 +338,81 @@ end
338338

339339
[`hasname`](@ref) is not required to always be `true` for symbolic types. For example,
340340
`Symbolics.Num` returns `false` whenever the wrapped value is a number, or an expression.
341+
342+
## Parameter Timeseries
343+
344+
If a solution object saves modified parameter values (such as through callbacks) during the
345+
simulation, it must implement [`parameter_timeseries`](@ref),
346+
[`parameter_values_at_time`](@ref) and [`parameter_values_at_state_time`](@ref) for correct
347+
functioning of [`getu`](@ref) and [`getp`](@ref). The following mockup gives an example
348+
of correct implementation of these functions and the indexing syntax they enable.
349+
350+
```@example param_timeseries
351+
using SymbolicIndexingInterface
352+
353+
struct ExampleSolution2
354+
sys::SymbolCache
355+
u::Vector{Vector{Float64}}
356+
t::Vector{Float64}
357+
p::Vector{Vector{Float64}}
358+
pt::Vector{Float64}
359+
end
360+
361+
# Add the `:ps` property to automatically wrap in `ParameterIndexingProxy`
362+
function Base.getproperty(fs::ExampleSolution2, s::Symbol)
363+
s === :ps ? ParameterIndexingProxy(fs) : getfield(fs, s)
364+
end
365+
# Use the contained `SymbolCache` for indexing
366+
SymbolicIndexingInterface.symbolic_container(fs::ExampleSolution2) = fs.sys
367+
# By default, `parameter_values` refers to the last value
368+
SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2) = fs.p[end]
369+
SymbolicIndexingInterface.parameter_values(fs::ExampleSolution2, i) = fs.p[end][i]
370+
# Index into the parameter timeseries vector
371+
function SymbolicIndexingInterface.parameter_values_at_time(fs::ExampleSolution2, t)
372+
fs.p[t]
373+
end
374+
# Find the first index in the parameter timeseries vector with a time smaller
375+
# than the time from the state timeseries, and use that to index the parameter
376+
# timeseries
377+
function SymbolicIndexingInterface.parameter_values_at_state_time(fs::ExampleSolution2, t)
378+
ptind = searchsortedfirst(fs.pt, fs.t[t]; lt = <=)
379+
fs.p[ptind - 1]
380+
end
381+
SymbolicIndexingInterface.parameter_timeseries(fs::ExampleSolution2) = fs.pt
382+
# Mark the object as a `Timeseries` object
383+
SymbolicIndexingInterface.is_timeseries(::Type{ExampleSolution2}) = Timeseries()
384+
385+
```
386+
387+
Now we can create an example object and observe the new functionality. Note that
388+
`sol.ps[sym, args...]` is identical to `getp(sol, sym)(sol, args...)`.
389+
390+
```@example param_timeseries
391+
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
392+
sol = ExampleSolution2(
393+
sys,
394+
[i * ones(3) for i in 1:5],
395+
[0.2i for i in 1:5],
396+
[2i * ones(3) for i in 1:10],
397+
[0.1i for i in 1:10]
398+
)
399+
sol.ps[:a] # returns the value at the last timestep
400+
```
401+
402+
```@example param_timeseries
403+
sol.ps[:a, :] # use Colon to fetch the entire parameter timeseries
404+
```
405+
406+
```@example param_timeseries
407+
sol.ps[:a, 3] # index at a specific index in the parameter timeseries
408+
```
409+
410+
```@example param_timeseries
411+
sol.ps[:a, [3, 6, 8]] # index using arrays
412+
```
413+
414+
```@example param_timeseries
415+
idxs = @show rand(Bool, 10) # boolean mask for indexing
416+
sol.ps[:a, idxs]
417+
```
418+

docs/src/usage.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ sol2 = solve(prob, Tsit5())
168168
σ_ρ_getter(sol)
169169
```
170170

171-
To set the entire parameter vector at once, [`parameter_values`](@ref) can be used
172-
(note the usage of broadcasted assignment).
171+
To set the entire parameter vector at once, [`setp`](@ref) can be used
172+
(note that the order of symbols passed to `setp` must match the order of values in the array).
173173

174174
```@example Usage
175-
parameter_values(prob) .= [29.0, 11.0, 2.5]
175+
setp(prob, parameter_symbols(prob))(prob, [29.0, 11.0, 2.5])
176176
parameter_values(prob)
177177
```
178178

src/SymbolicIndexingInterface.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SymbolicIndexingInterface
22

3-
export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname
3+
export ScalarSymbolic, ArraySymbolic, NotSymbolic, symbolic_type, hasname, getname,
4+
Timeseries, NotTimeseries, is_timeseries
45
include("trait.jl")
56

67
export is_variable, variable_index, variable_symbols, is_parameter, parameter_index,
@@ -14,11 +15,11 @@ include("interface.jl")
1415
export SymbolCache
1516
include("symbol_cache.jl")
1617

17-
export parameter_values, set_parameter!, getp, setp
18+
export parameter_values, set_parameter!, parameter_values_at_time,
19+
parameter_values_at_state_time, parameter_timeseries, getp, setp
1820
include("parameter_indexing.jl")
1921

20-
export Timeseries,
21-
NotTimeseries, is_timeseries, state_values, set_state!, current_time, getu, setu
22+
export state_values, set_state!, current_time, getu, setu
2223
include("state_indexing.jl")
2324

2425
export ParameterIndexingProxy

src/parameter_indexing.jl

Lines changed: 135 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,56 @@ parameter_values(arr::AbstractArray) = arr
1515
parameter_values(arr::AbstractArray, i) = arr[i]
1616
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
1717

18+
"""
19+
parameter_values_at_time(p, i)
20+
21+
Return an indexable collection containing the value of all parameters in `p` at time index
22+
`i`. This is useful when parameter values change during the simulation
23+
(such as through callbacks) and their values are saved. `i` is the time index in the
24+
timeseries formed by these changing parameter values, obtained using
25+
[`parameter_timeseries`](@ref).
26+
27+
By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
28+
to be specialized for timeseries objects where parameter values are not constant at all
29+
times. The resultant object should be indexable using [`parameter_values`](@ref).
30+
31+
If this function is implemented, [`parameter_values_at_state_time`](@ref) must be
32+
implemented for [`getu`](@ref) to work correctly.
33+
"""
34+
function parameter_values_at_time end
35+
parameter_values_at_time(p, i) = parameter_values(p)
36+
37+
"""
38+
parameter_values_at_state_time(p, i)
39+
40+
Return an indexable collection containing the value of all parameters in `p` at time
41+
index `i`. This is useful when parameter values change during the simulation (such as
42+
through callbacks) and their values are saved. `i` is the time index in the timeseries
43+
formed by dependent variables (as opposed to the timeseries of the parameters, as in
44+
[`parameter_values_at_time`](@ref)).
45+
46+
By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
47+
to be specialized for timeseries objects where parameter values are not constant at
48+
all times. The resultant object should be indexable using [`parameter_values`](@ref).
49+
50+
If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for
51+
[`getp`](@ref) to work correctly.
52+
"""
53+
function parameter_values_at_state_time end
54+
parameter_values_at_state_time(p, i) = parameter_values(p)
55+
56+
"""
57+
parameter_timeseries(p)
58+
59+
Return an iterable of time steps at which the parameter values are saved. This is only
60+
required for objects where `is_timeseries(p) === Timeseries()` and the parameter values
61+
change during the simulation (such as through callbacks). By default, this returns `[0]`.
62+
63+
See also: [`parameter_values_at_time`](@ref).
64+
"""
65+
function parameter_timeseries end
66+
parameter_timeseries(_) = [0]
67+
1868
"""
1969
set_parameter!(sys, val, idx)
2070
@@ -47,6 +97,13 @@ solution from which the values are obtained.
4797
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
4898
typically does not need to be implemented, and has a default implementation relying on
4999
[`parameter_values`](@ref).
100+
101+
If the returned function is used on a timeseries object which saves parameter timeseries, it
102+
can be used to index said timeseries. The timeseries object must implement
103+
[`parameter_timeseries`](@ref), [`parameter_values_at_time`](@ref) and
104+
[`parameter_values_at_state_time`](@ref). The function returned from `getp` will can be passed
105+
`Colon()` (`:`) as the last argument to return the entire parameter timeseries for `p`, or
106+
any index into the parameter timeseries for a subset of values.
50107
"""
51108
function getp(sys, p)
52109
symtype = symbolic_type(p)
@@ -55,18 +112,42 @@ function getp(sys, p)
55112
end
56113

57114
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
58-
return function getter(sol)
59-
return parameter_values(sol, p)
115+
return let p = p
116+
function _getter(::NotTimeseries, prob)
117+
parameter_values(prob, p)
118+
end
119+
function _getter(::Timeseries, prob)
120+
parameter_values(prob, p)
121+
end
122+
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
123+
parameter_values(
124+
parameter_values_at_time(
125+
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
126+
p)
127+
end
128+
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
129+
parameter_values.(
130+
parameter_values_at_time.((prob,),
131+
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
132+
p)
133+
end
134+
function _getter(::Timeseries, prob, i)
135+
parameter_values.(parameter_values_at_time.((prob,), i), (p,))
136+
end
137+
getter = let _getter = _getter
138+
function getter(prob, args...)
139+
return _getter(is_timeseries(prob), prob, args...)
140+
end
141+
end
142+
getter
60143
end
61144
end
62145

63146
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
64147
idx = parameter_index(sys, p)
65-
return let idx = idx
66-
function getter(sol)
67-
return parameter_values(sol, idx)
68-
end
69-
end
148+
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
149+
sys, NotSymbolic(), NotSymbolic(), idx)
150+
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)
70151
end
71152

72153
for (t1, t2) in [
@@ -78,15 +159,57 @@ for (t1, t2) in [
78159
getters = getp.((sys,), p)
79160

80161
return let getters = getters
81-
function getter(sol)
82-
map(g -> g(sol), getters)
162+
function _getter(::NotTimeseries, prob)
163+
map(g -> g(prob), getters)
83164
end
84-
function getter(buffer, sol)
85-
for (i, g) in zip(eachindex(buffer), getters)
86-
buffer[i] = g(sol)
165+
function _getter(::Timeseries, prob)
166+
map(g -> g(prob), getters)
167+
end
168+
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
169+
map(g -> g(prob, i), getters)
170+
end
171+
function _getter(::Timeseries, prob, i)
172+
[map(g -> g(prob, j), getters)
173+
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
174+
end
175+
function _getter!(buffer, ::NotTimeseries, prob)
176+
for (g, bufi) in zip(getters, eachindex(buffer))
177+
buffer[bufi] = g(prob)
87178
end
88179
buffer
89180
end
181+
function _getter!(buffer, ::Timeseries, prob)
182+
for (g, bufi) in zip(getters, eachindex(buffer))
183+
buffer[bufi] = g(prob)
184+
end
185+
buffer
186+
end
187+
function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex})
188+
for (g, bufi) in zip(getters, eachindex(buffer))
189+
buffer[bufi] = g(prob, i)
190+
end
191+
buffer
192+
end
193+
function _getter!(buffer, ::Timeseries, prob, i)
194+
for (bufi, tsi) in zip(
195+
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
196+
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
197+
buffer[bufi][bufj] = g(prob, tsi)
198+
end
199+
end
200+
buffer
201+
end
202+
_getter, _getter!
203+
getter = let _getter = _getter, _getter! = _getter!
204+
function getter(prob, i...)
205+
return _getter(is_timeseries(prob), prob, i...)
206+
end
207+
function getter(buffer::AbstractArray, prob, i...)
208+
return _getter!(buffer, is_timeseries(prob), prob, i...)
209+
end
210+
getter
211+
end
212+
getter
90213
end
91214
end
92215
end

src/parameter_indexing_proxy.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ struct ParameterIndexingProxy{T}
1010
wrapped::T
1111
end
1212

13-
function Base.getindex(p::ParameterIndexingProxy, idx)
14-
return getp(p.wrapped, idx)(p.wrapped)
13+
function Base.getindex(p::ParameterIndexingProxy, idx, args...)
14+
getp(p.wrapped, idx)(p.wrapped, args...)
1515
end
1616

1717
function Base.setindex!(p::ParameterIndexingProxy, val, idx)

0 commit comments

Comments
 (0)