Skip to content

Commit c7066ec

Browse files
fixup! wip: better parameter indexing
1 parent b8139fa commit c7066ec

File tree

2 files changed

+109
-30
lines changed

2 files changed

+109
-30
lines changed

src/parameter_indexing.jl

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,28 +44,50 @@ end
4444
function (gpi::GetParameterIndex)(::Timeseries, prob, args)
4545
throw(ParameterTimeseriesValueIndexMismatchError{Timeseries}(prob, gpi, args))
4646
end
47-
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob)
48-
parameter_values.(
49-
(prob,), (gpi.idx,),
50-
eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
47+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob)
48+
gpi.((ts,), (prob,), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
49+
end
50+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob)
51+
for (buf_idx, ts_idx) in zip(eachindex(buffer), eachindex(parameter_timeseries(prob, indexer_timeseries_index(gpi))))
52+
buffer[buf_idx] = gpi(ts, prob, ts_idx)
53+
end
54+
buffer
5155
end
5256
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::Timeseries, prob, i::Union{Int, CartesianIndex})
5357
parameter_values(prob, gpi.idx, i)
5458
end
5559
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, ::Colon)
5660
gpi(ts, prob)
5761
end
62+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, ::Colon)
63+
gpi(buffer, ts, prob)
64+
end
5865
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i::AbstractArray{Bool})
5966
map(only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,)))) do idx
6067
gpi(ts, prob, idx)
6168
end
6269
end
70+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i::AbstractArray{Bool})
71+
for (buf_idx, ts_idx) in zip(eachindex(buffer), only(to_indices(parameter_timeseries(prob, indexer_timeseries_index(gpi)), (i,))))
72+
buffer[buf_idx] = gpi(ts, prob, ts_idx)
73+
end
74+
buffer
75+
end
6376
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(ts::Timeseries, prob, i)
6477
gpi.((ts,), (prob,), i)
6578
end
79+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(buffer::AbstractArray, ts::Timeseries, prob, i)
80+
for (buf_idx, subidx) in zip(eachindex(buffer), i)
81+
buffer[buf_idx] = gpi(ts, prob, subidx)
82+
end
83+
buffer
84+
end
6685
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::NotTimeseries, prob)
6786
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
6887
end
88+
function (gpi::GetParameterIndex{<:ParameterTimeseriesIndex})(::AbstractArray, ::NotTimeseries, prob)
89+
throw(ParameterTimeseriesValueIndexMismatchError{NotTimeseries}(prob, gpi))
90+
end
6991

7092
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
7193
return GetParameterIndex(p)
@@ -84,6 +106,9 @@ as_timeseries_indexer(::IndexerBoth, gpti::GetParameterTimeseriesIndex) = gpti.p
84106
function (gpti::GetParameterTimeseriesIndex)(ts::Timeseries, prob, args...)
85107
gpti.param_timeseries_idx(ts, prob, args...)
86108
end
109+
function (gpti::GetParameterTimeseriesIndex)(buffer::AbstractArray, ts::Timeseries, prob, args...)
110+
gpti.param_timeseries_idx(buffer, ts, prob, args...)
111+
end
87112
function (gpti::GetParameterTimeseriesIndex)(ts::NotTimeseries, prob)
88113
gpti.param_idx(ts, prob)
89114
end

test/parameter_indexing_test.jl

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -166,55 +166,109 @@ dval = fs.p[4]
166166
bidx = timeseries_parameter_index(sys, :b)
167167
cidx = timeseries_parameter_index(sys, :c)
168168

