Skip to content

Commit e560401

Browse files
refactor: update implementation of discrete save interface
1 parent 4135f00 commit e560401

File tree

6 files changed

+79
-94
lines changed

6 files changed

+79
-94
lines changed

src/systems/abstractsystem.jl

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
447447
sym = unwrap(sym)
448448
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
449449
return sym isa ParameterIndex || is_parameter(ic, sym) ||
450-
iscall(sym) && operation(sym) === getindex &&
450+
iscall(sym) &&
451+
operation(sym) === getindex &&
451452
is_parameter(ic, first(arguments(sym)))
452453
end
453454
if unwrap(sym) isa Int
@@ -526,34 +527,19 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Sym
526527
end
527528

528529
function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem, sym)
530+
is_time_dependent(sys) || return false
529531
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
530532
is_timeseries_parameter(ic, sym)
531533
end
532534

533535
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
536+
is_time_dependent(sys) || return nothing
534537
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
535538
timeseries_parameter_index(ic, sym)
536539
end
537540

538541
function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
539542
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
540-
allvars = vars(sym; op = Symbolics.Operator)
541-
ts_idxs = Set{Int}()
542-
for var in allvars
543-
var = unwrap(var)
544-
# FIXME: Shouldn't have to shift systems
545-
if istree(var) && (op = operation(var)) isa Shift && op.steps == 1
546-
var = only(arguments(var))
547-
end
548-
ts_idx = check_index_map(ic.discrete_idx, unwrap(var))
549-
ts_idx === nothing && continue
550-
push!(ts_idxs, ts_idx[1])
551-
end
552-
if length(ts_idxs) == 1
553-
ts_idx = only(ts_idxs)
554-
else
555-
ts_idx = nothing
556-
end
557543
rawobs = build_explicit_observed_function(
558544
sys, sym; param_only = true, return_inplace = true)
559545
if rawobs isa Tuple
@@ -580,10 +566,44 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
580566
end
581567
end
582568
else
583-
ts_idx = nothing
584569
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
585570
end
586-
return ParameterObservedFunction(ts_idx, obsfn)
571+
return obsfn
572+
end
573+
574+
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
575+
if is_variable(sys, sym)
576+
push!(ts_idxs, ContinuousTimeseries())
577+
elseif is_timeseries_parameter(sys, sym)
578+
push!(ts_idxs, timeseries_parameter_index(sys, sym).timeseries_idx)
579+
end
580+
end
581+
# Need this to avoid ambiguity with the array case
582+
for traitT in [
583+
ScalarSymbolic,
584+
ArraySymbolic
585+
]
586+
@eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym)
587+
allsyms = vars(sym; op = Symbolics.Operator)
588+
foreach(allsyms) do s
589+
_all_ts_idxs!(ts_idxs, sys, s)
590+
end
591+
end
592+
end
593+
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym::AbstractArray)
594+
foreach(sym) do s
595+
_all_ts_idxs!(ts_idxs, sys, s)
596+
end
597+
end
598+
_all_ts_idxs!(ts_idxs, sys, sym) = _all_ts_idxs!(ts_idxs, NotSymbolic(), sys, sym)
599+
600+
function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSystem, sym)
601+
if !is_time_dependent(sys)
602+
return Set()
603+
end
604+
ts_idxs = Set()
605+
_all_ts_idxs!(ts_idxs, sys, sym)
606+
return ts_idxs
587607
end
588608

589609
function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)

