Skip to content

Commit 71ebfdd

Browse files
feat: shift all discrete systems by 1 to fix correctness issues
1 parent a103690 commit 71ebfdd

File tree

10 files changed

+147
-83
lines changed

10 files changed

+147
-83
lines changed

src/structural_transformation/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ end
433433

434434
function simplify_shifts(var)
435435
ModelingToolkit.hasshift(var) || return var
436+
var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs)
436437
if isdoubleshift(var)
437438
op1 = operation(var)
438439
vv1 = arguments(var)[1]

src/systems/alias_elimination.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ function observed2graph(eqs, unknowns)
453453
lhs_j === nothing &&
454454
throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns."))
455455
assigns[i] = lhs_j
456-
vs = vars(eq.rhs)
456+
vs = vars(eq.rhs; op = Symbolics.Operator)
457457
for v in vs
458458
j = get(v2j, v, nothing)
459459
j !== nothing && add_edge!(graph, i, j)

src/systems/clock_inference.jl

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ function split_system(ci::ClockInference{S}) where {S}
133133
tss = similar(cid_to_eq, S)
134134
for (id, ieqs) in enumerate(cid_to_eq)
135135
ts_i = system_subset(ts, ieqs)
136-
@set! ts_i.structure.only_discrete = id != continuous_id
136+
if id != continuous_id
137+
ts_i = shift_discrete_system(ts_i)
138+
@set! ts_i.structure.only_discrete = true
139+
end
137140
tss[id] = ts_i
138141
end
139142
return tss, inputs, continuous_id, id_to_clock
@@ -148,7 +151,7 @@ function generate_discrete_affect(
148151
end
149152
use_index_cache = has_index_cache(osys) && get_index_cache(osys) !== nothing
150153
out = Sym{Any}(:out)
151-
appended_parameters = parameters(syss[continuous_id])
154+
appended_parameters = full_parameters(syss[continuous_id])
152155
offset = length(appended_parameters)
153156
param_to_idx = if use_index_cache
154157
Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
@@ -157,6 +160,7 @@ function generate_discrete_affect(
157160
Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
158161
end
159162
affect_funs = []
163+
init_funs = []
160164
svs = []
161165
clocks = TimeDomain[]
162166
for (i, (sys, input)) in enumerate(zip(syss, inputs))
@@ -183,47 +187,38 @@ function generate_discrete_affect(
183187
if _v in fullvars
184188
push!(needed_disc_to_cont_obs, _v)
185189
push!(disc_to_cont_idxs, param_to_idx[v])
190+
continue
186191
end
187192

188-
# In the above case, `_v` was in `observed(sys)`
189-
# It may also be in `unknowns(sys)`, in which case it
190-
# will be shifted back by one step
191-
if istree(v) && (op = operation(v)) isa Shift
192-
_v = arguments(_v)[1]
193-
_v = Shift(op.t, op.steps - 1)(_v)
194-
else
195-
_v = Shift(get_iv(sys), -1)(_v)
196-
end
193+
# If the held quantity is calculated through observed
194+
# it will be shifted forward by 1
195+
_v = Shift(get_iv(sys), 1)(_v)
197196
if _v in fullvars
198197
push!(needed_disc_to_cont_obs, _v)
199198
push!(disc_to_cont_idxs, param_to_idx[v])
199+
continue
200200
end
201201
end
202-
append!(appended_parameters, input, unknowns(sys))
202+
append!(appended_parameters, input)
203203
cont_to_disc_obs = build_explicit_observed_function(
204204
use_index_cache ? osys : syss[continuous_id],
205205
needed_cont_to_disc_obs,
206206
throw = false,
207207
expression = true,
208208
output_type = SVector)
209-
@set! sys.ps = appended_parameters
210209
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
211210
throw = false,
212211
expression = true,
213212
output_type = SVector,
214213
op = Shift,
215-
ps = reorder_parameters(osys, full_parameters(sys)))
214+
ps = reorder_parameters(osys, appended_parameters))
216215
ni = length(input)
217216
ns = length(unknowns(sys))
218217
disc = Func(
219218
[
220219
out,
221220
DestructuredArgs(unknowns(osys)),
222-
if use_index_cache
223-
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))
224-
else
225-
(DestructuredArgs(appended_parameters),)
226-
end...,
221+
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
227222
get_iv(sys)
228223
],
229224
[],
@@ -248,6 +243,36 @@ function generate_discrete_affect(
248243
end
249244
end
250245
empty_disc = isempty(disc_range)
246+
disc_init = if use_index_cache
247+
:(function (u, p, t)
248+
c2d_obs = $cont_to_disc_obs
249+
d2c_obs = $disc_to_cont_obs
250+
result = c2d_obs(u, p..., t)
251+
for (val, i) in zip(result, $cont_to_disc_idxs)
252+
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
253+
end
254+
255+
disc_state = Tuple($(parameter_values)(p, i) for i in $disc_range)
256+
result = d2c_obs(disc_state, p..., t)
257+
for (val, i) in zip(result, $disc_to_cont_idxs)
258+
# prevent multiple updates to dependents
259+
_set_parameter_unchecked!(p, val, i; update_dependent = false)
260+
end
261+
discretes, repack, _ = $(SciMLStructures.canonicalize)(
262+
$(SciMLStructures.Discrete()), p)
263+
repack(discretes) # to force recalculation of dependents
264+
end)
265+
else
266+
:(function (u, p, t)
267+
c2d_obs = $cont_to_disc_obs
268+
d2c_obs = $disc_to_cont_obs
269+
c2d_view = view(p, $cont_to_disc_idxs)
270+
d2c_view = view(p, $disc_to_cont_idxs)
271+
disc_unknowns = view(p, $disc_range)
272+
copyto!(c2d_view, c2d_obs(u, p, t))
273+
copyto!(d2c_view, d2c_obs(disc_unknowns, p, t))
274+
end)
275+
end
251276

252277
# @show disc_to_cont_idxs
253278
# @show cont_to_disc_idxs
@@ -270,9 +295,6 @@ function generate_discrete_affect(
270295
# TODO: find a way to do this without allocating
271296
disc = $disc
272297

273-
push!(saved_values.t, t)
274-
push!(saved_values.saveval, $save_vec)
275-
276298
# Write continuous into to discrete: handles `Sample`
277299
# Write discrete into to continuous
278300
# Update discrete unknowns
@@ -322,6 +344,10 @@ function generate_discrete_affect(
322344
:(copyto!(d2c_view, d2c_obs(disc_unknowns, p, t)))
323345
end
324346
)
347+
348+
push!(saved_values.t, t)
349+
push!(saved_values.saveval, $save_vec)
350+
325351
# @show "after d2c", p
326352
$(
327353
if use_index_cache
@@ -335,15 +361,20 @@ function generate_discrete_affect(
335361
end)
336362
sv = SavedValues(Float64, Vector{Float64})
337363
push!(affect_funs, affect!)
364+
push!(init_funs, disc_init)
338365
push!(svs, sv)
339366
end
340367
if eval_expression
341368
affects = map(affect_funs) do a
342369
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
343370
end
371+
inits = map(init_funs) do a
372+
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
373+
end
344374
else
345375
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
376+
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
346377
end
347378
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
348-
return affects, clocks, svs, appended_parameters, defaults
379+
return affects, inits, clocks, svs, appended_parameters, defaults
349380
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,12 +1039,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10391039
t = tspan !== nothing ? tspan[1] : tspan,
10401040
check_length, warn_initialize_determined, kwargs...)
10411041
cbs = process_events(sys; callback, kwargs...)
1042-
affects = []
1042+
inits = []
10431043
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1044-
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1044+
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
10451045
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
10461046
if clock isa Clock
1047-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1047+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1048+
final_affect = true)
10481049
elseif clock isa SolverStepClock
10491050
affect = DiscreteSaveAffect(affect, sv)
10501051
DiscreteCallback(Returns(true), affect,
@@ -1062,12 +1063,6 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10621063
else
10631064
cbs = CallbackSet(cbs, discrete_cbs...)
10641065
end
1065-
# initialize by running affects
1066-
dummy_saveval = (; t = [], saveval = [])
1067-
for affect! in affects
1068-
affect!(
1069-
(; u = u0, p = p, t = tspan !== nothing ? tspan[1] : tspan), dummy_saveval)
1070-
end
10711066
else
10721067
svs = nothing
10731068
end
@@ -1081,7 +1076,14 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
10811076
if svs !== nothing
10821077
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
10831078
end
1084-
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1079+
1080+
prob = ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
1081+
if !isempty(inits)
1082+
for init in inits
1083+
init(prob.u0, prob.p, tspan[1])
1084+
end
1085+
end
1086+
prob
10851087
end
10861088
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
10871089

@@ -1150,12 +1152,12 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11501152
h(p::MTKParameters, t) = h_oop(p..., t)
11511153
u0 = h(p, tspan[1])
11521154
cbs = process_events(sys; callback, kwargs...)
1153-
inits = []
11541155
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1155-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1156+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
11561157
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
11571158
if clock isa Clock
1158-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1159+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1160+
final_affect = true, initial_affect = true)
11591161
else
11601162
error("$clock is not a supported clock type.")
11611163
end
@@ -1181,13 +1183,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
11811183
if svs !== nothing
11821184
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
11831185
end
1184-
prob = DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1185-
if !isempty(inits)
1186-
for init in inits
1187-
init(prob.p, tspan[1])
1188-
end
1189-
end
1190-
prob
1186+
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
11911187
end
11921188

11931189
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1212,12 +1208,12 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
12121208
h(p, t) = h_oop(p, t)
12131209
u0 = h(p, tspan[1])
12141210
cbs = process_events(sys; callback, kwargs...)
1215-
inits = []
12161211
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
1217-
affects, inits, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
1212+
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(sys, dss...)
12181213
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
12191214
if clock isa Clock
1220-
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
1215+
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt;
1216+
final_affect = true, initial_affect = true)
12211217
else
12221218
error("$clock is not a supported clock type.")
12231219
end
@@ -1254,15 +1250,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
12541250
else
12551251
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
12561252
end
1257-
prob = SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1253+
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
12581254
noise_rate_prototype =
12591255
noise_rate_prototype, kwargs1..., kwargs...)
1260-
if !isempty(inits)
1261-
for init in inits
1262-
init(prob.p, tspan[1])
1263-
end
1264-
end
1265-
prob
12661256
end
12671257

