Skip to content

Commit 0bb1b44

Browse files
Distributed parallelism build targets
```julia using ModelingToolkit, OrdinaryDiffEq @parameters t σ ρ β @variables x(t) y(t) z(t) @derivatives D'~t eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, D(z) ~ x*y - β*z] lorenz1 = ODESystem(eqs,name=:lorenz1) lorenz2 = ODESystem(eqs,name=:lorenz2) @variables a @parameters γ connections = [0 ~ lorenz1.x + lorenz2.y + a*γ] connected = ODESystem(connections,t,[a],[γ],systems=[lorenz1,lorenz2]) using Distributed addprocs(2) generate_function(connected,parallel=ModelingToolkit.DistributedForm())[2] generate_jacobian(connected,parallel=ModelingToolkit.DistributedForm(),sparse=true)[2] ``` gives: ```julia :((var"##MTIIPVar#365", var"##MTKArg#361", var"##MTKArg#362", var"##MTKArg#363")->begin @inbounds begin let (a, lorenz1₊x, lorenz1₊y, lorenz1₊z, lorenz2₊x, lorenz2₊y, lorenz2₊z, γ, lorenz1₊ρ, lorenz1₊σ, lorenz1₊β, lorenz2₊ρ, lorenz2₊σ, lorenz2₊β, t) = (var"##MTKArg#361"[1], var"##MTKArg#361"[2], var"##MTKArg#361"[3], var"##MTKArg#361"[4], var"##MTKArg#361"[5], var"##MTKArg#361"[6], var"##MTKArg#361"[7], var"##MTKArg#362"[1], var"##MTKArg#362"[2], var"##MTKArg#362"[3], var"##MTKArg#362"[4], var"##MTKArg#362"[5], var"##MTKArg#362"[6], var"##MTKArg#362"[7], var"##MTKArg#363") begin begin var"##MTSpawnVar#368" = Distributed.@spawnat(2, begin [(lorenz1₊x + lorenz2₊y) + a * γ, lorenz1₊σ * (lorenz1₊y - lorenz1₊x), lorenz1₊x * (lorenz1₊ρ - lorenz1₊z) - lorenz1₊y, lorenz1₊x * lorenz1₊y - lorenz1₊β * lorenz1₊z] end) end begin var"##MTSpawnVar#369" = Distributed.@spawnat(3, begin [lorenz2₊σ * (lorenz2₊y - lorenz2₊x), lorenz2₊x * (lorenz2₊ρ - lorenz2₊z) - lorenz2₊y, lorenz2₊x * lorenz2₊y - lorenz2₊β * lorenz2₊z] end) end end var"##MTReduceVar#366" = fetch(var"##MTSpawnVar#368") var"##MTReduceVar#367" = fetch(var"##MTSpawnVar#369") begin var"##MTIIPVar#365"[1] = getindex(var"##MTReduceVar#366", 1) var"##MTIIPVar#365"[2] = getindex(var"##MTReduceVar#366", 2) var"##MTIIPVar#365"[3] = getindex(var"##MTReduceVar#366", 3) var"##MTIIPVar#365"[4] = getindex(var"##MTReduceVar#366", 4) var"##MTIIPVar#365"[5] = getindex(var"##MTReduceVar#367", 1) var"##MTIIPVar#365"[6] = getindex(var"##MTReduceVar#367", 2) var"##MTIIPVar#365"[7] = getindex(var"##MTReduceVar#367", 3) end end end nothing end) ``` and ```julia :((var"##MTIIPVar#371", var"##MTKArg#367", var"##MTKArg#368", var"##MTKArg#369")->begin @inbounds begin let (a, lorenz1₊x, lorenz1₊y, lorenz1₊z, lorenz2₊x, lorenz2₊y, lorenz2₊z, γ, lorenz1₊σ, lorenz1₊β, lorenz1₊ρ, lorenz2₊σ, lorenz2₊β, lorenz2₊ρ, t) = (var"##MTKArg#367"[1], var"##MTKArg#367"[2], var"##MTKArg#367"[3], var"##MTKArg#367"[4], var"##MTKArg#367"[5], var"##MTKArg#367"[6], var"##MTKArg#367"[7], var"##MTKArg#368"[1], var"##MTKArg#368"[2], var"##MTKArg#368"[3], var"##MTKArg#368"[4], var"##MTKArg#368"[5], var"##MTKArg#368"[6], var"##MTKArg#368"[7], var"##MTKArg#369") begin begin var"##MTSpawnVar#374" = Distributed.@spawnat(2, begin [1, -1lorenz1₊σ, -1lorenz1₊z + lorenz1₊ρ, lorenz1₊y, lorenz1₊σ, -1, lorenz1₊x, -1lorenz1₊x, -1lorenz1₊β] end) end begin var"##MTSpawnVar#375" = Distributed.@spawnat(3, begin [-1lorenz2₊σ, -1lorenz2₊z + lorenz2₊ρ, lorenz2₊y, 1, lorenz2₊σ, -1, lorenz2₊x, -1lorenz2₊x, -1lorenz2₊β] end) end end var"##MTReduceVar#372" = fetch(var"##MTSpawnVar#374") var"##MTReduceVar#373" = fetch(var"##MTSpawnVar#375") begin (var"##MTIIPVar#371").nzval[1] = getindex(var"##MTReduceVar#372", 1) (var"##MTIIPVar#371").nzval[2] = getindex(var"##MTReduceVar#372", 2) (var"##MTIIPVar#371").nzval[3] = getindex(var"##MTReduceVar#372", 3) (var"##MTIIPVar#371").nzval[4] = getindex(var"##MTReduceVar#372", 4) (var"##MTIIPVar#371").nzval[5] = getindex(var"##MTReduceVar#372", 5) (var"##MTIIPVar#371").nzval[6] = getindex(var"##MTReduceVar#372", 6) (var"##MTIIPVar#371").nzval[7] = getindex(var"##MTReduceVar#372", 7) (var"##MTIIPVar#371").nzval[8] = getindex(var"##MTReduceVar#372", 8) (var"##MTIIPVar#371").nzval[9] = getindex(var"##MTReduceVar#372", 9) (var"##MTIIPVar#371").nzval[10] = getindex(var"##MTReduceVar#373", 1) (var"##MTIIPVar#371").nzval[11] = getindex(var"##MTReduceVar#373", 2) (var"##MTIIPVar#371").nzval[12] = getindex(var"##MTReduceVar#373", 3) (var"##MTIIPVar#371").nzval[13] = getindex(var"##MTReduceVar#373", 4) (var"##MTIIPVar#371").nzval[14] = getindex(var"##MTReduceVar#373", 5) (var"##MTIIPVar#371").nzval[15] = getindex(var"##MTReduceVar#373", 6) (var"##MTIIPVar#371").nzval[16] = getindex(var"##MTReduceVar#373", 7) (var"##MTIIPVar#371").nzval[17] = getindex(var"##MTReduceVar#373", 8) (var"##MTIIPVar#371").nzval[18] = getindex(var"##MTReduceVar#373", 9) end end end nothing end) ```
1 parent 904d92b commit 0bb1b44