169-
for (sym, indexer_trait, timeseries_index, val, check_inference) in [
170-
(:a, IndexerNotTimeseries, 0, aval, true),
171-
(1, IndexerNotTimeseries, 0, aval, true),
172-
([:a, :d], IndexerNotTimeseries, 0, [aval, dval], true),
173-
((:a, :d), IndexerNotTimeseries, 0, (aval, dval), true),
174-
([1, 4], IndexerNotTimeseries, 0, [aval, dval], true),
175-
((1, 4), IndexerNotTimeseries, 0, (aval, dval), true),
176-
([:a, 4], IndexerNotTimeseries, 0, [aval, dval], true),
177-
((:a, 4), IndexerNotTimeseries, 0, (aval, dval), true),
178-
(:b, IndexerBoth, 1, bval, true),
179-
(bidx, IndexerTimeseries, 1, bval, true),
180-
([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], true),
181-
((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), true),
182-
([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], true),
183-
((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), true),
184-
([:b, :b], IndexerBoth, 1, vcat.(bval, bval), true),
185-
((:b, :b), IndexerBoth, 1, tuple.(bval, bval), true),
186-
([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), true),
187-
((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval), true),
188-
([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), true),
189-
((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), true),
169+
for (sym, indexer_trait, timeseries_index, val, buffer, check_inference) in [
170+
(:a, IndexerNotTimeseries, 0, aval, nothing, true),
171+
(1, IndexerNotTimeseries, 0, aval, nothing, true),
172+
([:a, :d], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true),
173+
((:a, :d), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true),
174+
([1, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true),
175+
((1, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true),
176+
([:a, 4], IndexerNotTimeseries, 0, [aval, dval], zeros(2), true),
177+
((:a, 4), IndexerNotTimeseries, 0, (aval, dval), zeros(2), true),
178+
(:b, IndexerBoth, 1, bval, zeros(length(bval)), true),
179+
(bidx, IndexerTimeseries, 1, bval, zeros(length(bval)), true),
180+
([:a, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true),
181+
((:a, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true),
182+
([1, :b], IndexerNotTimeseries, 0, [aval, bval[end]], zeros(2), true),
183+
((1, :b), IndexerNotTimeseries, 0, (aval, bval[end]), zeros(2), true),
184+
([:b, :b], IndexerBoth, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true),
185+
((:b, :b), IndexerBoth, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true),
186+
([bidx, :b], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true),
187+
((bidx, :b), IndexerTimeseries, 1, tuple.(bval, bval),map(_ -> zeros(2), bval), true),
188+
([bidx, bidx], IndexerTimeseries, 1, vcat.(bval, bval), map(_ -> zeros(2), bval), true),
189+
((bidx, bidx), IndexerTimeseries, 1, tuple.(bval, bval), map(_ -> zeros(2), bval), true),
190190
]
191191
getter = getp(sys, sym)
192192
@test is_indexer_timeseries(getter) isa indexer_trait
193193
if indexer_trait <: Union{IndexerTimeseries, IndexerBoth}
194194
@test indexer_timeseries_index(getter) == timeseries_index
195195
end
196+
test_inplace = buffer !== nothing
197+
test_non_timeseries = indexer_trait !== IndexerTimeseries
198+
if test_inplace && test_non_timeseries
199+
non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end]
200+
non_timeseries_buffer = indexer_trait == IndexerNotTimeseries ? deepcopy(buffer) : deepcopy(buffer[end])
201+
test_non_timeseries_inplace = non_timeseries_buffer isa AbstractArray
202+
end
196203
if check_inference
197204
@inferred getter(fs)
198-
if indexer_trait != IndexerTimeseries
205+
if test_inplace
206+
@inferred getter(deepcopy(buffer), fs)
207+
end
208+
if test_non_timeseries
199209
@inferred getter(parameter_values(fs))
210+
if test_inplace && test_non_timeseries_inplace && test_non_timeseries_inplace
211+
@inferred getter(deepcopy(non_timeseries_buffer), parameter_values(fs))
212+
end
200213
end
201214
end
202215
@test getter(fs) == val
203-
204-
if indexer_trait == IndexerTimeseries
205-
@test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs))
206-
else
216+
if test_inplace
217+
tmp = deepcopy(buffer)
218+
getter(tmp, fs)
219+
if val isa Tuple
220+
target = collect(val)
221+
elseif eltype(val) <: Tuple
222+
target = collect.(val)
223+
else
224+
target = val
225+
end
226+
@test tmp == target
227+
end
228+
if test_non_timeseries
207229
non_timeseries_val = indexer_trait == IndexerNotTimeseries ? val : val[end]
208230
@test getter(parameter_values(fs)) == non_timeseries_val
231+
if test_inplace && test_non_timeseries && test_non_timeseries_inplace
232+
getter(non_timeseries_buffer, parameter_values(fs))
233+
if non_timeseries_val isa Tuple
234+
target = collect(non_timeseries_val)
235+
else
236+
target = non_timeseries_val
237+
end
238+
@test non_timeseries_buffer == target
239+
end
240+
else
241+
@test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter(parameter_values(fs))
242+
if test_inplace
243+
@test_throws ParameterTimeseriesValueIndexMismatchError{NotTimeseries} getter([], parameter_values(fs))
244+
end
209245
end
210246
for subidx in [1, CartesianIndex(1), :, rand(Bool, length(val)), rand(eachindex(val), 3), 1:2]
211247
if indexer_trait <: IndexerNotTimeseries
212248
@test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter(fs, subidx)
249+
if test_inplace
250+
@test_throws ParameterTimeseriesValueIndexMismatchError{Timeseries} getter([], fs, subidx)
251+
end
213252
else
214253
if check_inference
215254
@inferred getter(fs, subidx)
255+
if test_inplace && buffer[subidx] isa AbstractArray
256+
@inferred getter(deepcopy(buffer[subidx]), fs, subidx)
257+
end
216258
end
217259
@test getter(fs, subidx) == val[subidx]
260+
if test_inplace && buffer[subidx] isa AbstractArray
261+
tmp = deepcopy(buffer[subidx])
262+
getter(tmp, fs, subidx)
263+
if val[subidx] isa Tuple
264+
target = collect(val[subidx])
265+
elseif eltype(val) <: Tuple
266+
target = collect.(val[subidx])
267+
else
268+
target = val[subidx]
269+
end
270+
@test tmp == target
271+
end
218272
end
219273
end
220274
end

0 commit comments

Comments
 (0)