Skip to content

Commit 13062bd

Browse files
Merge pull request #326 from shashi/s/sybolic-simplification
SymbolicUtils integration
2 parents f0b733f + 69690e5 commit 13062bd

File tree

9 files changed

+86
-120
lines changed

9 files changed

+86
-120
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1818
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1919
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2020
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
21+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
2122
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"
2223
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2324
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
@@ -34,6 +35,7 @@ NaNMath = "0.3"
3435
SafeTestsets = "0.0.1"
3536
SpecialFunctions = "0.7, 0.8, 0.9, 0.10"
3637
StaticArrays = "0.10, 0.11, 0.12"
38+
SymbolicUtils = "0.1.1"
3739
TreeViews = "0.3"
3840
UnPack = "0.1"
3941
Unitful = "1.1"

src/differentials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function expand_derivatives(O::Operation)
5252
end |> simplify_constants
5353
end
5454

55-
return O
55+
return simplify_constants(O)
5656
end
5757
expand_derivatives(x) = x
5858

src/simplify.jl

Lines changed: 40 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,52 @@
1-
function simplify_constants(O::Operation, shorten_tree)
2-
while true
3-
O′ = _simplify_constants(O, shorten_tree)
4-
if is_operation(O′)
5-
O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree))
6-
end
7-
isequal(O, O′) && return O
8-
O = O′
1+
import SymbolicUtils
2+
import SymbolicUtils: FnType
3+
4+
# ModelingToolkit -> SymbolicUtils
5+
SymbolicUtils.istree(x::Operation) = true
6+
function SymbolicUtils.operation(x::Operation)
7+
if x.op isa Variable
8+
T = FnType{NTuple{length(x.args), Any}, vartype(x.op)}
9+
SymbolicUtils.Variable{T}(x.op.name)
10+
else
11+
x.op
912
end
1013
end
11-
simplify_constants(x, shorten_tree) = x
1214

13-
"""
14-
simplify_constants(x::Operation)
15+
# This is required to infer the right type for
16+
# Operation(Variable{Parameter{Number}}(:foo), [])
17+
# While keeping the metadata that the variable is a parameter.
18+
SymbolicUtils.promote_symtype(f::SymbolicUtils.Sym{FnType{X,Parameter{Y}}},
19+
xs...) where {X, Y} = Y
1520

16-
Simplifies the constants within an expression, for example removing equations
17-
multiplied by a zero and summing constant values.
18-
"""
19-
simplify_constants(x) = simplify_constants(x, true)
21+
SymbolicUtils.arguments(x::Operation) = x.args
2022

21-
Base.isone(x::Operation) = x.op == one || x.op == Constant && isone(x.args)
22-
const AC_OPERATORS = (*, +)
23+
# SymbolicUtils wants raw numbers
24+
SymbolicUtils.to_symbolic(x::Constant) = x.value
25+
SymbolicUtils.to_symbolic(x::Variable{T}) where {T} = SymbolicUtils.Sym{T}(x.name)
2326

24-
function _simplify_constants(O::Operation, shorten_tree)
25-
# Tree shrinking
26-
if shorten_tree && O.op AC_OPERATORS
27-
# Flatten tree
28-
idxs = findall(x -> is_operation(x) && x.op === O.op, O.args)
29-
if !isempty(idxs)
30-
keep_idxs = eachindex(O.args) .∉ (idxs,)
31-
args = Vector{Expression}[O.args[i].args for i in idxs]
32-
push!(args, O.args[keep_idxs])
33-
return Operation(O.op, vcat(args...))
34-
end
27+
# Optional types of vars
28+
# Once converted to SymbolicUtils Variable, a Parameter needs to hide its metadata
29+
_vartype(x::Variable{<:Parameter{T}}) where {T} = T
30+
_vartype(x::Variable{T}) where {T} = T
31+
SymbolicUtils.symtype(x::Variable) = _vartype(x) # needed for a()
32+
SymbolicUtils.symtype(x::SymbolicUtils.Sym{<:Parameter{T}}) where {T} = T
3533

36-
# Collapse constants
37-
idxs = findall(is_constant, O.args)
38-
if length(idxs) > 1
39-
other_idxs = eachindex(O.args) .∉ (idxs,)
40-
new_const = Constant(mapreduce(get, O.op, O.args[idxs]))
41-
args = push!(O.args[other_idxs], new_const)
34+
# returning Any causes SymbolicUtils to infer the type using `promote_symtype`
35+
# But we are OK with Number here for now I guess
36+
SymbolicUtils.symtype(x::Expression) = Number
4237

43-
length(args) == 1 && return first(args)
44-
return Operation(O.op, args)
45-
end
46-
end
47-
48-
if O.op === (*)
49-
# If any variable is `Constant(0)`, zero the whole thing
50-
any(iszero, O.args) && return Constant(0)
51-
52-
# If any variable is `Constant(1)`, remove that `Constant(1)` unless
53-
# they are all `Constant(1)`, in which case simplify to a single variable
54-
if any(isone, O.args)
55-
args = filter(!isone, O.args)
56-
57-
isempty(args) && return Constant(1)
58-
length(args) == 1 && return first(args)
59-
return Operation(O.op, args)
60-
end
61-
62-
return O
63-
end
64-
65-
if O.op === (^) && length(O.args) == 2 && iszero(O.args[2])
66-
return Constant(1)
67-
end
6838

