Skip to content

Commit 2b7218c

Browse files
authored
Merge branch 'main' into sroa
2 parents b47597f + e718f44 commit 2b7218c

File tree

9 files changed

+221
-70
lines changed

9 files changed

+221
-70
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ using Adapt
99

1010
"""
    CuTracedArray{T,N,A,Size} <: DenseArray{T,N}

Kernel-side handle for a `TracedRArray`. The single field `ptr` does not point at
array data: it is the object address of the wrapped `TracedRArray`, reinterpreted
as a `Core.LLVMPtr{T,A}` so it can be passed through kernel argument conversion.

`T` and `N` are the element type and rank reported to `DenseArray`; `A` is the
address-space tag of `ptr`. `Size` is supplied by the caller as a type parameter
(presumably the array's dimensions -- confirm against the call sites).
"""
struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
    ptr::Core.LLVMPtr{T,A}

    function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
        # Register `xs` with the vector kept for the active MLIR context.
        # NOTE(review): presumably this roots `xs` so the raw object address stored
        # below stays valid while the context is being compiled -- confirm.
        push!(Reactant.Compiler.context_gc_vector[MLIR.IR.context()], xs)
        # Use the struct's own address-space parameter `A` so the pointer type
        # always matches the declared field type (previously hard-coded to
        # `CUDA.AS.Global`, which only agreed with the field when `A` was Global).
        ptr = Base.reinterpret(Core.LLVMPtr{T,A}, Base.pointer_from_objref(xs))
        return new(ptr)
    end
end
1319

1420
function Base.show(io::IO, a::AT) where {AT<:CuTracedArray}
@@ -211,10 +217,34 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
211217
return _derived_array(a, T, dims)
212218
end
213219

214-
function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
215-
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(
216-
Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
220+
# Field-less adaptor type used as the `to` argument of the `Adapt.adapt_storage` /
# `Adapt.adapt_structure` methods that convert kernel arguments under Reactant.
struct ReactantKernelAdaptor
end
221+
222+
# Raw `CuPtr` kernel arguments are rejected: always raises an error.
Adapt.adapt_storage(::ReactantKernelAdaptor, ::CUDA.CuPtr) =
    error("Cannot convert CuPtr argument of Reactant Kernel")
225+
# Dense CUDA arrays are first copied into a host `Array`, then adapted through the
# `Array` method of the same adaptor.
function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::DenseCuArray)
    host_copy = Array(xs)
    return Adapt.adapt_storage(ka, host_copy)
end
228+
# Host arrays are turned into a value via `Reactant.Ops.constant` and the result is
# adapted again with the same adaptor.
# NOTE(review): presumably `Ops.constant` yields a `TracedRArray`, which then hits
# the `TracedRArray` method -- confirm.
function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::Array)
    as_constant = Reactant.Ops.constant(xs)
    return Adapt.adapt_storage(ka, as_constant)
end
231+
# `Base.RefValue` kernel arguments are rejected: always raises an error.
Adapt.adapt_structure(::ReactantKernelAdaptor, ::Base.RefValue) =
    error("Cannot convert RefValue argument of Reactant Kernel")
234+
# Adapts a `Broadcasted` whose function slot is a type `T` (i.e. a constructor
# call being broadcast). The type is wrapped in a closure that forwards its
# arguments to `T`, the arguments are adapted with `to`, and the axes are reused.
function Adapt.adapt_structure(
    to::ReactantKernelAdaptor, bc::Broadcast.Broadcasted{Style,<:Any,Type{T}}
) where {Style,T}
    construct = (x...) -> T(x...)
    adapted_args = Adapt.adapt(to, bc.args)
    return Broadcast.Broadcasted{Style}(construct, adapted_args, bc.axes)
end
241+
242+
# Overlay for `CUDA.cudaconvert`: converts `arg` by adapting it with a
# `ReactantKernelAdaptor`.
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
    adaptor = ReactantKernelAdaptor()
    return adapt(adaptor, arg)
