Skip to content

refactor: return functors from getu and getp #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/SymbolicIndexingInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ export SymbolCache
include("symbol_cache.jl")

export parameter_values, set_parameter!, finalize_parameters_hook!,
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries, getp,
setp
parameter_values_at_time, parameter_values_at_state_time, parameter_timeseries,
state_values, set_state!, current_time
include("value_provider_interface.jl")

export getp, setp
include("parameter_indexing.jl")

export state_values, set_state!, current_time, getu, setu
export getu, setu
include("state_indexing.jl")

export BatchedInterface, associated_systems
Expand Down
239 changes: 75 additions & 164 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,101 +1,21 @@
"""
parameter_values(p)
parameter_values(p, i)

Return an indexable collection containing the value of each parameter in `p`. The two-
argument version of this function returns the parameter value at index `i`. The
two-argument version of this function will default to returning
`parameter_values(p)[i]`.

If this function is called with an `AbstractArray` or `Tuple`, it will return the same
array/tuple.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::Tuple) = arr
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(arr::Tuple, i) = arr[i]
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
parameter_values_at_time(p, i)

Return an indexable collection containing the value of all parameters in `p` at time index
`i`. This is useful when parameter values change during the simulation
(such as through callbacks) and their values are saved. `i` is the time index in the
timeseries formed by these changing parameter values, obtained using
[`parameter_timeseries`](@ref).

By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
to be specialized for timeseries objects where parameter values are not constant at all
times. The resultant object should be indexable using [`parameter_values`](@ref).

If this function is implemented, [`parameter_values_at_state_time`](@ref) must be
implemented for [`getu`](@ref) to work correctly.
"""
function parameter_values_at_time end
parameter_values_at_time(p, i) = parameter_values(p)

"""
parameter_values_at_state_time(p, i)

Return an indexable collection containing the value of all parameters in `p` at time
index `i`. This is useful when parameter values change during the simulation (such as
through callbacks) and their values are saved. `i` is the time index in the timeseries
formed by dependent variables (as opposed to the timeseries of the parameters, as in
[`parameter_values_at_time`](@ref)).

By default, this function returns `parameter_values(p)` regardless of `i`, and only needs
to be specialized for timeseries objects where parameter values are not constant at
all times. The resultant object should be indexable using [`parameter_values`](@ref).

If this function is implemented, [`parameter_values_at_time`](@ref) must be implemented for
[`getp`](@ref) to work correctly.
"""
function parameter_values_at_state_time end
parameter_values_at_state_time(p, i) = parameter_values(p)

"""
parameter_timeseries(p)

Return an iterable of time steps at which the parameter values are saved. This is only
required for objects where `is_timeseries(p) === Timeseries()` and the parameter values
change during the simulation (such as through callbacks). By default, this returns `[0]`.

See also: [`parameter_values_at_time`](@ref).
"""
function parameter_timeseries end
parameter_timeseries(_) = [0]

"""
set_parameter!(sys, val, idx)

Set the parameter at index `idx` to `val` for system `sys`. This defaults to modifying
`parameter_values(sys)`. If any additional bookkeeping needs to be performed or the
default implementation does not work for a particular type, this method needs to be
defined to enable the proper functioning of [`setp`](@ref).

See: [`parameter_values`](@ref)
"""
function set_parameter! end

# Tuple only included for the error message
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)

"""
finalize_parameters_hook!(prob, p)

This is a callback run one for each call to the function returned by [`setp`](@ref)
which can be used to update internal data structures when parameters are modified.
This is in contrast to [`set_parameter!`](@ref) which is run once for each parameter
that is updated.
"""
finalize_parameters_hook!(prob, p) = nothing

