Skip to content

Commit f08f203

Browse files
refactor: return functors from getu and getp
1 parent 632032f commit f08f203

File tree

4 files changed

+319
-336
lines changed

4 files changed

+319
-336
lines changed

src/SymbolicIndexingInterface.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ export SymbolCache
2424
include("symbol_cache.jl")
2525

2626
export parameter_values, set_parameter!, finalize_parameters_hook!,
27-
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp,
28-
setp
27+
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries,
28+
state_values, set_state!, current_time
29+
include("value_provider_interface.jl")
30+
31+
export getp, setp
2932
include("parameter_indexing.jl")
3033

31-
export state_values, set_state!, current_time, getu, setu
34+
export getu, setu
3235
include("state_indexing.jl")
3336

3437
export BatchedInterface, associated_systems

src/parameter_indexing.jl

Lines changed: 75 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,21 @@
1-
"""
2-
parameter_values(p)
3-
parameter_values(p, i)
4-
5-
Return an indexable collection containing the value of each parameter in `p`. The two-
6-
argument version of this function returns the parameter value at index `i`. The
7-
two-argument version of this function will default to returning
8-
`parameter_values(p)[i]`.
9-
10-
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
11-
array/tuple.
12-
"""
13-
function parameter_values end
14-
151
parameter_values(arr::AbstractArray) = arr
162
parameter_values(arr::Tuple) = arr
173
parameter_values(arr::AbstractArray, i) = arr[i]
184
parameter_values(arr::Tuple, i) = arr[i]
195
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)
206

21-
"""
22-
parameter_values_at_time(p, i)
23-
24-
Return an indexable collection containing the value of all parameters in `p` at time index
25-
`i`. This is useful when parameter values change during the simulation
26-
(such as through callbacks) and their values are saved. `i` is the time index in the
27-
timeseries formed by these changing parameter values, obtained using
28-
[`parameter_timeseries`](@ref).
29-
30-
By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
31-
to be specialized for timeseries objects where parameter values are not constant at all
32-
times. The resultant object should be indexable using [`parameter_values`](@ref).
33-
34-
If this function is implemented, [`parameter_values_at_state_time`](@ref) must be
35-
implemented for [`getu`](@ref) to work correctly.
36-
"""
37-
function parameter_values_at_time end
387
parameter_values_at_time(p, i) = parameter_values(p)
398

40-
"""
41-
parameter_values_at_state_time(p, i)
42-
43-
Return an indexable collection containing the value of all parameters in `p` at time
44-
index `i`. This is useful when parameter values change during the simulation (such as
45-
through callbacks) and their values are saved. `i` is the time index in the timeseries
46-
formed by dependent variables (as opposed to the timeseries of the parameters, as in
47-
[`parameter_values_at_time`](@ref)).
48-
49-
By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
50-
to be specialized for timeseries objects where parameter values are not constant at
51-
all times. The resultant object should be indexable using [`parameter_values`](@ref).
52-
53-
If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for
54-
[`getp`](@ref) to work correctly.
55-
"""
56-
function parameter_values_at_state_time end
579
parameter_values_at_state_time(p, i) = parameter_values(p)
5810

59-
"""
60-
parameter_timeseries(p)
61-
62-
Return an iterable of time steps at which the parameter values are saved. This is only
63-
required for objects where `is_timeseries(p) === Timeseries()` and the parameter values
64-
change during the simulation (such as through callbacks). By default, this returns `[0]`.
65-
66-
See also: [`parameter_values_at_time`](@ref).
67-
"""
68-
function parameter_timeseries end
6911
parameter_timeseries(_) = [0]
7012

71-
"""
72-
set_parameter!(sys, val, idx)
73-
74-
Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
75-
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
76-
default implementation does not work for a particular type, this method needs to be
77-
defined to enable the proper functioning of [`setp`](@ref).
78-
79-
See: [`parameter_values`](@ref)
80-
"""
81-
function set_parameter! end
82-
8313
# Tuple only included for the error message
8414
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
8515
sys[idx] = val
8616
end
8717
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
8818

