Skip to content

Commit 13967d0

Browse files
feat: support negative shifts in structural_simplify
1 parent 30df163 commit 13967d0

File tree

8 files changed

+175
-77
lines changed

8 files changed

+175
-77
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,8 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
382382
dx = fullvars[dv]
383383
# add `x_t`
384384
order, lv = var_order(dv)
385-
x_t = lower_varname(fullvars[lv], iv, order)
386-
push!(fullvars, x_t)
385+
x_t = lower_varname_withshift(fullvars[lv], iv, order)
386+
push!(fullvars, simplify_shifts(x_t))
387387
v_t = length(fullvars)
388388
v_t_idx = add_vertex!(var_to_diff)
389389
add_vertex!(graph, DST)
@@ -437,11 +437,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
437437
# We cannot solve the differential variable like D(x)
438438
if isdervar(iv)
439439
order, lv = var_order(iv)
440-
dx = D(lower_varname(fullvars[lv], idep, order - 1))
441-
eq = dx ~ ModelingToolkit.fixpoint_sub(
440+
dx = D(simplify_shifts(lower_varname_withshift(
441+
fullvars[lv], idep, order - 1)))
442+
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
442443
Symbolics.solve_for(neweqs[ieq],
443444
fullvars[iv]),
444-
total_sub)
445+
total_sub; operator = ModelingToolkit.Shift))
445446
for e in 𝑑neighbors(graph, iv)
446447
e == ieq && continue
447448
for v in 𝑠neighbors(graph, e)
@@ -450,7 +451,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
450451
rem_edge!(graph, e, iv)
451452
end
452453
push!(diff_eqs, eq)
453-
total_sub[eq.lhs] = eq.rhs
454+
total_sub[simplify_shifts(eq.lhs)] = eq.rhs
454455
push!(diffeq_idxs, ieq)
455456
push!(diff_vars, diff_to_var[iv])
456457
continue
@@ -469,7 +470,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
469470
neweq = var ~ ModelingToolkit.fixpoint_sub(
470471
simplify ?
471472
Symbolics.simplify(rhs) : rhs,
472-
total_sub)
473+
total_sub; operator = ModelingToolkit.Shift)
473474
push!(subeqs, neweq)
474475
push!(solved_equations, ieq)
475476
push!(solved_variables, iv)

src/structural_transformation/utils.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,3 +412,58 @@ function numerical_nlsolve(f, u0, p)
412412
# TODO: robust initial guess, better debugging info, and residual check
413413
sol.u
414414
end
415+
416+
###
417+
### Misc
418+
###
419+
420+
function lower_varname_withshift(var, iv, order)
421+
order == 0 && return var
422+
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
423+
op = operation(var)
424+
return Shift(op.t, order)(var)
425+
end
426+
return lower_varname(var, iv, order)
427+
end
428+
429+
function isdoubleshift(var)
430+
return ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
431+
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
432+
end
433+
434+
function simplify_shifts(var)
435+
ModelingToolkit.hasshift(var) || return var
436+
r = @rule ~x::isdoubleshift => begin
437+
op1 = operation(~x)
438+
vv1 = arguments(~x)[1]
439+
op2 = operation(vv1)
440+
vv2 = arguments(vv1)[1]
441+
s1 = op1.steps
442+
s2 = op2.steps
443+
t1 = op1.t
444+
t2 = op2.t
445+
if t1 === nothing
446+
ModelingToolkit.Shift(t2, s1 + s2)(vv2)
447+
else
448+
ModelingToolkit.Shift(t1, s1 + s2)(vv2)
449+
end
450+
end
451+
return Postwalk(PassThrough(r))(var)
452+
while ModelingToolkit.isoperator(var, ModelingToolkit.Shift) &&
453+
ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift)
454+
op1 = operation(var)
455+
vv1 = arguments(var)[1]
456+
op2 = operation(vv1)
457+
vv2 = arguments(vv1)[1]
458+
s1 = op1.steps
459+
s2 = op2.steps
460+
t1 = op1.t
461+
t2 = op2.t
462+
if t1 === nothing
463+
var = Shift(t2, s1 + s2)(vv2)
464+
else
465+
var = Shift(t1, s1 + s2)(vv2)
466+
end
467+
end
468+
return var
469+
end

src/systems/alias_elimination.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,11 +463,11 @@ function observed2graph(eqs, unknowns)
463463
return graph, assigns
464464
end
465465

