Skip to content

Commit e718f44

Browse files
Kernel: support constant input arg (#522)
* Kernel: support constant input arg * Update utils.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 7c2e390 commit e718f44

File tree

4 files changed

+83
-11
lines changed

4 files changed

+83
-11
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ using Adapt
99

1010
struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
1111
ptr::Core.LLVMPtr{T,A}
12+
13+
function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
14+
push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs)
15+
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
16+
return new(ptr)
17+
end
1218
end
1319

1420
function Base.show(io::IO, a::AT) where {AT<:CuTracedArray}
@@ -211,10 +217,34 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
211217
return _derived_array(a, T, dims)
212218
end
213219

214-
function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
215-
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(
216-
Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
220+
struct ReactantKernelAdaptor end
221+
222+
function Adapt.adapt_storage(to::ReactantKernelAdaptor, p::CUDA.CuPtr)
223+
return error("Cannot convert CuPtr argument of Reactant Kernel")
224+
end
225+
function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::DenseCuArray)
226+
return Adapt.adapt_storage(ka, Array(xs))
227+
end
228+
function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::Array)
229+
return Adapt.adapt_storage(ka, Reactant.Ops.constant(xs))
230+
end
231+
function Adapt.adapt_structure(to::ReactantKernelAdaptor, ref::Base.RefValue)
232+
return error("Cannot convert RefValue argument of Reactant Kernel")
233+
end
234+
function Adapt.adapt_structure(
235+
to::ReactantKernelAdaptor, bc::Broadcast.Broadcasted{Style,<:Any,Type{T}}
236+
) where {Style,T}
237+
return Broadcast.Broadcasted{Style}(
238+
(x...) -> T(x...), Adapt.adapt(to, bc.args), bc.axes
217239
)
240+
end
241+
242+
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
243+
return adapt(ReactantKernelAdaptor(), arg)
244+
end
245+
246+
function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
247+
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
218248
return res
219249
end
220250

@@ -383,7 +413,8 @@ end
383413
function Reactant.make_tracer(
384414
seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
385415
)
386-
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
416+
x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))
417+
x = x::TracedRArray
387418
Reactant.make_tracer(seen, x, path, mode; kwargs...)
388419
return prev
389420
end
@@ -441,9 +472,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
441472

442473
# linearize kernel arguments
443474
seen = Reactant.OrderedIdDict()
444-
prev = Any[func.f, args...]
445475
kernelargsym = gensym("kernelarg")
446-
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.NoStopTracedTrack)
476+
for (i, prev) in enumerate(Any[func.f, args...])
477+
Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
478+
end
447479
wrapper_tys = MLIR.IR.Type[]
448480
for arg in values(seen)
449481
if !(arg isa TracedRArray || arg isa TracedRNumber)

src/Compiler.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,26 +318,34 @@ function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
318318
return mod
319319
end
320320

321+
const context_gc_vector = Dict{MLIR.IR.Context,Vector{TracedRArray}}()
322+
321323
# helper for debug purposes: String -> Text
322324
function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)
323325
ctx = MLIR.IR.Context(Reactant.registry[], false)
326+
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
324327
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
325-
MLIR.IR.context!(ctx) do
328+
result = MLIR.IR.context!(ctx) do
326329
mod = parse(MLIR.IR.Module, source)
327330
run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
328331
MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
329332
Text(repr(mod))
330333
end
334+
Base.delete!(context_gc_vector, ctx)
335+
return result
331336
end
332337

333338
function compile_mlir(f, args; kwargs...)
334339
ctx = MLIR.IR.Context(Reactant.registry[], false)
340+
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
335341
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
336-
MLIR.IR.context!(ctx) do
342+
results = MLIR.IR.context!(ctx) do
337343
mod = MLIR.IR.Module(MLIR.IR.Location())
338344
evalinfo = compile_mlir!(mod, f, args; kwargs...)
339-
return mod, evalinfo...
345+
return (mod, evalinfo...)
340346
end
347+
Base.delete!(context_gc_vector, ctx)
348+
return results
341349
end
342350

343351
const cuLaunch = Ref{UInt}(0)
@@ -859,10 +867,11 @@ end
859867
function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
860868
# register MLIR dialects
861869
ctx = MLIR.IR.Context(Reactant.registry[], false)
870+
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
862871
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
863872

864873
MLIR.IR.activate!(ctx)
865-
return try
874+
results = try
866875
# compile function to MLIR module
867876
mod = MLIR.IR.Module(MLIR.IR.Location())
868877
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
@@ -888,6 +897,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
888897
finally
889898
MLIR.IR.deactivate!(ctx)
890899
end
900+
Base.delete!(context_gc_vector, ctx)
901+
return results
891902
end
892903

893904
function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=false)

src/utils.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,11 @@ struct MustThrowError end
195195
@generated function applyiterate_with_reactant(
196196
iteratefn, applyfn, args::Vararg{Any,N}
197197
) where {N}
198-
@assert iteratefn == typeof(Base.iterate)
198+
if iteratefn != typeof(Base.iterate)
199+
return quote
200+
error("Unhandled apply_iterate with iteratefn=$iteratefn")
201+
end
202+
end
199203
newargs = Vector{Expr}(undef, N)
200204
for i in 1:N
201205
@inbounds newargs[i] = :(args[$i]...)

test/integration/cuda.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,28 @@ end
152152
end
153153
end
154154
end
155+
156+
using Reactant, CUDA
157+
158+
function cmul!(a, b)
159+
b[1] *= a[1]
160+
return nothing
161+
end
162+
163+
function mixed(a, b)
164+
@cuda threads = 1 cmul!(a, b)
165+
return nothing
166+
end
167+
168+
@static if !Sys.isapple()
169+
@testset "Non-traced argument" begin
170+
if CUDA.functional()
171+
a = CuArray([4])
172+
b = ConcreteRArray([3])
173+
174+
@jit mixed(a, b)
175+
@test all(Array(a) == 4)
176+
@test all(Array(b) == 12)
177+
end
178+
end
179+
end

0 commit comments

Comments
 (0)