Skip to content

Commit 7072cfc

Browse files
push simplify through derivative expansion
1 parent 964cd9a commit 7072cfc

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

src/differentials.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ $(SIGNATURES)
3636
3737
TODO
3838
"""
39-
function expand_derivatives(O::Operation)
40-
@. O.args = expand_derivatives(O.args)
39+
function expand_derivatives(O::Operation,simplify=true)
40+
@. O.args = expand_derivatives(O.args,simplify)
4141

4242
if isa(O.op, Differential)
4343
(D, o) = (O.op, O.args[1])
@@ -47,14 +47,16 @@ function expand_derivatives(O::Operation)
4747
isa(o, Operation) || return O
4848
isa(o.op, Variable) && return O
4949

50-
return sum(1:length(o.args)) do i
51-
derivative(o, i) * expand_derivatives(D(o.args[i]))
52-
end |> simplify
50+
x = sum(1:length(o.args)) do i
51+
derivative(o, i) * expand_derivatives(D(o.args[i]),simplify)
52+
end
53+
54+
return simplify ? ModelingToolkit.simplify(x) : x
5355
end
5456

55-
return simplify(O)
57+
return simplify ? ModelingToolkit.simplify(O) : O
5658
end
57-
expand_derivatives(x) = x
59+
expand_derivatives(x,args...) = x
5860

5961
# Don't specialize on the function here
6062
"""

src/direct.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ A helper function for computing the gradient of an expression with respect to
77
an array of variable expressions.
88
"""
99
function gradient(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
10-
out = [expand_derivatives(Differential(v)(O)) for v in vars]
11-
simplify ? ModelingToolkit.simplify.(out) : out
10+
[expand_derivatives(Differential(v)(O),simplify) for v in vars]
1211
end
1312

1413
"""
@@ -20,8 +19,7 @@ A helper function for computing the Jacobian of an array of expressions with res
2019
an array of variable expressions.
2120
"""
2221
function jacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expression}; simplify = true)
23-
out = [expand_derivatives(Differential(v)(O)) for O in ops, v in vars]
24-
simplify ? ModelingToolkit.simplify.(out) : out
22+
[expand_derivatives(Differential(v)(O),simplify) for O in ops, v in vars]
2523
end
2624

2725
"""
@@ -33,8 +31,7 @@ A helper function for computing the Hessian of an expression with respect to
3331
an array of variable expressions.
3432
"""
3533
function hessian(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
36-
out = [expand_derivatives(Differential(v2)(Differential(v1)(O))) for v1 in vars, v2 in vars]
37-
simplify ? ModelingToolkit.simplify.(out) : out
34+
[expand_derivatives(Differential(v2)(Differential(v1)(O)),simplify) for v1 in vars, v2 in vars]
3835
end
3936

4037
function simplified_expr(O::Operation)

test/bigsystem.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,31 @@ end
4545

4646
f(du,u,nothing,0.0)
4747

48-
multithreadedf = eval(ModelingToolkit.build_function(du,u,parallel=MultithreadedForm())[2])
48+
multithreadedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.MultithreadedForm())[2])
4949
_du = rand(N,N,3)
5050
_u = rand(N,N,3)
5151
multithreadedf(_du,_u)
5252

53+
using Distributed
54+
addprocs(4)
55+
distributedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.DistributedForm())[2])
56+
5357
jac = sparse(ModelingToolkit.jacobian(vec(du),vec(u),simplify=false))
54-
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=MultithreadedForm())[2])
58+
serialjac = eval(ModelingToolkit.build_function(vec(jac),u)[2])
59+
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.MultithreadedForm())[2])
60+
distributedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.DistributedForm())[2])
61+
62+
#=
63+
MyA = zeros(N,N)
64+
AMx = zeros(N,N)
65+
DA = zeros(N,N)
66+
using BenchmarkTools
67+
@btime f(_du,_u,nothing,0.0)
68+
@btime multithreadedf(_du,_u)
69+
@btime distributedf(_du,_u)
5570
56-
#_jac = similar(jac,Float64)
57-
#multithreadedjac(_jac,_u)
71+
_jac = similar(jac,Float64)
72+
@btime serialjac(_jac,_u)
73+
@btime multithreadedjac(_jac,_u)
74+
@btime distributedjac(_jac,_u)
75+
=#

0 commit comments

Comments
 (0)