Skip to content

Commit 42250e1

Browse files
Merge pull request #348 from SciML/distributed
Distributed parallelism build targets
2 parents 904d92b + b3f3bd1 commit 42250e1

File tree

9 files changed

+124
-60
lines changed

9 files changed

+124
-60
lines changed

src/build_function.jl

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ struct StanTarget <: BuildTargets end
44
struct CTarget <: BuildTargets end
55
struct MATLABTarget <: BuildTargets end
66

7+
abstract type ParallelForm end
8+
struct SerialForm <: ParallelForm end
9+
struct MultithreadedForm <: ParallelForm end
10+
struct DistributedForm <: ParallelForm end
11+
712
"""
813
`build_function`
914
@@ -57,7 +62,7 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5762
end
5863

5964
function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
60-
if iip
65+
if iip
6166
wrappedex = :(
6267
($X,$(fargs.args...)) -> begin
6368
$ex
@@ -69,28 +74,28 @@ function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
6974
($(fargs.args...),) -> begin
7075
$ex
7176
end
72-
)
77+
)
7378
end
7479
wrappedex
7580
end
7681

7782
function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
7883
integrator = gensym(:MTKIntegrator)
79-
if iip
84+
if iip
8085
wrappedex = :(
81-
$integrator -> begin
86+
$integrator -> begin
8287
($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
8388
$ex
8489
nothing
8590
end
8691
)
8792
else
8893
wrappedex = :(
89-
$integrator -> begin
94+
$integrator -> begin
9095
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
9196
$ex
9297
end
93-
)
98+
)
9499
end
95100
wrappedex
96101
end
@@ -114,7 +119,7 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
114119

115120
fargs = Expr(:tuple,argnames...)
116121
oop_ex = headerfun(bounds_block, fargs, false)
117-
122+
118123
if !linenumbers
119124
oop_ex = striplines(oop_ex)
120125
end
@@ -129,8 +134,15 @@ end
129134
function _build_function(target::JuliaTarget, rhss, args...;
130135
conv = simplified_expr, expression = Val{true},
131136
checkbounds = false, constructor=nothing,
132-
linenumbers = false, multithread=false,
133-
headerfun=addheader, outputidxs=nothing)
137+
linenumbers = false, multithread=nothing,
138+
headerfun=addheader, outputidxs=nothing,
139+
parallel=SerialForm())
140+
141+
if multithread isa Bool
142+
@warn("multithraded is deprecated for the parallel argument. See the documentation.")
143+
parallel = multithread ? MultithreadedForm() : SerialForm()
144+
end
145+
134146
argnames = [gensym(:MTKArg) for i in 1:length(args)]
135147
arg_pairs = map(vars_to_pairs,zip(argnames,args))
136148
ls = reduce(vcat,first.(arg_pairs))
@@ -143,23 +155,39 @@ function _build_function(target::JuliaTarget, rhss, args...;
143155

144156
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
145157
X = gensym(:MTIIPVar)
158+
159+
rhs_length = rhss isa SparseMatrixCSC ? length(rhss.nzval) : length(rhss)
160+
161+
if parallel isa DistributedForm
162+
numworks = Distributed.nworkers()
163+
reducevars = [Variable(gensym(:MTReduceVar))() for i in 1:numworks]
164+
lens = Int(ceil(rhs_length/numworks))
165+
finalsize = rhs_length - (numworks-1)*lens
166+
_rhss = vcat(reduce(vcat,[[getindex(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]),
167+
[getindex(reducevars[end],j) for j in 1:finalsize])
168+
elseif rhss isa SparseMatrixCSC
169+
_rhss = rhss.nzval
170+
else
171+
_rhss = rhss
172+
end
173+
146174
if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
147-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(rhss)],init=Expr[])
175+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
148176
elseif eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of arrays
149-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(rhss)])
177+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(_rhss)])
150178
elseif eltype(rhss) <: SparseMatrixCSC # Array of sparse matrices
151-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(rhss)])
179+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)])
152180
elseif eltype(rhss) <: AbstractArray # Array of arrays
153-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(rhss)], init = Expr[])
181+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init = Expr[])
154182
elseif rhss isa SparseMatrixCSC
155-
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss.nzval)]
183+
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
156184
else
157-
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
185+
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
158186
end
159187

160188
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
161189

162-
if multithread
190+
if parallel isa MultithreadedForm
163191
lens = Int(ceil(length(ip_let_expr.args[2].args)/Threads.nthreads()))
164192
threaded_exprs = vcat([quote
165193
Threads.@spawn begin
@@ -172,7 +200,29 @@ function _build_function(target::JuliaTarget, rhss, args...;
172200
end
173201
end)
174202
ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs)
175-
end
203+
elseif parallel isa DistributedForm
204+
numworks = Distributed.nworkers()
205+
lens = Int(ceil(length(ip_let_expr.args[2].args)/numworks))
206+
spawnvars = [gensym(:MTSpawnVar) for i in 1:numworks]
207+
rhss_flat = rhss isa SparseMatrixCSC ? rhss.nzval : rhss
208+
spawnvectors = vcat(
209+
[build_expr(:vect, [conv(rhs) for rhs rhss_flat[((i-1)*lens+1):i*lens]]) for i in 1:numworks-1],
210+
build_expr(:vect, [conv(rhs) for rhs rhss_flat[((numworks-1)*lens+1):end]]))
211+
212+
spawn_exprs = [quote
213+
$(spawnvars[i]) = ModelingToolkit.Distributed.remotecall($(i+1)) do
214+
$(spawnvectors[i])
215+
end
216+
end for i in 1:numworks]
217+
spawn_exprs = ModelingToolkit.build_expr(:block, spawn_exprs)
218+
resunpack_exprs = [:($(Symbol(reducevars[iter])) = fetch($(spawnvars[iter]))) for iter in 1:numworks]
219+
220+
ip_let_expr.args[2] = quote
221+
$spawn_exprs
222+
$(resunpack_exprs...)
223+
$(ip_let_expr.args[2])
224+
end
225+
end
176226

177227
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
178228

@@ -214,7 +264,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
214264

215265
oop_ex = headerfun(oop_body_block, fargs, false)
216266
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
217-
267+
218268
if !linenumbers
219269
oop_ex = striplines(oop_ex)
220270
iip_ex = striplines(iip_ex)

src/differentials.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ $(SIGNATURES)
3636
3737
TODO
3838
"""
39-
function expand_derivatives(O::Operation)
40-
@. O.args = expand_derivatives(O.args)
39+
function expand_derivatives(O::Operation,simplify=true)
40+
@. O.args = expand_derivatives(O.args,simplify)
4141