12681258
"""

src/systems/diffeqs/odesystem.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,13 +453,16 @@ function build_explicit_observed_function(sys, ts;
453453
if inputs !== nothing
454454
ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list
455455
end
456-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
457-
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
458-
elseif ps isa Tuple
456+
if ps isa Tuple
459457
ps = DestructuredArgs.(ps, inbounds = !checkbounds)
458+
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
459+
ps = DestructuredArgs.(reorder_parameters(get_index_cache(sys), ps))
460460
else
461461
ps = (DestructuredArgs(ps, inbounds = !checkbounds),)
462462
end
463+
if isempty(ps)
464+
ps = (DestructuredArgs([]),)
465+
end
463466
dvs = DestructuredArgs(unknowns(sys), inbounds = !checkbounds)
464467
if inputs === nothing
465468
args = [dvs, ps..., ivs...]

src/systems/discrete_system/discrete_system.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
139139
iv′ = value(iv)
140140
dvs′ = value.(dvs)
141141
ps′ = value.(ps)
142-
if !all(hasshift, eqs)
143-
error("All equations in a `DiscreteSystem` must be difference equations")
142+
if any(hasderiv, eqs) || any(hashold, eqs) || any(hassample, eqs) || any(hasdiff, eqs)
143+
error("Equations in a `DiscreteSystem` can only have `Shift` operators.")
144144
end
145145
if !(isempty(default_u0) && isempty(default_p))
146146
Base.depwarn(

src/systems/systemstructure.jl

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -369,12 +369,6 @@ function TearingState(sys; quick_cancel = false, check = true)
369369
steps = 0
370370
tt = iv
371371
v = var
372-
if lshift < 0
373-
defs = ModelingToolkit.get_defaults(sys)
374-
if (_val = get(defs, var, nothing)) !== nothing
375-
defs[Shift(tt, -1)(v)] = _val
376-
end
377-
end
378372
else
379373
continue
380374
end
@@ -434,10 +428,14 @@ function TearingState(sys; quick_cancel = false, check = true)
434428

435429
eq_to_diff = DiffGraph(nsrcs(graph))
436430

437-
return TearingState(sys, fullvars,
431+
ts = TearingState(sys, fullvars,
438432
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
439433
complete(graph), nothing, var_types, sys isa DiscreteSystem),
440434
Any[])
435+
if sys isa DiscreteSystem
436+
ts = shift_discrete_system(ts)
437+
end
438+
return ts
441439
end
442440

443441
function lower_order_var(dervar, t)
@@ -458,6 +456,30 @@ function lower_order_var(dervar, t)
458456
diffvar
459457
end
460458

459+
function shift_discrete_system(ts::TearingState)
460+
@unpack fullvars, sys = ts
461+
discvars = OrderedSet()
462+
eqs = equations(sys)
463+
for eq in eqs
464+
vars!(discvars, eq; op = Union{Sample, Hold})
465+
end
466+
iv = get_iv(sys)
467+
discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k))
468+
for k in discvars
469+
if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold}))
470+
for i in eachindex(fullvars)
471+
fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute(
472+
fullvars[i], discmap; operator = Union{Sample, Hold}))
473+
end
474+
for i in eachindex(eqs)
475+
eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute(
476+
eqs[i], discmap; operator = Union{Sample, Hold}))
477+
end
478+
@set! ts.sys.eqs = eqs
479+
@set! ts.fullvars = fullvars
480+
return ts
481+
end
482+
461483
using .BipartiteGraphs: Label, BipartiteAdjacencyList
462484
struct SystemStructurePrintMatrix <:
463485
AbstractMatrix{Union{Label, BipartiteAdjacencyList}}

src/variables.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
195195
values = Dict()
196196
for var in varlist
197197
var = unwrap(var)
198-
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap), defaults))
198+
val = unwrap(fixpoint_sub(fixpoint_sub(var, varmap; operator = Symbolics.Operator),
199+
defaults; operator = Symbolics.Operator))
199200
if symbolic_type(val) === NotSymbolic()
200201
values[var] = val
201202
end

0 commit comments

Comments
 (0)