Skip to content

Commit 08ed9a2

Browse files
refactor: store parameters from different clock partitions separately
1 parent ceed151 commit 08ed9a2

File tree

3 files changed

+341
-86
lines changed

3 files changed

+341
-86
lines changed

src/systems/abstractsystem.jl

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ end
439439
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
440440
sym = unwrap(sym)
441441
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
442-
return is_parameter(ic, sym) ||
442+
return sym isa ParameterIndex || is_parameter(ic, sym) ||
443443
iscall(sym) && operation(sym) === getindex &&
444444
is_parameter(ic, first(arguments(sym)))
445445
end
@@ -465,11 +465,21 @@ end
465465
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
466466
sym = unwrap(sym)
467467
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
468-
return if (idx = parameter_index(ic, sym)) !== nothing
469-
idx
468+
return if sym isa ParameterIndex
469+
sym
470+
elseif (idx = parameter_index(ic, sym)) !== nothing
471+
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
472+
return nothing
473+
else
474+
idx
475+
end
470476
elseif iscall(sym) && operation(sym) === getindex &&
471477
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
472-
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
478+
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing
479+
return nothing
480+
else
481+
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
482+
end
473483
else
474484
nothing
475485
end
@@ -487,7 +497,12 @@ end
487497

488498
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
489499
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
490-
return parameter_index(ic, sym)
500+
idx = parameter_index(ic, sym)
501+
if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
502+
return nothing
503+
else
504+
return idx
505+
end
491506
end
492507
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
493508
if idx !== nothing
@@ -500,6 +515,67 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
500515
return nothing
501516
end
502517

518+
function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
519+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
520+
is_timeseries_parameter(ic, sym)
521+
end
522+
523+
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
524+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
525+
timeseries_parameter_index(ic, sym)
526+
end
527+
528+
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
529+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
530+
allvars = vars(sym; op = Symbolics.Operator)
531+
ts_idxs = Set{Int}()
532+
for var in allvars
533+
var = unwrap(var)
534+
# FIXME: Shouldn't have to shift systems
535+
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
536+
var = only(arguments(var))
537+
end
538+
ts_idx = check_index_map(ic.discrete_idx, unwrap(var))
539+
ts_idx === nothing && continue
540+
push!(ts_idxs, ts_idx[1])
541+
end
542+
if length(ts_idxs) == 1
543+
ts_idx = only(ts_idxs)
544+
else
545+
ts_idx = nothing
546+
end
547+
rawobs = build_explicit_observed_function(
548+
sys, sym; param_only = true, return_inplace = true)
549+
if rawobs isa Tuple
550+
if is_time_dependent(sys)
551+
obsfn = let oop = rawobs[1], iip = rawobs[2]
552+
f1a(p::MTKParameters, t) = oop(p..., t)
553+
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
554+
end
555+
else
556+
obsfn = let oop = rawobs[1], iip = rawobs[2]
557+
f1b(p::MTKParameters) = oop(p...)
558+
f1b(out, p::MTKParameters) = iip(out, p...)
559+
end
560+
end
561+
else
562+
if is_time_dependent(sys)
563+
obsfn = let rawobs = rawobs
564+
f2a(p::MTKParameters, t) = rawobs(p..., t)
565+
end
566+
else
567+
obsfn = let rawobs = rawobs
568+
f2b(p::MTKParameters) = rawobs(p...)
569+
end
570+
end
571+
end
572+
else
573+
ts_idx = nothing
574+
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
575+
end
576+
return ParameterObservedFunction(ts_idx, obsfn)
577+
end
578+
503579
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
504580
return full_parameters(sys)
505581
end
@@ -517,7 +593,7 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSys
517593
end
518594

519595
function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
520-
return !is_variable(sys, sym) && !is_parameter(sys, sym) &&
596+
return !is_variable(sys, sym) && parameter_index(sys, sym) === nothing &&
521597
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
522598
end
523599

0 commit comments

Comments
 (0)