4242
if isa(O.op, Differential)
4343
(D, o) = (O.op, O.args[1])
@@ -47,14 +47,16 @@ function expand_derivatives(O::Operation)
4747
isa(o, Operation) || return O
4848
isa(o.op, Variable) && return O
4949

50-
return sum(1:length(o.args)) do i
51-
derivative(o, i) * expand_derivatives(D(o.args[i]))
52-
end |> simplify
50+
x = sum(1:length(o.args)) do i
51+
derivative(o, i) * expand_derivatives(D(o.args[i]),simplify)
52+
end
53+
54+
return simplify ? ModelingToolkit.simplify(x) : x
5355
end
5456

55-
return simplify(O)
57+
return simplify ? ModelingToolkit.simplify(O) : O
5658
end
57-
expand_derivatives(x) = x
59+
expand_derivatives(x,args...) = x
5860

5961
# Don't specialize on the function here
6062
"""

src/direct.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ A helper function for computing the gradient of an expression with respect to
77
an array of variable expressions.
88
"""
99
function gradient(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
10-
out = [expand_derivatives(Differential(v)(O)) for v in vars]
11-
simplify ? ModelingToolkit.simplify.(out) : out
10+
[expand_derivatives(Differential(v)(O),simplify) for v in vars]
1211
end
1312

1413
"""
@@ -20,8 +19,7 @@ A helper function for computing the Jacobian of an array of expressions with res
2019
an array of variable expressions.
2120
"""
2221
function jacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expression}; simplify = true)
23-
out = [expand_derivatives(Differential(v)(O)) for O in ops, v in vars]
24-
simplify ? ModelingToolkit.simplify.(out) : out
22+
[expand_derivatives(Differential(v)(O),simplify) for O in ops, v in vars]
2523
end
2624

2725
"""
@@ -33,8 +31,7 @@ A helper function for computing the Hessian of an expression with respect to
3331
an array of variable expressions.
3432
"""
3533
function hessian(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)
36-
out = [expand_derivatives(Differential(v2)(Differential(v1)(O))) for v1 in vars, v2 in vars]
37-
simplify ? ModelingToolkit.simplify.(out) : out
34+
[expand_derivatives(Differential(v2)(Differential(v1)(O)),simplify) for v1 in vars, v2 in vars]
3835
end
3936

4037
function simplified_expr(O::Operation)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
197197
version = nothing, tgrad=false,
198198
jac = false, Wfact = false,
199199
checkbounds = false, sparse = false,
200-
linenumbers = true, multithread=false,
200+
linenumbers = true, parallel=SerialForm(),
201201
kwargs...) where iip
202202
```
203203
@@ -209,10 +209,10 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
209209
version = nothing, tgrad=false,
210210
jac = false, Wfact = false,
211211
checkbounds = false, sparse = false,
212-
linenumbers = true, multithread=false,
212+
linenumbers = true, parallel=SerialForm(),
213213
kwargs...) where iip
214214
f = ODEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
215-
linenumbers=linenumbers,multithread=multithread,
215+
linenumbers=linenumbers,parallel=parallel,
216216
sparse=sparse)
217217
u0 = varmap_to_vars(u0map,states(sys))
218218
p = varmap_to_vars(parammap,parameters(sys))

src/systems/diffeqs/sdesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,p=parammap;
161161
version = nothing, tgrad=false,
162162
jac = false, Wfact = false,
163163
checkbounds = false, sparse = false,
164-
linenumbers = true, multithread=false,
164+
linenumbers = true, parallel=SerialForm(),
165165
kwargs...)
166166
```
167167
@@ -172,11 +172,11 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,parammap=DiffEqBa
172172
version = nothing, tgrad=false,
173173
jac = false, Wfact = false,
174174
checkbounds = false, sparse = false,
175-
linenumbers = true, multithread=false,
175+
linenumbers = true, parallel=SerialForm(),
176176
kwargs...) where iip
177177