89-
"""
90-
finalize_parameters_hook!(prob, p)
91-
92-
This is a callback run one for each call to the function returned by [`setp`](@ref)
93-
which can be used to update internal data structures when parameters are modified.
94-
This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter
95-
that is updated.
96-
"""
97-
finalize_parameters_hook!(prob, p) = nothing
98-
9919
"""
10020
getp(sys, p)
10121
@@ -125,43 +45,87 @@ function getp(sys, p)
12545
_getp(sys, symtype, elsymtype, p)
12646
end
12747

48+
struct GetParameterIndex{I} <: AbstractIndexer
49+
idx::I
50+
end
51+
52+
function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob)
53+
parameter_values(prob, gpi.idx)
54+
end
55+
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{Int, CartesianIndex})
56+
parameter_values(
57+
parameter_values_at_time(
58+
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
59+
gpi.idx)
60+
end
61+
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
62+
parameter_values.(
63+
parameter_values_at_time.((prob,),
64+
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
65+
(gpi.idx,))
66+
end
67+
function (gpi::GetParameterIndex)(::Timeseries, prob, i)
68+
parameter_values.(parameter_values_at_time.((prob,), i), (gpi.idx,))
69+
end
70+
12871
function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
129-
return let p = p
130-
function _getter(::NotTimeseries, prob)
131-
parameter_values(prob, p)
132-
end
133-
function _getter(::Timeseries, prob)
134-
parameter_values(prob, p)
135-
end
136-
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
137-
parameter_values(
138-
parameter_values_at_time(
139-
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
140-
p)
141-
end
142-
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
143-
parameter_values.(
144-
parameter_values_at_time.((prob,),
145-
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
146-
(p,))
147-
end
148-
function _getter(::Timeseries, prob, i)
149-
parameter_values.(parameter_values_at_time.((prob,), i), (p,))
150-
end
151-
getter = let _getter = _getter
152-
function getter(prob, args...)
153-
return _getter(is_timeseries(prob), prob, args...)
154-
end
155-
end
156-
getter
157-
end
72+
return GetParameterIndex(p)
15873
end
15974

16075
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
16176
idx = parameter_index(sys, p)
16277
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
16378
sys, NotSymbolic(), NotSymbolic(), idx)
164-
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)
79+
end
80+
81+
struct MultipleParameterGetters{G}
82+
getters::G
83+
end
84+
85+
function (mpg::MultipleParameterGetters)(::IsTimeseriesTrait, prob)
86+
map(g -> g(prob), mpg.getters)
87+
end
88+
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex})
89+
map(g -> g(prob, i), mpg.getters)
90+
end
91+
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i)
92+
[map(g -> g(prob, j), mpg.getters)
93+
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
94+
end
95+
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob)
96+
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
97+
buffer[bufi] = g(prob)
98+
end
99+
buffer
100+
end
101+
function (mpg::MultipleParameterGetters)(
102+
buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex})
103+
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
104+
buffer[bufi] = g(prob, i)
105+
end
106+
buffer
107+
end
108+
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob, i)
109+
for (bufi, tsi) in zip(
110+
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
111+
for (g, bufj) in zip(mpg.getters, eachindex(buffer[bufi]))
112+
buffer[bufi][bufj] = g(prob, tsi)
113+
end
114+
end
115+
buffer
116+
end
117+
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::NotTimeseries, prob)
118+
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
119+
buffer[bufi] = g(prob)
120+
end
121+
buffer
122+
end
123+
124+
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, prob, i...)
125+
mpg(buffer, is_timeseries(prob), prob, i...)
126+
end
127+
function (mpg::MultipleParameterGetters)(prob, i...)
128+
mpg(is_timeseries(prob), prob, i...)
165129
end
166130

167131
for (t1, t2) in [
@@ -171,60 +135,7 @@ for (t1, t2) in [
171135
]
172136
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
173137
getters = getp.((sys,), p)
174-
175-
return let getters = getters
176-
function _getter(::NotTimeseries, prob)
177-
map(g -> g(prob), getters)
178-
end
179-
function _getter(::Timeseries, prob)
180-
map(g -> g(prob), getters)
181-
end
182-
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
183-
map(g -> g(prob, i), getters)
184-
end
185-
function _getter(::Timeseries, prob, i)
186-
[map(g -> g(prob, j), getters)
187-
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
188-
end
189-
function _getter!(buffer, ::NotTimeseries, prob)
190-
for (g, bufi) in zip(getters, eachindex(buffer))
191-
buffer[bufi] = g(prob)
192-
end
193-
buffer
194-
end
195-
function _getter!(buffer, ::Timeseries, prob)
196-
for (g, bufi) in zip(getters, eachindex(buffer))
197-
buffer[bufi] = g(prob)
198-
end
199-
buffer
200-
end
201-
function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex})
202-
for (g, bufi) in zip(getters, eachindex(buffer))
203-
buffer[bufi] = g(prob, i)
204-
end
205-
buffer
206-
end
207-
function _getter!(buffer, ::Timeseries, prob, i)
208-
for (bufi, tsi) in zip(
209-
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
210-
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
211-
buffer[bufi][bufj] = g(prob, tsi)
212-
end
213-
end
214-
buffer
215-
end
216-
_getter, _getter!
217-
getter = let _getter = _getter, _getter! = _getter!
218-
function getter(prob, i...)
219-
return _getter(is_timeseries(prob), prob, i...)
220-
end
221-
function getter(buffer::AbstractArray, prob, i...)
222-
return _getter!(buffer, is_timeseries(prob), prob, i...)
223-
end
224-
getter
225-
end
226-
getter
227-
end
138+
return MultipleParameterGetters(getters)
228139
end
229140
end
230141

0 commit comments

Comments
 (0)