Skip to content

Commit df03a12

Browse files
Merge pull request #431 from jpsamaroo/jps/daggerform
Add DaggerForm
2 parents 90bace8 + f63f23f commit df03a12

File tree

5 files changed

+43
-1
lines changed

5 files changed

+43
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1919
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
2020
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
21+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2122
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2223
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2324
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -50,10 +51,11 @@ Unitful = "1.1"
5051
julia = "1.2"
5152

5253
[extras]
54+
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
5355
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
5456
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
5557
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
5658
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5759

5860
[targets]
59-
test = ["OrdinaryDiffEq", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]
61+
test = ["Dagger", "OrdinaryDiffEq", "SteadyStateDiffEq", "Test", "StochasticDiffEq"]

src/ModelingToolkit.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import LightGraphs: SimpleDiGraph, add_edge!
2323

2424
import TreeViews
2525

26+
using Requires
27+
2628
"""
2729
$(TYPEDEF)
2830
@@ -148,4 +150,10 @@ export build_function
148150
export @register
149151
export modelingtoolkitize
150152
export @variables, @parameters
153+
154+
const HAS_DAGGER = Ref{Bool}(false)
155+
function __init__()
156+
@require Dagger="d58978e5-989f-55fb-8d15-ea34adc7bf54" include("dagger.jl")
157+
end
158+
151159
end # module

src/build_function.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ abstract type ParallelForm end
88
struct SerialForm <: ParallelForm end
99
struct MultithreadedForm <: ParallelForm end
1010
struct DistributedForm <: ParallelForm end
11+
struct DaggerForm <: ParallelForm end
1112

1213
"""
1314
`build_function`
@@ -196,6 +197,10 @@ function _build_function(target::JuliaTarget, rhss, args...;
196197
finalsize = rhs_length - (numworks-1)*lens
197198
_rhss = vcat(reduce(vcat,[[getindex(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]),
198199
[getindex(reducevars[end],j) for j in 1:finalsize])
200+
elseif parallel isa DaggerForm
201+
computevars = [Variable(gensym(:MTComputeVar))() for i in axes(rhss,1)]
202+
reducevar = Variable(gensym(:MTReduceVar))()
203+
_rhss = [getindex(reducevar,i) for i in axes(rhss,1)]
199204
elseif rhss isa SparseMatrixCSC
200205
_rhss = rhss.nzval
201206
else
@@ -290,6 +295,23 @@ function _build_function(target::JuliaTarget, rhss, args...;
290295
$(resunpack_exprs...)
291296
$(ip_let_expr.args[2])
292297
end
298+
elseif parallel isa DaggerForm
299+
@assert HAS_DAGGER[] "Dagger.jl is not loaded; please do `using Dagger`"
300+
dagwrap(x) = x
301+
dagwrap(ex::Expr) = dagwrap(ex, Val(ex.head))
302+
dagwrap(ex::Expr, ::Val) = ex
303+
dagwrap(ex::Expr, ::Val{:call}) = :(Dagger.delayed($(ex.args[1]))($(dagwrap.(ex.args[2:end])...)))
304+
new_rhss = dagwrap.(conv.(rhss))
305+
delayed_exprs = build_expr(:block, [:($(Symbol(computevars[i])) = Dagger.delayed(identity)($(new_rhss[i]))) for i in axes(computevars,1)])
306+
# TODO: treereduce?
307+
reduce_expr = quote
308+
$(Symbol(reducevar)) = collect(Dagger.delayed(vcat)($(computevars...)))
309+
end
310+
ip_let_expr.args[2] = quote
311+
$delayed_exprs
312+
$reduce_expr
313+
$(ip_let_expr.args[2])
314+
end
293315
end
294316

295317
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])

src/dagger.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
using .Dagger
2+
3+
HAS_DAGGER[] = true

test/bigsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,14 @@ using Distributed
5454
addprocs(4)
5555
distributedf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.DistributedForm())[2])
5656

57+
using Dagger
58+
daggerf = eval(ModelingToolkit.build_function(du,u,parallel=ModelingToolkit.DaggerForm())[2])
59+
5760
jac = sparse(ModelingToolkit.jacobian(vec(du),vec(u),simplify=false))
5861
serialjac = eval(ModelingToolkit.build_function(vec(jac),u)[2])
5962
multithreadedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.MultithreadedForm())[2])
6063
distributedjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.DistributedForm())[2])
64+
daggerjac = eval(ModelingToolkit.build_function(vec(jac),u,parallel=ModelingToolkit.DaggerForm())[2])
6165

6266
MyA = zeros(N,N)
6367
AMx = zeros(N,N)
@@ -66,15 +70,18 @@ DA = zeros(N,N)
6670
f(_du,_u,nothing,0.0)
6771
multithreadedf(_du,_u)
6872
#distributedf(_du,_u)
73+
#daggerf(_du,_u)
6974

7075
#=
7176
using BenchmarkTools
7277
@btime f(_du,_u,nothing,0.0)
7378
@btime multithreadedf(_du,_u)
7479
@btime distributedf(_du,_u)
80+
@btime daggerf(_du,_u)
7581
7682
_jac = similar(jac,Float64)
7783
@btime serialjac(_jac,_u)
7884
@btime multithreadedjac(_jac,_u)
7985
@btime distributedjac(_jac,_u)
86+
@btime daggerjac(_jac,_u)
8087
=#

0 commit comments

Comments
 (0)