466-
function fixpoint_sub(x, dict)
467-
y = fast_substitute(x, dict)
466+
function fixpoint_sub(x, dict; operator = Nothing)
467+
y = fast_substitute(x, dict; operator)
468468
while !isequal(x, y)
469469
y = x
470-
x = fast_substitute(y, dict)
470+
x = fast_substitute(y, dict; operator)
471471
end
472472

473473
return x

src/systems/discrete_system/discrete_system.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,6 @@ function DiscreteSystem(eqs, iv; kwargs...)
179179
eq.lhs in diffvars &&
180180
throw(ArgumentError("The shift variable $(eq.lhs) is not unique in the system of equations."))
181181
push!(diffvars, eq.lhs)
182-
else
183-
throw(ArgumentError("All equations in a `DiscreteSystem` must be difference equations with positive shifts"))
184182
end
185183
end
186184
new_ps = OrderedSet()

src/systems/systemstructure.jl

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -327,24 +327,6 @@ function TearingState(sys; quick_cancel = false, check = true)
327327

328328
dvar = var
329329
idx = varidx
330-
if ModelingToolkit.isoperator(dvar, ModelingToolkit.Shift)
331-
if !(idx in dervaridxs)
332-
push!(dervaridxs, idx)
333-
end
334-
op = operation(dvar)
335-
tt = op.t
336-
steps = op.steps
337-
v = arguments(dvar)[1]
338-
for s in (steps - 1):-1:1
339-
sf = Shift(tt, s)
340-
dvar = sf(v)
341-
idx = addvar!(dvar)
342-
if !(idx in dervaridxs)
343-
push!(dervaridxs, idx)
344-
end
345-
end
346-
idx = addvar!(v)
347-
end
348330

349331
if istree(var) && operation(var) isa Symbolics.Operator &&
350332
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
@@ -364,6 +346,47 @@ function TearingState(sys; quick_cancel = false, check = true)
364346
eqs[i] = eqs[i].lhs ~ rhs
365347
end
366348
end
349+
lowest_shift = Dict()
350+
for var in fullvars
351+
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
352+
steps = operation(var).steps
353+
v = arguments(var)[1]
354+
lowest_shift[v] = min(get(lowest_shift, v, 0), steps)
355+
end
356+
end
357+
for var in fullvars
358+
if ModelingToolkit.isoperator(var, ModelingToolkit.Shift)
359+
op = operation(var)
360+
steps = op.steps
361+
v = arguments(var)[1]
362+
lshift = lowest_shift[v]
363+
tt = op.t
364+
elseif haskey(lowest_shift, var)
365+
lshift = lowest_shift[var]
366+
steps = 0
367+
tt = iv
368+
v = var
369+
if lshift < 0
370+
defs = ModelingToolkit.get_defaults(sys)
371+
if (_val = get(defs, var, nothing)) !== nothing
372+
defs[Shift(tt, lshift)(v)] = _val
373+
end
374+
end
375+
else
376+
continue
377+
end
378+
if lshift < steps
379+
push!(dervaridxs, var2idx[var])
380+
end
381+
for s in (steps - 1):-1:(lshift + 1)
382+
sf = Shift(tt, s)
383+
dvar = sf(v)
384+
idx = addvar!(dvar)
385+
if !(idx in dervaridxs)
386+
push!(dervaridxs, idx)
387+
end
388+
end
389+
end
367390

368391
# sort `fullvars` such that the mass matrix is as diagonal as possible.
369392
dervaridxs = collect(dervaridxs)
@@ -418,15 +441,18 @@ end
418441
function lower_order_var(dervar)
419442
if isdifferential(dervar)
420443
diffvar = arguments(dervar)[1]
421-
else # shift
444+
elseif ModelingToolkit.isoperator(dervar, ModelingToolkit.Shift)
422445
s = operation(dervar)
423446
step = s.steps - 1
424447
vv = arguments(dervar)[1]
425-
if step >= 1
448+
if step != 0
426449
diffvar = Shift(s.t, step)(vv)
427450
else
428451
diffvar = vv
429452
end
453+
else
454+
iv = only(arguments(dervar))
455+
return Shift(iv, -1)(dervar)
430456
end
431457
diffvar
432458
end