69-
if O.op === (^) && length(O.args) == 2 && isone(O.args[2])
70-
return O.args[1]
71-
end
72-
73-
if O.op === (+) && any(iszero, O.args)
74-
# If there are Constant(0)s in a big `+` expression, get rid of them
75-
args = filter(!iszero, O.args)
76-
77-
isempty(args) && return Constant(0)
78-
length(args) == 1 && return first(args)
79-
return Operation(O.op, args)
80-
end
39+
# SymbolicUtils -> ModelingToolkit
8140

82-
if (O.op === (-) || O.op === (+) || O.op === (*)) && all(is_constant, O.args) && !isempty(O.args)
83-
v = O.args[1].value
84-
for i in 2:length(O.args)
85-
v = O.op(v, O.args[i].value)
86-
end
87-
return Constant(v)
88-
end
89-
90-
(O.op, length(O.args)) === (identity, 1) && return O.args[1]
91-
92-
(O.op, length(O.args)) === (-, 1) && return Operation(*, Expression[-1, O.args[1]])
41+
function simplify_constants(expr)
42+
SymbolicUtils.simplify(expr) |> to_mtk
43+
end
9344

94-
return O
45+
to_mtk(x) = x
46+
to_mtk(x::Number) = Constant(x)
47+
to_mtk(v::SymbolicUtils.Sym{T}) where {T} = Variable{T}(nameof(v))
48+
to_mtk(v::SymbolicUtils.Sym{FnType{X,Y}}) where {X,Y} = Variable{Y}(nameof(v))
49+
function to_mtk(expr::SymbolicUtils.Term)
50+
Operation(to_mtk(SymbolicUtils.operation(expr)),
51+
map(to_mtk, SymbolicUtils.arguments(expr)))
9552
end
96-
_simplify_constants(x, shorten_tree) = x
97-
_simplify_constants(x) = _simplify_constants(x, true)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
114114
end
115115
end
116116
M = simplify ? simplify_constants.(M) : M
117+
# M should only contain concrete numbers
118+
M = map(x->x isa Constant ? x.value : x, M)
117119
M == I ? I : M
118120
end
119121

test/derivatives.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using Test
66
@variables x y z
77
@derivatives D'~t D2''~t Dx'~x
88

9+
test_equal(a, b) = @test isequal(simplify_constants(a), simplify_constants(b))
10+
911
@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))
1012

1113
@test isequal(expand_derivatives(D(t)), 1)
@@ -15,12 +17,12 @@ dsin = D(sin(t))
1517
@test isequal(expand_derivatives(dsin), cos(t))
1618

1719
dcsch = D(csch(t))
18-
@test isequal(expand_derivatives(dcsch), simplify_constants(coth(t) * csch(t) * -1))
20+
@test isequal(expand_derivatives(dcsch), simplify_constants(-coth(t) * csch(t)))
1921

2022
@test isequal(expand_derivatives(D(-7)), 0)
2123
@test isequal(expand_derivatives(D(sin(2t))), simplify_constants(cos(2t) * 2))
2224
@test isequal(expand_derivatives(D2(sin(t))), simplify_constants(-sin(t)))
23-
@test isequal(expand_derivatives(D2(sin(2t))), simplify_constants(sin(2t) * -4))
25+
@test isequal(expand_derivatives(D2(sin(2t))), simplify_constants(-sin(2t) * 4))
2426
@test isequal(expand_derivatives(D2(t)), 0)
2527
@test isequal(expand_derivatives(D2(5)), 0)
2628

@@ -30,23 +32,23 @@ dsinsin = D(sin(sin(t)))
3032

3133
d1 = D(sin(t)*t)
3234
d2 = D(sin(t)*cos(t))
33-
@test isequal(expand_derivatives(d1), t*cos(t)+sin(t))
34-
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+(sin(t)*-1)*sin(t)))
35+
@test isequal(expand_derivatives(d1), simplify_constants(t*cos(t)+sin(t)))
36+
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+(-sin(t))*sin(t)))
3537

3638
eqs = [0 ~ σ*(y-x),
3739
0 ~ x*-z)-y,
3840
0 ~ x*y - β*z]
3941
sys = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
4042
jac = calculate_jacobian(sys)
41-
@test isequal(jac[1,1], σ*-1)
42-
@test isequal(jac[1,2], σ)
43-
@test isequal(jac[1,3], 0)
44-
@test isequal(jac[2,1], ρ-z)
45-
@test isequal(jac[2,2], -1)
46-
@test isequal(jac[2,3], x*-1)
47-
@test isequal(jac[3,1], y)
48-
@test isequal(jac[3,2], x)
49-
@test isequal(jac[3,3], -1*β)
43+
test_equal(jac[1,1], -1σ)
44+
test_equal(jac[1,2], σ)
45+
test_equal(jac[1,3], 0)
46+
test_equal(jac[2,1], ρ - z)
47+
test_equal(jac[2,2], -1)
48+
test_equal(jac[2,3], -1x)
49+
test_equal(jac[3,1], y)
50+
test_equal(jac[3,2], x)
51+
test_equal(jac[3,3], -1β)
5052

