Skip to content

Commit b8139fa

Browse files
fixup! wip: better parameter indexing
1 parent 259d2e6 commit b8139fa

File tree

5 files changed

+280
-104
lines changed

5 files changed

+280
-104
lines changed

src/parameter_indexing.jl

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,7 @@ function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob)
4242
parameter_values(prob, gpi.idx)
4343
end
4444
function (gpi::GetParameterIndex)(::Timeseries, prob, args)
45-
error("""
46-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
47-
is a parameter timeseries object) with parameter index $(gpi.idx) (which is not\
48-
a `ParameterTimeseriesIndex`) at index $(args) in the timeseries.
49-
""")
45+
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args))
5046
end
5147
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob)
5248
parameter_values.(
@@ -68,10 +64,7 @@ function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, pr
6864
gpi.((ts,), (prob,), i)
6965
end
7066
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob)
71-
error("""
72-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
73-
is not a parameter timeseries object) using index $(gpi.idx).
74-
""")
67+
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
7568
end
7669

7770
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
@@ -177,26 +170,14 @@ for (indexerTimeseriesType, timeseriesType) in [
177170
end
178171

179172
function (mpg::MixedTimeseriesIndexMPG)(::Timeseries, prob, args...)
180-
error("""
181-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
182-
is a parameter timeseries object) with variables having mixed timeseries indexes\
183-
$(mpg.timeseries_idx.indexes).
184-
""")
173+
throw(MixedParameterTimeseriesIndexError(prob, mpg.timeseries_idx.indexes))
185174
end
186175

187176
function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::Timeseries, prob, args)
188-
error("""
189-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
190-
is a parameter timeseries object) with non-timeseries indexer $mpg at index $args\
191-
in the timeseries.
192-
""")
177+
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args))
193178
end
194179
function (mpg::MultipleParametersGetter{IndexerNotTimeseries})(::AbstractArray, ::Timeseries, prob, args)
195-
error("""
196-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
197-
is a parameter timeseries object) with non-timeseries indexer $mpg at index $args\
198-
in the timeseries.
199-
""")
180+
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, mpg, args))
200181
end
201182
function (mpg::AtLeastTimeseriesMPG)(ts::Timeseries, prob)
202183
map(eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg)))) do i
@@ -242,16 +223,10 @@ function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob
242223
buffer
243224
end
244225
function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob)
245-
error("""
246-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
247-
is not a parameter timeseries object) with timeseries indexer $mpg.
248-
""")
226+
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
249227
end
250228
function (mpg::MultipleParametersGetter{IndexerTimeseries})(::AbstractArray, ::NotTimeseries, prob)
251-
error("""
252-
Invalid indexing operation: tried to access object of type $(typeof(prob)) (which\
253-
is not a parameter timeseries object) with timeseries indexer $mpg.
254-
""")
229+
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
255230
end
256231

