Skip to content

Commit 9a84642

Browse files
fix: use new trait in parameter timeseries indexing
1 parent 9a8d1c0 commit 9a84642

File tree

4 files changed

+93
-51
lines changed

4 files changed

+93
-51
lines changed

src/parameter_indexing.jl

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,3 @@
1-
parameter_values(arr::AbstractArray) = arr
2-
parameter_values(arr::Tuple) = arr
3-
parameter_values(arr::AbstractArray, i) = arr[i]
4-
parameter_values(arr::Tuple, i) = arr[i]
5-
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
6-
7-
parameter_values_at_time(p, i) = parameter_values(p)
8-
9-
parameter_values_at_state_time(p, i) = parameter_values(p)
10-
11-
parameter_timeseries(_) = [0]
12-
13-
# Tuple only included for the error message
14-
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
15-
sys[idx] = val
16-
end
17-
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
18-
191
"""
202
getp(indp, sym)
213
@@ -35,18 +17,19 @@ that implement `getindex`.
3517
3618
If the returned function is used on a timeseries object which saves parameter timeseries, it
3719
can be used to index said timeseries. The timeseries object must implement
38-
[`parameter_timeseries`](@ref), [`parameter_values_at_time`](@ref) and
39-
[`parameter_values_at_state_time`](@ref). The function returned from `getp` will can be passed
40-
`Colon()` (`:`) as the last argument to return the entire parameter timeseries for `p`, or
41-
any index into the parameter timeseries for a subset of values.
20+
[`parameter_timeseries`](@ref), [`parameter_values_at_time`](@ref),
21+
[`parameter_values_at_state_time`](@ref), and [`is_parameter_timeseries`](@ref). The function
22+
returned from `getp` can be passed `Colon()` (`:`) as the last argument to return the
23+
entire parameter timeseries for `p`, or any index into the parameter timeseries for a
24+
subset of values.
4225
"""
4326
function getp(sys, p)
4427
symtype = symbolic_type(p)
4528
elsymtype = symbolic_type(eltype(p))
4629
_getp(sys, symtype, elsymtype, p)
4730
end
4831

49-
struct GetParameterIndex{I} <: AbstractGetIndexer
32+
struct GetParameterIndex{I} <: AbstractParameterGetIndexer
5033
idx::I
5134
end
5235

@@ -59,10 +42,16 @@ function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{Int, CartesianInd
5942
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
6043
gpi.idx)
6144
end
62-
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
45+
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Colon)
6346
parameter_values.(
6447
parameter_values_at_time.((prob,),
65-
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
48+
only(to_indices(parameter_timeseries(prob), (i,)))),
49+
(gpi.idx,))
50+
end
51+
function (gpi::GetParameterIndex)(::Timeseries, prob, i::AbstractArray{Bool})
52+
parameter_values.(
53+
map(Base.Fix1(parameter_values_at_time, prob),
54+
only(to_indices(parameter_timeseries(prob), (i,)))),
6655
(gpi.idx,))
6756
end
6857
function (gpi::GetParameterIndex)(::Timeseries, prob, i)
@@ -84,14 +73,19 @@ struct MultipleParameterGetters{G} <: AbstractGetIndexer
8473
end
8574

8675
function (mpg::MultipleParameterGetters)(::IsTimeseriesTrait, prob)
87-
map(g -> g(prob), mpg.getters)
76+
map(CallWith(prob), mpg.getters)
8877
end
8978
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex})
90-
map(g -> g(prob, i), mpg.getters)
79+
map(CallWith(prob, i), mpg.getters)
9180
end
9281
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i)
93-
[map(g -> g(prob, j), mpg.getters)
94-
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
82+
map.(CallWith.((prob,), only(to_indices(parameter_timeseries(prob), (i,)))),
83+
(mpg.getters,))
84+
end
85+
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i::AbstractArray{Bool})
86+
callers = map(
87+
Base.Fix1(CallWith, prob), only(to_indices(parameter_timeseries(prob), (i,))))
88+
map.(callers, (mpg.getters,))
9589
end
9690
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob)
9791
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
@@ -123,10 +117,10 @@ function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::NotTimeseries,
123117
end
124118

125119
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, prob, i...)
126-
mpg(buffer, is_timeseries(prob), prob, i...)
120+
mpg(buffer, is_parameter_timeseries(prob), prob, i...)
127121
end
128122
function (mpg::MultipleParameterGetters)(prob, i...)
129-
mpg(is_timeseries(prob), prob, i...)
123+
mpg(is_parameter_timeseries(prob), prob, i...)
130124
end
131125

