Skip to content

Commit cb63715

Browse files
committed
Recursively apply delayed() to AST
1 parent ae5a0b6 commit cb63715

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/build_function.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,12 @@ function _build_function(target::JuliaTarget, rhss, args...;
229229
end
230230
elseif parallel isa DaggerForm
231231
@assert HAS_DAGGER[] "Dagger.jl is not loaded; please do `using Dagger`"
232-
delayed_exprs = build_expr(:block, [:($(Symbol(computevars[i])) = Dagger.delayed(identity)($(conv(rhss[i])))) for i in axes(computevars,1)])
232+
dagwrap(x) = x
233+
dagwrap(ex::Expr) = dagwrap(ex, Val(ex.head))
234+
dagwrap(ex::Expr, ::Val) = ex
235+
dagwrap(ex::Expr, ::Val{:call}) = :(Dagger.delayed($(ex.args[1]))($(dagwrap.(ex.args[2:end])...)))
236+
new_rhss = dagwrap.(conv.(rhss))
237+
delayed_exprs = build_expr(:block, [:($(Symbol(computevars[i])) = Dagger.delayed(identity)($(new_rhss[i]))) for i in axes(computevars,1)])
233238
# TODO: treereduce?
234239
reduce_expr = quote
235240
$(Symbol(reducevar)) = collect(Dagger.delayed(vcat)($(computevars...)))

0 commit comments

Comments
 (0)