end
245+
246+
# Wraps a traced array in a `CuTracedArray` tagged with the global address space
# and carrying `size(xs)` as its `Size` type parameter.
function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
    return CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
end
220250

@@ -383,7 +413,8 @@ end
383413
# Tracing a `CuTracedArray`: recover the `TracedRArray` whose object address is
# stored in `prev.ptr`, trace that array with the same arguments, and hand back the
# original wrapper unchanged.
function Reactant.make_tracer(
    seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...
)
    raw_address = Base.reinterpret(Ptr{Cvoid}, prev.ptr)
    # The assertion fails loudly if the stored address is not a TracedRArray.
    traced = Base.unsafe_pointer_to_objref(raw_address)::TracedRArray
    Reactant.make_tracer(seen, traced, path, mode; kwargs...)
    return prev
end
@@ -441,12 +472,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
441472

442473
# linearize kernel arguments
443474
seen = Reactant.OrderedIdDict()
444-
prev = Any[func.f, args...]
445475
kernelargsym = gensym("kernelarg")
446-
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
447-
@show prev
448-
@show Core.Typeof(prev)
449-
@show seen
476+
for (i, prev) in enumerate(Any[func.f, args...])
477+
Reactant.make_tracer(seen, prev, (kernelargsym, i), Reactant.NoStopTracedTrack)
478+
end
450479
wrapper_tys = MLIR.IR.Type[]
451480
for arg in values(seen)
452481
if !(arg isa TracedRArray || arg isa TracedRNumber)
@@ -539,16 +568,18 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
539568
if !(arg isa TracedRArray || arg isa TracedRNumber)
540569
continue
541570
end
542-
for p in Reactant.TracedUtils.get_paths(arg)
571+
572+
paths = Reactant.TracedUtils.get_paths(arg)
573+
574+
arg = arg.mlir_data
575+
arg = Reactant.TracedUtils.transpose_val(arg)
576+
push!(restys, MLIR.IR.type(arg))
577+
push!(mlir_args, arg)
578+
579+
for p in paths
543580
if p[1] !== kernelargsym
544581
continue
545582
end
546-
547-
arg = arg.mlir_data
548-
arg = Reactant.TracedUtils.transpose_val(arg)
549-
push!(restys, MLIR.IR.type(arg))
550-
push!(mlir_args, arg)
551-
552583
# Get the allocation corresponding to which arg we're doing
553584
alloc = allocs[p[2]][1]
554585

@@ -583,9 +614,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
583614
),
584615
),
585616
)
586-
587-
argidx += 1
588617
end
618+
argidx += 1
589619
end
590620

591621
MLIR.IR.block!(wrapbody) do

src/Compiler.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,26 +324,34 @@ function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
324324
return mod
325325
end
326326

327+
# Maps each live MLIR context to a vector of `TracedRArray`s. An entry is created
# when a compilation context is set up and deleted once that compilation returns.
# NOTE(review): presumably used to keep traced arrays reachable while raw object
# pointers to them are in flight -- confirm against the code that pushes into it.
const context_gc_vector = Dict{MLIR.IR.Context,Vector{TracedRArray}}()
328+
327329
"""
    run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)

Debugging helper (String -> Text): parse `source` as an MLIR module in a fresh
context, run `pass_pipeline` on it, verify the result, and return the printed
module wrapped in `Text`.

The context's entry in `context_gc_vector` is removed even when parsing, the pass
pipeline, or verification throws.
"""
function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)
    ctx = MLIR.IR.Context(Reactant.registry[], false)
    context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
    try
        @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
        result = MLIR.IR.context!(ctx) do
            mod = parse(MLIR.IR.Module, source)
            run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
            MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
            Text(repr(mod))
        end
        return result
    finally
        # Always drop the per-context entry so a failed run does not leave the
        # context (and anything pushed into its vector) referenced forever.
        Base.delete!(context_gc_vector, ctx)
    end
end
338343

