Skip to content

Commit c09a6b2

Browse files
committed
cuconvert
1 parent 789d9cc commit c09a6b2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,12 @@ function Adapt.adapt_structure(
239239
)
240240
end
241241

242-
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
242+
"""
    recudaconvert(arg)

Return `arg` adapted with a freshly constructed `ReactantKernelAdaptor`, i.e. the
result of `adapt(ReactantKernelAdaptor(), arg)`.
"""
recudaconvert(arg) = adapt(ReactantKernelAdaptor(), arg)
245+
# Overlay method for `CUDA.cudaconvert`, registered through `Reactant.@reactant_overlay`.
# It forwards `arg` unchanged to `recudaconvert` and returns that result, keeping the
# conversion logic in an ordinary function that can also be called directly.
# `@noinline` asks the compiler not to inline this method into its callers.
# NOTE(review): the exact contexts in which the overlay replaces the stock CUDA
# method depend on `@reactant_overlay`, which is defined elsewhere — confirm there.
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
246+
return recudaconvert(arg)
247+
end
245248

246249
function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
247250
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
@@ -454,7 +457,7 @@ end
454457

455458
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
456459
args...;
457-
convert=Val(false),
460+
convert=Val(true),
458461
blocks::CuDim=1,
459462
threads::CuDim=1,
460463
cooperative::Bool=false,
@@ -464,6 +467,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
464467
blockdim = CUDA.CuDim3(blocks)
465468
threaddim = CUDA.CuDim3(threads)
466469

470+
if convert == Val(true)
471+
args = recudaconvert.(args)
472+
end
473+
467474
mlir_args = MLIR.IR.Value[]
468475
restys = MLIR.IR.Type[]
469476
aliases = MLIR.IR.Attribute[]

0 commit comments

Comments
 (0)