Skip to content

Commit 192441f

Browse files
Merge pull request #285 from SciML/myb/Pantelides
Pantelides algorithm
2 parents 9d842fd + 7ec2a4f commit 192441f

File tree

6 files changed

+244
-11
lines changed

6 files changed

+244
-11
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ include("systems/diffeqs/odesystem.jl")
8383
include("systems/diffeqs/sdesystem.jl")
8484
include("systems/diffeqs/abstractodesystem.jl")
8585
include("systems/diffeqs/first_order_transform.jl")
86+
include("systems/diffeqs/index_reduction.jl")
8687
include("systems/diffeqs/modelingtoolkitize.jl")
8788
include("systems/diffeqs/validation.jl")
8889

src/systems/diffeqs/first_order_transform.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ Takes a Nth order ODESystem and returns a new ODESystem written in first order
1919
form by defining new variables which represent the N-1 derivatives.
2020
"""
2121
function ode_order_lowering(sys::ODESystem)
22-
eqs_lowered, _ = ode_order_lowering(sys.eqs, sys.iv)
23-
ODESystem(eqs_lowered, sys.iv, [var_from_nested_derivative(eq.lhs)[1] for eq in eqs_lowered], sys.ps)
22+
eqs_lowered, new_vars = ode_order_lowering(sys.eqs, sys.iv)
23+
ODESystem(eqs_lowered, sys.iv, vcat(new_vars, states(sys)), sys.ps)
2424
end
2525

2626
function ode_order_lowering(eqs, iv)
@@ -30,14 +30,18 @@ function ode_order_lowering(eqs, iv)
3030
new_vars = Variable[]
3131

3232
for (i, eq) enumerate(eqs)
33-
var, maxorder = var_from_nested_derivative(eq.lhs)
34-
if maxorder > get(var_order, var, 0)
35-
var_order[var] = maxorder
36-
any(isequal(var), vars) || push!(vars, var)
33+
if isequal(eq.lhs, Constant(0))
34+
push!(new_eqs, eq)
35+
else
36+
var, maxorder = var_from_nested_derivative(eq.lhs)
37+
if maxorder > get(var_order, var, 0)
38+
var_order[var] = maxorder
39+
any(isequal(var), vars) || push!(vars, var)
40+
end
41+
var′ = lower_varname(var, iv, maxorder - 1)
42+
rhs′ = rename_lower_order(eq.rhs)
43+
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
3744
end
38-
var′ = lower_varname(var, iv, maxorder - 1)
39-
rhs′ = rename_lower_order(eq.rhs)
40-
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
4145
end
4246

4347
for var vars
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# V-nodes `[x_1, x_2, x_3, ..., dx_1, dx_2, ..., y_1, y_2, ...]` where `x`s are
2+
# differential variables and `y`s are algebraic variables.
3+
function get_vnodes(sys)
4+
dxvars = Operation[]
5+
edges = map(_->Int[], 1:length(sys.eqs))
6+
for (i, eq) in enumerate(sys.eqs)
7+
if !(eq.lhs isa Constant)
8+
# Make sure that the LHS is a first order derivative of a var.
9+
@assert eq.lhs.op isa Differential
10+
@assert !(eq.lhs.args[1] isa Differential) # first order
11+
12+
push!(dxvars, eq.lhs)
13+
# For efficiency we note down the diff edges here
14+
push!(edges[i], length(dxvars))
15+
end
16+
end
17+
18+
xvars = (first var_from_nested_derivative).(dxvars)
19+
algvars = setdiff(states(sys), xvars)
20+
return xvars, dxvars, edges, algvars
21+
end
22+
23+
function sys2bigraph(sys)
24+
xvars, dxvars, edges, algvars = get_vnodes(sys)
25+
xvar_offset = length(xvars)
26+
algvar_offset = 2xvar_offset
27+
for edge in edges
28+
isempty(edge) || (edge .+= xvar_offset)
29+
end
30+
31+
for (i, eq) in enumerate(sys.eqs)
32+
# T or D(x):
33+
# We assume no derivatives appear on the RHS at this point
34+
vs = vars(eq.rhs)
35+
for v in vs
36+
for (j, target_v) in enumerate(xvars)
37+
if v == target_v
38+
push!(edges[i], j)
39+
end
40+
end
41+
for (j, target_v) in enumerate(algvars)
42+
if v == target_v
43+
push!(edges[i], j+algvar_offset)
44+
end
45+
end
46+
end
47+
end
48+
49+
fullvars = [xvars; dxvars; algvars] # full list of variables
50+
vars_asso = [(1:xvar_offset) .+ xvar_offset; zeros(Int, length(fullvars) - xvar_offset)] # variable association list
51+
return edges, fullvars, vars_asso
52+
end
53+
54+
print_bigraph(sys, vars, edges) = print_bigraph(stdout, sys, vars, edges)
55+
function print_bigraph(io::IO, sys, vars, edges)
56+
println(io, "Equations:")
57+
foreach(x->println(io, x), [i => sys.eqs[i] for i in 1:length(sys.eqs)])
58+
for (i, edge) in enumerate(edges)
59+
println(io, "\nEq $i has:")
60+
print(io, '[')
61+
for e in edge
62+
print(io, "$(vars[e]), ")
63+
end
64+
print(io, ']')
65+
end
66+
return nothing
67+
end
68+
69+
function match_equation!(edges, i, assign, active, vcolor=falses(length(active)), ecolor=falses(length(edges)))
70+
# `edge[active]` are active edges
71+
# i: equations
72+
# j: variables
73+
# assign: assign[j] == i means (i-j) is assigned
74+
#
75+
# color the equation
76+
ecolor[i] = true
77+
# if a V-node j exists s.t. edge (i-j) exists and assign[j] == 0
78+
for j in edges[i]
79+
if active[j] && assign[j] == 0
80+
assign[j] = i
81+
return true
82+
end
83+
end
84+
# for every j such that edge (i-j) exists and j is uncolored
85+
for j in edges[i]
86+
(active[j] && !vcolor[j]) || continue
87+
# color the variable
88+
vcolor[j] = true
89+
if match_equation!(edges, assign[j], assign, active, vcolor, ecolor)
90+
assign[j] = i
91+
return true
92+
end
93+
end
94+
return false
95+
end
96+
97+
function matching(edges, nvars, active=trues(nvars))
98+
assign = zeros(Int, nvars)
99+
for i in 1:length(edges)
100+
match_equation!(edges, i, assign, active)
101+
end
102+
return assign
103+
end
104+
105+
function pantelides(sys::ODESystem; kwargs...)
106+
edges, fullvars, vars_asso = sys2bigraph(sys)
107+
return pantelides!(edges, fullvars, vars_asso; kwargs...)
108+
end
109+
110+
function pantelides!(edges, vars, vars_asso; maxiter = 8000)
111+
neqs = length(edges)
112+
nvars = length(vars)
113+
assign = zeros(Int, nvars)
114+
eqs_asso = fill(0, neqs)
115+
neqs′ = neqs
116+
for k in 1:neqs′
117+
i = k
118+
pathfound = false
119+
# In practice, `maxiter=8000` should never be reached, otherwise, the
120+
# index would be on the order of thousands.
121+
for _ in 1:maxiter
122+
# run matching on (dx, y) variables
123+
active = vars_asso .== 0
124+
vcolor = falses(nvars)
125+
ecolor = falses(neqs)
126+
pathfound = match_equation!(edges, i, assign, active, vcolor, ecolor)
127+
pathfound && break # terminating condition
128+
# for every colored V-node j
129+
for j in eachindex(vcolor); vcolor[j] || continue
130+
# introduce a new variable
131+
nvars += 1
132+
push!(vars_asso, 0)
133+
vars_asso[j] = nvars
134+
push!(assign, 0)
135+
end
136+
137+
# for every colored E-node l
138+
for l in eachindex(ecolor); ecolor[l] || continue
139+
neqs += 1
140+
# create new E-node
141+
push!(edges, copy(edges[l]))
142+
# create edges from E-node `neqs` to all V-nodes `j` and
143+
# `vars_asso[j]` s.t. edge `(l-j)` exists
144+
for j in edges[l]
145+
if !(vars_asso[j] in edges[neqs])
146+
push!(edges[neqs], vars_asso[j])
147+
end
148+
end
149+
push!(eqs_asso, 0)
150+
eqs_asso[l] = neqs
151+
end
152+
153+
# for every colored V-node j
154+
for j in eachindex(vcolor); vcolor[j] || continue
155+
assign[vars_asso[j]] = eqs_asso[assign[j]]
156+
end
157+
i = eqs_asso[i]
158+
end # for _ in 1:maxiter
159+
pathfound || error("maxiter=$maxiter reached! File a bug report if your system has a reasonable index (<100), and you are using the default `maxiter`. Try to increase the maxiter by `pantelides(sys::ODESystem; maxiter=1_000_000)` if your system has an incredibly high index and it is truly extremely large.")
160+
end # for k in 1:neqs′
161+
return edges, assign, vars_asso, eqs_asso
162+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,17 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
7575
end
7676

7777
var_from_nested_derivative(x) = var_from_nested_derivative(x,0)
78+
var_from_nested_derivative(x::Constant) = (missing, missing)
7879
var_from_nested_derivative(x,i) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
7980
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1].op
81+
iv_from_nested_derivative(x::Constant) = missing
8082

8183
function ODESystem(eqs; kwargs...)
82-
ivs = unique(iv_from_nested_derivative(eq.lhs) for eq eqs)
84+
ivs = unique(skipmissing(iv_from_nested_derivative(eq.lhs) for eq eqs))
8385
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
8486
iv = first(ivs)
8587

86-
dvs = unique(var_from_nested_derivative(eq.lhs)[1] for eq eqs)
88+
dvs = unique(skipmissing(var_from_nested_derivative(eq.lhs)[1] for eq eqs))
8789
ps = filter(vars(eq.rhs for eq eqs)) do x
8890
isparameter(x) & !isequal(x, iv)
8991
end |> collect

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Base.occursin(t::Expression, x::Expression) = isequal(x, t)
6767
vars(exprs) = foldl(vars!, exprs; init = Set{Variable}())
6868
function vars!(vars, O)
6969
isa(O, Operation) || return vars
70+
O.op isa Variable && push!(vars, O.op)
7071
for arg O.args
7172
if isa(arg, Operation)
7273
isa(arg.op, Variable) && push!(vars, arg.op)

test/index_reduction.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using ModelingToolkit
2+
using ModelingToolkit: sys2bigraph
3+
using DiffEqBase
4+
using Test
5+
6+
# Define some variables
7+
@parameters t L g
8+
@variables x(t) y(t) w(t) z(t) T(t) xˍt(t) yˍt(t)
9+
@derivatives D'~t
10+
11+
eqs2 = [D(D(x)) ~ T*x,
12+
D(D(y)) ~ T*y - g,
13+
0 ~ x^2 + y^2 - L^2]
14+
pendulum2 = ODESystem(eqs2, t, [x, y, T], [L, g], name=:pendulum)
15+
lowered_sys = ModelingToolkit.ode_order_lowering(pendulum2)
16+
17+
lowered_eqs = [D(xˍt) ~ T*x,
18+
D(yˍt) ~ T*y - g,
19+
0 ~ x^2 + y^2 - L^2,
20+
D(x) ~ xˍt,
21+
D(y) ~ yˍt]
22+
@test ODESystem(lowered_eqs, t, [xˍt, yˍt, x, y, T], [L, g]) == lowered_sys
23+
@test isequal(lowered_sys.eqs, lowered_eqs)
24+
25+
# Simple pendulum in cartesian coordinates
26+
eqs = [D(x) ~ w,
27+
D(y) ~ z,
28+
D(w) ~ T*x,
29+
D(z) ~ T*y - g,
30+
0 ~ x^2 + y^2 - L^2]
31+
pendulum = ODESystem(eqs, t, [x, y, w, z, T], [L, g], name=:pendulum)
32+
33+
edges, vars, vars_asso = sys2bigraph(pendulum)
34+
@test ModelingToolkit.matching(edges, length(vars), vars_asso .== 0) == [0, 0, 0, 0, 1, 2, 3, 4, 0]
35+
36+
edges, assign, vars_asso, eqs_asso = ModelingToolkit.pantelides(pendulum)
37+
38+
@test sort.(edges) == sort.([
39+
[5, 3], # 1
40+
[6, 4], # 2
41+
[7, 9, 1], # 3
42+
[8, 9, 2], # 4
43+
[2, 1], # 5
44+
[2, 1, 6, 5], # 6
45+
[5, 3, 10, 7], # 7
46+
[6, 4, 11, 8], # 8
47+
[2, 1, 6, 5, 11, 10], # 9
48+
])
49+
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
50+
# [x, y, w, z, x', y', w', z', T, x'', y'']
51+
@test vars_asso == [5, 6, 7, 8, 10, 11, 0, 0, 0, 0, 0]
52+
#1: D(x) ~ w
53+
#2: D(y) ~ z
54+
#3: D(w) ~ T*x
55+
#4: D(z) ~ T*y - g
56+
#5: 0 ~ x^2 + y^2 - L^2
57+
# ----
58+
#6: D(5) -> 0 ~ 2xx'+ 2yy'
59+
#7: D(1) -> D(D(x)) ~ D(w)
60+
#8: D(2) -> D(D(y)) ~ D(z)
61+
#9: D(6) -> 0 ~ 2xx'' + 2x'x' + 2yy'' + 2y'y'
62+
# [1, 2, 3, 4, 5, 6, 7, 8, 9]
63+
@test eqs_asso == [7, 8, 0, 0, 6, 9, 0, 0, 0]

0 commit comments

Comments
 (0)