src/systems/index_cache.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function IndexCache(sys::AbstractSystem)
113113
error("Discrete subsystem $i input $inp is not a parameter")
114114
disc_clocks[inp] = i
115115
disc_clocks[default_toterm(inp)] = i
116-
if hasname(inp) && (!istree(inp) || operation(inp) !== getindex)
116+
if hasname(inp) && (!iscall(inp) || operation(inp) !== getindex)
117117
disc_clocks[getname(inp)] = i
118118
disc_clocks[default_toterm(inp)] = i
119119
end
@@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem)
126126
error("Discrete subsystem $i unknown $sym is not a parameter")
127127
disc_clocks[sym] = i
128128
disc_clocks[default_toterm(sym)] = i
129-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
129+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
130130
disc_clocks[getname(sym)] = i
131131
disc_clocks[getname(default_toterm(sym))] = i
132132
end
@@ -138,13 +138,13 @@ function IndexCache(sys::AbstractSystem)
138138
# FIXME: This shouldn't be necessary
139139
eq.rhs === -0.0 && continue
140140
sym = eq.lhs
141-
if istree(sym) && operation(sym) == Shift(t, 1)
141+
if iscall(sym) && operation(sym) == Shift(t, 1)
142142
sym = only(arguments(sym))
143143
end
144144
disc_clocks[sym] = i
145145
disc_clocks[sym] = i
146146
disc_clocks[default_toterm(sym)] = i
147-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
147+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
148148
disc_clocks[getname(sym)] = i
149149
disc_clocks[getname(default_toterm(sym))] = i
150150
end
@@ -153,7 +153,7 @@ function IndexCache(sys::AbstractSystem)
153153

154154
for par in inputs[continuous_id]
155155
is_parameter(sys, par) || error("Discrete subsystem input is not a parameter")
156-
istree(par) && operation(par) isa Hold ||
156+
iscall(par) && operation(par) isa Hold ||
157157
error("Continuous subsystem input is not a Hold")
158158
if haskey(disc_clocks, par)
159159
sym = par
@@ -176,7 +176,7 @@ function IndexCache(sys::AbstractSystem)
176176
disc_clocks[affect.lhs] = user_affect_clock
177177
disc_clocks[default_toterm(affect.lhs)] = user_affect_clock
178178
if hasname(affect.lhs) &&
179-
(!istree(affect.lhs) || operation(affect.lhs) !== getindex)
179+
(!iscall(affect.lhs) || operation(affect.lhs) !== getindex)
180180
disc_clocks[getname(affect.lhs)] = user_affect_clock
181181
disc_clocks[getname(default_toterm(affect.lhs))] = user_affect_clock
182182
end
@@ -190,7 +190,7 @@ function IndexCache(sys::AbstractSystem)
190190
disc = unwrap(disc)
191191
disc_clocks[disc] = user_affect_clock
192192
disc_clocks[default_toterm(disc)] = user_affect_clock
193-
if hasname(disc) && (!istree(disc) || operation(disc) !== getindex)
193+
if hasname(disc) && (!iscall(disc) || operation(disc) !== getindex)
194194
disc_clocks[getname(disc)] = user_affect_clock
195195
disc_clocks[getname(default_toterm(disc))] = user_affect_clock
196196
end
@@ -245,7 +245,7 @@ function IndexCache(sys::AbstractSystem)
245245
for (j, sym) in enumerate(buffer[btype])
246246
disc_idxs[sym] = (clockidx, i, j)
247247
disc_idxs[default_toterm(sym)] = (clockidx, i, j)
248-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
248+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
249249
disc_idxs[getname(sym)] = (clockidx, i, j)
250250
disc_idxs[getname(default_toterm(sym))] = (clockidx, i, j)
251251
end
@@ -256,7 +256,7 @@ function IndexCache(sys::AbstractSystem)
256256
haskey(disc_idxs, sym) && continue
257257
disc_idxs[sym] = (clockid, 0, 0)
258258
disc_idxs[default_toterm(sym)] = (clockid, 0, 0)
259-
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
259+
if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex)
260260
disc_idxs[getname(sym)] = (clockid, 0, 0)
261261
disc_idxs[getname(default_toterm(sym))] = (clockid, 0, 0)
262262
end

src/systems/parameter_buffer.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ function SymbolicIndexingInterface.set_parameter!(
363363
if validate_size && size(val) !== size(p.discrete[i][j][k])
364364
throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val)))
365365
end
366-
p.discrete[i][j][k][l...] = val
366+
p.discrete[i][j][k] = val
367367
else
368368
p.discrete[i][j][k][l...] = val
369369
end
@@ -563,7 +563,8 @@ end
563563
Base.size(::NestedGetIndex) = ()
564564

565565
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
566-
ps::MTKParameters, args::Pair{A, B}...) where {A, B <: NestedGetIndex}
566+
::AbstractSystem, ps::MTKParameters, args::Pair{A, B}...) where {
567+
A, B <: NestedGetIndex}
567568
for (i, val) in args
568569
ps.discrete[i] = val.x
569570
end