File tree

1 file changed

+68
-18
lines changed

1 file changed

+68
-18
lines changed

src/build_function.jl

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ struct StanTarget <: BuildTargets end
44
struct CTarget <: BuildTargets end
55
struct MATLABTarget <: BuildTargets end
66

7+
abstract type ParallelForm end
8+
struct SerialForm <: ParallelForm end
9+
struct MultithreadedForm <: ParallelForm end
10+
struct DistributedForm <: ParallelForm end
11+
712
"""
813
`build_function`
914
@@ -57,7 +62,7 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
5762
end
5863

5964
function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
60-
if iip
65+
if iip
6166
wrappedex = :(
6267
($X,$(fargs.args...)) -> begin
6368
$ex
@@ -69,28 +74,28 @@ function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
6974
($(fargs.args...),) -> begin
7075
$ex
7176
end
72-
)
77+
)
7378
end
7479
wrappedex
7580
end
7681

7782
function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
7883
integrator = gensym(:MTKIntegrator)
79-
if iip
84+
if iip
8085
wrappedex = :(
81-
$integrator -> begin
86+
$integrator -> begin
8287
($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
8388
$ex
8489
nothing
8590
end
8691
)
8792
else
8893
wrappedex = :(
89-
$integrator -> begin
94+
$integrator -> begin
9095
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
9196
$ex
9297
end
93-
)
98+
)
9499
end
95100
wrappedex
96101
end
@@ -114,7 +119,7 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
114119

115120
fargs = Expr(:tuple,argnames...)
116121
oop_ex = headerfun(bounds_block, fargs, false)
117-
122+
118123
if !linenumbers
119124
oop_ex = striplines(oop_ex)
120125
end
@@ -129,8 +134,15 @@ end
129134
function _build_function(target::JuliaTarget, rhss, args...;
130135
conv = simplified_expr, expression = Val{true},
131136
checkbounds = false, constructor=nothing,
132-
linenumbers = false, multithread=false,
133-
headerfun=addheader, outputidxs=nothing)
137+
linenumbers = false, multithread=nothing,
138+
headerfun=addheader, outputidxs=nothing,
139+
parallel=SerialForm())
140+
141+
if multithread isa Bool
142+
@warn("multithraded is deprecated for the parallel argument. See the documentation.")
143+
parallel = multithread ? MultithreadedForm() : SerialForm()
144+
end
145+
134146
argnames = [gensym(:MTKArg) for i in 1:length(args)]
135147
arg_pairs = map(vars_to_pairs,zip(argnames,args))
136148
ls = reduce(vcat,first.(arg_pairs))
@@ -143,23 +155,39 @@ function _build_function(target::JuliaTarget, rhss, args...;
143155

144156
oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
145157
X = gensym(:MTIIPVar)
158+
159+
rhs_length = rhss isa SparseMatrixCSC ? length(rhss.nzval) : length(rhss)
160+
161+
if parallel isa DistributedForm
162+
numworks = Distributed.nworkers()
163+
reducevars = [Variable(gensym(:MTReduceVar))() for i in 1:numworks]
164+
lens = Int(ceil(rhs_length/numworks))
165+
finalsize = rhs_length - (numworks-1)*lens
166+
_rhss = vcat(reduce(vcat,[[getindex(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]),
167+
[getindex(reducevars[end],j) for j in 1:finalsize])
168+
elseif rhss isa SparseMatrixCSC
169+
_rhss = rhss.nzval
170+
else
171+
_rhss = rhss
172+
end
173+
146174
if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
147-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(rhss)],init=Expr[])
175+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)],init=Expr[])) for (i,rhsel) enumerate(_rhss)],init=Expr[])
148176
elseif eltype(eltype(rhss)) <: SparseMatrixCSC # Array of arrays of arrays
149-
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(rhss)])
177+
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j].nzval[$k] = $(conv(rhs))) for (k, rhs) enumerate(rhsel2)]) for (j, rhsel2) enumerate(rhsel)])) for (i,rhsel) enumerate(_rhss)])
150178
elseif eltype(rhss) <: SparseMatrixCSC # Array of sparse matrices
151-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(rhss)])
179+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i].nzval[$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)])
152180
elseif eltype(rhss) <: AbstractArray # Array of arrays
153-
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(rhss)], init = Expr[])
181+
ip_sys_exprs = reduce(vcat,[vec([:($X[$i][$j] = $(conv(rhs))) for (j, rhs) enumerate(rhsel)]) for (i,rhsel) enumerate(_rhss)], init = Expr[])
154182
elseif rhss isa SparseMatrixCSC
155-
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss.nzval)]
183+
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
156184
else
157-
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
185+
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) enumerate(_rhss)]
158186
end
159187

160188
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
161189

162-
if multithread
190+
if parallel isa MultithreadedForm
163191
lens = Int(ceil(length(ip_let_expr.args[2].args)/Threads.nthreads()))
164192
threaded_exprs = vcat([quote
165193
Threads.@spawn begin
@@ -172,7 +200,29 @@ function _build_function(target::JuliaTarget, rhss, args...;
172200
end
173201
end)
174202
ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs)
175-
end
203+
elseif parallel isa DistributedForm
204+
numworks = Distributed.nworkers()
205+
lens = Int(ceil(length(ip_let_expr.args[2].args)/numworks))
206+
spawnvars = [gensym(:MTSpawnVar) for i in 1:numworks]
207+
rhss_flat = rhss isa SparseMatrixCSC ? rhss.nzval : rhss
208+
spawnvectors = vcat(
209+
[build_expr(:vect, [conv(rhs) for rhs rhss_flat[((i-1)*lens+1):i*lens]]) for i in 1:numworks-1],
210+
build_expr(:vect, [conv(rhs) for rhs rhss_flat[((numworks-1)*lens+1):end]]))
211+
212+
spawn_exprs = [quote
213+
$(spawnvars[i]) = Distributed.@spawnat $(i+1) begin
214+
$(spawnvectors[i])
215+
end
216+
end for i in 1:numworks]
217+
spawn_exprs = ModelingToolkit.build_expr(:block, spawn_exprs)
218+
resunpack_exprs = [:($(Symbol(reducevars[iter])) = fetch($(spawnvars[iter]))) for iter in 1:numworks]
219+
220+
ip_let_expr.args[2] = quote
221+
$spawn_exprs
222+
$(resunpack_exprs...)
223+
$(ip_let_expr.args[2])
224+
end
225+
end
176226

177227
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
178228

@@ -214,7 +264,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
214264

215265
oop_ex = headerfun(oop_body_block, fargs, false)
216266
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)
217-
267+
218268
if !linenumbers
219269
oop_ex = striplines(oop_ex)
220270
iip_ex = striplines(iip_ex)

0 commit comments

Comments
 (0)