-
Notifications
You must be signed in to change notification settings - Fork 23
WIP: adapt to sroa jll #521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fdcb111
4b2bd6b
68afe86
843c462
4f3c68b
149f24b
cb86b3c
9c1381d
e68b0dd
060f245
46e8210
4755292
63c6c25
522221e
0aaebc5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +0,0 @@ | ||
ENZYMEXLA_COMMIT = "049a05abfaf23abee646ad26834bb8725c348f51" | ||
ENZYMEXLA_SHA256 = "" | ||
|
||
NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" | ||
NSYNC_SHA256 = "" | ||
|
||
RULES_CC_COMMIT = "c8c38f8c710cbbf834283e4777916b68261b359c" | ||
RULES_CC_SHA256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4" | ||
|
||
RULES_PYTHON_VERSION = "0.34.0" | ||
RULES_PYTHON_SHA256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618" | ||
|
||
UPB_COMMIT = "9effcbcb27f0a665f9f345030188c0b291e32482" | ||
UPB_SHA256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454" | ||
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -239,9 +239,12 @@ function Adapt.adapt_structure( | |||||
) | ||||||
end | ||||||
|
||||||
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) | ||||||
function recudaconvert(arg) | ||||||
return adapt(ReactantKernelAdaptor(), arg) | ||||||
end | ||||||
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg) | ||||||
return recudaconvert(arg) | ||||||
end | ||||||
|
||||||
function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N} | ||||||
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs) | ||||||
|
@@ -425,6 +428,7 @@ function get_field_offset(T::Type, path) | |||||
offset = 0 | ||||||
current_type = T | ||||||
|
||||||
|
||||||
for field in path | ||||||
# Get the field index | ||||||
field_idx = if field isa Integer | ||||||
|
@@ -440,18 +444,22 @@ function get_field_offset(T::Type, path) | |||||
end | ||||||
|
||||||
# Add the offset of this field | ||||||
offset += fieldoffset(current_type, field_idx) | ||||||
toffset = fieldoffset(current_type, field_idx) | ||||||
tcurrent_type = fieldtype(current_type, field_idx) | ||||||
offset += toffset | ||||||
|
||||||
# Update current_type to the field's type for next iteration | ||||||
current_type = fieldtype(current_type, field_idx) | ||||||
current_type = tcurrent_type | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
|
||||||
return offset | ||||||
end | ||||||
|
||||||
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | ||||||
args...; | ||||||
convert=Val(false), | ||||||
convert=Val(true), | ||||||
blocks::CuDim=1, | ||||||
threads::CuDim=1, | ||||||
cooperative::Bool=false, | ||||||
|
@@ -461,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||
blockdim = CUDA.CuDim3(blocks) | ||||||
threaddim = CUDA.CuDim3(threads) | ||||||
|
||||||
if convert == Val(true) | ||||||
args = recudaconvert.(args) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
end | ||||||
|
||||||
mlir_args = MLIR.IR.Value[] | ||||||
restys = MLIR.IR.Type[] | ||||||
aliases = MLIR.IR.Attribute[] | ||||||
|
@@ -578,6 +590,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||
push!(restys, MLIR.IR.type(arg)) | ||||||
push!(mlir_args, arg) | ||||||
|
||||||
push!( | ||||||
aliases, | ||||||
MLIR.IR.Attribute( | ||||||
MLIR.API.stablehloOutputOperandAliasGet( | ||||||
MLIR.IR.context(), | ||||||
length(wrapper_tys) == 1 ? 0 : 1, | ||||||
length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), | ||||||
argidx - 1, | ||||||
0, | ||||||
C_NULL, | ||||||
), | ||||||
), | ||||||
) | ||||||
|
||||||
for p in paths | ||||||
if p[1] !== kernelargsym | ||||||
continue | ||||||
|
@@ -602,20 +628,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||
) | ||||||
MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr) | ||||||
end | ||||||
|
||||||
push!( | ||||||
aliases, | ||||||
MLIR.IR.Attribute( | ||||||
MLIR.API.stablehloOutputOperandAliasGet( | ||||||
MLIR.IR.context(), | ||||||
length(wrapper_tys) == 1 ? 0 : 1, | ||||||
length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1), | ||||||
argidx - 1, | ||||||
0, | ||||||
C_NULL, | ||||||
), | ||||||
), | ||||||
) | ||||||
end | ||||||
argidx += 1 | ||||||
end | ||||||
|
@@ -650,6 +662,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( | |||||
end | ||||||
|
||||||
location = MLIR.IR.Location() | ||||||
@assert length(restys) == length(aliases) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
call = MLIR.Dialects.enzymexla.kernel_call( | ||||||
blk_operands..., | ||||||
mlir_args; | ||||||
|
@@ -786,6 +799,11 @@ function __init__() | |||||
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1) | ||||||
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2) | ||||||
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3) | ||||||
ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false) | ||||||
if ptr4 === nothing | ||||||
ptr4 = C_NULL | ||||||
end | ||||||
Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4) | ||||||
end | ||||||
return nothing | ||||||
end | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -116,7 +116,7 @@ function create_result( | |||||||||||||
end | ||||||||||||||
|
||||||||||||||
# Optimization passes via transform dialect | ||||||||||||||
function optimization_passes(; no_nan::Bool=false) | ||||||||||||||
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false) | ||||||||||||||
transform_passes_list = [ | ||||||||||||||
"patterns=compare_op_canon<16>", | ||||||||||||||
"transpose_transpose<16>", | ||||||||||||||
|
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false) | |||||||||||||
",", | ||||||||||||||
) | ||||||||||||||
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",") | ||||||||||||||
return join( | ||||||||||||||
[ | ||||||||||||||
"inline{default-pipeline=canonicalize max-iterations=4}", | ||||||||||||||
"libdevice-funcs-raise", | ||||||||||||||
func_passes, | ||||||||||||||
], | ||||||||||||||
passes = [ | ||||||||||||||
"inline{default-pipeline=canonicalize max-iterations=4}" | ||||||||||||||
] | ||||||||||||||
Comment on lines
+298
to
+300
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
if sroa | ||||||||||||||
push!(passes, "sroa-wrappers") | ||||||||||||||
push!(passes, "libdevice-funcs-raise") | ||||||||||||||
push!(passes, "canonicalize") | ||||||||||||||
Comment on lines
+302
to
+304
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
end | ||||||||||||||
push!(passes, func_passes) | ||||||||||||||
return join(passes, | ||||||||||||||
',', | ||||||||||||||
) | ||||||||||||||
Comment on lines
+307
to
309
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
end | ||||||||||||||
|
@@ -351,6 +355,8 @@ end | |||||||||||||
const cuLaunch = Ref{UInt}(0) | ||||||||||||||
const cuFunc = Ref{UInt}(0) | ||||||||||||||
const cuModule = Ref{UInt}(0) | ||||||||||||||
const cuSync = Ref{UInt}(0) | ||||||||||||||
const DEBUG_KERNEL = Ref{Bool}(false) | ||||||||||||||
|
||||||||||||||
function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false) | ||||||||||||||
# Explicitly don't use block! to avoid creating a closure, which creates | ||||||||||||||
|
@@ -379,12 +385,20 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||||||||||||
if isdefined(Reactant_jll, :ptxas_path) | ||||||||||||||
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))] | ||||||||||||||
end | ||||||||||||||
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" | ||||||||||||||
if DEBUG_KERNEL[] | ||||||||||||||
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult") | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||
@assert curesulthandler !== nothing | ||||||||||||||
curesulthandler = Base.reinterpret(UInt, curesulthandler) | ||||||||||||||
kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce" | ||||||||||||||
else | ||||||||||||||
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce" | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
opt_passes = optimization_passes(; no_nan) | ||||||||||||||
opt_passes = optimization_passes(; no_nan, sroa=true) | ||||||||||||||
opt_passes2 = optimization_passes(; no_nan, sroa=false) | ||||||||||||||
|
||||||||||||||
if optimize === :all | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) | ||||||||||||||
run_pass_pipeline!( | ||||||||||||||
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false | ||||||||||||||
) | ||||||||||||||
|
@@ -395,14 +409,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||||||||||||
"canonicalize", | ||||||||||||||
"remove-unnecessary-enzyme-ops", | ||||||||||||||
"enzyme-simplify-math", | ||||||||||||||
opt_passes, | ||||||||||||||
opt_passes2, | ||||||||||||||
kern, | ||||||||||||||
], | ||||||||||||||
',', | ||||||||||||||
), | ||||||||||||||
) | ||||||||||||||
elseif optimize === :before_kernel | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) | ||||||||||||||
run_pass_pipeline!( | ||||||||||||||
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false | ||||||||||||||
) | ||||||||||||||
|
@@ -413,13 +427,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||||||||||||
"canonicalize", | ||||||||||||||
"remove-unnecessary-enzyme-ops", | ||||||||||||||
"enzyme-simplify-math", | ||||||||||||||
opt_passes, | ||||||||||||||
opt_passes2, | ||||||||||||||
], | ||||||||||||||
',', | ||||||||||||||
), | ||||||||||||||
) | ||||||||||||||
elseif optimize === :no_enzyme | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) | ||||||||||||||
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false) | ||||||||||||||
run_pass_pipeline!( | ||||||||||||||
mod, | ||||||||||||||
|
@@ -428,7 +442,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||||||||||||
"canonicalize", | ||||||||||||||
"remove-unnecessary-enzyme-ops", | ||||||||||||||
"enzyme-simplify-math", | ||||||||||||||
opt_passes, | ||||||||||||||
opt_passes2, | ||||||||||||||
], | ||||||||||||||
',', | ||||||||||||||
), | ||||||||||||||
|
@@ -457,14 +471,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan:: | |||||||||||||
"canonicalize", | ||||||||||||||
"remove-unnecessary-enzyme-ops", | ||||||||||||||
"enzyme-simplify-math", | ||||||||||||||
opt_passes, | ||||||||||||||
opt_passes2, | ||||||||||||||
kern, | ||||||||||||||
], | ||||||||||||||
',', | ||||||||||||||
), | ||||||||||||||
) | ||||||||||||||
elseif optimize === :before_enzyme | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) | ||||||||||||||
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ",")) | ||||||||||||||
run_pass_pipeline!( | ||||||||||||||
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false | ||||||||||||||
) | ||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError) | |||||
) | ||||||
end | ||||||
|
||||||
function make_tracer( | ||||||
seen, | ||||||
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||
@nospecialize(path), | ||||||
mode; | ||||||
toscalar=false, | ||||||
tobatch=nothing, | ||||||
track_numbers=(), | ||||||
kwargs..., | ||||||
) | ||||||
return prev | ||||||
end | ||||||
append_path(path, i) = (path..., i) | ||||||
|
||||||
function make_tracer( | ||||||
|
@@ -590,7 +602,7 @@ function make_tracer( | |||||
if mode == ArrayToConcrete | ||||||
return ConcreteRNumber(prev) | ||||||
else | ||||||
if mode == TracedTrack | ||||||
if mode == TracedTrack || mode == NoStopTracedTrack | ||||||
res = TracedRNumber{RT}( | ||||||
(path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data | ||||||
) | ||||||
|
@@ -638,7 +650,7 @@ end | |||||
function make_tracer( | ||||||
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... | ||||||
) where {RT<:Array} | ||||||
if haskey(seen, prev) | ||||||
if mode != NoStopTracedTrack && haskey(seen, prev) | ||||||
return seen[prev] | ||||||
end | ||||||
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive | ||||||
|
@@ -699,7 +711,7 @@ function make_tracer( | |||||
end | ||||||
|
||||||
function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) | ||||||
if haskey(seen, prev) | ||||||
if mode != NoStopTracedTrack && haskey(seen, prev) | ||||||
return seen[prev] | ||||||
end | ||||||
prev2 = prev.contents | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