Skip to content

Commit 25fe115

Browse files
refactor: rework timeseries parameter indexing API
1 parent 7a8351c commit 25fe115

7 files changed

+172
-136
lines changed

src/SymbolicIndexingInterface.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@ export SymbolCache
2424
include("symbol_cache.jl")
2525

2626
export parameter_values, set_parameter!, finalize_parameters_hook!,
27-
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries,
28-
parameter_timeseries_at_state_time, state_values, set_state!, current_time
27+
get_parameter_timeseries_collection, with_updated_parameter_timeseries_values,
28+
state_values, set_state!, current_time
2929
include("value_provider_interface.jl")
3030

31+
export ParameterTimeseriesCollection
32+
include("parameter_timeseries_collection.jl")
33+
3134
export getp, setp
3235
include("parameter_indexing.jl")
3336

@@ -40,9 +43,6 @@ include("batched_interface.jl")
4043
export ProblemState
4144
include("problem_state.jl")
4245

43-
export ParameterTimeseriesCollection
44-
include("parameter_timeseries_collection.jl")
45-
4646
export ParameterIndexingProxy
4747
include("parameter_indexing_proxy.jl")
4848

src/parameter_indexing.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ that implement `getindex`.
1717
1818
If the returned function is used on a timeseries object which saves parameter timeseries,
1919
it can be used to index said timeseries. The timeseries object must implement
20-
[`parameter_timeseries`](@ref), [`parameter_values_at_state_time`](@ref),
21-
[`parameter_timeseries_at_state_time`](@ref) and [`is_parameter_timeseries`](@ref).
20+
[`is_parameter_timeseries`](@ref) and [`get_parameter_timeseries_collection`](@ref).
21+
Additionally, the parameter object must implement
22+
[`with_updated_parameter_timeseries_values`](@ref).
2223
2324
If `sym` is a timeseries parameter, the function will return the timeseries of the
2425
parameter if the value provider is a parameter timeseries object. An additional argument
@@ -58,9 +59,8 @@ end
5859
function (gpi::GetParameterIndex)(::Timeseries, prob, args)
5960
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args))
6061
end
61-
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob)
62-
gpi.((ts,), (prob,),
63-
eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
62+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob)
63+
get_parameter_timeseries_collection(prob)[gpi.idx]
6464
end
6565
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(
6666
buffer::AbstractArray, ts::Timeseries, prob)
@@ -200,8 +200,9 @@ function (gpo::GetParameterObserved)(buffer::AbstractArray, ::NotTimeseries, pro
200200
return buffer
201201
end
202202
function (gpo::GetParameterObserved)(::Timeseries, prob)
203-
times = parameter_timeseries(prob, gpo.timeseries_idx)
204-
gpo.obsfn.(parameter_values_at_time.((prob,), times), times)
203+
map(parameter_timeseries(prob, gpo.timeseries_idx)) do t
204+
gpo.obsfn(parameter_values_at_time(prob, t), t)
205+
end
205206
end
206207
function (gpo::MultipleGetParameterObserved)(buffer::AbstractArray, ::Timeseries, prob)
207208
times = parameter_timeseries(prob, gpo.timeseries_idx)
@@ -256,7 +257,10 @@ function (gpo::SingleGetParameterObserved)(
256257
return buffer
257258
end
258259
function (gpo::GetParameterObserved)(ts::Timeseries, prob, i)
259-
gpo.((ts,), (prob,), i)
260+
map(i) do idx
261+
gpo(ts, prob, idx)
262+
end
263+
# gpo.((ts,), (prob,), i)
260264
end
261265
function (gpo::MultipleGetParameterObserved)(buffer::AbstractArray, ts::Timeseries, prob, i)
262266
for (buf_idx, time_idx) in zip(eachindex(buffer), i)

src/parameter_timeseries_collection.jl

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,18 @@ The three-argument version of [`parameter_values`](@ref) is implemented for this
3030
[`parameter_timeseries`](@ref) is implemented for this type. This type does not implement
3131
any traits.
3232
"""
33-
struct ParameterTimeseriesCollection{T}
33+
struct ParameterTimeseriesCollection{T, P}
3434
collection::T
35+
paramcache::P
3536

36-
function ParameterTimeseriesCollection(collection::T) where {T}
37+
function ParameterTimeseriesCollection(collection::T, paramcache::P) where {T, P}
3738
if any(x -> is_timeseries(x) == NotTimeseries(), collection)
3839
throw(ArgumentError("""
39-
All objects passed to `ParameterTimeseriesCollection` must be timeseries\
40-
objects.
40+
All objects in the collection `ParameterTimeseriesCollection` must be \
41+
timeseries objects.
4142
"""))
4243
end
43-
new{T}(collection)
44+
new{T, P}(collection, paramcache)
4445
end
4546
end
4647

@@ -73,6 +74,111 @@ function parameter_values(
7374
ptc::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex, subidx)
7475
return ptc[idx, subidx]
7576
end
77+
function parameter_values(prob, i::ParameterTimeseriesIndex, j)
78+
parameter_values(get_parameter_timeseries_collection(prob), i, j)
79+
end
7680
function parameter_timeseries(ptc::ParameterTimeseriesCollection, idx)
7781
return current_time(ptc[idx])
7882
end
83+
84+
function _timeseries_value(ptc::ParameterTimeseriesCollection, ts_idx, t)
85+
ts_obj = ptc[ts_idx]
86+
time_idx = searchsortedlast(current_time(ts_obj), t)
87+
value = state_values(ts_obj, time_idx)
88+
return value
89+
end
90+
91+
"""
92+
parameter_values_at_time(valp, t)
93+
94+
Return an indexable collection containing the value of all parameters in `valp` at time
95+
`t`. Note that `t` here is a floating-point time, and not an index into a timeseries.
96+
97+
This has a default implementation relying on [`get_parameter_timeseries_collection`](@ref)
98+
and [`with_updated_parameter_timeseries_values`](@ref).
99+
"""
100+
function parameter_values_at_time(valp, t)
101+
ptc = get_parameter_timeseries_collection(valp)
102+
with_updated_parameter_timeseries_values(ptc.paramcache,
103+
(ts_idx => _timeseries_value(ptc, ts_idx, t) for ts_idx in eachindex(ptc))...)
104+
end
105+
106+
"""
107+
parameter_values_at_state_time(valp, i)
108+
parameter_values_at_state_time(valp)
109+
110+
Return an indexable collection containing the value of all parameters in `valp` at time
111+
index `i` in the state timeseries.
112+
113+
By default, this function relies on [`parameter_values_at_time`](@ref) and
114+
[`current_time`](@ref) for a default implementation.
115+
116+
The single-argument version of this function is a shorthand to return parameter values
117+
at each point in the state timeseries. This also has a default implementation relying on
118+
[`parameter_values_at_time`](@ref) and [`current_time`](@ref).
119+
"""
120+
function parameter_values_at_state_time end
121+
122+
function parameter_values_at_state_time(p, i)
123+
state_time = current_time(p, i)
124+
return parameter_values_at_time(p, state_time)
125+
end
126+
function parameter_values_at_state_time(p)
127+
return (parameter_values_at_time(p, t) for t in current_time(p))
128+
end
129+
130+
"""
131+
parameter_timeseries(valp, i)
132+
133+
Return a vector of the time steps at which the parameter values in the parameter
134+
timeseries at index `i` are saved. This is only required for objects where
135+
`is_parameter_timeseries(valp) === Timeseries()`. It will not be called otherwise. It is
136+
assumed that the timeseries is sorted in increasing order.
137+
138+
See also: [`is_parameter_timeseries`](@ref).
139+
"""
140+
function parameter_timeseries end
141+
142+
function parameter_timeseries(valp, i)
143+
return parameter_timeseries(get_parameter_timeseries_collection(valp), i)
144+
end
145+
146+
"""
147+
parameter_timeseries_at_state_time(valp, i, j)
148+
parameter_timeseries_at_state_time(valp, i)
149+
150+
Return the index of the timestep in the parameter timeseries at timeseries index `i` which
151+
occurs just before or at the same time as the state timestep with index `j`. The two-
152+
argument version of this function returns an iterable of indexes, one for each timestep in
153+
the state timeseries. If `j` is an object that refers to multiple values in the state
154+
timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries
155+
at the appropriate points.
156+
157+
Both versions of this function have default implementations relying on
158+
[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one
159+
of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the
160+
aforementioned.
161+
"""
162+
function parameter_timeseries_at_state_time end
163+
164+
function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex})
165+
state_time = current_time(valp, j)
166+
timeseries = parameter_timeseries(valp, i)
167+
searchsortedlast(timeseries, state_time)
168+
end
169+
170+
function parameter_timeseries_at_state_time(valp, i, ::Colon)
171+
parameter_timeseries_at_state_time(valp, i)
172+
end
173+
174+
function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool})
175+
parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,))))
176+
end
177+
178+
function parameter_timeseries_at_state_time(valp, i, j)
179+
(parameter_timeseries_at_state_time(valp, i, jj) for jj in j)
180+
end
181+
182+
function parameter_timeseries_at_state_time(valp, i)
183+
parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp)))
184+
end

src/state_indexing.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ relying on the above functions.
2323
2424
If the value provider is a parameter timeseries object, the same rules apply as
2525
[`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols,
26-
and the values are always returned corresponding to the state timeseries. This utilizes
27-
[`parameter_values_at_state_time`](@ref) and [`parameter_timeseries_at_state_time`](@ref).
26+
and the values are always returned corresponding to the state timeseries.
2827
"""
2928
function getu(sys, sym)
3029
symtype = symbolic_type(sym)
@@ -117,9 +116,8 @@ function (o::TimeDependentObservedFunction)(ts::Timeseries, prob)
117116
return o(ts, is_parameter_timeseries(prob), prob)
118117
end
119118
function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob)
120-
o.obsfn.(state_values(prob),
121-
parameter_values_at_state_time(prob),
122-
current_time(prob))
119+
map(o.obsfn, state_values(prob),
120+
parameter_values_at_state_time(prob), current_time(prob))
123121
end
124122
function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob)
125123
o.obsfn.(state_values(prob),

src/value_provider_interface.jl

Lines changed: 10 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,12 @@
55
"""
66
parameter_values(valp)
77
parameter_values(valp, i)
8-
parameter_values(valp, i::ParameterTimeseriesIndex, j)
98
109
Return an indexable collection containing the value of each parameter in `valp`. The two-
1110
argument version of this function returns the parameter value at index `i`. The
1211
two-argument version of this function will default to returning
1312
`parameter_values(valp)[i]`.
1413
15-
For a parameter timeseries object, this should return the parameter values at the final
16-
time. The two-argument version for a parameter timeseries object should also access
17-
parameter values at the final time. An additional three-argument version is also
18-
necessary for parameter timeseries objects. It accepts a [`ParameterTimeseriesIndex`](@ref)
19-
object passed as `i`, the index in the corresponding timeseries `j` and returns the value
20-
of that parameter at the specified time index `j` in the appropriate parameter timeseries.
21-
2214
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
2315
array/tuple.
2416
"""
@@ -31,90 +23,22 @@ parameter_values(arr::Tuple, i) = arr[i]
3123
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
3224

3325
"""
34-
parameter_values_at_time(valp, t)
35-
36-
Return an indexable collection containing the value of all parameters in `valp` at time
37-
`t`. Note that `t` here is a floating-point time, and not an index into a timeseries.
26+
get_parameter_timeseries_collection(valp)
3827
39-
This is useful for parameter timeseries objects, since some parameters change over time.
28+
Return the [`ParameterTimeseriesCollection`](@ref) stored in `valp`. Only required for
29+
parameter timeseries objects.
4030
"""
41-
function parameter_values_at_time end
31+
function get_parameter_timeseries_collection end
4232

4333
"""
44-
parameter_values_at_state_time(valp, i)
45-
parameter_values_at_state_time(valp)
46-
47-
Return an indexable collection containing the value of all parameters in `valp` at time
48-
index `i` in the state timeseries.
49-
50-
By default, this function relies on [`parameter_values_at_time`](@ref) and
51-
[`current_time`](@ref) for a default implementation.
34+
with_updated_parameter_timeseries_values(valp, args::Pair...)
5235
53-
The single-argument version of this function is a shorthand to return parameter values
54-
at each point in the state timeseries. This also has a default implementation relying on
55-
[`parameter_values_at_time`](@ref) and [`current_time`](@ref).
36+
Return an indexable collection containing the value of all parameters in `valp`, with
37+
parameters belonging to specific timeseries updated to different values. Each element in
38+
`args...` contains the timeseries index as the first value, and the saved parameter values
39+
in that partition. Not all parameter timeseries have to be updated using this method.
5640
"""
57-
function parameter_values_at_state_time end
58-
59-
function parameter_values_at_state_time(p, i)
60-
state_time = current_time(p, i)
61-
return parameter_values_at_time(p, state_time)
62-
end
63-
function parameter_values_at_state_time(p)
64-
parameter_values_at_time.((p,), current_time(p))
65-
end
66-
67-
"""
68-
parameter_timeseries(valp, i)
69-
70-
Return a vector of the time steps at which the parameter values in the parameter
71-
timeseries at index `i` are saved. This is only required for objects where
72-
`is_parameter_timeseries(valp) === Timeseries()`. It will not be called otherwise. It is
73-
assumed that the timeseries is sorted in increasing order.
74-
75-
See also: [`is_parameter_timeseries`](@ref).
76-
"""
77-
function parameter_timeseries end
78-
79-
"""
80-
parameter_timeseries_at_state_time(valp, i, j)
81-
parameter_timeseries_at_state_time(valp, i)
82-
83-
Return the index of the timestep in the parameter timeseries at timeseries index `i` which
84-
occurs just before or at the same time as the state timestep with index `j`. The two-
85-
argument version of this function returns an iterable of indexes, one for each timestep in
86-
the state timeseries. If `j` is an object that refers to multiple values in the state
87-
timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries
88-
at the appropriate points.
89-
90-
Both versions of this function have default implementations relying on
91-
[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one
92-
of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the
93-
aforementioned.
94-
"""
95-
function parameter_timeseries_at_state_time end
96-
97-
function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex})
98-
state_time = current_time(valp, j)
99-
timeseries = parameter_timeseries(valp, i)
100-
searchsortedlast(timeseries, state_time)
101-
end
102-
103-
function parameter_timeseries_at_state_time(valp, i, ::Colon)
104-
parameter_timeseries_at_state_time(valp, i)
105-
end
106-
107-
function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool})
108-
parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,))))
109-
end
110-
111-
function parameter_timeseries_at_state_time(valp, i, j)
112-
(parameter_timeseries_at_state_time(valp, i, jj) for jj in j)
113-
end
114-
115-
function parameter_timeseries_at_state_time(valp, i)
116-
parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp)))
117-
end
41+
function with_updated_parameter_timeseries_values end
11842

11943
"""
12044
set_parameter!(valp, val, idx)

0 commit comments

Comments
 (0)