132126
for (t1, t2) in [

src/state_indexing.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function getu(sys, sym)
3232
_getu(sys, symtype, elsymtype, sym)
3333
end
3434

35-
struct GetStateIndex{I} <: AbstractGetIndexer
35+
struct GetStateIndex{I} <: AbstractStateGetIndexer
3636
idx::I
3737
end
3838
function (gsi::GetStateIndex)(::Timeseries, prob)
@@ -49,13 +49,12 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
4949
return GetStateIndex(sym)
5050
end
5151

52-
struct GetpAtStateTime{G} <: AbstractGetIndexer
52+
struct GetpAtStateTime{G} <: AbstractStateGetIndexer
5353
getter::G
5454
end
5555

5656
function (g::GetpAtStateTime)(::Timeseries, prob)
57-
[g.getter(parameter_values_at_state_time(prob, i))
58-
for i in eachindex(current_time(prob))]
57+
g.getter.(parameter_values_at_state_time(prob))
5958
end
6059
function (g::GetpAtStateTime)(::Timeseries, prob, i)
6160
g.getter(parameter_values_at_state_time(prob, i))
@@ -64,20 +63,19 @@ function (g::GetpAtStateTime)(::NotTimeseries, prob)
6463
g.getter(prob)
6564
end
6665

67-
struct GetIndepvar <: AbstractGetIndexer end
66+
struct GetIndepvar <: AbstractStateGetIndexer end
6867

6968
(::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob)
7069
(::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i)
7170

72-
struct TimeDependentObservedFunction{F} <: AbstractGetIndexer
71+
struct TimeDependentObservedFunction{F} <: AbstractStateGetIndexer
7372
obsfn::F
7473
end
7574

7675
function (o::TimeDependentObservedFunction)(::Timeseries, prob)
77-
curtime = current_time(prob)
78-
return o.obsfn.(state_values(prob),
79-
(parameter_values_at_state_time(prob, i) for i in eachindex(curtime)),
80-
curtime)
76+
o.obsfn.(state_values(prob),
77+
parameter_values_at_state_time(prob),
78+
current_time(prob))
8179
end
8280
function (o::TimeDependentObservedFunction)(::Timeseries, prob, i)
8381
return o.obsfn(state_values(prob, i),
@@ -88,7 +86,7 @@ function (o::TimeDependentObservedFunction)(::NotTimeseries, prob)
8886
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
8987
end
9088

91-
struct TimeIndependentObservedFunction{F} <: AbstractGetIndexer
89+
struct TimeIndependentObservedFunction{F} <: AbstractStateGetIndexer
9290
obsfn::F
9391
end
9492

@@ -115,7 +113,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
115113
error("Invalid symbol $sym for `getu`")
116114
end
117115

118-
struct MultipleGetters{G} <: AbstractGetIndexer
116+
struct MultipleGetters{G} <: AbstractStateGetIndexer
119117
getters::G
120118
end
121119

@@ -130,7 +128,7 @@ function (mg::MultipleGetters)(::NotTimeseries, prob)
130128
return map(g -> g(prob), mg.getters)
131129
end
132130

133-
struct AsTupleWrapper{G} <: AbstractGetIndexer
131+
struct AsTupleWrapper{G} <: AbstractStateGetIndexer
134132
getter::G
135133
end
136134

src/value_provider_interface.jl

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,23 @@
99
Return an indexable collection containing the value of each parameter in `valp`. The two-
1010
argument version of this function returns the parameter value at index `i`. The
1111
two-argument version of this function will default to returning
12-
`parameter_values(valp)[i]`.
12+
`parameter_values(valp)[i]`. For a parameter timeseries object, this should return the
13+
parameter values at the initial time.
1314
1415
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
1516
array/tuple.
1617
"""
1718
function parameter_values end
1819

20+
parameter_values(arr::AbstractArray) = arr
21+
parameter_values(arr::Tuple) = arr
22+
parameter_values(arr::AbstractArray, i) = arr[i]
23+
parameter_values(arr::Tuple, i) = arr[i]
24+
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
25+
1926
"""
2027
parameter_values_at_time(valp, i)
28+
parameter_values_at_time(valp)
2129
2230
Return an indexable collection containing the value of all parameters in `valp` at time
2331
index `i`. This is useful when parameter values change during the simulation (such as
@@ -28,11 +36,21 @@ By default, this function returns `parameter_values(valp)` regardless of `i`, an
2836
to be specialized for timeseries objects where parameter values are not constant at all
2937
times. The resultant object should be indexable using [`parameter_values`](@ref).
3038
31-
If this function is implemented, [`parameter_values_at_state_time`](@ref) must be
32-
implemented for [`getu`](@ref) to work correctly.
39+
If this function is implemented for a type, [`parameter_values_at_state_time`](@ref) must
40+
be implemented for [`getu`](@ref) to work correctly. Additionally,
41+
[`is_parameter_timeseries`](@ref) must be [`Timeseries`](@ref) for the type.
42+
43+
The single-argument version of this function is a shorthand to return parameter values
44+
at each point in the parameter timeseries. This has a default implementation relying on
45+
[`parameter_timeseries`](@ref) and the two-argument version of this method.
3346
"""
3447
function parameter_values_at_time end
3548

49+
parameter_values_at_time(valp, i) = parameter_values(valp)
50+
function parameter_values_at_time(valp)
51+
parameter_values_at_time.((valp,), eachindex(parameter_timeseries(valp)))
52+
end
53+
3654
"""
3755
parameter_values_at_state_time(valp, i)
3856
@@ -48,20 +66,31 @@ all times. The resultant object should be indexable using [`parameter_values`](@
4866
4967
If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for
5068
[`getp`](@ref) to work correctly.
69+
70+
The single-argument version of this function is a shorthand to return parameter values
71+
at each point in the state timeseries. This has a default implementation relying on
72+
[`current_time`](@ref) and the two-argument version of this method.
5173
"""
5274
function parameter_values_at_state_time end
5375

76+
parameter_values_at_state_time(p, i) = parameter_values(p)
77+
function parameter_values_at_state_time(p)
78+
parameter_values_at_state_time.((p,), eachindex(current_time(p)))
79+
end
80+
5481
"""
5582
parameter_timeseries(valp)
5683
5784
Return an iterable of time steps at which the parameter values are saved. This is only
58-
required for objects where `is_timeseries(valp) === Timeseries()` and the parameter values
59-
change during the simulation (such as through callbacks). By default, this returns `[0]`.
85+
required for objects where `is_parameter_timeseries(valp) === Timeseries()`. By default,
86+
this returns `[0]`.
6087
6188
See also: [`parameter_values_at_time`](@ref).
6289
"""
6390
function parameter_timeseries end
6491

92+
parameter_timeseries(_) = [0]
93+
6594
"""
6695
set_parameter!(valp, val, idx)
6796
@@ -74,6 +103,12 @@ See: [`parameter_values`](@ref)
74103
"""
75104
function set_parameter! end
76105

106+
# Tuple only included for the error message
107+
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
108+
sys[idx] = val
109+
end
110+
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
111+
77112
"""
78113
finalize_parameters_hook!(valp, sym)
79114
@@ -139,7 +174,21 @@ function current_time end
139174
abstract type AbstractIndexer end
140175

141176
abstract type AbstractGetIndexer <: AbstractIndexer end
177+
abstract type AbstractStateGetIndexer <: AbstractGetIndexer end
178+
abstract type AbstractParameterGetIndexer <: AbstractGetIndexer end
142179
abstract type AbstractSetIndexer <: AbstractIndexer end
143180

144-
(ai::AbstractGetIndexer)(prob) = ai(is_timeseries(prob), prob)
145-
(ai::AbstractGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)
181+
(ai::AbstractStateGetIndexer)(prob) = ai(is_timeseries(prob), prob)
182+
(ai::AbstractStateGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)
183+
(ai::AbstractParameterGetIndexer)(prob) = ai(is_parameter_timeseries(prob), prob)
184+
(ai::AbstractParameterGetIndexer)(prob, i) = ai(is_parameter_timeseries(prob), prob, i)
185+
186+
struct CallWith{A}
187+
args::A
188+
189+
CallWith(args...) = new{typeof(args)}(args)
190+
end
191+
192+
function (cw::CallWith)(arg)
193+
arg(cw.args...)
194+
end

test/parameter_indexing_test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ function SymbolicIndexingInterface.parameter_values_at_state_time(fs::FakeSoluti
124124
end
125125
SymbolicIndexingInterface.parameter_timeseries(fs::FakeSolution) = fs.pt
126126
SymbolicIndexingInterface.is_timeseries(::Type{FakeSolution}) = Timeseries()
127+
SymbolicIndexingInterface.is_parameter_timeseries(::Type{FakeSolution}) = Timeseries()
127128
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
128129
fs = FakeSolution(
129130
sys,

0 commit comments

Comments
 (0)