Skip to content

Commit e5621bc

Browse files
YingboMashashi
andcommitted
update first_order_transform
Co-authored-by: "Shashi Gowda" <[email protected]> Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent e2f3f72 commit e5621bc

File tree

4 files changed

+29
-14
lines changed

4 files changed

+29
-14
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ include("systems/diffeqs/odesystem.jl")
8383
include("systems/diffeqs/sdesystem.jl")
8484
include("systems/diffeqs/abstractodesystem.jl")
8585
include("systems/diffeqs/first_order_transform.jl")
86+
include("systems/diffeqs/index_reduction.jl")
8687
include("systems/diffeqs/modelingtoolkitize.jl")
8788
include("systems/diffeqs/validation.jl")
8889

src/systems/diffeqs/first_order_transform.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ Takes a Nth order ODESystem and returns a new ODESystem written in first order
1919
form by defining new variables which represent the N-1 derivatives.
2020
"""
2121
function ode_order_lowering(sys::ODESystem)
22-
eqs_lowered, _ = ode_order_lowering(sys.eqs, sys.iv)
23-
ODESystem(eqs_lowered, sys.iv, [var_from_nested_derivative(eq.lhs)[1] for eq in eqs_lowered], sys.ps)
22+
eqs_lowered, new_vars = ode_order_lowering(sys.eqs, sys.iv)
23+
ODESystem(eqs_lowered, sys.iv, vcat(new_vars, states(sys)), sys.ps)
2424
end
2525

2626
function ode_order_lowering(eqs, iv)
@@ -30,14 +30,18 @@ function ode_order_lowering(eqs, iv)
3030
new_vars = Variable[]
3131

3232
for (i, eq) enumerate(eqs)
33-
var, maxorder = var_from_nested_derivative(eq.lhs)
34-
if maxorder > get(var_order, var, 0)
35-
var_order[var] = maxorder
36-
any(isequal(var), vars) || push!(vars, var)
33+
if isequal(eq.lhs, Constant(0))
34+
push!(new_eqs, eq)
35+
else
36+
var, maxorder = var_from_nested_derivative(eq.lhs)
37+
if maxorder > get(var_order, var, 0)
38+
var_order[var] = maxorder
39+
any(isequal(var), vars) || push!(vars, var)
40+
end
41+
var′ = lower_varname(var, iv, maxorder - 1)
42+
rhs′ = rename_lower_order(eq.rhs)
43+
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
3744
end
38-
var′ = lower_varname(var, iv, maxorder - 1)
39-
rhs′ = rename_lower_order(eq.rhs)
40-
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
4145
end
4246

4347
for var vars

src/systems/diffeqs/odesystem.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,17 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
7575
end
7676

7777
var_from_nested_derivative(x) = var_from_nested_derivative(x,0)
78+
var_from_nested_derivative(x::Constant) = (missing, missing)
7879
var_from_nested_derivative(x,i) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
7980
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1].op
81+
iv_from_nested_derivative(x::Constant) = missing
8082

8183
function ODESystem(eqs; kwargs...)
82-
ivs = unique(iv_from_nested_derivative(eq.lhs) for eq eqs)
84+
ivs = unique(skipmissing(iv_from_nested_derivative(eq.lhs) for eq eqs))
8385
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
8486
iv = first(ivs)
8587

86-
dvs = unique(var_from_nested_derivative(eq.lhs)[1] for eq eqs)
88+
dvs = unique(skipmissing(var_from_nested_derivative(eq.lhs)[1] for eq eqs))
8789
ps = filter(vars(eq.rhs for eq eqs)) do x
8890
isparameter(x) & !isequal(x, iv)
8991
end |> collect

test/index_reduction.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@ using Test
55

66
# Define some variables
77
@parameters t L g
8-
@variables x(t) y(t) w(t) z(t) T(t)
8+
@variables x(t) y(t) w(t) z(t) T(t) x_t(t) y_t(t)
99
@derivatives D'~t
1010

1111
eqs2 = [D(D(x)) ~ T*x,
1212
D(D(y)) ~ T*y - g,
1313
0 ~ x^2 + y^2 - L^2]
14-
pendulum2 = ODESystem(eqs, t, [x, y, T], [L, g], name=:pendulum)
15-
ModelingToolkit.ode_order_lowering(pendulum2)
14+
pendulum2 = ODESystem(eqs2, t, [x, y, T], [L, g], name=:pendulum)
15+
lowered_sys = ModelingToolkit.ode_order_lowering(pendulum2)
16+
17+
lowered_eqs = [D(x_t) ~ T*x,
18+
D(y_t) ~ T*y - g,
19+
0 ~ x^2 + y^2 - L^2,
20+
D(x) ~ x_t,
21+
D(y) ~ x_t]
22+
@test_skip ODESystem(lowered_eqs) == lowered_sys # not gonna work
23+
@test_broken isequal(lowered_sys.eqs, lowered_eqs)
1624

1725
# Simple pendulum in cartesian coordinates
1826
eqs = [D(x) ~ w,

0 commit comments

Comments
 (0)