178178
f = SDEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
179-
linenumbers=linenumbers,multithread=multithread,
179+
linenumbers=linenumbers,parallel=parallel,
180180
sparse=sparse)
181181
u0 = varmap_to_vars(u0map,states(sys))
182182
p = varmap_to_vars(parammap,parameters(sys))
@@ -185,4 +185,4 @@ end
185185

186186
function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
187187
SDEProblem{true}(sys, args...; kwargs...)
188-
end
188+
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
7171
parammap=DiffEqBase.NullParameters();
7272
jac = false, sparse=false,
7373
checkbounds = false,
74-
linenumbers = true, multithread=false,
74+
linenumbers = true, parallel=SerialForm(),
7575
kwargs...) where iip
7676
```
7777
@@ -82,13 +82,13 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
8282
parammap=DiffEqBase.NullParameters();
8383
jac = false, sparse=false,
8484
checkbounds = false,
85-
linenumbers = true, multithread=false,
85+
linenumbers = true, parallel=SerialForm(),
8686
kwargs...) where iip
8787
dvs = states(sys)
8888
ps = parameters(sys)
8989

9090
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
91-
multithread=multithread,sparse=sparse,expression=Val{false})
91+
parallel=parallel,sparse=sparse,expression=Val{false})
9292
u0 = varmap_to_vars(u0map,dvs)
9393
p = varmap_to_vars(parammap,ps)
9494
NonlinearProblem(f,u0,tspan,p;kwargs...)

src/systems/optimization/optimizationsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
8080
u0=nothing, lb=nothing, ub=nothing,
8181
hes = false, sparse = false,
8282
checkbounds = false,
83-
linenumbers = true, multithread=false,
83+
linenumbers = true, parallel=SerialForm(),
8484
kwargs...) where iip
8585
```
8686
@@ -92,13 +92,13 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
9292
u0=nothing, lb=nothing, ub=nothing,
9393
hes = false, sparse = false,
9494
checkbounds = false,
95-
linenumbers = true, multithread=false,
95+
linenumbers = true, parallel=SerialForm(),
9696
kwargs...) where iip
9797
dvs = states(sys)
9898
ps = parameters(sys)
9999

100100
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
101-
multithread=multithread,expression=Val{false})
101+
parallel=parallel,expression=Val{false})
102102
u0 = varmap_to_vars(u0,dvs)
103103
p = varmap_to_vars(parammap,ps)
104104
lb = varmap_to_vars(lb,dvs)

test/bigsystem.jl

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,36 @@ end
4545

4646
f(du,u,nothing,0.0)
4747

48-
multithreadedf = eval(ModelingToolkit.build_function(du,u,multithread=true)[2])
48+
multithreadedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.MultithreadedForm())[2])
4949
_du = rand(N,N,3)
5050
_u = rand(N,N,3)
5151
multithreadedf(_du,_u)
5252

53+
using Distributed
54+
addprocs(4)
55+
distributedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.DistributedForm())[2])
56+
5357
jac = sparse(ModelingToolkit.jacobian(vec(du),vec(u),simplify=false))
54-
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,multithread=true)[2])
58+
serialjac = eval(ModelingToolkit.build_function(vec(jac),u)[2])
59+
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.MultithreadedForm())[2])
60+
distributedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.DistributedForm())[2])
61+
62+
MyA = zeros(N,N)
63+
AMx = zeros(N,N)
64+
DA = zeros(N,N)
65+
66+
f(_du,_u,nothing,0.0)
67+
multithreadedf(_du,_u)
68+
distributedf(_du,_u)
69+
70+
#=
71+
using BenchmarkTools
72+
@btime f(_du,_u,nothing,0.0)
73+
@btime multithreadedf(_du,_u)
74+
@btime distributedf(_du,_u)
5575
56-
#_jac = similar(jac,Float64)
57-
#multithreadedjac(_jac,_u)
76+
_jac = similar(jac,Float64)
77+
@btime serialjac(_jac,_u)
78+
@btime multithreadedjac(_jac,_u)
79+
@btime distributedjac(_jac,_u)
80+
=#

0 commit comments

Comments
 (0)