"""
getp(sys, p)

Expand Down Expand Up @@ -125,43 +45,87 @@ function getp(sys, p)
_getp(sys, symtype, elsymtype, p)
end

struct GetParameterIndex{I} <: AbstractIndexer
idx::I
end

function (gpi::GetParameterIndex)(::IsTimeseriesTrait, prob)
parameter_values(prob, gpi.idx)
end
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(
parameter_values_at_time(
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
gpi.idx)
end
function (gpi::GetParameterIndex)(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(
parameter_values_at_time.((prob,),
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
(gpi.idx,))
end
function (gpi::GetParameterIndex)(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (gpi.idx,))
end

function _getp(sys, ::NotSymbolic, ::NotSymbolic, p)
return let p = p
function _getter(::NotTimeseries, prob)
parameter_values(prob, p)
end
function _getter(::Timeseries, prob)
parameter_values(prob, p)
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
parameter_values(
parameter_values_at_time(
prob, only(to_indices(parameter_timeseries(prob), (i,)))),
p)
end
function _getter(::Timeseries, prob, i::Union{AbstractArray{Bool}, Colon})
parameter_values.(
parameter_values_at_time.((prob,),
(j for j in only(to_indices(parameter_timeseries(prob), (i,))))),
(p,))
end
function _getter(::Timeseries, prob, i)
parameter_values.(parameter_values_at_time.((prob,), i), (p,))
end
getter = let _getter = _getter
function getter(prob, args...)
return _getter(is_timeseries(prob), prob, args...)
end
end
getter
end
return GetParameterIndex(p)
end

function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
idx = parameter_index(sys, p)
return invoke(_getp, Tuple{Any, NotSymbolic, NotSymbolic, Any},
sys, NotSymbolic(), NotSymbolic(), idx)
return _getp(sys, NotSymbolic(), NotSymbolic(), idx)
end

struct MultipleParameterGetters{G}
getters::G
end

function (mpg::MultipleParameterGetters)(::IsTimeseriesTrait, prob)
map(g -> g(prob), mpg.getters)
end
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i::Union{Int, CartesianIndex})
map(g -> g(prob, i), mpg.getters)
end
function (mpg::MultipleParameterGetters)(::Timeseries, prob, i)
[map(g -> g(prob, j), mpg.getters)
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
end
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob)
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end
function (mpg::MultipleParameterGetters)(
buffer::AbstractArray, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
buffer[bufi] = g(prob, i)
end
buffer
end
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::Timeseries, prob, i)
for (bufi, tsi) in zip(
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(mpg.getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end
end
buffer
end
function (mpg::MultipleParameterGetters)(buffer::AbstractArray, ::NotTimeseries, prob)
for (g, bufi) in zip(mpg.getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end

function (mpg::MultipleParameterGetters)(buffer::AbstractArray, prob, i...)
mpg(buffer, is_timeseries(prob), prob, i...)
end
function (mpg::MultipleParameterGetters)(prob, i...)
mpg(is_timeseries(prob), prob, i...)
end

for (t1, t2) in [
Expand All @@ -171,60 +135,7 @@ for (t1, t2) in [
]
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
getters = getp.((sys,), p)

return let getters = getters
function _getter(::NotTimeseries, prob)
map(g -> g(prob), getters)
end
function _getter(::Timeseries, prob)
map(g -> g(prob), getters)
end
function _getter(::Timeseries, prob, i::Union{Int, CartesianIndex})
map(g -> g(prob, i), getters)
end
function _getter(::Timeseries, prob, i)
[map(g -> g(prob, j), getters)
for j in only(to_indices(parameter_timeseries(prob), (i,)))]
end
function _getter!(buffer, ::NotTimeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end
function _getter!(buffer, ::Timeseries, prob)
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob)
end
buffer
end
function _getter!(buffer, ::Timeseries, prob, i::Union{Int, CartesianIndex})
for (g, bufi) in zip(getters, eachindex(buffer))
buffer[bufi] = g(prob, i)
end
buffer
end
function _getter!(buffer, ::Timeseries, prob, i)
for (bufi, tsi) in zip(
eachindex(buffer), only(to_indices(parameter_timeseries(prob), (i,))))
for (g, bufj) in zip(getters, eachindex(buffer[bufi]))
buffer[bufi][bufj] = g(prob, tsi)
end
end
buffer
end
_getter, _getter!
getter = let _getter = _getter, _getter! = _getter!
function getter(prob, i...)
return _getter(is_timeseries(prob), prob, i...)
end
function getter(buffer::AbstractArray, prob, i...)
return _getter!(buffer, is_timeseries(prob), prob, i...)
end
getter
end
getter
end
return MultipleParameterGetters(getters)
end
end

Expand Down
Loading