Skip to content

Commit dc0bcf6

Browse files
Merge pull request #342 from SciML/substitute
substitute using SymbolicUtils
2 parents 2d9e975 + d294c5f commit dc0bcf6

File tree

13 files changed

+81
-79
lines changed

13 files changed

+81
-79
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ NaNMath = "0.3"
3737
SafeTestsets = "0.0.1"
3838
SpecialFunctions = "0.7, 0.8, 0.9, 0.10"
3939
StaticArrays = "0.10, 0.11, 0.12"
40-
SymbolicUtils = "0.1.1, 0.2"
40+
SymbolicUtils = "0.3"
4141
TreeViews = "0.3"
4242
UnPack = "0.1"
4343
Unitful = "1.1"

src/ModelingToolkit.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ import GeneralizedGenerated
1313
using DocStringExtensions
1414
using Base: RefValue
1515

16+
import SymbolicUtils
17+
import SymbolicUtils: to_symbolic, FnType
18+
1619
import TreeViews
1720

1821
"""
@@ -112,17 +115,18 @@ export Reaction, ReactionSystem
112115
export Differential, expand_derivatives, @derivatives
113116
export IntervalDomain, ProductDomain, , CircleDomain
114117
export Equation, ConstrainedEquation
115-
export simplify_constants
116-
117118
export Operation, Expression, Variable
119+
export independent_variable, states, parameters, equations
120+
118121
export calculate_jacobian, generate_jacobian, generate_function
119122
export calculate_tgrad, generate_tgrad
120123
export calculate_gradient, generate_gradient
121124
export calculate_factorized_W, generate_factorized_W
122125
export calculate_hessian, generate_hessian
123126
export calculate_massmatrix, generate_diffusion_function
124-
export independent_variable, states, parameters, equations
125-
export simplified_expr, rename, get_variables, substitute_expr!
127+
128+
export simplified_expr, rename, get_variables
129+
export simplify, substitute
126130
export build_function
127131
export @register
128132
export modelingtoolkitize

src/differentials.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ function expand_derivatives(O::Operation)
4949

5050
return sum(1:length(o.args)) do i
5151
derivative(o, i) * expand_derivatives(D(o.args[i]))
52-
end |> simplify_constants
52+
end |> simplify
5353
end
5454

55-
return simplify_constants(O)
55+
return simplify(O)
5656
end
5757
expand_derivatives(x) = x
5858

src/direct.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ an array of variable expressions.
88
"""
99
function gradient(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
1010
out = [expand_derivatives(Differential(v)(O)) for v in vars]
11-
simplify ? simplify_constants.(out) : out
11+
simplify ? ModelingToolkit.simplify.(out) : out
1212
end
1313

1414
"""
@@ -21,7 +21,7 @@ an array of variable expressions.
2121
"""
2222
function jacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expression}; simplify = true)
2323
out = [expand_derivatives(Differential(v)(O)) for O in ops, v in vars]
24-
simplify ? simplify_constants.(out) : out
24+
simplify ? ModelingToolkit.simplify.(out) : out
2525
end
2626