339344
"""
    compile_mlir(f, args; kwargs...)

Create a fresh MLIR context and module, run `compile_mlir!` on it for `f` and
`args` (forwarding `kwargs`), and return the tuple `(mod, evalinfo...)` where
`evalinfo` is whatever `compile_mlir!` returned.

The context's entry in `context_gc_vector` is removed even when compilation
throws.
"""
function compile_mlir(f, args; kwargs...)
    ctx = MLIR.IR.Context(Reactant.registry[], false)
    context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
    try
        @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
        results = MLIR.IR.context!(ctx) do
            mod = MLIR.IR.Module(MLIR.IR.Location())
            evalinfo = compile_mlir!(mod, f, args; kwargs...)
            return (mod, evalinfo...)
        end
        return results
    finally
        # Always drop the per-context entry so a failed compilation does not leave
        # the context (and anything pushed into its vector) referenced forever.
        Base.delete!(context_gc_vector, ctx)
    end
end
348356

349357
const cuLaunch = Ref{UInt}(0)
@@ -866,10 +874,11 @@ end
866874
function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
867875
# register MLIR dialects
868876
ctx = MLIR.IR.Context(Reactant.registry[], false)
877+
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
869878
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
870879

871880
MLIR.IR.activate!(ctx)
872-
return try
881+
results = try
873882
# compile function to MLIR module
874883
mod = MLIR.IR.Module(MLIR.IR.Location())
875884
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
@@ -895,6 +904,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
895904
finally
896905
MLIR.IR.deactivate!(ctx)
897906
end
907+
Base.delete!(context_gc_vector, ctx)
908+
return results
898909
end
899910

900911
function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=false)

src/Ops.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -936,24 +936,24 @@ end
936936
end
937937

938938
# broadcast ops
939-
# function broadcast_in_dim(
940-
# x::TracedRArray{T,N},
941-
# dims::Vector{Int};
942-
# location=mlir_stacktrace(
943-
# "broadcast_in_dim", @__FILE__, @__LINE__
944-
# ),
945-
# ) where {T,N}
946-
# rsize = restype = MLIR.IR.TensorType([...], mlir_type(T)) # mlir_type(TracedRArray{T,N}, size(x))
947-
# res = MLIR.IR.result(
948-
# stablehlo.broadcast_in_dim(
949-
# x.mlir_data;
950-
# result_0=restype,
951-
# broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims),
952-
# location,
953-
# ),
954-
# )
955-
# return TracedRArray{T,N}((), res, size(x))
956-
# end
939+
"""
    broadcast_in_dim(x, dims, result_size; location)

Emit a `stablehlo.broadcast_in_dim` for `x` and return the result as a
`TracedRArray` of shape `result_size`.

`dims` is 1-based and must have one entry per dimension of `x`; it is shifted to
0-based before being attached to the op as `broadcast_dimensions`.
"""
function broadcast_in_dim(
    x::TracedRArray{T,N},
    dims::Vector{Int},
    result_size::Vector{Int};
    location=mlir_stacktrace("broadcast_in_dim", @__FILE__, @__LINE__),
) where {T,N}
    @assert length(dims) == N

    result_type = MLIR.IR.TensorType(result_size, MLIR.IR.Type(T))
    # Callers pass 1-based dimension numbers; the attribute is 0-based.
    zero_based_dims = MLIR.IR.DenseArrayAttribute(dims .- 1)
    op = stablehlo.broadcast_in_dim(
        x.mlir_data;
        result_0=result_type,
        broadcast_dimensions=zero_based_dims,
        location,
    )
    res = MLIR.IR.result(op)
    return TracedRArray{T,Int64(length(result_size))}((), res, Tuple(result_size))
end
957957

958958
@noinline function sort(
959959
x::TracedRArray{T,N};

src/TracedRArray.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,21 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
218218
return v
219219
end
220220

221-
v = TracedUtils.broadcast_to_size(v, length.(indices))
222-
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
221+
if v isa Number
222+
v = TracedUtils.broadcast_to_size(v, length.(indices))
223+
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
224+
else
225+
v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v)
226+
non_integer_indices = [!(idx isa Integer) for idx in indices]
227+
broadcast_dims = findall(non_integer_indices)
228+
if length(broadcast_dims) == N
229+
v = TracedUtils.broadcast_to_size(v, length.(indices))
230+
else
231+
v = Ops.broadcast_in_dim(
232+
materialize_traced_array(v), broadcast_dims, Int64.(length.(indices))
233+
)
234+
end
235+
end
223236