src/utils.jl

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -799,35 +799,51 @@ end
799799
# Symbolics needs to call unwrap on the substitution rules, but most of the time
800800
# we don't want to do that in MTK.
801801
const Eq = Union{Equation, Inequality}
802-
function fast_substitute(eq::Eq, subs)
802+
function fast_substitute(eq::Eq, subs; operator = Nothing)
803803
if eq isa Inequality
804-
Inequality(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs),
804+
Inequality(fast_substitute(eq.lhs, subs; operator),
805+
fast_substitute(eq.rhs, subs; operator),
805806
eq.relational_op)
806807
else
807-
Equation(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
808+
Equation(fast_substitute(eq.lhs, subs; operator),
809+
fast_substitute(eq.rhs, subs; operator))
808810
end
809811
end
810-
function fast_substitute(eq::T, subs::Pair) where {T <: Eq}
811-
T(fast_substitute(eq.lhs, subs), fast_substitute(eq.rhs, subs))
812+
function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq}
813+
T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator))
812814
end
813-
fast_substitute(eqs::AbstractArray, subs) = fast_substitute.(eqs, (subs,))
814-
fast_substitute(a, b) = substitute(a, b)
815-
function fast_substitute(expr, pair::Pair)
815+
function fast_substitute(eqs::AbstractArray, subs; operator = Nothing)
816+
fast_substitute.(eqs, (subs,); operator)
817+
end
818+
function fast_substitute(a, b; operator = Nothing)
819+
b = Dict(value(k) => value(v) for (k, v) in b)
820+
a = value(a)
821+
haskey(b, a) && return b[a]
822+
for _b in b
823+
a = fast_substitute(a, _b; operator)
824+
end
825+
a
826+
end
827+
function fast_substitute(expr, pair::Pair; operator = Nothing)
816828
a, b = pair
829+
a = value(a)
830+
b = value(b)
817831
isequal(expr, a) && return b
818832

819833
istree(expr) || return expr
820-
op = fast_substitute(operation(expr), pair)
821-
canfold = Ref(!(op isa Symbolic))
822-
args = let canfold = canfold
823-
map(SymbolicUtils.unsorted_arguments(expr)) do x
824-
x′ = fast_substitute(x, pair)
825-
canfold[] = canfold[] && !(x′ isa Symbolic)
826-
x′
834+
op = fast_substitute(operation(expr), pair; operator)
835+
args = SymbolicUtils.unsorted_arguments(expr)
836+
if !(op isa operator)
837+
canfold = Ref(!(op isa Symbolic))
838+
args = let canfold = canfold
839+
map(args) do x
840+
x′ = fast_substitute(x, pair; operator)
841+
canfold[] = canfold[] && !(x′ isa Symbolic)
842+
x′
843+
end
827844
end
845+
canfold[] && return op(args...)
828846
end
829-
canfold[] && return op(args...)
830-
831847
similarterm(expr,
832848
op,
833849
args,

test/clock.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ ci, varmap = infer_clocks(expand_connections(_model))
423423
@test varmap[_model.feedback.output.u] == d
424424
@test varmap[_model.feedback.input2.u] == d
425425

426-
@test_skip ssys = structural_simplify(model)
426+
ssys = structural_simplify(model)
427427

428428
Tf = 0.2
429429
timevec = 0:(d.dt):Tf
@@ -445,20 +445,20 @@ y = res.y[:]
445445
# ref = Constant(k = 0.5)
446446

447447
# ; model.controller.x(k-1) => 0.0
448+
prob = ODEProblem(ssys,
449+
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
450+
(0.0, Tf))
451+
452+
@test prob.ps[Hold(ss.holder.input.u)] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
453+
@test prob.ps[ss.controller.x(k - 1)] == 0 # c2d
454+
@test prob.ps[Sample(d)(ss.sampler.input.u)] == 0 # disc state
455+
sol = solve(prob,
456+
Tsit5(),
457+
kwargshandle = KeywordArgSilent,
458+
abstol = 1e-8,
459+
reltol = 1e-8)
448460
@test_skip begin
449-
prob = ODEProblem(ssys,
450-
[model.plant.x => 0.0; model.controller.kp => 2.0; model.controller.ki => 2.0],
451-
(0.0, Tf))
452-
453-
@test prob.p[9] == 1 # constant output * kp issue https://github.com/SciML/ModelingToolkit.jl/issues/2356
454-
@test prob.p[10] == 0 # c2d
455-
@test prob.p[11] == 0 # disc state
456-
sol = solve(prob,
457-
Tsit5(),
458-
kwargshandle = KeywordArgSilent,
459-
abstol = 1e-8,
460-
reltol = 1e-8)
461-
plot([y sol(timevec, idxs = model.plant.output.u)], m = :o, lab = ["CS" "MTK"])
461+
# plot([y sol(timevec, idxs = model.plant.output.u)], m = :o, lab = ["CS" "MTK"])
462462

463463
##
464464

0 commit comments

Comments
 (0)