Skip to content

Commit 320753e

Browse files
refactor: use callable structs for setu and setp
1 parent 635595a commit 320753e

File tree

3 files changed

+57
-43
lines changed

3 files changed

+57
-43
lines changed

src/parameter_indexing.jl

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function getp(sys, p)
4545
_getp(sys, symtype, elsymtype, p)
4646
end
4747

48-
struct GetParameterIndex{I} <: AbstractIndexer
48+
struct GetParameterIndex{I} <: AbstractGetIndexer
4949
idx::I
5050
end
5151

@@ -78,7 +78,7 @@ function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
7878
sys, NotSymbolic(), NotSymbolic(), idx)
7979
end
8080

81-
struct MultipleParameterGetters{G}
81+
struct MultipleParameterGetters{G} <: AbstractGetIndexer
8282
getters::G
8383
end
8484

@@ -148,6 +148,17 @@ function _getp(sys, ::ArraySymbolic, ::NotSymbolic, p)
148148
return getp(sys, collect(p))
149149
end
150150

151+
struct ParameterHookWrapper{S, O} <: AbstractSetIndexer
152+
setter::S
153+
original_index::O
154+
end
155+
156+
function (phw::ParameterHookWrapper)(prob, args...)
157+
res = phw.setter(prob, args...)
158+
finalize_parameters_hook!(prob, phw.original_index)
159+
res
160+
end
161+
151162
"""
152163
setp(sys, p)
153164
@@ -165,33 +176,35 @@ function setp(sys, p; run_hook = true)
165176
symtype = symbolic_type(p)
166177
elsymtype = symbolic_type(eltype(p))
167178
return if run_hook
168-
let _setter! = _setp(sys, symtype, elsymtype, p), p = p
169-
function setter!(prob, args...)
170-
res = _setter!(prob, args...)
171-
finalize_parameters_hook!(prob, p)
172-
res
173-
end
174-
end
179+
return ParameterHookWrapper(_setp(sys, symtype, elsymtype, p), p)
175180
else
176181
_setp(sys, symtype, elsymtype, p)
177182
end
178183
end
179184

185+
struct SetParameterIndex{I} <: AbstractSetIndexer
186+
idx::I
187+
end
188+
189+
function (spi::SetParameterIndex)(prob, val)
190+
set_parameter!(prob, val, spi.idx)
191+
end
192+
180193
function _setp(sys, ::NotSymbolic, ::NotSymbolic, p)
181-
return let p = p
182-
function setter!(sol, val)
183-
set_parameter!(sol, val, p)
184-
end
185-
end
194+
return SetParameterIndex(p)
186195
end
187196

188197
function _setp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
189198
idx = parameter_index(sys, p)
190-
return let idx = idx
191-
function setter!(sol, val)
192-
set_parameter!(sol, val, idx)
193-
end
194-
end
199+
return SetParameterIndex(idx)
200+
end
201+
202+
struct MultipleSetters{S} <: AbstractSetIndexer
203+
setters::S
204+
end
205+
206+
function (ms::MultipleSetters)(prob, val)
207+
map((s!, v) -> s!(prob, v), ms.setters, val)
195208
end
196209

197210
for (t1, t2) in [
@@ -201,11 +214,7 @@ for (t1, t2) in [
201214
]
202215
@eval function _setp(sys, ::NotSymbolic, ::$t1, p::$t2)
203216
setters = setp.((sys,), p; run_hook = false)
204-
return let setters = setters
205-
function setter!(sol, val)
206-
map((s!, v) -> s!(sol, v), setters, val)
207-
end
208-
end
217+
return MultipleSetters(setters)
209218
end
210219
end
211220

src/state_indexing.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function getu(sys, sym)
3333
_getu(sys, symtype, elsymtype, sym)
3434
end
3535

36-
struct GetStateIndex{I} <: AbstractIndexer
36+
struct GetStateIndex{I} <: AbstractGetIndexer
3737
idx::I
3838
end
3939
function (gsi::GetStateIndex)(::Timeseries, prob)
@@ -50,7 +50,7 @@ function _getu(sys, ::NotSymbolic, ::NotSymbolic, sym)
5050
return GetStateIndex(sym)
5151
end
5252

53-
struct GetpAtStateTime{G} <: AbstractIndexer
53+
struct GetpAtStateTime{G} <: AbstractGetIndexer
5454
getter::G
5555
end
5656

@@ -65,12 +65,12 @@ function (g::GetpAtStateTime)(::NotTimeseries, prob)
6565
g.getter(prob)
6666
end
6767

68-
struct GetIndepvar <: AbstractIndexer end
68+
struct GetIndepvar <: AbstractGetIndexer end
6969

7070
(::GetIndepvar)(::IsTimeseriesTrait, prob) = current_time(prob)
7171
(::GetIndepvar)(::Timeseries, prob, i) = current_time(prob, i)
7272

73-
struct TimeDependentObservedFunction{F} <: AbstractIndexer
73+
struct TimeDependentObservedFunction{F} <: AbstractGetIndexer
7474
obsfn::F
7575
end
7676

@@ -89,7 +89,7 @@ function (o::TimeDependentObservedFunction)(::NotTimeseries, prob)
8989
return o.obsfn(state_values(prob), parameter_values(prob), current_time(prob))
9090
end
9191

92-
struct TimeIndependentObservedFunction{F} <: AbstractIndexer
92+
struct TimeIndependentObservedFunction{F} <: AbstractGetIndexer
9393
obsfn::F
9494
end
9595

@@ -116,7 +116,7 @@ function _getu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
116116
error("Invalid symbol $sym for `getu`")
117117
end
118118

119-
struct MultipleGetters{G} <: AbstractIndexer
119+
struct MultipleGetters{G} <: AbstractGetIndexer
120120
getters::G
121121
end
122122

@@ -131,7 +131,7 @@ function (mg::MultipleGetters)(::NotTimeseries, prob)
131131
return map(g -> g(prob), mg.getters)
132132
end
133133

134-
struct AsTupleWrapper{G} <: AbstractIndexer
134+
struct AsTupleWrapper{G} <: AbstractGetIndexer
135135
getter::G
136136
end
137137

@@ -201,18 +201,22 @@ function setu(sys, sym)
201201
_setu(sys, symtype, elsymtype, sym)
202202
end
203203

204+
struct SetStateIndex{I} <: AbstractSetIndexer
205+
idx::I
206+
end
207+
208+
function (ssi::SetStateIndex)(prob, val)
209+
set_state!(prob, val, ssi.idx)
210+
end
211+
204212
function _setu(sys, ::NotSymbolic, ::NotSymbolic, sym)
205-
return function setter!(prob, val)
206-
set_state!(prob, val, sym)
207-
end
213+
return SetStateIndex(sym)
208214
end
209215

210216
function _setu(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, sym)
211217
if is_variable(sys, sym)
212218
idx = variable_index(sys, sym)
213-
return function setter!(prob, val)
214-
set_state!(prob, val, idx)
215-
end
219+
return SetStateIndex(idx)
216220
elseif is_parameter(sys, sym)
217221
return setp(sys, sym)
218222
end
@@ -226,16 +230,14 @@ for (t1, t2) in [
226230
]
227231
@eval function _setu(sys, ::NotSymbolic, ::$t1, sym::$t2)
228232
setters = setu.((sys,), sym)
229-
return function setter!(prob, val)
230-
map((s!, v) -> s!(prob, v), setters, val)
231-
end
233+
return MultipleSetters(setters)
232234
end
233235
end
234236

235237
function _setu(sys, ::ArraySymbolic, ::NotSymbolic, sym)
236238
if is_variable(sys, sym)
237239
idx = variable_index(sys, sym)
238-
return setu(sys, idx)
240+
return MultipleSetters(SetStateIndex.(idx))
239241
elseif is_parameter(sys, sym)
240242
return setp(sys, sym)
241243
end

src/value_provider_interface.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,8 @@ function current_time end
139139

140140
abstract type AbstractIndexer end
141141

142-
(ai::AbstractIndexer)(prob) = ai(is_timeseries(prob), prob)
143-
(ai::AbstractIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)
142+
abstract type AbstractGetIndexer <: AbstractIndexer end
143+
abstract type AbstractSetIndexer <: AbstractIndexer end
144+
145+
(ai::AbstractGetIndexer)(prob) = ai(is_timeseries(prob), prob)
146+
(ai::AbstractGetIndexer)(prob, i) = ai(is_timeseries(prob), prob, i)

0 commit comments

Comments
 (0)