Skip to content

Commit 1eeb5c5

Browse files
fixup! wip: better parameter indexing
1 parent c7066ec commit 1eeb5c5

12 files changed

+215
-45
lines changed

src/index_provider_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function is_timeseries_parameter(indp, sym)
6262
if hasmethod(symbolic_container, Tuple{typeof(indp)})
6363
is_timeseries_parameter(symbolic_container(indp), sym)
6464
else
65-
false
65+
return false
6666
end
6767
end
6868

@@ -94,7 +94,7 @@ function timeseries_parameter_index(indp, sym)
9494
if hasmethod(symbolic_container, Tuple{typeof(indp)})
9595
timeseries_parameter_index(symbolic_container(indp), sym)
9696
else
97-
nothing
97+
return nothing
9898
end
9999
end
100100

src/parameter_indexing.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,24 @@ Requires that the value provider implement [`parameter_values`](@ref). This func
1515
may not always need to be implemented, and has a default implementation for collections
1616
that implement `getindex`.
1717
18-
If the returned function is used on a timeseries object which saves parameter timeseries, it
19-
can be used to index said timeseries. The timeseries object must implement
18+
If the returned function is used on a timeseries object which saves parameter timeseries,
19+
it can be used to index said timeseries. The timeseries object must implement
2020
[`parameter_timeseries`](@ref), [`parameter_values_at_state_time`](@ref),
21-
[`parameter_timeseries_at_state_time`](@ref) and [`is_parameter_timeseries`](@ref). The
22-
function returned from `getp` can be passed `Colon()` (`:`) as the last argument to return
23-
the entire parameter timeseries for `p`, or any index into the parameter timeseries for a
24-
subset of values.
21+
[`parameter_timeseries_at_state_time`](@ref) and [`is_parameter_timeseries`](@ref).
22+
23+
If `sym` is a timeseries parameter, the function will return the timeseries of the
24+
parameter if the value provider is a parameter timeseries object. An additional argument
25+
can be provided to the function indicating the specific indexes in the timeseries at
26+
which to access the values. If `sym` is an array of parameters, the following cases
27+
apply:
28+
29+
- All parameters are non-timeseries parameters: The function returns the value of each
30+
parameter.
31+
- All parameters are timeseries parameters: All the parameters must belong to the same
32+
timeseries (otherwise `getp` will error). The function returns the timeseries of all
33+
parameter values, and can be accessed at specific indices in the timeseries.
34+
- A mix of timeseries and non-timeseries parameters: The function can _only_ be used on
35+
non-timeseries objects and will return the value of each parameter at in the object.
2536
"""
2637
function getp(sys, p)
2738
symtype = symbolic_type(p)
@@ -51,7 +62,7 @@ function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractAr
5162
for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
5263
buffer[buf_idx] = gpi(ts, prob, ts_idx)
5364
end
54-
buffer
65+
return buffer
5566
end
5667
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob, i::Union{Int, CartesianIndex})
5768
parameter_values(prob, gpi.idx, i)
@@ -71,7 +82,7 @@ function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractAr
7182
for (buf_idx, ts_idx) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,))))
7283
buffer[buf_idx] = gpi(ts, prob, ts_idx)
7384
end
74-
buffer
85+
return buffer
7586
end
7687
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i)
7788
gpi.((ts,), (prob,), i)
@@ -80,7 +91,7 @@ function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractAr
8091
for (buf_idx, subidx) in zip(eachindex(buffer), i)
8192
buffer[buf_idx] = gpi(ts, prob, subidx)
8293
end
83-
buffer
94+
return buffer
8495
end
8596
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob)
8697
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
@@ -190,7 +201,7 @@ for (indexerTimeseriesType, timeseriesType) in [
190201
for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters)
191202
buffer[buf_idx] = getter(prob)
192203
end
193-
buffer
204+
return buffer
194205
end
195206
end
196207

@@ -227,13 +238,13 @@ function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob
227238
for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(mpg))))
228239
mpg(buffer[buf_idx], ts, prob, ts_idx)
229240
end
230-
buffer
241+
return buffer
231242
end
232243
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex})
233244
for (buf_idx, getter) in zip(eachindex(buffer), mpg.getters)
234245
buffer[buf_idx] = getter(prob, i)
235246
end
236-
buffer
247+
return buffer
237248
end
238249
function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob, ::Colon)
239250
mpg(buffer, ts, prob)
@@ -245,7 +256,7 @@ function (mpg::AtLeastTimeseriesMPG)(buffer::AbstractArray, ts::Timeseries, prob
245256
for (buf_idx, ts_idx) in zip(eachindex(buffer), i)
246257
mpg(buffer[buf_idx], ts, prob, ts_idx)
247258
end
248-
buffer
259+
return buffer
249260
end
250261
function (mpg::MultipleParametersGetter{IndexerTimeseries})(::NotTimeseries, prob)
251262
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, mpg))
@@ -285,7 +296,7 @@ end
285296
function (phw::ParameterHookWrapper)(prob, args...)
286297
res = phw.setter(prob, args...)
287298
finalize_parameters_hook!(prob, phw.original_index)
288-
res
299+
return res
289300
end
290301

291302
"""

