Skip to content

Commit a1f93a0

Browse files
authored
WIP: adapt to sroa jll (#521)
* WIP: adapt to sroa jll * fixup * fix * fixup * fixup * rmprint * fix patch * Update WORKSPACE * Update Project.toml * adapt to upstream properly * cuconvert * Update Project.toml * fix ci errs * alias * cuda test
1 parent 4fd0492 commit a1f93a0

File tree

10 files changed

+107
-71
lines changed

10 files changed

+107
-71
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ PythonCall = "0.9"
6767
Random = "1.10"
6868
Random123 = "1.7"
6969
ReactantCore = "0.1.3"
70-
Reactant_jll = "0.0.39"
70+
Reactant_jll = "0.0.41"
7171
Scratch = "1.2"
7272
SpecialFunctions = "2"
7373
Statistics = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ template <typename T> T MyValueOrThrow(absl::StatusOr<T> v) {
119119
}
120120
}
121121

122+
extern "C" void ReactantHandleCuResult(uint32_t curesult) {
123+
if (curesult != 0) {
124+
std::string err = "Bad Cuda Result = " + std::to_string(curesult);
125+
if (ReactantThrowError) {
126+
ReactantThrowError(err.c_str());
127+
}
128+
}
129+
}
130+
122131
// MLIR C-API extras
123132
#pragma region MLIR Extra
124133
extern "C" MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx,
@@ -599,7 +608,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
599608
prepareRegistry(registry);
600609

601610
mlir::registerenzymePasses();
602-
regsiterenzymeXLAPasses();
611+
registerenzymexlaPasses();
603612