2727
"""
@@ -34,7 +34,7 @@ an array of variable expressions.
3434
"""
3535
function hessian(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
3636
out = [expand_derivatives(Differential(v2)(Differential(v1)(O))) for v1 in vars, v2 in vars]
37-
simplify ? simplify_constants.(out) : out
37+
simplify ? ModelingToolkit.simplify.(out) : out
3838
end
3939

4040
function simplified_expr(O::Operation)

src/simplify.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import SymbolicUtils
2-
import SymbolicUtils: FnType
3-
41
# ModelingToolkit -> SymbolicUtils
52
SymbolicUtils.istree(x::Operation) = true
63
function SymbolicUtils.operation(x::Operation)
@@ -38,9 +35,10 @@ SymbolicUtils.symtype(x::Expression) = Number
3835

3936
# SymbolicUtils -> ModelingToolkit
4037

41-
function simplify_constants(expr)
42-
SymbolicUtils.simplify(expr) |> to_mtk
43-
end
38+
simplify(expr::Expression) = SymbolicUtils.simplify(expr) |> to_mtk
39+
simplify(expr) = expr |> to_mtk
40+
41+
@deprecate simplify_constants(ex) simplify(ex)
4442

4543
to_mtk(x) = x
4644
to_mtk(x::Number) = Constant(x)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ function calculate_factorized_W(sys::AbstractODESystem, simplify=true)
6969
Wfact = lu(W, Val(false), check=false).factors
7070

7171
if simplify
72-
Wfact = simplify_constants.(Wfact)
72+
Wfact = ModelingToolkit.simplify.(Wfact)
7373
end
7474

7575
W_t = - LinearAlgebra.I/gam + jac
7676
Wfact_t = lu(W_t, Val(false), check=false).factors
7777
if simplify
78-
Wfact_t = simplify_constants.(Wfact_t)
78+
Wfact_t = ModelingToolkit.simplify.(Wfact_t)
7979
end
8080
sys.Wfact[] = Wfact
8181
sys.Wfact_t[] = Wfact_t
@@ -113,7 +113,7 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
113113
error("Only semi-explicit constant mass matrices are currently supported")
114114
end
115115
end
116-
M = simplify ? simplify_constants.(M) : M
116+
M = simplify ? ModelingToolkit.simplify.(M) : M
117117
# M should only contain concrete numbers
118118
M = map(x->x isa Constant ? x.value : x, M)
119119
M == I ? I : M

src/systems/jumps/jumpsystem.jl

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@ struct JumpSystem <: AbstractSystem
99
systems::Vector{JumpSystem}
1010
end
1111

12-
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
12+
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
1313
name = gensym(:JumpSystem))
1414
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems)
1515
end
1616

1717

1818

19-
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
20-
independent_variable(js),
19+
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
20+
independent_variable(js),
2121
expression=Val{false})
2222

23-
generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js),
24-
parameters(js),
23+
generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js),
24+
parameters(js),
2525
independent_variable(js),
2626
expression=Val{false},
27-
headerfun=add_integrator_header,
27+
headerfun=add_integrator_header,
2828
outputidxs=outputidxs)[2]
29-
function assemble_vrj(js, vrj, statetoid)
29+
function assemble_vrj(js, vrj, statetoid)
3030
rate = generate_rate_function(js, vrj.rate)
3131
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
3232
outputidxs = ((statetoid[var] for var in outputvars)...,)
@@ -42,18 +42,20 @@ function assemble_crj(js, crj, statetoid)
4242
ConstantRateJump(rate, affect)
4343
end
4444

45-
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
46-
statetoid, ptoid, p, pcontext) where {U,V,W,V2,W2}
45+
function assemble_maj(js, maj::MassActionJump{U,Vector{Pair{V,W}},Vector{Pair{V2,W2}}},
46+
statetoid, ptoid, parammap) where {U,V,W,V2,W2}
4747
sr = maj.scaled_rates
48-
if sr isa Operation || sr isa Variable
49-
pval = Base.eval(pcontext, Expr(maj.scaled_rates))
50-
else
48+
if sr isa Operation
49+
pval = substitute(sr,parammap)
50+
elseif sr isa Variable
51+
pval = Dict(parammap)[sr()]
52+
else
5153
pval = maj.scaled_rates
5254
end
53-
55+
5456
rs = Vector{Pair{Int,W}}()
5557
for (spec,stoich) in maj.reactant_stoch
56-
if iszero(spec)
58+
if iszero(spec)
5759
push!(rs, 0 => stoich)
5860
else
5961
push!(rs, statetoid[convert(Variable,spec)] => stoich)
@@ -73,13 +75,13 @@ end
7375

7476
"""
7577
```julia
76-
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan,
78+
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan,
7779
parammap=DiffEqBase.NullParameters; kwargs...)
7880
```
7981
8082
Generates a DiscreteProblem from an AbstractSystem
8183
"""
82-
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan::Tuple,
84+
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan::Tuple,
8385
parammap=DiffEqBase.NullParameters(); kwargs...)
8486
u0 = varmap_to_vars(u0map, states(sys))
8587
p = varmap_to_vars(parammap, parameters(sys))
@@ -99,29 +101,20 @@ function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
99101
majs = Vector{MassActionJump}()
100102
pvars = parameters(js)
101103
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
102-
ptoid = Dict(convert(Variable,par) => i for (i,par) in enumerate(parameters(js)))
103-
104-
# for mass action jumps might need to evaluate parameter expressions
105-
# populate dummy module with params as local variables
106-
# (for eval-ing parameter expressions)
107-
param_context = Module()
108-
for (i, pval) in enumerate(prob.p)
109-
psym = Symbol(pvars[i])
110-
Base.eval(param_context, :($psym = $pval))
111-
end
104+
parammap = map(Pair,pvars,prob.p)
112105