src/parameter_timeseries_collection.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ The collection is expected to implement `Base.eachindex`, `Base.iterate` and
1313
returned by calling [`timeseries_parameter_index`](@ref) on the corresponding index
1414
provider.
1515
16-
This type forwards `eachindex` and `iterate` to the contained `collection`. It implements
17-
`Base.parent` to allow access to the contained `collection`, and has the following
18-
`getindex` methods:
16+
This type forwards `eachindex`, `iterate` and `length` to the contained `collection`. It
17+
implements `Base.parent` to allow access to the contained `collection`, and has the
18+
following `getindex` methods:
1919
2020
- `getindex(ptc::ParameterTimeseriesCollection, idx) = ptc.collection[idx]`.
2121
- `getindex(::ParameterTimeseriesCollection, idx::ParameterTimeseriesIndex)` returns the
@@ -48,6 +48,8 @@ Base.eachindex(ptc::ParameterTimeseriesCollection) = eachindex(ptc.collection)
4848

4949
Base.iterate(ptc::ParameterTimeseriesCollection, args...) = iterate(ptc.collection, args...)
5050

51+
Base.length(ptc::ParameterTimeseriesCollection) = length(ptc.collection)
52+
5153
Base.parent(ptc::ParameterTimeseriesCollection) = ptc.collection
5254

5355
Base.getindex(ptc::ParameterTimeseriesCollection, idx) = ptc.collection[idx]