test/mtkparameters.jl

Lines changed: 8 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using StaticArrays: SizedVector
56
using OrdinaryDiffEq
67
using ForwardDiff
78
using JET
@@ -292,29 +293,10 @@ end
292293
end
293294

294295
# Parameter timeseries
295-
# dt = 0.1
296-
# dt2 = 0.2
297-
# @variables x(t)=0 y(t)=0 u(t)=0 yd1(t)=0 ud1(t)=0 yd2(t)=0 ud2(t)=0
298-
# @parameters kp=1 r=1
299-
300-
# eqs = [
301-
# # controller (time discrete part `dt=0.1`)
302-
# yd1 ~ Sample(t, dt)(y)
303-
# ud1 ~ kp * (r - yd1)
304-
# # controller (time discrete part `dt=0.2`)
305-
# yd2 ~ Sample(t, dt2)(y)
306-
# ud2 ~ kp * (r - yd2)
307-
308-
# # plant (time continuous part)
309-
# u ~ Hold(ud1) + Hold(ud2)
310-
# D(x) ~ -x + u
311-
# y ~ x]
312-
313-
# @mtkbuild cl = ODESystem(eqs, t)
314-
ps = MTKParameters(([1.0, 1.0],), SizedArray{2}([([0.0, 0.0],), ([0.0, 0.0],)]), (), (), (), nothing, nothing)
315-
# ps = MTKParameters(cl, [kp => 1.0])
296+
ps = MTKParameters(([1.0, 1.0],), SizedVector{2}([([0.0, 0.0],), ([0.0, 0.0],)]),
297+
(), (), (), nothing, nothing)
316298
with_updated_parameter_timeseries_values(
317-
ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],)))
299+
sys, ps, 1 => ModelingToolkit.NestedGetIndex(([5.0, 10.0],)))
318300
@test ps.discrete[1][1] == [5.0, 10.0]
319301
with_updated_parameter_timeseries_values(
320302
ps, 1 => ModelingToolkit.NestedGetIndex(([3.0, 30.0],)),
@@ -324,27 +306,9 @@ with_updated_parameter_timeseries_values(
324306
@test SciMLBase.get_saveable_values(ps, 1).x == ps.discrete[1]
325307

326308
# With multiple types and clocks
327-
# @variables x(t) xd1(t) xd2(t) flag(t)::Bool yd1(t) yd2(t) yc1(t) yc2(t)
328-
# dt = 0.1
329-
# k1 = ShiftIndex(t, dt)
330-
# ssc = ModelingToolkit.SolverStepClock(t)
331-
# k2 = ShiftIndex(ssc)
332-
333-
# eqs = [
334-
# flag ~ ~flag(k1 - 1),
335-
# xd1 ~ Sample(t, dt)(x),
336-
# yd1 ~ ifelse(flag, xd1, yd1(k1 - 1)), xd2 ~ Sample(ssc)(x),
337-
# yd2 ~ yd2(k2 - 1) + xd2, yc1 ~ Hold(yd1),
338-
# yc2 ~ Hold(yd2),
339-
# D(x) ~ yc1 + yc2
340-
# ]
341-
# @mtkbuild sys = ODESystem(eqs, t)
342-
# ps = MTKParameters(sys,
343-
# [flag => true, yd1 => ifelse(flag, Sample(t, dt)(x), 1.0),
344-
# yd2 => 2.0 + Sample(ssc)(x), Sample(t, dt)(x) => x,
345-
# Sample(ssc)(x) => x, Hold(yd1) => yd1, Hold(yd2) => yd2],
346-
# [x => 3.0])
347-
ps = MTKParameters((), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]), (), (), (), nothing, nothing)
309+
ps = MTKParameters(
310+
(), SizedVector{2}([([1.0, 2.0, 3.0], falses(1)), ([4.0, 5.0, 6.0], falses(0))]),
311+
(), (), (), nothing, nothing)
348312
@test SciMLBase.get_saveable_values(ps, 1).x isa Tuple{Vector{Float64}, BitVector}
349313
# tsidx1 = timeseries_parameter_index(sys, flag).timeseries_idx
350314
# tsidx2 = 3 - tsidx1
@@ -355,6 +319,6 @@ tsidx2 = 2
355319
@test length(ps.discrete[tsidx2][1]) == 3
356320
@test length(ps.discrete[tsidx2][2]) == 0
357321
with_updated_parameter_timeseries_values(
358-
ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
322+
sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
359323
@test ps.discrete[tsidx1][1] == [10.0, 11.0, 12.0]
360324
@test ps.discrete[tsidx1][2][] == false

test/parameter_dependencies.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,21 @@ end
173173
@test_skip begin
174174
Tf = 1.0
175175
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
176-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
176+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
177+
yd(k - 2) => 2.0])
177178
@test_nowarn solve(prob, Tsit5())
178179

