Skip to content

Commit 3ac5867

Browse files
refactor: enforce non-positive shifts, update documentation
1 parent 5d7693e commit 3ac5867

File tree

9 files changed

+76
-72
lines changed

9 files changed

+76
-72
lines changed

docs/src/tutorials/discrete_system.md

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,30 @@ end
1515
@constants h = 1
1616
@variables S(t) I(t) R(t)
1717
k = ShiftIndex(t)
18-
infection = rate_to_proportion(β * c * I / (S * h + I + R), δt * h) * S
19-
recovery = rate_to_proportion(γ * h, δt) * I
18+
infection = rate_to_proportion(β * c * I(k-1) / (S(k-1) * h + I(k-1) + R(k-1)), δt * h) * S(k-1)
19+
recovery = rate_to_proportion(γ * h, δt) * I(k-1)
2020
2121
# Equations
22-
eqs = [S(k + 1) ~ S - infection * h,
23-
I(k + 1) ~ I + infection - recovery,
24-
R(k + 1) ~ R + recovery]
22+
eqs = [S(k) ~ S(k-1) - infection * h,
23+
I(k) ~ I(k-1) + infection - recovery,
24+
R(k) ~ R(k-1) + recovery]
2525
@mtkbuild sys = DiscreteSystem(eqs, t)
2626
27-
u0 = [S => 990.0, I => 10.0, R => 0.0]
27+
u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
2828
p = [β => 0.05, c => 10.0, γ => 0.25, δt => 0.1]
2929
tspan = (0.0, 100.0)
3030
prob = DiscreteProblem(sys, u0, tspan, p)
3131
sol = solve(prob, FunctionMap())
3232
```
33+
34+
All shifts must be negative. If default values are provided, they are treated as the value
35+
for the variable at the previous timestep. For example, consider the following system to
36+
generate the Fibonacci series:
37+
38+
```@example discrete
39+
@variables x(t) = 1.0
40+
@mtkbuild sys = DiscreteSystem([x ~ x(k-1) + x(k-2)], t)
41+
```
42+
43+
Note that the default value is treated as the initial value of `x(k-1)`. The value for
44+
`x(k-2)` must be provided during problem construction.

src/discretedomain.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ function (D::Shift)(x::Num, allow_zero = false)
3838
vt = value(x)
3939
if istree(vt)
4040
op = operation(vt)
41-
if op isa Shift
41+
if op isa Sample
42+
error("Cannot shift a `Sample`. Create a variable to represent the sampled value and shift that instead")
43+
elseif op isa Shift
4244
if D.t === nothing || isequal(D.t, op.t)
4345
arg = arguments(vt)[1]
4446
newsteps = D.steps + op.steps

src/structural_transformation/utils.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -449,21 +449,4 @@ function simplify_shifts(var)
449449
end
450450
end
451451
return Postwalk(PassThrough(r))(var)
452-
while ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
453-
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
454-
op1 = operation(var)
455-
vv1 = arguments(var)[1]
456-
op2 = operation(vv1)
457-
vv2 = arguments(vv1)[1]
458-
s1 = op1.steps
459-
s2 = op2.steps
460-
t1 = op1.t
461-
t2 = op2.t
462-
if t1 === nothing
463-
var = Shift(t2, s1 + s2)(vv2)
464-
else
465-
var = Shift(t1, s1 + s2)(vv2)
466-
end
467-
end
468-
return var
469452
end

src/systems/clock_inference.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,23 @@ function generate_discrete_affect(
180180
disc_to_cont_idxs = Int[]
181181
end
182182
for v in inputs[continuous_id]
183-
vv = arguments(v)[1]
184-
if vv in fullvars
185-
push!(needed_disc_to_cont_obs, vv)
183+
_v = arguments(v)[1]
184+
if _v in fullvars
185+
push!(needed_disc_to_cont_obs, _v)
186+
push!(disc_to_cont_idxs, param_to_idx[v])
187+
end
188+
189+
# In the above case, `_v` was in `observed(sys)`
190+
# It may also be in `unknowns(sys)`, in which case it
191+
# will be shifted back by one step
192+
if istree(v) && (op = operation(v)) isa Shift
193+
_v = arguments(_v)[1]
194+
_v = Shift(op.t, op.steps - 1)(_v)
195+
else
196+
_v = Shift(get_iv(sys), -1)(_v)
197+
end
198+
if _v in fullvars
199+
push!(needed_disc_to_cont_obs, _v)
186200
push!(disc_to_cont_idxs, param_to_idx[v])
187201
end
188202
end
@@ -198,6 +212,7 @@ function generate_discrete_affect(
198212
throw = false,
199213
expression = true,
200214
output_type = SVector,
215+
op = Shift,
201216
ps = reorder_parameters(osys, full_parameters(sys)))
202217
ni = length(input)
203218
ns = length(unknowns(sys))
@@ -213,7 +228,7 @@ function generate_discrete_affect(
213228
get_iv(sys)
214229
],
215230
[],
216-
let_block)
231+
let_block) |> toexpr
217232
if use_index_cache
218233
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
219234
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,14 +378,15 @@ function build_explicit_observed_function(sys, ts;
378378
checkbounds = true,
379379
drop_expr = drop_expr,
380380
ps = full_parameters(sys),
381+
op = Differential,
381382
throw = true)
382383
if (isscalar = !(ts isa AbstractVector))
383384
ts = [ts]
384385
end
385386
ts = unwrap.(Symbolics.scalarize(ts))
386387

387388
vars = Set()
388-
foreach(Base.Fix1(vars!, vars), ts)
389+
foreach(v -> vars!(vars, v; op), ts)
389390
ivs = independent_variables(sys)
390391
dep_vars = scalarize(setdiff(vars, ivs))
391392

src/systems/systemstructure.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ function TearingState(sys; quick_cancel = false, check = true)
350350
for var in fullvars
351351
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
352352
steps = operation(var).steps
353+
if steps > 0
354+
error("Only non-positive shifts allowed. Found $var with a shift of $steps")
355+
end
353356
v = arguments(var)[1]
354357
lowest_shift[v] = min(get(lowest_shift, v, 0), steps)
355358
end
@@ -369,7 +372,7 @@ function TearingState(sys; quick_cancel = false, check = true)
369372
if lshift < 0
370373
defs = ModelingToolkit.get_defaults(sys)
371374
if (_val = get(defs, var, nothing)) !== nothing
372-
defs[Shift(tt, lshift)(v)] = _val
375+
defs[Shift(tt, -1)(v)] = _val
373376
end
374377
end
375378
else
@@ -387,14 +390,13 @@ function TearingState(sys; quick_cancel = false, check = true)
387390
end
388391
end
389392
end
390-
391393
# sort `fullvars` such that the mass matrix is as diagonal as possible.
392394
dervaridxs = collect(dervaridxs)
393395
sorted_fullvars = OrderedSet(fullvars[dervaridxs])
394396
var_to_old_var = Dict(zip(fullvars, fullvars))
395397
for dervaridx in dervaridxs
396398
dervar = fullvars[dervaridx]
397-
diffvar = var_to_old_var[lower_order_var(dervar)]
399+
diffvar = var_to_old_var[lower_order_var(dervar, iv)]
398400
if !(diffvar in sorted_fullvars)
399401
push!(sorted_fullvars, diffvar)
400402
end
@@ -416,7 +418,7 @@ function TearingState(sys; quick_cancel = false, check = true)
416418
var_to_diff = DiffGraph(nvars, true)
417419
for dervaridx in dervaridxs
418420
dervar = fullvars[dervaridx]
419-
diffvar = lower_order_var(dervar)
421+
diffvar = lower_order_var(dervar, iv)
420422
diffvaridx = var2idx[diffvar]
421423
push!(diffvars, diffvar)
422424
var_to_diff[diffvaridx] = dervaridx
@@ -438,7 +440,7 @@ function TearingState(sys; quick_cancel = false, check = true)
438440
Any[])
439441
end
440442

441-
function lower_order_var(dervar)
443+
function lower_order_var(dervar, t)
442444
if isdifferential(dervar)
443445
diffvar = arguments(dervar)[1]
444446
elseif ModelingToolkit.isoperator(dervar, ModelingToolkit.Shift)
@@ -451,8 +453,7 @@ function lower_order_var(dervar)
451453
diffvar = vv
452454
end
453455
else
454-
iv = only(arguments(dervar))
455-
return Shift(iv, -1)(dervar)
456+
return Shift(t, -1)(dervar)
456457
end
457458
diffvar
458459
end

test/clock.jl

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -89,40 +89,31 @@ d = Clock(t, dt)
8989

9090
@info "Testing shift normalization"
9191
dt = 0.1
92-
@variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t)
92+
@variables x(t) y(t) u(t) yd(t) ud(t)
9393
@parameters kp
9494
d = Clock(t, dt)
9595
k = ShiftIndex(d)
9696

9797
eqs = [yd ~ Sample(t, dt)(y)
98-
ud ~ kp * (r - yd) + z(k)
99-
r ~ 1.0
98+
ud ~ kp * yd + ud(k - 2)
10099

101100
# plant (time continuous part)
102101
u ~ Hold(ud)
103102
D(x) ~ -x + u
104-
y ~ x
105-
z(k + 2) ~ z(k) + yd
106-
#=
107-
z(k + 2) ~ z(k) + yd
108-
=>
109-
z′(k + 1) ~ z(k) + yd
110-
z(k + 1) ~ z′(k)
111-
=#
112-
]
103+
y ~ x]
113104
@named sys = ODESystem(eqs, t)
114105
ss = structural_simplify(sys);
115106

116107
Tf = 1.0
117108
prob = ODEProblem(ss, [x => 0.0, y => 0.0], (0.0, Tf),
118-
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
119-
@test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud
109+
[kp => 1.0; ud(k - 1) => 2.0; ud(k - 2) => 2.0])
110+
@test sort(vcat(prob.p...)) == [0, 1.0, 2.0, 2.0, 2.0] # yd, Hold(ud), kp, ud(k - 1)
120111
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
121112

122113
ss_nosplit = structural_simplify(sys; split = false)
123114
prob_nosplit = ODEProblem(ss_nosplit, [x => 0.0, y => 0.0], (0.0, Tf),
124-
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
125-
@test sort(prob_nosplit.p) == [0, 1.0, 2.0, 3.0, 4.0] # yd, kp, z(k+1), z(k), ud
115+
[kp => 1.0; ud(k - 1) => 2.0; ud(k - 2) => 2.0])
116+
@test sort(prob_nosplit.p) == [0, 1.0, 2.0, 2.0, 2.0] # yd, Hold(ud), kp, ud(k - 1)
126117
sol_nosplit = solve(prob_nosplit, Tsit5(), kwargshandle = KeywordArgSilent)
127118
# For all inputs in parameters, just initialize them to 0.0, and then set them
128119
# in the callback.
@@ -134,30 +125,23 @@ function foo!(du, u, p, t)
134125
du[1] = -x + ud
135126
end
136127
function affect!(integrator, saved_values)
137-
z_t, z = integrator.p[3], integrator.p[4]
138128
yd = integrator.u[1]
139129
kp = integrator.p[1]
140-
r = 1.0
130+
ud = integrator.p[2]
131+
udd = integrator.p[3]
141132

142133
push!(saved_values.t, integrator.t)
143-
push!(saved_values.saveval, [z_t, z])
144-
145-
# Update the discrete state
146-
z_t, z = z + yd, z_t
147-
# @show z_t, z
148-
integrator.p[3] = z_t
149-
integrator.p[4] = z
134+
push!(saved_values.saveval, [ud, udd])
150135

151-
ud = kp * (r - yd) + z
152-
integrator.p[2] = ud
136+
integrator.p[2] = kp * yd + udd
137+
integrator.p[3] = ud
153138

154139
nothing
155140
end
156141
saved_values = SavedValues(Float64, Vector{Float64})
157142
cb = PeriodicCallback(Base.Fix2(affect!, saved_values), 0.1)
158-
# kp ud z_t z
159-
prob = ODEProblem(foo!, [0.0], (0.0, Tf), [1.0, 4.0, 2.0, 3.0], callback = cb)
160-
# ud initializes to kp * (r - yd) + z = 1 * (1 - 0) + 3 = 4
143+
# kp ud
144+
prob = ODEProblem(foo!, [0.0], (0.0, Tf), [1.0, 2.0, 2.0], callback = cb)
161145
sol2 = solve(prob, Tsit5())
162146
@test sol.u == sol2.u
163147
@test sol_nosplit.u == sol2.u
@@ -217,7 +201,7 @@ end
217201
function filt(; name)
218202
@variables x(t)=0 u(t)=0 y(t)=0
219203
a = 1 / exp(dt)
220-
eqs = [x(k + 1) ~ a * x + (1 - a) * u(k)
204+
eqs = [x ~ a * x(k - 1) + (1 - a) * u(k - 1)
221205
y ~ x]
222206
ODESystem(eqs, t, name = name)
223207
end
@@ -487,9 +471,11 @@ k = ShiftIndex(c)
487471
@variables begin
488472
count(t) = 0
489473
u(t) = 0
474+
ud(t) = 0
490475
end
491476
@equations begin
492-
count(k + 1) ~ Sample(c)(u)
477+
ud ~ Sample(c)(u)
478+
count ~ ud(k - 1)
493479
end
494480
end
495481

test/discrete_system.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ using ModelingToolkit, Test
77
using ModelingToolkit: t_nounits as t
88
using ModelingToolkit: get_metadata, MTKParameters
99

10+
# Make sure positive shifts error
11+
@variables x(t)
12+
k = ShiftIndex(t)
13+
@test_throws ErrorException @mtkbuild sys = DiscreteSystem([x(k + 1) ~ x + x(k - 1)], t)
14+
1015
@inline function rate_to_proportion(r, t)
1116
1 - exp(-r * t)
1217
end;
@@ -15,7 +20,6 @@ end;
1520
@parameters c nsteps δt β γ
1621
@constants h = 1
1722
@variables S(t) I(t) R(t)
18-
k = ShiftIndex(t)
1923
infection = rate_to_proportion(
2024
β * c * I(k - 1) / (S(k - 1) * h + I(k - 1) + R(k - 1)), δt * h) * S(k - 1)
2125
recovery = rate_to_proportion* h, δt) * I(k - 1)

test/parameter_dependencies.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ end
6161
u ~ Hold(ud)
6262
D(x) ~ -x + u
6363
y ~ x
64-
z(k + 2) ~ z(k) + yd]
64+
z(k) ~ z(k - 2) + yd(k - 2)]
6565
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp])
6666

6767
Tf = 1.0
6868
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
69-
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
69+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0])
7070
@test_nowarn solve(prob, Tsit5(); kwargshandle = KeywordArgSilent)
7171

7272
@mtkbuild sys = ODESystem(eqs, t; parameter_dependencies = [kq => 2kp],
@@ -77,7 +77,7 @@ end
7777
@test prob.ps[kq] == 2.0
7878
@test_nowarn solve(prob, Tsit5(), kwargshandle = KeywordArgSilent)
7979
prob = ODEProblem(sys, [x => 0.0, y => 0.0], (0.0, Tf),
80-
[kp => 1.0; z => 3.0; z(k + 1) => 2.0])
80+
[kp => 1.0; z(k - 1) => 3.0; yd(k - 1) => 0.0])
8181
integ = init(prob, Tsit5(), kwargshandle = KeywordArgSilent)
8282
@test integ.ps[kp] == 1.0
8383
@test integ.ps[kq] == 2.0

0 commit comments

Comments
 (0)