113106
for j in equations(js)
114107
if j isa ConstantRateJump
115108
push!(crjs, assemble_crj(js, j, statetoid))
116109
elseif j isa VariableRateJump
117110
push!(vrjs, assemble_vrj(js, j, statetoid))
118111
elseif j isa MassActionJump
119-
push!(majs, assemble_maj(js, j, statetoid, ptoid, prob.p, param_context))
112+
push!(majs, assemble_maj(js, j, statetoid, parammap))
120113
else
121114
error("JumpSystems should only contain Constant, Variable or Mass Action Jumps.")
122115
end
123116
end
124-
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
117+
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
125118
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
126119
JumpProblem(prob, aggregator, jset)
127-
end
120+
end

src/utils.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,23 @@ end
100100

101101
# variable substitution
102102
"""
103-
substitute_expr!(expr::Operation, s::Pair{Operation, Operation})
103+
substitute(expr::Operation, s::Pair)
104+
substitute(expr::Operation, s::Dict)
105+
substitute(expr::Operation, s::Vector)
104106
105-
Performs the substitution `Operation => Operation` on the `expr` Operation.
107+
Performs the substitution `Operation => val` on the `expr` Operation.
106108
"""
107-
substitute_expr!(expr::Constant, s::Pair{Operation, Operation}) = nothing
108-
function substitute_expr!(expr::Operation, s::Pair{Operation, Operation})
109-
if !is_singleton(expr)
110-
expr.args .= replace(expr.args, s)
111-
for arg in expr.args
112-
substitute_expr!(arg, s)
113-
end
114-
end
115-
return nothing
109+
substitute(expr::Constant, s) = expr
110+
substitute(expr::Operation, s::Pair) = _substitute(expr, [s[1]], [s[2]])
111+
substitute(expr::Operation, dict::Dict) = _substitute(expr, keys(dict), values(dict))
112+
substitute(expr::Operation, s::Vector) = _substitute(expr, first.(s), last.(s))
113+
114+
function _substitute(expr, ks, vs)
115+
_substitute(expr, Dict(map(Pair, map(to_symbolic, ks), map(to_symbolic, vs))))
116116
end
117+
118+
function _substitute(expr, dict::Dict)
119+
to_mtk(simplify(SymbolicUtils.substitute(expr, dict)))
120+
end
121+
122+
@deprecate substitute_expr!(expr,s) substitute(expr,s)

test/derivatives.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Test
66
@variables x y z
77
@derivatives D'~t D2''~t Dx'~x
88

9-
test_equal(a, b) = @test isequal(simplify_constants(a), simplify_constants(b))
9+
test_equal(a, b) = @test isequal(simplify(a), simplify(b))
1010

1111
@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))
1212

@@ -17,12 +17,12 @@ dsin = D(sin(t))
1717
@test isequal(expand_derivatives(dsin), cos(t))
1818

1919
dcsch = D(csch(t))
20-
@test isequal(expand_derivatives(dcsch), simplify_constants(-coth(t) * csch(t)))
20+
@test isequal(expand_derivatives(dcsch), simplify(-coth(t) * csch(t)))
2121

2222
@test isequal(expand_derivatives(D(-7)), 0)
23-
@test isequal(expand_derivatives(D(sin(2t))), simplify_constants(cos(2t) * 2))
24-
@test isequal(expand_derivatives(D2(sin(t))), simplify_constants(-sin(t)))
25-
@test isequal(expand_derivatives(D2(sin(2t))), simplify_constants(-sin(2t) * 4))
23+
@test isequal(expand_derivatives(D(sin(2t))), simplify(cos(2t) * 2))
24+
@test isequal(expand_derivatives(D2(sin(t))), simplify(-sin(t)))
25+
@test isequal(expand_derivatives(D2(sin(2t))), simplify(-sin(2t) * 4))
2626
@test isequal(expand_derivatives(D2(t)), 0)
2727
@test isequal(expand_derivatives(D2(5)), 0)
2828

@@ -32,8 +32,8 @@ dsinsin = D(sin(sin(t)))
3232