604613
// Register the standard passes we want.
605614
mlir::registerCSEPass();

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ cc_library(
436436
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
437437
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
438438
"-Wl,-exported_symbol,_ReactantThrowError",
439+
"-Wl,-exported_symbol,_ReactantHandleCuResult",
439440
]}),
440441
deps = [
441442
"@enzyme//:EnzymeMLIR",

deps/ReactantExtra/WORKSPACE

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
12+
ENZYMEXLA_COMMIT = "12dc0bf6932befe236eacfcd19ca9522f870f7b9"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(
@@ -54,9 +54,6 @@ XLA_PATCHES = XLA_PATCHES + [
5454
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/backends/cpu/runtime/thunk_executor.h
5555
""",
5656
"""
57-
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/stream_executor/host/host_kernel.cc
58-
""",
59-
"""
6057
sed -i.bak0 "s/__cpp_lib_hardware_interference_size/HW_INTERFERENCE_SIZE/g" xla/tsl/concurrency/async_value_ref.h
6158
""",
6259
"""
@@ -95,7 +92,7 @@ LLVM_TARGETS = select({
9592
}) + ["AArch64", "X86", "ARM"]
9693

9794
# Uncomment these lines to use a custom LLVM commit
98-
# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908"
95+
# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
9996
# LLVM_SHA256 = ""
10097
# http_archive(
10198
# name = "llvm-raw",
@@ -138,9 +135,7 @@ http_archive(
138135
patches = ["@enzyme_ad//:patches/jax.patch"],
139136
)
140137

141-
# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142-
XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
143-
XLA_SHA256 = ""
138+
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
144139

145140
http_archive(
146141
name = "xla",

deps/ReactantExtra/workspace.bzl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +0,0 @@
1-
ENZYMEXLA_COMMIT = "049a05abfaf23abee646ad26834bb8725c348f51"
2-
ENZYMEXLA_SHA256 = ""
3-
4-
NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
5-
NSYNC_SHA256 = ""
6-
7-
RULES_CC_COMMIT = "c8c38f8c710cbbf834283e4777916b68261b359c"
8-
RULES_CC_SHA256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4"
9-
10-
RULES_PYTHON_VERSION = "0.34.0"
11-
RULES_PYTHON_SHA256 = "778aaeab3e6cfd56d681c89f5c10d7ad6bf8d2f1a72de9de55b23081b2d31618"
12-
13-
UPB_COMMIT = "9effcbcb27f0a665f9f345030188c0b291e32482"
14-
UPB_SHA256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454"

ext/ReactantCUDAExt.jl

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,12 @@ function Adapt.adapt_structure(
239239
)
240240
end
241241

242-
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
242+
function recudaconvert(arg)
243243
return adapt(ReactantKernelAdaptor(), arg)
244244
end
245+
Reactant.@reactant_overlay @noinline function CUDA.cudaconvert(arg)
246+
return recudaconvert(arg)
247+
end
245248

246249
function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRArray{T,N}) where {T,N}
247250
res = CuTracedArray{T,N,CUDA.AS.Global,size(xs)}(xs)
@@ -425,6 +428,7 @@ function get_field_offset(T::Type, path)
425428
offset = 0
426429
current_type = T
427430

431+
428432
for field in path
429433
# Get the field index
430434
field_idx = if field isa Integer
@@ -440,18 +444,22 @@ function get_field_offset(T::Type, path)
440444
end
441445

442446
# Add the offset of this field
443-
offset += fieldoffset(current_type, field_idx)
447+
toffset = fieldoffset(current_type, field_idx)
448+
tcurrent_type = fieldtype(current_type, field_idx)
449+
offset += toffset
444450

445451
# Update current_type to the field's type for next iteration
446-
current_type = fieldtype(current_type, field_idx)
452+
current_type = tcurrent_type
453+
447454
end
455+
448456

449457
return offset
450458
end
451459

452460
Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
453461
args...;
454-
convert=Val(false),
462+
convert=Val(true),
455463
blocks::CuDim=1,
456464
threads::CuDim=1,
457465
cooperative::Bool=false,
@@ -461,6 +469,10 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
461469
blockdim = CUDA.CuDim3(blocks)
462470
threaddim = CUDA.CuDim3(threads)
463471

472+
if convert == Val(true)
473+
args = recudaconvert.(args)
474+
end
475+
464476
mlir_args = MLIR.IR.Value[]
465477
restys = MLIR.IR.Type[]
466478
aliases = MLIR.IR.Attribute[]
@@ -578,6 +590,20 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
578590
push!(restys, MLIR.IR.type(arg))
579591
push!(mlir_args, arg)
580592

593+
push!(
594+
aliases,
595+
MLIR.IR.Attribute(
596+
MLIR.API.stablehloOutputOperandAliasGet(
597+
MLIR.IR.context(),
598+
length(wrapper_tys) == 1 ? 0 : 1,
599+
length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
600+
argidx - 1,
601+
0,
602+
C_NULL,
603+
),
604+
),
605+
)
606+
581607
for p in paths
582608
if p[1] !== kernelargsym
583609
continue
@@ -602,20 +628,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
602628
)
603629
MLIR.Dialects.llvm.store(MLIR.IR.argument(wrapbody, argidx), ptr)
604630
end
605-
606-
push!(
607-
aliases,
608-
MLIR.IR.Attribute(
609-
MLIR.API.stablehloOutputOperandAliasGet(
610-
MLIR.IR.context(),
611-
length(wrapper_tys) == 1 ? 0 : 1,
612-
length(wrapper_tys) == 1 ? C_NULL : Ref{Int64}(argidx - 1),
613-
argidx - 1,
614-
0,
615-
C_NULL,
616-
),
617-
),
618-
)
619631
end
620632
argidx += 1
621633
end
@@ -650,6 +662,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
650662
end
651663

652664
location = MLIR.IR.Location()
665+
@assert length(restys) == length(aliases)
653666
call = MLIR.Dialects.enzymexla.kernel_call(
654667
blk_operands...,
655668
mlir_args;
@@ -786,6 +799,11 @@ function __init__()
786799
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
787800
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
788801
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
802+
ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false)
803+
if ptr4 === nothing
804+
ptr4 = C_NULL
805+
end
806+
Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4)
789807
end
790808
return nothing
791809
end

src/Compiler.jl

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function create_result(
116116
end
117117

118118
# Optimization passes via transform dialect
119-
function optimization_passes(; no_nan::Bool=false)
119+
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
120120
transform_passes_list = [
121121
"patterns=compare_op_canon<16>",
122122
"transpose_transpose<16>",
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295295
",",
296296
)
297297
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
298-
return join(
299-
[
300-
"inline{default-pipeline=canonicalize max-iterations=4}",
301-
"libdevice-funcs-raise",
302-
func_passes,
303-
],
298+
passes = [
299+
"inline{default-pipeline=canonicalize max-iterations=4}"
300+
]
301+
if sroa
302+
push!(passes, "sroa-wrappers")
303+
push!(passes, "libdevice-funcs-raise")
304+
push!(passes, "canonicalize")
305+
end
306+
push!(passes, func_passes)
307+
return join(passes,
304308
',',
305309
)
306310
end
@@ -351,6 +355,8 @@ end
351355
const cuLaunch = Ref{UInt}(0)
352356
const cuFunc = Ref{UInt}(0)
353357
const cuModule = Ref{UInt}(0)
358+
const cuSync = Ref{UInt}(0)
359+
const DEBUG_KERNEL = Ref{Bool}(false)
354360

355361
function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
356362
# Explicitly don't use block! to avoid creating a closure, which creates
@@ -379,12 +385,20 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
379385
if isdefined(Reactant_jll, :ptxas_path)
380386
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
381387
end
382-
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
388+
if DEBUG_KERNEL[]
389+
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult")
390+
@assert curesulthandler !== nothing
391+
curesulthandler = Base.reinterpret(UInt, curesulthandler)
392+
kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce"
393+
else
394+
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
395+
end
383396

384-
opt_passes = optimization_passes(; no_nan)
397+
opt_passes = optimization_passes(; no_nan, sroa=true)
398+
opt_passes2 = optimization_passes(; no_nan, sroa=false)
385399

386400
if optimize === :all
387-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
401+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
388402
run_pass_pipeline!(
389403
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
390404
)
@@ -395,14 +409,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
395409
"canonicalize",
396410
"remove-unnecessary-enzyme-ops",
397411
"enzyme-simplify-math",
398-
opt_passes,
412+
opt_passes2,
399413
kern,
400414
],
401415
',',
402416
),
403417
)
404418
elseif optimize === :before_kernel
405-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
419+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
406420
run_pass_pipeline!(
407421
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
408422
)
@@ -413,13 +427,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
413427
"canonicalize",
414428
"remove-unnecessary-enzyme-ops",
415429
"enzyme-simplify-math",
416-
opt_passes,
430+
opt_passes2,
417431
],
418432
',',
419433
),
420434
)
421435
elseif optimize === :no_enzyme
422-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
436+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
423437
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
424438
run_pass_pipeline!(
425439
mod,
@@ -428,7 +442,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
428442
"canonicalize",
429443
"remove-unnecessary-enzyme-ops",
430444
"enzyme-simplify-math",
431-
opt_passes,
445+
opt_passes2,
432446
],
433447
',',
434448
),
@@ -457,14 +471,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
457471
"canonicalize",
458472
"remove-unnecessary-enzyme-ops",
459473
"enzyme-simplify-math",
460-
opt_passes,
474+
opt_passes2,
461475
kern,
462476
],
463477
',',
464478
),
465479
)
466480
elseif optimize === :before_enzyme
467-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
481+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
468482
run_pass_pipeline!(
469483
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
470484
)