5153
# Variable dependence checking in differentiation
5254
@variables a(t) b(a)
@@ -57,7 +59,7 @@ jac = calculate_jacobian(sys)
5759
@variables x(t) y(t) z(t)
5860

5961
@test isequal(expand_derivatives(D(x * y)), simplify_constants(y*D(x) + x*D(y)))
60-
@test_broken isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))
62+
@test isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))
6163

6264
@test isequal(expand_derivatives(D(2t)), 2)
6365
@test isequal(expand_derivatives(D(2x)), 2D(x))

test/direct.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using ModelingToolkit, StaticArrays, LinearAlgebra, SparseArrays
22
using DiffEqBase
33
using Test
44

5+
canonequal(a, b) = isequal(simplify_constants(a), simplify_constants(b))
6+
57
# Calculus
68
@parameters t σ ρ β
79
@variables x y z
@@ -24,11 +26,11 @@ end
2426
= ModelingToolkit.jacobian(eqs,[x,y,z])
2527
for i in 1:3
2628
= ModelingToolkit.gradient(eqs[i],[x,y,z])
27-
@test isequal(∂[i,:],∇)
29+
@test canonequal(∂[i,:],∇)
2830
end
2931

30-
@test all(isequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
31-
@test all(isequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))
32+
@test all(canonequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
33+
@test all(canonequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))
3234

3335
Joop,Jiip = eval.(ModelingToolkit.build_function(∂,[x,y,z],[σ,ρ,β],t))
3436
J = Joop([1.0,2.0,3.0],[1.0,2.0,3.0],1.0)

test/mass_matrix.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ eqs = [D(y[1]) ~ -k[1]*y[1] + k[3]*y[2]*y[3],
1010

1111
sys = ODESystem(eqs,t,y,k)
1212
M = calculate_massmatrix(sys)
13-
M == [1 0 0
14-
0 1 0
15-
0 0 0]
13+
@test M == [1 0 0
14+
0 1 0
15+
0 0 0]
1616

1717
f = ODEFunction(sys)
1818
prob_mm = ODEProblem(f,[1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4))

test/nonlinearsystem.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using ModelingToolkit, StaticArrays, LinearAlgebra
22
using DiffEqBase
33
using Test
44

5+
canonequal(a, b) = isequal(simplify_constants(a), simplify_constants(b))
6+
57
# Define some variables
68
@parameters t σ ρ β
79
@variables x y z
@@ -37,15 +39,15 @@ eqs = [0 ~ σ*(y-x),
3739
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
3840
jac = calculate_jacobian(ns)
3941
@testset "nlsys jacobian" begin
40-
@test isequal(jac[1,1], σ * -1)
41-
@test isequal(jac[1,2], σ)
42-
@test isequal(jac[1,3], 0)
43-
@test isequal(jac[2,1], ρ - z)
44-
@test isequal(jac[2,2], -1)
45-
@test isequal(jac[2,3], x * -1)
46-
@test isequal(jac[3,1], y)
47-
@test isequal(jac[3,2], x)
48-
@test isequal(jac[3,3], -1 * β)
42+
@test canonequal(jac[1,1], σ * -1)
43+
@test canonequal(jac[1,2], σ)
44+
@test canonequal(jac[1,3], 0)
45+
@test canonequal(jac[2,1], ρ - z)
46+
@test canonequal(jac[2,2], -1)
47+
@test canonequal(jac[2,3], x * -1)
48+
@test canonequal(jac[3,1], y)
49+
@test canonequal(jac[3,2], x)
50+
@test canonequal(jac[3,3], -1 * β)
4951
end
5052
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
5153
jac_func = generate_jacobian(ns)

test/simplify.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@ identity_op = Operation(identity,[x])
1414
@test isequal(simplify_constants(identity_op), x)
1515

1616
minus_op = -x
17-
@test isequal(simplify_constants(minus_op), -1*x)
17+
@test isequal(simplify_constants(minus_op), -x)
1818
simplify_constants(minus_op)
1919

2020
@variables x
2121

22-
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^2))) == :((x-2) * 2)
23-
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^3))) == :((x-2)^2 * 3)
24-
@test simplified_expr(simplify_constants(x+2+3)) == :(x + 5)
22+
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^2))) == :(2 * (-2 + x))
23+
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^3))) == :(3 * (-2 + x)^2)
24+
@test simplified_expr(simplify_constants(x+2+3)) == :(5 + x)
2525

26-
d1 = Differential(x)((x-2)^2)
26+
d1 = Differential(x)((-2 + x)^2)
2727
d2 = Differential(x)(d1)
2828
d3 = Differential(x)(d2)
29+
2930
@test simplified_expr(expand_derivatives(d3)) == :(0)
3031
@test simplified_expr(simplify_constants(x^0)) == :(1)

0 commit comments

Comments
 (0)