Skip to content

Commit 3e96e40

Browse files
committed
Add DaggerForm
1 parent da1ce37 commit 3e96e40

File tree

4 files changed

+29
-0
lines changed

4 files changed

+29
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1616
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1717
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1818
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
19+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1920
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2021
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2122
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

src/ModelingToolkit.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import SymbolicUtils: to_symbolic, FnType
2020

2121
import TreeViews
2222

23+
using Requires
24+
2325
"""
2426
$(TYPEDEF)
2527
@@ -133,4 +135,10 @@ export build_function
133135
export @register
134136
export modelingtoolkitize
135137
export @variables, @parameters
138+
139+
const HAS_DAGGER = Ref{Bool}(false)
140+
function __init__()
141+
@require Dagger="d58978e5-989f-55fb-8d15-ea34adc7bf54" include("dagger.jl")
142+
end
143+
136144
end # module

src/build_function.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ abstract type ParallelForm end
88
struct SerialForm <: ParallelForm end
99
struct MultithreadedForm <: ParallelForm end
1010
struct DistributedForm <: ParallelForm end
11+
struct DaggerForm <: ParallelForm end
1112

1213
"""
1314
`build_function`
@@ -165,6 +166,10 @@ function _build_function(target::JuliaTarget, rhss, args...;
165166
finalsize = rhs_length - (numworks-1)*lens
166167
_rhss = vcat(reduce(vcat,[[getindex(reducevars[i],j) for j in 1:lens] for i in 1:numworks-1],init=Expr[]),
167168
[getindex(reducevars[end],j) for j in 1:finalsize])
169+
elseif parallel isa DaggerForm
170+
computevars = [Variable(gensym(:MTComputeVar))() for i in axes(rhss,1)]
171+
reducevar = Variable(gensym(:MTReduceVar))()
172+
_rhss = [getindex(reducevar,i) for i in axes(rhss,1)]
168173
elseif rhss isa SparseMatrixCSC
169174
_rhss = rhss.nzval
170175
else
@@ -222,6 +227,18 @@ function _build_function(target::JuliaTarget, rhss, args...;
222227
$(resunpack_exprs...)
223228
$(ip_let_expr.args[2])
224229
end
230+
elseif parallel isa DaggerForm
231+
@assert HAS_DAGGER[] "Dagger.jl is not loaded; please do `using Dagger`"
232+
delayed_exprs = build_expr(:block, [:($(Symbol(computevars[i])) = Dagger.delayed(identity)($(conv(rhss[i])))) for i in axes(computevars,1)])
233+
# TODO: treereduce?
234+
reduce_expr = quote
235+
$(Symbol(reducevar)) = collect(Dagger.delayed(vcat)($(computevars...)))
236+
end
237+
ip_let_expr.args[2] = quote
238+
$delayed_exprs
239+
$reduce_expr
240+
$(ip_let_expr.args[2])
241+
end
225242
end
226243

227244
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])

src/dagger.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
using .Dagger
2+
3+
HAS_DAGGER[] = true

0 commit comments

Comments
 (0)