Skip to content

Commit 522ad07

Browse files
refactor: store parameters from different clock partitions separately
1 parent a54b391 commit 522ad07

File tree

3 files changed

+341
-86
lines changed

3 files changed

+341
-86
lines changed

src/systems/abstractsystem.jl

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ end
415415
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
416416
sym = unwrap(sym)
417417
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
418-
return is_parameter(ic, sym) ||
418+
return sym isa ParameterIndex || is_parameter(ic, sym) ||
419419
istree(sym) && operation(sym) === getindex &&
420420
is_parameter(ic, first(arguments(sym)))
421421
end
@@ -441,11 +441,21 @@ end
441441
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
442442
sym = unwrap(sym)
443443
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
444-
return if (idx = parameter_index(ic, sym)) !== nothing
445-
idx
444+
return if sym isa ParameterIndex
445+
sym
446+
elseif (idx = parameter_index(ic, sym)) !== nothing
447+
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
448+
return nothing
449+
else
450+
idx
451+
end
446452
elseif istree(sym) && operation(sym) === getindex &&
447453
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
448-
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
454+
if idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == nothing
455+
return nothing
456+
else
457+
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
458+
end
449459
else
450460
nothing
451461
end
@@ -463,7 +473,12 @@ end
463473

464474
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
465475
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
466-
return parameter_index(ic, sym)
476+
idx = parameter_index(ic, sym)
477+
if idx === nothing || idx.portion isa SciMLStructures.Discrete && idx.idx[2] == idx.idx[3] == 0
478+
return nothing
479+
else
480+
return idx
481+
end
467482
end
468483
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
469484
if idx !== nothing
@@ -475,6 +490,67 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
475490
return nothing
476491
end
477492

493+
function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
494+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
495+
is_timeseries_parameter(ic, sym)
496+
end
497+
498+
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
499+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
500+
timeseries_parameter_index(ic, sym)
501+
end
502+
503+
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
504+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
505+
allvars = vars(sym; op = Symbolics.Operator)
506+
ts_idxs = Set{Int}()
507+
for var in allvars
508+
var = unwrap(var)
509+
# FIXME: Shouldn't have to shift systems
510+
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
511+
var = only(arguments(var))
512+
end
513+
ts_idx = check_index_map(ic.discrete_idx, unwrap(var))
514+
ts_idx === nothing && continue
515+
push!(ts_idxs, ts_idx[1])
516+
end
517+
if length(ts_idxs) == 1
518+
ts_idx = only(ts_idxs)
519+
else
520+
ts_idx = nothing
521+
end
522+
rawobs = build_explicit_observed_function(
523+
sys, sym; param_only = true, return_inplace = true)
524+
if rawobs isa Tuple
525+
if is_time_dependent(sys)
526+
obsfn = let oop = rawobs[1], iip = rawobs[2]
527+
f1a(p::MTKParameters, t) = oop(p..., t)
528+
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
529+
end
530+
else
531+
obsfn = let oop = rawobs[1], iip = rawobs[2]
532+
f1b(p::MTKParameters) = oop(p...)
533+
f1b(out, p::MTKParameters) = iip(out, p...)
534+
end
535+
end
536+
else
537+
if is_time_dependent(sys)
538+
obsfn = let rawobs = rawobs
539+
f2a(p::MTKParameters, t) = rawobs(p..., t)
540+
end
541+
else
542+
obsfn = let rawobs = rawobs
543+
f2b(p::MTKParameters) = rawobs(p...)
544+
end
545+
end
546+
end
547+
else
548+
ts_idx = nothing
549+
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
550+
end
551+
return ParameterObservedFunction(ts_idx, obsfn)
552+
end
553+
478554
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
479555
return full_parameters(sys)
480556
end
@@ -492,7 +568,7 @@ function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSys
492568
end
493569

494570
function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
495-
return !is_variable(sys, sym) && !is_parameter(sys, sym) &&
571+
return !is_variable(sys, sym) && parameter_index(sys, sym) === nothing &&
496572
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
497573
end
498574

0 commit comments

Comments
 (0)