224237
indices = [
225238
(

src/TracedUtils.jl

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -496,30 +496,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize)
496496
end
497497

498498
# Broadcast `x` to `rsize` through `Ops.broadcast_in_dim`, mapping dimension `i`
# of `x` onto dimension `i` of the result for every `i` in `1:ndims(x)`.
@noinline function broadcast_to_size_internal(x::TracedRArray{T}, rsize) where {T}
    source_dims = collect(Int64, 1:ndims(x))
    target_size = collect(Int64, rsize)
    return Ops.broadcast_in_dim(x, source_dims, target_size)
end
524501

525502
end

src/Tracing.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
TracedToConcrete = 3
55
ArrayToConcrete = 4
66
TracedSetPath = 5
7+
NoStopTracedTrack = 6
78
end
89

910
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RNumber)
@@ -249,7 +250,7 @@ function traced_type(
249250
@inline base_typec(TV::TT) where {TT<:DataType} =
250251
(T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
251252
return base_typec(T)
252-
elseif mode == TracedTrack || mode == TracedSetPath
253+
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
253254
return T
254255
else
255256
throw("Abstract RArray $T cannot be made concrete in mode $mode")
@@ -261,7 +262,7 @@ function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:Trac
261262
throw("TracedRNG cannot be traced")
262263
elseif mode == TracedToConcrete
263264
return ConcreteRNG
264-
elseif mode == TracedTrack || mode == TracedSetPath
265+
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
265266
return T
266267
else
267268
throw("Unsupported mode: $mode")
@@ -329,7 +330,7 @@ function make_tracer(
329330
track_numbers=(),
330331
kwargs...,
331332
) where {RT}
332-
if haskey(seen, prev)
333+
if mode != NoStopTracedTrack && haskey(seen, prev)
333334
return seen[prev]
334335
end
335336
TT = traced_type(RT, (), Val(mode), track_numbers)
@@ -460,6 +461,13 @@ function make_tracer(
460461
end
461462
return prev
462463
end
464+
if mode == NoStopTracedTrack
465+
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
466+
if !haskey(seen, prev)
467+
seen[prev] = prev # don't return!
468+
end
469+
return prev
470+
end
463471
if mode == TracedSetPath
464472
if haskey(seen, prev)
465473
return seen[prev]
@@ -506,6 +514,13 @@ function make_tracer(
506514
end
507515
return prev
508516
end
517+
if mode == NoStopTracedTrack
518+
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
519+
if !haskey(seen, prev)
520+
seen[prev] = prev # don't return!
521+
end
522+
return prev
523+
end
509524
if mode == TracedSetPath
510525
if haskey(seen, prev)
511526
return seen[prev]
@@ -546,6 +561,13 @@ function make_tracer(
546561
end
547562
return prev
548563
end
564+
if mode == NoStopTracedTrack
565+
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
566+
if !haskey(seen, prev)
567+
seen[prev] = prev # don't return!
568+
end
569+
return prev
570+
end
549571
if mode == TracedSetPath
550572
haskey(seen, prev) && return seen[prev]
551573
res = MissingTracedValue((path,))

src/utils.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,11 @@ struct MustThrowError end
195195
@generated function applyiterate_with_reactant(
196196
iteratefn, applyfn, args::Vararg{Any,N}
197197
) where {N}
198-
@assert iteratefn == typeof(Base.iterate)
198+
if iteratefn != typeof(Base.iterate)
199+
return quote
200+
error("Unhandled apply_iterate with iteratefn=$iteratefn")
201+
end
202+
end
199203
newargs = Vector{Expr}(undef, N)
200204
for i in 1:N
201205
@inbounds newargs[i] = :(args[$i]...)

0 commit comments

Comments
 (0)