src/state_indexing.jl

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ support symbolic expressions, the value provider must implement [`observed`](@re
2121
2222
This function typically does not need to be implemented, and has a default implementation
2323
relying on the above functions.
24+
25+
If the value provider is a parameter timeseries object, the same rules apply as
26+
[`getp`](@ref). The difference here is that `sym` may also contain non-parameter symbols,
27+
and the values are always returned corresponding to the state timeseries. This utilizes
28+
[`parameter_values_at_state_time`](@ref) and [`parameter_timeseries_at_state_time`](@ref).
2429
"""
2530
function getu(sys, sym)
2631
symtype = symbolic_type(sym)
@@ -102,16 +107,41 @@ struct TimeDependentObservedFunction{F} <: AbstractStateGetIndexer
102107
obsfn::F
103108
end
104109

105-
function (o::TimeDependentObservedFunction)(::Timeseries, prob)
110+
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob)
111+
return o(ts, is_parameter_timeseries(prob), prob)
112+
end
113+
function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob)
106114
o.obsfn.(state_values(prob),
107115
parameter_values_at_state_time(prob),
108116
current_time(prob))
109117
end
110-
function (o::TimeDependentObservedFunction)(::Timeseries, prob, i)
118+
function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob)
119+
o.obsfn.(state_values(prob),
120+
(parameter_values(prob),),
121+
current_time(prob))
122+
end
123+
function (o::TimeDependentObservedFunction)(ts::Timeseries, prob, i)
124+
return o(ts, is_parameter_timeseries(prob), prob, i)
125+
end
126+
function (o::TimeDependentObservedFunction)(::Timeseries, ::Timeseries, prob, i::Union{Int, CartesianIndex})
111127
return o.obsfn(state_values(prob, i),
112128
parameter_values_at_state_time(prob, i),
113129
current_time(prob, i))
114130
end
131+
function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, ::Colon)
132+
return o(ts, p_ts, prob)
133+
end
134+
function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i::AbstractArray{Bool})
135+
map(only(to_indices(current_time(prob), (i,)))) do idx
136+
o(ts, p_ts, prob, idx)
137+
end
138+
end
139+
function (o::TimeDependentObservedFunction)(ts::Timeseries, p_ts::IsTimeseriesTrait, prob, i)
140+
o.((ts,), (p_ts,), (prob,), i)
141+
end
142+
function (o::TimeDependentObservedFunction)(::Timeseries, ::NotTimeseries, prob, i::Union{Int, CartesianIndex})
143+
o.obsfn(state_values(prob, i), parameter_values(prob), current_time(prob, i))
144+
end
115145
function (o::TimeDependentObservedFunction)(::NotTimeseries, prob)
116146
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
117147
end
@@ -168,21 +198,29 @@ function (mg::MultipleGetters)(::NotTimeseries, prob)
168198
return map(g -> g(prob), mg.getters)
169199
end
170200

171-
struct AsTupleWrapper{G} <: AbstractStateGetIndexer
201+
struct AsTupleWrapper{N, G} <: AbstractStateGetIndexer
172202
getter::G
173203
end
174204

175-
function (atw::AsTupleWrapper)(::Timeseries, prob)
176-
return Tuple.(atw.getter(prob))
205+
AsTupleWrapper{N}(getter::G) where {N, G} = AsTupleWrapper{N, G}(getter)
206+
207+
wrap_tuple(::AsTupleWrapper{N}, val) where {N} = ntuple(i -> val[i], Val(N))
208+
209+
function (atw::(AsTupleWrapper{N} where {N}))(::Timeseries, prob)
210+
return wrap_tuple.((atw,), atw.getter(prob))
211+
# return Tuple.(atw.getter(prob))
177212
end
178213
function (atw::AsTupleWrapper)(::Timeseries, prob, i::Union{Int, CartesianIndex})
179-
return Tuple(atw.getter(prob, i))
214+
return wrap_tuple(atw, atw.getter(prob, i))
215+
# return Tuple(atw.getter(prob, i))
180216
end
181217
function (atw::AsTupleWrapper)(::Timeseries, prob, i)
182-
return Tuple.(atw.getter(prob, i))
218+
return wrap_tuple.((atw,), atw.getter(prob, i))
219+
# return Tuple.(atw.getter(prob, i))
183220
end
184221
function (atw::AsTupleWrapper)(::NotTimeseries, prob)
185-
return Tuple(atw.getter(prob))
222+
wrap_tuple(atw, atw.getter(prob))
223+
# return Tuple(atw.getter(prob))
186224
end
187225

188226
for (t1, t2) in [
@@ -207,7 +245,7 @@ for (t1, t2) in [
207245
TimeIndependentObservedFunction(obs)
208246
end
209247
if sym isa Tuple
210-
getter = AsTupleWrapper(getter)
248+
getter = AsTupleWrapper{length(sym)}(getter)
211249
end
212250
return getter
213251
end

src/symbol_cache.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ function SymbolCache(vars = nothing, params = nothing, indepvars = nothing;
4646
params = to_dict_or_nothing(params)
4747
timeseries_parameters = to_dict_or_nothing(timeseries_parameters)
4848
if timeseries_parameters !== nothing
49+
if indepvars === nothing
50+
throw(ArgumentError("Independent variable is required for timeseries parameters to exist"))
51+
end
4952
for (k, v) in timeseries_parameters
5053
if !haskey(params, k)
5154
throw(ArgumentError("Timeseries parameter $k must also be present in parameters."))
@@ -75,7 +78,7 @@ function variable_symbols(sc::SymbolCache, i = nothing)
7578
for (k, v) in sc.variables
7679
buffer[v] = k
7780
end
78-
buffer
81+
return buffer
7982
end
8083
function is_parameter(sc::SymbolCache, sym)
8184
sc.parameters !== nothing && haskey(sc.parameters, sym)

src/value_provider_interface.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -215,15 +215,8 @@ abstract type AbstractSetIndexer <: AbstractIndexer end
215215
(ai::AbstractStateGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)
216216
(ai::AbstractParameterGetIndexer)(prob) = ai(is_parameter_timeseries(prob), prob)
217217
(ai::AbstractParameterGetIndexer)(prob, i) = ai(is_parameter_timeseries(prob), prob, i)
218-
# unfortunately, this is ambiguous
219-
function (ai::AbstractParameterGetIndexer)(arg1::AbstractArray, arg2)
220-
if hasmethod(parameter_values, Tuple{typeof(arg2)})
221-
# arg1 is buffer
222-
ai(arg1, is_parameter_timeseries(arg2), arg2)
223-
else
224-
# arg1 is value provider
225-
ai(is_parameter_timeseries(arg1), arg1, arg2)
226-
end
218+
function (ai::AbstractParameterGetIndexer)(buffer::AbstractArray, prob)
219+
ai(buffer, is_parameter_timeseries(prob), prob)
227220
end
228221
function (ai::AbstractParameterGetIndexer)(buffer::AbstractArray, prob, i)
229222
ai(buffer, is_parameter_timeseries(prob), prob, i)

test/example_test.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
7171
@test all(.!is_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r]))
7272
@test all(parameter_index.((sys,), [:c, :a, :b]) .== [3, 1, 2])
7373
@test all(parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing)
74+
@test all(.!is_timeseries_parameter.((sys,), [:x, :y, :z, :t, :p, :q, :r])) # fallback even if not implemented
75+
@test all(timeseries_parameter_index.((sys,), [:x, :y, :z, :t, :p, :q, :r]) .=== nothing) # fallback
7476
@test is_independent_variable(sys, :t)
7577
@test all(.!is_independent_variable.((sys,), [:x, :y, :z, :a, :b, :c, :p, :q, :r]))
7678
@test all(is_observed.((sys,), [:x, :y, :z, :a, :b, :c, :t]))
@@ -88,6 +90,7 @@ sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], :t)
8890
@test independent_variable_symbols(sys) == [:t]
8991
@test all_variable_symbols(sys) == [:x, :y, :z]
9092
@test sort(all_symbols(sys)) == [:a, :b, :c, :t, :x, :y, :z]
93+
@test default_values(sys) == Dict() # fallback even if not implemented
9194

9295
sys = SystemMockup(true, [:x, :y, :z], [:a, :b, :c], nothing)
9396

test/parameter_indexing_test.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ using SymbolicIndexingInterface: IndexerTimeseries, IndexerNotTimeseries, Indexe
55
MixedParameterTimeseriesIndexError
66
using Test
77

8+
arr = [1.0, 2.0, 3.0]
9+
@test parameter_values(arr) == arr
10+
@test current_time(arr) == arr
11+
tp = (1.0, 2.0, 3.0)
12+
@test parameter_values(tp) == tp
13+
814
struct FakeIntegrator{S, P}
915
sys::S
1016
p::P
@@ -29,6 +35,9 @@ for sys in [
2935
p = [1.0, 2.0, 3.0, 4.0]
3036
fi = FakeIntegrator(sys, pType(copy(p)), Ref(0))
3137
new_p = [4.0, 5.0, 6.0, 7.0]
38+
for i in [7, CartesianIndex(5)]
39+
@test parameter_values_at_state_time(fi, i) == parameter_values(fi)
40+
end
3241
for (sym, oldval, newval, check_inference) in [
3342
(:a, p[1], new_p[1], true),
3443
(1, p[1], new_p[1], true),
@@ -308,6 +317,9 @@ for (sym, val_is_timeseries, val, check_inference) in [
308317
((:x, :b, :c), true, tuple.(xval, bval_state, cval_state), true),
309318
([:a, :b, :x], true, vcat.(aval, bval_state, xval), false),
310319
((:a, :b, :x), true, tuple.(aval, bval_state, xval), true),
320+
(:(2b), true, 2 .* bval_state, true),
321+
([:x, :(2b), :(3c)], true, vcat.(xval, 2 .* bval_state, 3 .* cval_state), true),
322+
((:x, :(2b), :(3c)), true, tuple.(xval, 2 .* bval_state, 3 .* cval_state), true),
311323
]
312324
getter = getu(sys, sym)
313325
if val isa DataType
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using SymbolicIndexingInterface
2+
using Test
3+
4+
struct MyDiffEqArray
5+
t::Vector{Float64}
6+
u::Vector{Vector{Float64}}
7+
end
8+
SymbolicIndexingInterface.current_time(mda::MyDiffEqArray) = mda.t
9+
SymbolicIndexingInterface.state_values(mda::MyDiffEqArray) = mda.u
10+
SymbolicIndexingInterface.is_timeseries(::Type{MyDiffEqArray}) = Timeseries()
11+
12+
@test_throws ArgumentError ParameterTimeseriesCollection((ones(3), 2ones(3)))
13+
14+
a_timeseries = MyDiffEqArray(collect(0:0.1:0.9), [[2.5i, sin(0.2i)] for i in 1:10])
15+
b_timeseries = MyDiffEqArray(collect(0:0.25:0.9), [[3.5i, log(1.3i)] for i in 1:4])
16+
c_timeseries = MyDiffEqArray(collect(0:0.17:0.90), [[4.3i] for i in 1:5])
17+
collection = (a_timeseries, b_timeseries, c_timeseries)
18+
ptc = ParameterTimeseriesCollection(collection)
19+
20+
@test collect(eachindex(ptc)) == [1, 2, 3]
21+
@test [x for x in ptc] == [a_timeseries, b_timeseries, c_timeseries]
22+
@test length(ptc) == 3
23+
@test parent(ptc) === collection
24+
25+
for i in 1:3
26+
@test ptc[i] === collection[i]
27+
@test parameter_timeseries(ptc, i) == collection[i].t
28+
for j in eachindex(collection[i].u[1])
29+
pti = ParameterTimeseriesIndex(i, j)
30+
@test ptc[pti] == getindex.(collection[i].u, j)
31+
for k in eachindex(collection[i].u)
32+
rhs = collection[i].u[k][j]
33+
@test ptc[pti, CartesianIndex(k)] == rhs
34+
@test ptc[pti, k] == rhs
35+
@test ptc[i, k] == collection[i].u[k]
36+
@test ptc[i, k, j] == rhs
37+
@test parameter_values(ptc, pti, k) == rhs
38+
end
39+
allidxs = eachindex(collection[i].u)
40+
for subidx in [:, rand(allidxs, 3), rand(Bool, length(allidxs))]
41+
rhs = getindex.(collection[i].u[subidx], j)
42+
@test ptc[pti, subidx] == rhs
43+
@test ptc[i, subidx, j] == rhs
44+
@test parameter_values(ptc, pti, subidx) == rhs
45+
end
46+
end
47+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ if GROUP == "All" || GROUP == "Core"
2727
@safetestset "Fallback test" begin
2828
@time include("fallback_test.jl")
2929
end
30+
@safetestset "ParameterTimeseriesCollection test" begin
31+
@time include("parameter_timeseries_collection_test.jl")
32+
end
3033
@safetestset "Parameter indexing test" begin
3134
@time include("parameter_indexing_test.jl")
3235
end

0 commit comments

Comments
 (0)