179180
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
180181
discrete_events = [[0.5] => [kp ~ 2.0]])
181182
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
182-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
183+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
184+
yd(k - 2) => 2.0])
183185
@test prob.ps[kp] == 1.0
184186
@test prob.ps[kq] == 2.0
185187
@test_nowarn solve(prob, Tsit5())
186188
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
187-
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0; yd(k - 2) => 2.0])
189+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0; z(k - 2) => 4.0;
190+
yd(k - 2) => 2.0])
188191
integ = init(prob, Tsit5())
189192
@test integ.ps[kp] == 1.0
190193
@test integ.ps[kq] == 2.0

test/symbolic_indexing_interface.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ using SciMLStructures: Tunable
3838
odesys = complete(odesys)
3939
@test default_values(odesys)[xy] == 3.0
4040
pobs = parameter_observed(odesys, a + b)
41-
@test pobs.timeseries_idx === nothing
42-
@test pobs.observed_fn(
41+
@test isempty(get_all_timeseries_indexes(odesys, a + b))
42+
@test pobs(
4343
ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) 3.0
4444
pobs = parameter_observed(odesys, [a + b, a - b])
45-
@test pobs.timeseries_idx === nothing
46-
@test pobs.observed_fn(
45+
@test isempty(get_all_timeseries_indexes(odesys, [a + b, a - b]))
46+
@test pobs(
4747
ModelingToolkit.MTKParameters(odesys, [a => 1.0, b => 2.0]), 0.0) [3.0, -1.0]
4848
end
4949

@@ -102,11 +102,11 @@ end
102102
@test !is_time_dependent(ns)
103103
ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0])
104104
pobs = parameter_observed(ns, σ + ρ)
105-
@test pobs.timeseries_idx === nothing
106-
@test pobs.observed_fn(ps) == 3.0
105+
@test isempty(get_all_timeseries_indexes(ns, σ + ρ))
106+
@test pobs(ps) == 3.0
107107
pobs = parameter_observed(ns, [σ + ρ, ρ + β])
108-
@test pobs.timeseries_idx === nothing
109-
@test pobs.observed_fn(ps) == [3.0, 5.0]
108+
@test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β]))
109+
@test pobs(ps) == [3.0, 5.0]
110110
end
111111

112112
@testset "PDESystem" begin
@@ -147,6 +147,11 @@ end
147147
domains = [t (0.0, 1.0),
148148
x (0.0, 1.0)]
149149

150+
analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
151+
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)
152+
153+
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic)
154+
150155
@test isequal(pdesys.ps, [h])
151156
@test isequal(parameter_symbols(pdesys), [h])
152157
@test isequal(parameters(pdesys), [h])
@@ -179,12 +184,4 @@ get_dep = @test_nowarn getu(prob, 2p1)
179184
@test getu(prob, z)(prob) == getu(prob, :z)(prob)
180185
@test getu(prob, p1)(prob) == getu(prob, :p1)(prob)
181186
@test getu(prob, p2)(prob) == getu(prob, :p2)(prob)
182-
analytic = [u(t, x) ~ -h * x * (x - 1) * sin(x) * exp(-2 * h * t)]
183-
analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] * t)
184-
185-
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u], [h], analytic = analytic)
186-
187-
@test isequal(pdesys.ps, [h])
188-
@test isequal(parameter_symbols(pdesys), [h])
189-
@test isequal(parameters(pdesys), [h])
190187
end

0 commit comments

Comments
 (0)