Skip to content

Commit 462e0fb

Browse files
feat: implement new SII and SciMLBase interfaces for discrete saving/indexing
1 parent f2aad04 commit 462e0fb

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

src/systems/abstractsystem.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,46 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
475475
return nothing
476476
end
477477

478+
function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
479+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
480+
is_timeseries_parameter(ic, sym)
481+
end
482+
483+
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
484+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
485+
timeseries_parameter_index(ic, sym)
486+
end
487+
488+
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
489+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
490+
allvars = vars(sym; op = Symbolics.Operator)
491+
ts_idxs = Set{Int}()
492+
for var in allvars
493+
var = unwrap(var)
494+
# FIXME: Shouldn't have to shift systems
495+
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
496+
var = only(arguments(var))
497+
end
498+
ts_idx = check_index_map(ic.discrete_clocks, unwrap(var))
499+
ts_idx === nothing && continue
500+
push!(ts_idxs, ts_idx)
501+
end
502+
if length(ts_idxs) == 1
503+
ts_idx = only(ts_idxs)
504+
else
505+
ts_idx = nothing
506+
end
507+
obsfn = let raw_obs_fn = build_explicit_observed_function(
508+
sys, sym; param_only = true)
509+
f(p::MTKParameters, t) = raw_obs_fn(p..., t)
510+
end
511+
else
512+
ts_idx = nothing
513+
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
514+
end
515+
return ParameterObservedFunction(ts_idx, obsfn)
516+
end
517+
478518
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
479519
return full_parameters(sys)
480520
end

src/systems/parameter_buffer.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,54 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
420420
return newbuf
421421
end
422422

423+
struct NestedGetIndex{T}
424+
x::T
425+
end
426+
427+
function Base.getindex(ngi::NestedGetIndex, idx::Tuple)
428+
i, j, k... = idx
429+
return ngi.x[i][j][k...]
430+
end
431+
432+
# Required for DiffEqArray constructor to work during interpolation
433+
Base.size(::NestedGetIndex) = ()
434+
435+
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
436+
ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex}
437+
for (i, val) in args
438+
ps.discrete[i] = val.x
439+
end
440+
return ps
441+
end
442+
443+
function SciMLBase.create_parameter_timeseries_collection(
444+
sys::AbstractSystem, ps::MTKParameters, tspan)
445+
ic = get_index_cache(sys) # this exists because the parameters are `MTKParameters`
446+
has_discrete_subsystems(sys) || return nothing
447+
(dss = get_discrete_subsystems(sys)) === nothing && return nothing
448+
_, _, _, id_to_clock = dss
449+
buffers = []
450+
451+
for (i, partition) in enumerate(ps.discrete)
452+
clock = id_to_clock[i + 1]
453+
if clock isa Clock
454+
ts = tspan[1]:(clock.dt):tspan[2]
455+
push!(buffers, DiffEqArray(NestedGetIndex{typeof(partition)}[], ts, (1, 1)))
456+
elseif clock isa SolverStepClock
457+
push!(buffers,
458+
DiffEqArray(NestedGetIndex{typeof(partition)}[], eltype(tspan)[], (1, 1)))
459+
else
460+
error("Unhandled clock $clock")
461+
end
462+
end
463+
464+
return ParameterTimeseriesCollection(Tuple(buffers), copy(ps))
465+
end
466+
467+
function SciMLBase.get_saveable_values(ps::MTKParameters, timeseries_idx)
468+
return NestedGetIndex(deepcopy(ps.discrete[timeseries_idx]))
469+
end
470+
423471
function DiffEqBase.anyeltypedual(
424472
p::MTKParameters, ::Type{Val{counter}} = Val{0}) where {counter}
425473
DiffEqBase.anyeltypedual(p.tunable)

0 commit comments

Comments
 (0)