Skip to content

Commit bfc73a3

Browse files
refactor: store parameters from different clock partitions separately
1 parent bc6a186 commit bfc73a3

File tree

3 files changed

+341
-86
lines changed

3 files changed

+341
-86
lines changed

src/systems/abstractsystem.jl

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ end
446446
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
447447
sym = unwrap(sym)
448448
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
449-
return is_parameter(ic, sym) ||
449+
return sym isa ParameterIndex || is_parameter(ic, sym) ||
450450
iscall(sym) && operation(sym) === getindex &&
451451
is_parameter(ic, first(arguments(sym)))
452452
end
@@ -472,11 +472,21 @@ end
472472
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
473473
sym = unwrap(sym)
474474
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
475-
return if (idx = parameter_index(ic, sym)) !== nothing
476-
idx
475+
return if sym isa ParameterIndex
476+
sym
477+
elseif (idx = parameter_index(ic, sym)) !== nothing
478+
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
479+
return nothing
480+
else
481+
idx
482+
end
477483
elseif iscall(sym) && operation(sym) === getindex &&
478484
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
479-
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
485+
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing
486+
return nothing
487+
else
488+
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
489+
end
480490
else
481491
nothing
482492
end
@@ -494,7 +504,12 @@ end
494504

495505
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
496506
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
497-
return parameter_index(ic, sym)
507+
idx = parameter_index(ic, sym)
508+
if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
509+
return nothing
510+
else
511+
return idx
512+
end
498513
end
499514
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
500515
if idx !== nothing
@@ -507,6 +522,67 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
507522
return nothing
508523
end
509524

525+
function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
526+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
527+
is_timeseries_parameter(ic, sym)
528+
end
529+
530+
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
531+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
532+
timeseries_parameter_index(ic, sym)
533+
end
534+
535+
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
536+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
537+
allvars = vars(sym; op = Symbolics.Operator)
538+
ts_idxs = Set{Int}()
539+
for var in allvars
540+
var = unwrap(var)
541+
# FIXME: Shouldn't have to shift systems
542+
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
543+
var = only(arguments(var))
544+
end
545+
ts_idx = check_index_map(ic.discrete_idx, unwrap(var))
546+
ts_idx === nothing && continue
547+
push!(ts_idxs, ts_idx[1])
548+
end
549+
if length(ts_idxs) == 1
550+
ts_idx = only(ts_idxs)
551+
else
552+
ts_idx = nothing
553+
end
554+
rawobs = build_explicit_observed_function(
555+
sys, sym; param_only = true, return_inplace = true)
556+
if rawobs isa Tuple
557+
if is_time_dependent(sys)
558+
obsfn = let oop = rawobs[1], iip = rawobs[2]
559+
f1a(p::MTKParameters, t) = oop(p..., t)
560+
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
561+
end
562+
else
563+
obsfn = let oop = rawobs[1], iip = rawobs[2]
564+
f1b(p::MTKParameters) = oop(p...)
565+
f1b(out, p::MTKParameters) = iip(out, p...)
566+
end
567+
end
568+
else
569+
if is_time_dependent(sys)
570+
obsfn = let rawobs = rawobs
571+
f2a(p::MTKParameters, t) = rawobs(p..., t)
572+
end
573+
else
574+
obsfn = let rawobs = rawobs
575+
f2b(p::MTKParameters) = rawobs(p...)
576+
end
577+
end
578+
end
579+
else
580+
ts_idx = nothing
581+
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
582+
end
583+
return ParameterObservedFunction(ts_idx, obsfn)
584+
end
585+
510586
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
511587
return full_parameters(sys)
512588
end
@@ -524,7 +600,7 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSys
524600
end
525601

526602
function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
527-
return !is_variable(sys, sym) && !is_parameter(sys, sym) &&
603+
return !is_variable(sys, sym) && parameter_index(sys, sym) === nothing &&
528604
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
529605
end
530606

0 commit comments

Comments
 (0)