src/Tracing.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError)
318318
)
319319
end
320320

321+
function make_tracer(
322+
seen,
323+
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}),
324+
@nospecialize(path),
325+
mode;
326+
toscalar=false,
327+
tobatch=nothing,
328+
track_numbers=(),
329+
kwargs...,
330+
)
331+
return prev
332+
end
321333
append_path(path, i) = (path..., i)
322334

323335
function make_tracer(
@@ -590,7 +602,7 @@ function make_tracer(
590602
if mode == ArrayToConcrete
591603
return ConcreteRNumber(prev)
592604
else
593-
if mode == TracedTrack
605+
if mode == TracedTrack || mode == NoStopTracedTrack
594606
res = TracedRNumber{RT}(
595607
(path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data
596608
)
@@ -638,7 +650,7 @@ end
638650
function make_tracer(
639651
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
640652
) where {RT<:Array}
641-
if haskey(seen, prev)
653+
if mode != NoStopTracedTrack && haskey(seen, prev)
642654
return seen[prev]
643655
end
644656
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
@@ -699,7 +711,7 @@ function make_tracer(
699711
end
700712

701713
function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
702-
if haskey(seen, prev)
714+
if mode != NoStopTracedTrack && haskey(seen, prev)
703715
return seen[prev]
704716
end
705717
prev2 = prev.contents

0 commit comments

Comments
 (0)