Skip to content

WIP: adapt to sroa jll #521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.39"
Reactant_jll = "0.0.41"
Scratch = "1.2"
SpecialFunctions = "2"
Statistics = "1.10"
Expand Down
11 changes: 10 additions & 1 deletion deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ template <typename T> T MyValueOrThrow(absl::StatusOr<T> v) {
}
}

extern "C" void ReactantHandleCuResult(uint32_t curesult) {
if (curesult != 0) {
std::string err = "Bad Cuda Result = " + std::to_string(curesult);
if (ReactantThrowError) {
ReactantThrowError(err.c_str());
}
}
}

// MLIR C-API extras
#pragma region MLIR Extra
extern "C" MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx,
Expand Down Expand Up @@ -599,7 +608,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
prepareRegistry(registry);

mlir::registerenzymePasses();
regsiterenzymeXLAPasses();
registerenzymexlaPasses();

// Register the standard passes we want.
mlir::registerCSEPass();
Expand Down
1 change: 1 addition & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ cc_library(
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
"-Wl,-exported_symbol,_ReactantThrowError",
"-Wl,-exported_symbol,_ReactantHandleCuResult",
]}),
deps = [
"@enzyme//:EnzymeMLIR",
Expand Down
11 changes: 3 additions & 8 deletions deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
ENZYMEXLA_COMMIT = "12dc0bf6932befe236eacfcd19ca9522f870f7b9"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down Expand Up @@ -54,9 +54,6 @@ XLA_PATCHES = XLA_PATCHES + [
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h
""",
"""
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc
""",
"""
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h
""",
"""
Expand Down Expand Up @@ -95,7 +92,7 @@ LLVM_TARGETS = select({
}) + ["AArch64", "X86", "ARM"]

# Uncomment these lines to use a custom LLVM commit
# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908"
# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
# LLVM_SHA256 = ""
# http_archive(
# name = "llvm-raw",
Expand Down Expand Up @@ -138,9 +135,7 @@ http_archive(
patches = ["@enzyme_ad//:patches/jax.patch"],
)

# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
XLA_SHA256 = ""
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")

http_archive(
name = "xla",
Expand Down
14 changes: 0 additions & 14 deletions deps/ReactantExtra/workspace.bzl
Original file line number Diff line number Diff line change
@@ -1,14 +0,0 @@
ENZYMEXLA_COMMIT = "049a05abfaf23abee646ad26834bb8725c348f51"
ENZYMEXLA_SHA256 = ""

NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
NSYNC_SHA256 = ""

RULES_CC_COMMIT = "c8c38f8c710cbbf834283e4777916b68261b359c"
RULES_CC_SHA256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4"

RULES_PYTHON_VERSION = "0.34.0"
RULES_PYTHON_SHA256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618"

UPB_COMMIT = "9effcbcb27f0a665f9f345030188c0b291e32482"
UPB_SHA256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454"
54 changes: 36 additions & 18 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,12 @@ function Adapt.adapt_structure(
)
end

Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
function recudaconvert(arg)
return adapt(ReactantKernelAdaptor(), arg)
end
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
return recudaconvert(arg)
end

function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
Expand Down Expand Up @@ -425,6 +428,7 @@ function get_field_offset(T::Type, path)
offset = 0
current_type = T


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

for field in path
# Get the field index
field_idx = if field isa Integer
Expand All @@ -440,18 +444,22 @@ function get_field_offset(T::Type, path)
end

# Add the offset of this field
offset += fieldoffset(current_type, field_idx)
toffset = fieldoffset(current_type, field_idx)
tcurrent_type = fieldtype(current_type, field_idx)
offset += toffset

# Update current_type to the field's type for next iteration
current_type = fieldtype(current_type, field_idx)
current_type = tcurrent_type

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change


return offset
end

Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
args...;
convert=Val(false),
convert=Val(true),
blocks::CuDim=1,
threads::CuDim=1,
cooperative::Bool=false,
Expand All @@ -461,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
blockdim = CUDA.CuDim3(blocks)
threaddim = CUDA.CuDim3(threads)

if convert == Val(true)
args = recudaconvert.(args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
args = recudaconvert.(args)
args = recudaconvert.(args)

end

mlir_args = MLIR.IR.Value[]
restys = MLIR.IR.Type[]
aliases = MLIR.IR.Attribute[]
Expand Down Expand Up @@ -578,6 +590,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
push!(restys, MLIR.IR.type(arg))
push!(mlir_args, arg)

push!(
aliases,
MLIR.IR.Attribute(
MLIR.API.stablehloOutputOperandAliasGet(
MLIR.IR.context(),
length(wrapper_tys) == 1 ? 0 : 1,
length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
argidx - 1,
0,
C_NULL,
),
),
)

for p in paths
if p[1] !== kernelargsym
continue
Expand All @@ -602,20 +628,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
)
MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
end

push!(
aliases,
MLIR.IR.Attribute(
MLIR.API.stablehloOutputOperandAliasGet(
MLIR.IR.context(),
length(wrapper_tys) == 1 ? 0 : 1,
length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
argidx - 1,
0,
C_NULL,
),
),
)
end
argidx += 1
end
Expand Down Expand Up @@ -650,6 +662,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
end

location = MLIR.IR.Location()
@assert length(restys) == length(aliases)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@assert length(restys) == length(aliases)
@assert length(restys) == length(aliases)

call = MLIR.Dialects.enzymexla.kernel_call(
blk_operands...,
mlir_args;
Expand Down Expand Up @@ -786,6 +799,11 @@ function __init__()
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false)
if ptr4 === nothing
ptr4 = C_NULL
end
Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4)
end
return nothing
end
Expand Down
48 changes: 31 additions & 17 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function create_result(
end

# Optimization passes via transform dialect
function optimization_passes(; no_nan::Bool=false)
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
Expand Down Expand Up @@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
",",
)
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
return join(
[
"inline{default-pipeline=canonicalize max-iterations=4}",
"libdevice-funcs-raise",
func_passes,
],
passes = [
"inline{default-pipeline=canonicalize max-iterations=4}"
]
Comment on lines +298 to +300
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
passes = [
"inline{default-pipeline=canonicalize max-iterations=4}"
]
passes = ["inline{default-pipeline=canonicalize max-iterations=4}"]

if sroa
push!(passes, "sroa-wrappers")
push!(passes, "libdevice-funcs-raise")
push!(passes, "canonicalize")
Comment on lines +302 to +304
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
push!(passes, "sroa-wrappers")
push!(passes, "libdevice-funcs-raise")
push!(passes, "canonicalize")
push!(passes, "sroa-wrappers")
push!(passes, "libdevice-funcs-raise")
push!(passes, "canonicalize")

end
push!(passes, func_passes)
return join(passes,
',',
)
Comment on lines +307 to 309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return join(passes,
',',
)
return join(passes, ',')

end
Expand Down Expand Up @@ -351,6 +355,8 @@ end
const cuLaunch = Ref{UInt}(0)
const cuFunc = Ref{UInt}(0)
const cuModule = Ref{UInt}(0)
const cuSync = Ref{UInt}(0)
const DEBUG_KERNEL = Ref{Bool}(false)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
# Explicitly don't use block! to avoid creating a closure, which creates
Expand Down Expand Up @@ -379,12 +385,20 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
if DEBUG_KERNEL[]
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult")
curesulthandler = XLA.Libdl.dlsym(
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
)

@assert curesulthandler !== nothing
curesulthandler = Base.reinterpret(UInt, curesulthandler)
kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce"
else
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
end

opt_passes = optimization_passes(; no_nan)
opt_passes = optimization_passes(; no_nan, sroa=true)
opt_passes2 = optimization_passes(; no_nan, sroa=false)

if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
Expand All @@ -395,14 +409,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
kern,
],
',',
),
)
elseif optimize === :before_kernel
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
Expand All @@ -413,13 +427,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
],
',',
),
)
elseif optimize === :no_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod,
Expand All @@ -428,7 +442,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
],
',',
),
Expand Down Expand Up @@ -457,14 +471,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
"canonicalize",
"remove-unnecessary-enzyme-ops",
"enzyme-simplify-math",
opt_passes,
opt_passes2,
kern,
],
',',
),
)
elseif optimize === :before_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
Expand Down
18 changes: 15 additions & 3 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError)
)
end

function make_tracer(
seen,
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}),
@nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}),

@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
track_numbers=(),
kwargs...,
)
return prev
end
append_path(path, i) = (path..., i)

function make_tracer(
Expand Down Expand Up @@ -590,7 +602,7 @@ function make_tracer(
if mode == ArrayToConcrete
return ConcreteRNumber(prev)
else
if mode == TracedTrack
if mode == TracedTrack || mode == NoStopTracedTrack
res = TracedRNumber{RT}(
(path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data
)
Expand Down Expand Up @@ -638,7 +650,7 @@ end
function make_tracer(
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
) where {RT<:Array}
if haskey(seen, prev)
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
Expand Down Expand Up @@ -699,7 +711,7 @@ function make_tracer(
end

function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
if haskey(seen, prev)
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
prev2 = prev.contents
Expand Down
Loading
Loading