3333
d1 = D(sin(t)*t)
3434
d2 = D(sin(t)*cos(t))
35-
@test isequal(expand_derivatives(d1), simplify_constants(t*cos(t)+sin(t)))
36-
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+(-sin(t))*sin(t)))
35+
@test isequal(expand_derivatives(d1), simplify(t*cos(t)+sin(t)))
36+
@test isequal(expand_derivatives(d2), simplify(cos(t)*cos(t)+(-sin(t))*sin(t)))
3737

3838
eqs = [0 ~ σ*(y-x),
3939
0 ~ x*-z)-y,
@@ -58,12 +58,12 @@ test_equal(jac[3,3], -1β)
5858

5959
@variables x(t) y(t) z(t)
6060

61-
@test isequal(expand_derivatives(D(x * y)), simplify_constants(y*D(x) + x*D(y)))
62-
@test isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))
61+
@test isequal(expand_derivatives(D(x * y)), simplify(y*D(x) + x*D(y)))
62+
@test isequal(expand_derivatives(D(x * y)), simplify(D(x)*y + x*D(y)))
6363

6464
@test isequal(expand_derivatives(D(2t)), 2)
6565
@test isequal(expand_derivatives(D(2x)), 2D(x))
66-
@test isequal(expand_derivatives(D(x^2)), simplify_constants(2 * x * D(x)))
66+
@test isequal(expand_derivatives(D(x^2)), simplify(2 * x * D(x)))
6767

6868
# n-ary * and +
6969
isequal(ModelingToolkit.derivative(Operation(*, [x, y, z*ρ]), 1), y*(z*ρ))

test/direct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ModelingToolkit, StaticArrays, LinearAlgebra, SparseArrays
22
using DiffEqBase
33
using Test
44

5-
canonequal(a, b) = isequal(simplify_constants(a), simplify_constants(b))
5+
canonequal(a, b) = isequal(simplify(a), simplify(b))
66

77
# Calculus
88
@parameters t σ ρ β

test/nonlinearsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using ModelingToolkit, StaticArrays, LinearAlgebra
22
using DiffEqBase
33
using Test
44

5-
canonequal(a, b) = isequal(simplify_constants(a), simplify_constants(b))
5+
canonequal(a, b) = isequal(simplify(a), simplify(b))
66

77
# Define some variables
88
@parameters t σ ρ β

test/simplify.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,27 @@ using Test
55
@variables x(t) y(t) z(t)
66

77
null_op = 0*t
8-
@test isequal(simplify_constants(null_op), 0)
8+
@test isequal(simplify(null_op), 0)
99

1010
one_op = 1*t
11-
@test isequal(simplify_constants(one_op), t)
11+
@test isequal(simplify(one_op), t)
1212

1313
identity_op = Operation(identity,[x])
14-
@test isequal(simplify_constants(identity_op), x)
14+
@test isequal(simplify(identity_op), x)
1515

1616
minus_op = -x
17-
@test isequal(simplify_constants(minus_op), -x)
18-
simplify_constants(minus_op)
17+
@test isequal(simplify(minus_op), -1x)
18+
simplify(minus_op)
1919

2020
@variables x
2121

2222
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^2))) == :(2 * (-2 + x))
2323
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^3))) == :(3 * (-2 + x)^2)
24-
@test simplified_expr(simplify_constants(x+2+3)) == :(5 + x)
24+
@test simplified_expr(simplify(x+2+3)) == :(5 + x)
2525

2626
d1 = Differential(x)((-2 + x)^2)
2727
d2 = Differential(x)(d1)
2828
d3 = Differential(x)(d2)
2929

3030
@test simplified_expr(expand_derivatives(d3)) == :(0)
31-
@test simplified_expr(simplify_constants(x^0)) == :(1)
31+
@test simplified_expr(simplify(x^0)) == :(1)

test/variable_utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ sol = ModelingToolkit.get_variables(expr)
77

88
@parameters γ
99
s = α => γ
10+
expr = (((1 / β - 1) + δ) / α) ^ (1 /- 1))
11+
ModelingToolkit.substitute(expr, s)
1012
new = (((1 / β - 1) + δ) / γ) ^ (1 /- 1))
11-
ModelingToolkit.substitute_expr!(expr, s)
1213
@test isequal(expr, new)

0 commit comments

Comments
 (0)