
Commit d0b31e3

mofeing, jumerckx, and wsmoses authored

linearize kernel args (#497)

* linearize kernel args
* Update ext/ReactantCUDAExt.jl
  Co-authored-by: jumerckx <[email protected]>
* tmp wip
* fix
* wip
* fix
* fixup
* traced type
* more fix
* fixup
* final diff
* Update WORKSPACE
* fix
* bump enzymexla
* bump enzymexla
* Update Project.toml
* Update WORKSPACE
* dont fix printlin
* libdevice
* bump offload commit
* Update Project.toml

---------

Co-authored-by: jumerckx <[email protected]>
Co-authored-by: William S. Moses <[email protected]>
1 parent 187bd31 commit d0b31e3
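
This commit changes how the CUDA extension passes kernel arguments: ext/ReactantCUDAExt.jl now walks every kernel argument with Reactant.make_tracer, collects the traced leaves (TracedRArray / TracedRNumber) into an OrderedIdDict, passes each leaf to the generated wrapper as a separate pointer, and patches those pointers back into the original argument bytes at each leaf's byte offset before calling the real kernel. The offset for a leaf is computed by walking its field path with standard Julia reflection (see get_field_offset in the diff below). The following is a minimal, hedged sketch of that offset walk; the Inner/Outer structs and the field_offset name are invented for illustration and are not part of the commit.

# Hedged sketch (not from the commit): turn a field path into a byte offset with
# plain Julia reflection, mirroring the get_field_offset helper added in
# ext/ReactantCUDAExt.jl. Inner and Outer are invented example types.
struct Inner
    a::Int64
    b::Float32
end

struct Outer
    x::Int32
    inner::Inner
end

function field_offset(T::Type, path)
    offset = 0
    for field in path
        idx = field isa Integer ? field : findfirst(==(field), fieldnames(T))
        idx === nothing && error("field $field not found in $T")
        offset += fieldoffset(T, idx)  # byte offset of this field within T
        T = fieldtype(T, idx)          # descend into the field's type
    end
    return offset
end

field_offset(Outer, (:inner, :b))  # 16 on a typical 64-bit layout (8 + 8)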

File tree

7 files changed: +183 -53 lines

Project.toml
deps/ReactantExtra/API.cpp
deps/ReactantExtra/WORKSPACE
ext/ReactantCUDAExt.jl
src/Compiler.jl
src/utils.jl
test/integration/cuda.jl

Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "Reactant"
 uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
 authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
-version = "0.2.17"
+version = "0.2.18"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -67,7 +67,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.3"
-Reactant_jll = "0.0.37"
+Reactant_jll = "0.0.39"
 Scratch = "1.2"
 SpecialFunctions = "2"
 Statistics = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 16 additions & 4 deletions
@@ -629,7 +629,8 @@ static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
 static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
                                                   mlir::ModuleOp source,
                                                   mlir::ModuleOp target,
-                                                  unsigned &lastUsedID) {
+                                                  unsigned &lastUsedID,
+                                                  bool &shouldRemove) {
   using namespace llvm;
   using namespace mlir;
 
@@ -639,6 +640,13 @@ static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
     return success();
   }
 
+  if (auto func = dyn_cast<FunctionOpInterface>(op.getOperation())) {
+    if (func.isExternal()) {
+      shouldRemove = true;
+      return success();
+    }
+  }
+
   StringAttr newSymName = renameSymbol(opName, lastUsedID, source, target);
 
   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
@@ -658,7 +666,7 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
 
   unsigned lastUsedID = 0;
 
-  for (auto &op : *newMod.getBody()) {
+  for (auto &op : make_early_inc_range(*newMod.getBody())) {
     auto symbolOp = dyn_cast<SymbolOpInterface>(op);
     if (!symbolOp)
       continue;
@@ -669,10 +677,14 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
       entryFn = &op;
     }
 
-    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID))) {
+    bool shouldRemove = false;
+    if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod, lastUsedID, shouldRemove))) {
      assert(0 && "failed to update all uses");
    }
-    SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
+    if (shouldRemove)
+      op.erase();
+    else
+      SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
  }
  prevMod.getBody()->getOperations().splice(
      prevMod.getBody()->getOperations().end(),

deps/ReactantExtra/WORKSPACE

Lines changed: 4 additions & 2 deletions
@@ -9,7 +9,7 @@ http_archive(
     urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
 )
 
-ENZYMEXLA_COMMIT = "85612ea74731f02aa4e30800038e065912d37ae2"
+ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
 ENZYMEXLA_SHA256 = ""
 
 http_archive(
@@ -138,7 +138,9 @@ http_archive(
     patches = ["@enzyme_ad//:patches/jax.patch"],
 )
 
-load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
+# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
+XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
+XLA_SHA256 = ""
 
 http_archive(
     name = "xla",

ext/ReactantCUDAExt.jl

Lines changed: 119 additions & 40 deletions
@@ -281,9 +281,13 @@ function compile(job)
     # TODO: on 1.9, this actually creates a context. cache those.
     entry = GPUCompiler.JuliaContext() do ctx
         mod, meta = GPUCompiler.compile(
+            # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=true
             :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
+            # :llvm, job; optimize=false, cleanup=false, validate=true, libraries=false
+            # :llvm, job; optimize=false, cleanup=false, validate=false, libraries=false
         )
 
+        GPUCompiler.link_library!(mod, GPUCompiler.load_runtime(job))
         entryname = LLVM.name(meta.entry)
 
         GPUCompiler.optimize_module!(job, mod)
@@ -319,6 +323,8 @@ function compile(job)
         end
     end
 
+    # GPUCompiler.check_ir(job, mod)
+
     LLVM.strip_debuginfo!(mod)
     modstr = string(mod)
 
@@ -363,6 +369,38 @@ function to_bytes(x)
     end
 end
 
+function Reactant.make_tracer(seen, @nospecialize(prev::CuTracedArray), @nospecialize(path), mode; kwargs...)
+    x = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, prev.ptr))::TracedRArray
+    Reactant.make_tracer(seen, x, path, mode; kwargs...)
+    return prev
+end
+
+function get_field_offset(T::Type, path)
+    offset = 0
+    current_type = T
+
+    for field in path
+        # Get the field index
+        field_idx = if field isa Integer
+            field
+        else
+            @assert field isa Symbol
+            findfirst(==(field), fieldnames(current_type))
+        end
+        if field_idx === nothing
+            error("Field $field not found in type $current_type, fieldnames=$(fieldnames(current_type)) T=$T path=$path")
+        end
+
+        # Add the offset of this field
+        offset += fieldoffset(current_type, field_idx)
+
+        # Update current_type to the field's type for next iteration
+        current_type = fieldtype(current_type, field_idx)
+    end
+
+    return offset
+end
+
 Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     args...;
     convert=Val(false),
@@ -384,20 +422,19 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
 
     wrapper_tys = MLIR.IR.Type[]
     ctx = MLIR.IR.context()
-    cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1), 1))
-    for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
-        if sizeof(a) == 0
+    cullvm_ty = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 1))
+
+    # linearize kernel arguments
+    seen = Reactant.OrderedIdDict()
+    prev = Any[func.f, args...]
+    kernelargsym = gensym("kernelarg")
+    Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
+    wrapper_tys = MLIR.IR.Type[]
+    for arg in values(seen)
+        if !(arg isa TracedRArray || arg isa TracedRNumber)
             continue
         end
-        if a isa CuTracedArray
-            a =
-                Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
-        end
-        if a isa TracedRArray || a isa TracedRNumber
-            push!(wrapper_tys, cullvm_ty)
-            continue
-        end
-        # Per below we assume we can inline all other types directly in
+        push!(wrapper_tys, cullvm_ty)
     end
 
     sym_name = String(gensym("call_$fname"))
@@ -426,20 +463,60 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
     gpu_function_type = MLIR.IR.Type(Reactant.TracedUtils.get_attribute_by_name(gpufunc, "function_type"))
 
     trueidx = 1
-    for (i, a) in Tuple{Int, Any}[(0, func.f), enumerate(args)...]
+    allocs = Union{Tuple{MLIR.IR.Value, MLIR.IR.Type}, Nothing}[]
+
+    llvmptr = MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))
+    i8 = MLIR.IR.Type(UInt8)
+    allargs = [func.f, args...]
+    for a in allargs
         if sizeof(a) == 0
+            push!(allocs, nothing)
             continue
         end
-        if a isa CuTracedArray
-            a =
-                Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray
+
+        # TODO check for only integer and explicitly non cutraced types
+        MLIR.IR.block!(wrapbody) do
+            argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
+            trueidx += 1
+            c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
+            alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr), 1)
+            push!(allocs, (alloc, argty))
+
+            sz = sizeof(a)
+            array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
+            cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
+            MLIR.Dialects.llvm.store(cdata, alloc)
         end
-        if a isa TracedRArray || a isa TracedRNumber
-            push!(rarrays, a)
-            arg = a.mlir_data
+
+    end
+
+    argidx = 1
+    for arg in values(seen)
+        if !(arg isa TracedRArray || arg isa TracedRNumber)
+            continue
+        end
+        for p in Reactant.TracedUtils.get_paths(arg)
+            if p[1] !== kernelargsym
+                continue
+            end
+
+            arg = arg.mlir_data
             arg = Reactant.TracedUtils.transpose_val(arg)
             push!(restys, MLIR.IR.type(arg))
             push!(mlir_args, arg)
+
+            # Get the allocation corresponding to which arg we're doing
+            alloc = allocs[p[2]][1]
+
+            # we need to now compute the offset in bytes of the path
+            julia_arg = allargs[p[2]]
+
+            offset = get_field_offset(typeof(julia_arg), p[3:end])
+            MLIR.IR.block!(wrapbody) do
+                ptr = MLIR.IR.result(MLIR.Dialects.llvm.getelementptr(alloc, MLIR.IR.Value[], res=llvmptr, elem_type=i8, rawConstantIndices=MLIR.IR.Attribute([Int32(offset)])), 1)
+                MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
+            end
+
             push!(
                 aliases,
                 MLIR.IR.Attribute(
@@ -453,30 +530,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
                     ),
                 ),
             )
-        push!(wrapargs, MLIR.IR.argument(wrapbody, argidx))
+
             argidx += 1
-        trueidx += 1
-        continue
-        end
-
-        # TODO check for only integer and explicitly non cutraced types
-        @show "Warning: using fallback for kernel argument type conversion for argument of type $(Core.Typeof(a)), if this contains a CuTracedArray this will segfault"
-        MLIR.IR.block!(wrapbody) do
-            argty = MLIR.IR.Type(MLIR.API.mlirLLVMFunctionTypeGetInput(gpu_function_type, trueidx-1))
-            trueidx += 1
-            c1 = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=MLIR.IR.Type(Int64), value=MLIR.IR.Attribute(1)), 1)
-            alloc = MLIR.IR.result(MLIR.Dialects.llvm.alloca(c1; elem_type=MLIR.IR.Attribute(argty), res=MLIR.IR.Type(MLIR.API.mlirLLVMPointerTypeGet(ctx, 0))), 1)
-
-            sz = sizeof(a)
-            array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
-            cdata = MLIR.IR.result(MLIR.Dialects.llvm.mlir_constant(; res=array_ty, value=MLIR.IR.DenseElementsAttribute(to_bytes(a))), 1)
-            MLIR.Dialects.llvm.store(cdata, alloc)
-            argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
-            push!(wrapargs, argres)
         end
     end
 
     MLIR.IR.block!(wrapbody) do
+        for arg in allocs
+            if arg === nothing
+                continue
+            end
+            alloc, argty = arg
+            argres = MLIR.IR.result(MLIR.Dialects.llvm.load(alloc; res=argty), 1)
+            push!(wrapargs, argres)
+        end
         MLIR.Dialects.llvm.call(wrapargs, MLIR.IR.Value[]; callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)), op_bundle_sizes=MLIR.IR.Attribute(Int32[]))
         MLIR.Dialects.llvm.return_(nothing)
     end
@@ -500,8 +567,14 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
         fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
         output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases)
     )
-    for (i, res) in enumerate(rarrays)
-        res.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, i))
+
+    argidx = 1
+    for arg in values(seen)
+        if !(arg isa TracedRArray || arg isa TracedRNumber)
+            continue
+        end
+        arg.mlir_data = Reactant.TracedUtils.transpose_val(MLIR.IR.result(call, argidx))
+        argidx+=1
     end
 end
 
@@ -546,6 +619,12 @@ Reactant.@reactant_overlay @noinline function CUDA.cufunction(
     return Core.Typeof(res)(f, res.entry)
 end
 
+function Reactant.traced_type(
+    ::Type{A}, seen::ST, ::Val{mode}, track_numbers
+) where {A<:CuTracedArray,ST,mode}
+    return A
+end
+
 function Reactant.traced_type(
     ::Type{A}, seen::ST, ::Val{mode}, track_numbers
 ) where {T,N,A<:CUDA.CuArray{T,N},ST,mode}
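
A usage note on the linearization above: each path recorded by make_tracer starts with the kernel-argument marker and the index of the top-level argument, and the remaining components name fields inside that argument; get_field_offset turns that tail into the raw byte offset used by the getelementptr/store pair in the wrapper. A small, hedged illustration with ordinary isbits tuple types, assuming the get_field_offset helper from the diff above is in scope; the element types here are stand-ins, not real device-side argument types.

# Illustration only: for a tuple argument like the (5, a) used in the new
# tuplef2 test, the path tail selects the second tuple field; its byte offset
# is where the wrapper stores the device pointer inside the argument bytes.
get_field_offset(Tuple{Int64,Ptr{Float32}}, (2,))                  # == 8: pointer follows the Int64
get_field_offset(Tuple{Int32,Tuple{Int64,Ptr{Float32}}}, (2, 2))   # == 16: 8 (alignment) + 8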

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ function optimization_passes(; no_nan::Bool=false)
     )
     func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
     return join(
-        ["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ','
+        ["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ','
     )
 end
 
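
For context, the new libdevice-funcs-raise entry sits between the inliner and the per-function passes in the comma-joined pipeline string that optimization_passes returns. A rough sketch of the resulting string shape, with transform_passes stubbed out as a placeholder since its real value is assembled earlier in Compiler.jl:

# Sketch only; "..." stands in for the real transform_passes contents.
transform_passes = "..."
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
join(["inline{default-pipeline=canonicalize max-iterations=4}", "libdevice-funcs-raise", func_passes], ',')
# -> "inline{default-pipeline=canonicalize max-iterations=4},libdevice-funcs-raise,canonicalize,cse,canonicalize,..."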

src/utils.jl

Lines changed: 8 additions & 1 deletion
@@ -116,6 +116,9 @@ function should_rewrite_ft(@nospecialize(ft))
                 if ft.name.name == Symbol("#launch_configuration")
                     return false
                 end
+                if ft.name.name == Symbol("cudaconvert")
+                    return false
+                end
             end
         end
     end
@@ -161,7 +164,11 @@ function should_rewrite_ft(@nospecialize(ft))
         ft <: typeof(Base.getproperty) ||
         ft <: typeof(Base.vect) ||
         ft <: typeof(Base.eltype) ||
-        ft <: typeof(Base.argtail)
+        ft <: typeof(Base.argtail) ||
+        ft <: typeof(Base.identity) ||
+        ft <: typeof(Base.print) ||
+        ft <: typeof(Base.println) ||
+        ft <: typeof(Adapt.adapt_structure)
         return false
     end
 
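
The effect of the utils.jl additions is that these call targets are left alone by Reactant's call-rewriting machinery rather than being wrapped. A hedged illustration, assuming should_rewrite_ft stays reachable as Reactant.should_rewrite_ft (it is an internal helper shown in the diff above, not exported API):

# Each of these should now report "do not rewrite" after this change.
using Adapt, Reactant
Reactant.should_rewrite_ft(typeof(Base.println))           # false
Reactant.should_rewrite_ft(typeof(Base.identity))          # false
Reactant.should_rewrite_ft(typeof(Adapt.adapt_structure))  # false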

test/integration/cuda.jl

Lines changed: 33 additions & 3 deletions
@@ -72,9 +72,6 @@ function smul!(x)
 end
 
 @static if !Sys.isapple()
-
-# Broken pending jll update
-@static if false
 @testset "Constant Op Kernel" begin
     oA = collect(1:1:64)
     A = Reactant.to_rarray(oA)
@@ -87,4 +84,37 @@ end
     end
 end
 
+
+function tuplef!(tup)
+    tup[1][] += 2
+    return nothing
+end
+
+function tuplef2!(tup)
+    tup[2][] *= tup[1]
+    return nothing
+end
+
+tuplef(a) = @cuda threads=1 tuplef!((a,))
+tuplef2(a) = @cuda threads=1 tuplef2!((5, a))
+
+@static if !Sys.isapple()
+    @testset "Structured Kernel Arguments" begin
+        A = ConcreteRArray(fill(1))
+        if CUDA.functional()
+            @jit tuplef(A)
+            @test all(Array(A) .≈ 3)
+        else
+            @code_hlo optimize = :before_kernel tuplef(A)
+        end
+
+        A = ConcreteRArray(fill(1))
+        if CUDA.functional()
+            @jit tuplef2(A)
+            @test all(Array(A) .≈ 5)
+        else
+            @code_hlo optimize = :before_kernel tuplef2(A)
+        end
+
+    end
 end