257232
for (t1, t2) in [

src/state_indexing.jl

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
state_values(arr::AbstractArray) = arr
2-
state_values(arr, i) = state_values(arr)[i]
3-
41
function set_state!(sys, val, idx)
52
state_values(sys)[idx] = val
63
end
74

8-
current_time(p, i) = current_time(p)[i]
95

106
"""
117
getu(indp, sym)
@@ -38,9 +34,12 @@ end
3834
function (gsi::GetStateIndex)(::Timeseries, prob)
3935
getindex.(state_values(prob), (gsi.idx,))
4036
end
41-
function (gsi::GetStateIndex)(::Timeseries, prob, i)
37+
function (gsi::GetStateIndex)(::Timeseries, prob, i::Union{Int, CartesianIndex})
4238
getindex(state_values(prob, i), gsi.idx)
4339
end
40+
function (gsi::GetStateIndex)(::Timeseries, prob, i)
41+
getindex.(state_values(prob, i), gsi.idx)
42+
end
4443
function (gsi::GetStateIndex)(::NotTimeseries, prob)
4544
state_values(prob, gsi.idx)
4645
end
@@ -62,12 +61,34 @@ end
6261
function (g::GetpAtStateTime)(::Timeseries, ::NotTimeseries, prob, _...)
6362
g.getter(prob)
6463
end
65-
function (g::GetpAtStateTime, ::Timeseries, ::Timeseries, prob)
64+
function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob)
65+
g(ts, p_ts, is_indexer_timeseries(g.getter), prob)
66+
end
67+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob)
6668
g.getter.((prob,), parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter)))
6769
end
68-
function (g::GetpAtStateTime, ::Timeseries, ::Timeseries, prob, i)
70+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob)
71+
g.getter(prob)
72+
end
73+
function (g::GetpAtStateTime)(ts::Timeseries, p_ts::Timeseries, prob, i)
74+
g(ts, p_ts, is_indexer_timeseries(g.getter), prob, i)
75+
end
76+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::Union{IndexerTimeseries, IndexerBoth}, prob, i)
6977
g.getter(prob, parameter_timeseries_at_state_time(prob, indexer_timeseries_index(g.getter), i))
7078
end
79+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Union{Int, CartesianIndex})
80+
g.getter(prob)
81+
end
82+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, ::Colon)
83+
map(_ -> g.getter(prob), current_time(prob))
84+
end
85+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i::AbstractArray{Bool})
86+
num_ones = sum(i)
87+
map(_ -> g.getter(prob), 1:num_ones)
88+
end
89+
function (g::GetpAtStateTime)(::Timeseries, ::Timeseries, ::IndexerNotTimeseries, prob, i)
90+
map(_ -> g.getter(prob), 1:length(i))
91+
end
7192
function (g::GetpAtStateTime)(::NotTimeseries, prob)
7293
g.getter(prob)
7394
end
@@ -126,12 +147,22 @@ struct MultipleGetters{G} <: AbstractStateGetIndexer
126147
getters::G
127148
end
128149

129-
function (mg::MultipleGetters)(::Timeseries, prob)
130-
return broadcast(i -> map(g -> g(prob, i), mg.getters),
131-
eachindex(state_values(prob)))
150+
function (mg::MultipleGetters)(ts::Timeseries, prob)
151+
return mg.((ts,), (prob,), eachindex(current_time(prob)))
152+
end
153+
function (mg::MultipleGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex})
154+
return map(CallWith(prob, i), mg.getters)
155+
end
156+
function (mg::MultipleGetters)(ts::Timeseries, prob, ::Colon)
157+
return mg(ts, prob)
132158
end
133-
function (mg::MultipleGetters)(::Timeseries, prob, i)
134-
return map(g -> g(prob, i), mg.getters)
159+
function (mg::MultipleGetters)(ts::Timeseries, prob, i::AbstractArray{Bool})
160+
return map(only(to_indices(current_time(prob), (i,)))) do idx
161+
mg(ts, prob, idx)
162+
end
163+
end
164+
function (mg::MultipleGetters)(ts::Timeseries, prob, i)
165+
mg.((ts,), (prob,), i)
135166
end
136167
function (mg::MultipleGetters)(::NotTimeseries, prob)
137168
return map(g -> g(prob), mg.getters)
@@ -144,9 +175,12 @@ end
144175
function (atw::AsTupleWrapper)(::Timeseries, prob)
145176
return Tuple.(atw.getter(prob))
146177
end
147-
function (atw::AsTupleWrapper)(::Timeseries, prob, i)
178+
function (atw::AsTupleWrapper)(::Timeseries, prob, i::Union{Int, CartesianIndex})
148179
return Tuple(atw.getter(prob, i))
149180
end
181+
function (atw::AsTupleWrapper)(::Timeseries, prob, i)
182+
return Tuple.(atw.getter(prob, i))
183+
end
150184
function (atw::AsTupleWrapper)(::NotTimeseries, prob)
151185
return Tuple(atw.getter(prob))
152186
end
@@ -158,9 +192,13 @@ for (t1, t2) in [
158192
]
159193
@eval function _getu(sys, ::NotSymbolic, ::$t1, sym::$t2)
160194
num_observed = count(x -> is_observed(sys, x), sym)
161-
if num_observed <= 1
162-
getters = getu.((sys,), sym)
163-
return MultipleGetters(getters)
195+
if num_observed == 0
196+
if all(Base.Fix1(is_parameter, sys), sym) && all(!Base.Fix1(is_timeseries_parameter, sys), sym)
197+
GetpAtStateTime(getp(sys, sym))
198+
else
199+
getters = getu.((sys,), sym)
200+
return MultipleGetters(getters)
201+
end
164202
else
165203
obs = observed(sys, sym isa Tuple ? collect(sym) : sym)
166204
getter = if is_time_dependent(sys)
@@ -181,7 +219,7 @@ function _getu(sys, ::ArraySymbolic, ::SymbolicTypeTrait, sym)
181219
idx = variable_index(sys, sym)
182220
return getu(sys, idx)
183221
elseif is_parameter(sys, sym)
184-
return getp(sys, sym)
222+
return GetpAtStateTime(getp(sys, sym))
185223
end
186224
return getu(sys, collect(sym))
187225
end

src/symbol_cache.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,20 @@ function observed(sc::SymbolCache, expr::Expr)
169169
end
170170
end
171171
end
172-
function observed(sc::SymbolCache, exprs::AbstractArray{Expr})
172+
function observed(sc::SymbolCache, exprs::AbstractArray)
173+
for expr in exprs
174+
if !(expr isa Union{Symbol, Expr})
175+
throw(TypeError(:observed, "SymbolCache", Union{Symbol, Expr}, expr))
176+
end
177+
end
173178
return observed(sc, :(reshape([$(exprs...)], $(size(exprs)))))
174179
end
175-
function observed(sc::SymbolCache, exprs::Tuple{Vararg{Expr}})
180+
function observed(sc::SymbolCache, exprs::Tuple)
181+
for expr in exprs
182+
if !(expr isa Union{Symbol, Expr})
183+
throw(TypeError(:observed, "SymbolCache", Union{Symbol, Expr}, expr))
184+
end
185+
end
176186
return observed(sc, :(($(exprs...),)))
177187
end
178188

src/value_provider_interface.jl

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,37 @@ function parameter_timeseries end
7373
Return the index of the timestep in the parameter timeseries at timeseries index `i` which
7474
occurs just before or at the same time as the state timestep with index `j`. The two-
7575
argument version of this function returns an iterable of indexes, one for each timestep in
76-
the state timeseries.
76+
the state timeseries. If `j` is an object that refers to multiple values in the state
77+
timeseries (e.g. `Colon`), return an iterable of the indexes in the parameter timeseries
78+
at the appropriate points.
7779
7880
Both versions of this function have default implementations relying on
79-
[`current_time`](@ref) and [`parameter_timeseries`](@ref).
81+
[`current_time`](@ref) and [`parameter_timeseries`](@ref), for the cases where `j` is one
82+
of: `Int`, `CartesianIndex`, `AbstractArray{Bool}`, `Colon` or an iterable of the
83+
aforementioned.
8084
"""
8185
function parameter_timeseries_at_state_time end
8286

83-
function parameter_timeseries_at_state_time(valp, i, j)
87+
function parameter_timeseries_at_state_time(valp, i, j::Union{Int, CartesianIndex})
8488
state_time = current_time(valp, j)
8589
timeseries = parameter_timeseries(valp, i)
8690
searchsortedlast(timeseries, state_time)
8791
end
8892

93+
function parameter_timeseries_at_state_time(valp, i, ::Colon)
94+
parameter_timeseries_at_state_time(valp, i)
95+
end
96+
97+
function parameter_timeseries_at_state_time(valp, i, j::AbstractArray{Bool})
98+
parameter_timeseries_at_state_time(valp, i, only(to_indices(current_time(valp), (j,))))
99+
end
100+
101+
function parameter_timeseries_at_state_time(valp, i, j)
102+
(parameter_timeseries_at_state_time(valp, i, jj) for jj in j)
103+
end
104+
89105
function parameter_timeseries_at_state_time(valp, i)
90-
(parameter_timeseries_at_state_time(valp, i, j) for j in eachindex(current_time(valp)))
106+
parameter_timeseries_at_state_time(valp, i, eachindex(current_time(valp)))
91107
end
92108

93109
"""
@@ -131,14 +147,21 @@ Return an indexable collection containing the values of all states in the value
131147
each of which contain the state values at the corresponding timestep. In this case, the
132148
two-argument version of the function can also be implemented to efficiently return
133149
the state values at timestep `i`. By default, the two-argument method calls
134-
`state_values(valp)[i]`
150+
`state_values(valp)[i]`. If `i` consists of multiple indices (for example, `Colon`,
151+
`AbstractArray{Int}`, `AbstractArray{Bool}`) specialized methods may be defined for
152+
efficiency. By default, `state_values(valp, ::Colon) = state_values(valp)` to avoid
153+
copying the timeseries.
135154
136155
If this function is called with an `AbstractArray`, it will return the same array.
137156
138157
See: [`is_timeseries`](@ref)
139158
"""
140159
function state_values end
141160

161+
state_values(arr::AbstractArray) = arr
162+
state_values(arr, i) = state_values(arr)[i]
163+
state_values(arr, ::Colon) = state_values(arr)
164+
142165
"""
143166
set_state!(valp, val, idx)
144167
@@ -162,11 +185,21 @@ also be implemented to efficiently return the time at timestep `i`. By default,
162185
argument method calls `current_time(p)[i]`. It is assumed that the timeseries is sorted
163186
in increasing order.
164187
188+
If `i` consists of multiple indices (for example, `Colon`, `AbstractArray{Int}`,
189+
`AbstractArray{Bool}`) specialized methods may be defined for efficiency. By default,
190+
`current_time(valp, ::Colon) = current_time(valp)` to avoid copying the timeseries.
191+
192+
By default, the single-argument version acts as the identity function if
193+
`valp isa AbstractVector`.
165194
166195
See: [`is_timeseries`](@ref)
167196
"""
168197
function current_time end
169198

199+
current_time(arr::AbstractVector) = arr
200+
current_time(valp, i) = current_time(valp)[i]
201+
current_time(valp, ::Colon) = current_time(valp)
202+
170203
###########
171204
# Utilities
172205
###########
@@ -232,3 +265,74 @@ end
232265
function (cw::CallWith)(arg)
233266
arg(cw.args...)
234267
end
268+
269+
###########
270+
# Errors
271+
###########
272+
273+
struct ParameterTimeseriesValueIndexMismatchError{P <: IsTimeseriesTrait} <: Exception
274+
valp
275+
indexer
276+
args
277+
278+
function ParameterTimeseriesValueIndexMismatchError{Timeseries}(valp, indexer, args)
279+
if is_parameter_timeseries(valp) != Timeseries()
280+
throw(ArgumentError("""
281+
This should never happen. Expected parameter timeseries value provider, \
282+
got $(valp). Open an issue in SymbolicIndexingInterface.jl with an MWE.
283+
"""))
284+
end
285+
if is_indexer_timeseries(indexer) != IndexerNotTimeseries()
286+
throw(ArgumentError("""
287+
This should never happen. Expected non-timeseries indexer, got \
288+
$(indexer). Open an issue in SymbolicIndexingInterface.jl with an MWE.
289+
"""))
290+
end
291+
return new{Timeseries}(valp, indexer, args)
292+
end
293+
function ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(valp, indexer)
294+
if is_parameter_timeseries(valp) != NotTimeseries()
295+
throw(ArgumentError("""
296+
This should never happen. Expected non-parameter timeseries value \
297+
provider, got $(valp). Open an issue in SymbolicIndexingInterface.jl \
298+
with an MWE.
299+
"""))
300+
end
301+
if is_indexer_timeseries(indexer) != IndexerTimeseries()
302+
throw(ArgumentError("""
303+
This should never happen. Expected timeseries indexer, got $(indexer). \
304+
Open an issue in SymbolicIndexingInterface.jl with an MWE.
305+
"""))
306+
end
307+
return new{NotTimeseries}(valp, indexer, nothing)
308+
end
309+
end
310+
311+
function Base.showerror(io::IO, err::ParameterTimeseriesValueIndexMismatchError{Timeseries})
312+
print(io, """
313+
Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \
314+
(which is a parameter timeseries object) with non-timeseries indexer \
315+
$(err.indexer) at index $(err.args) in the timeseries.
316+
""")
317+
end
318+
319+
function Base.showerror(io::IO, err::ParameterTimeseriesValueIndexMismatchError{NotTimeseries})
320+
print(io, """
321+
Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \
322+
(which is not a parameter timeseries object) using timeseries indexer \
323+
$(err.indexer).
324+
""")
325+
end
326+
327+
struct MixedParameterTimeseriesIndexError <: Exception
328+
valp
329+
ts_idxs
330+
end
331+
332+
function Base.showerror(io::IO, err::MixedParameterTimeseriesIndexError)
333+
print(io, """
334+
Invalid indexing operation: tried to access object of type $(typeof(err.valp)) \
335+
(which is a parameter timeseries object) with variables having mixed timeseries \
336+
indexes $(err.ts_idxs).
337+
""")
338+
end

0 commit comments

Comments
 (0)