Skip to content

Commit b47597f

Browse files
committed
WIP: adapt to sroa jll
1 parent ca98c17 commit b47597f

File tree

4 files changed

+63
-53
lines changed

4 files changed

+63
-53
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
565565
prepareRegistry(registry);
566566

567567
mlir::registerenzymePasses();
568-
regsiterenzymeXLAPasses();
568+
registerenzymexlaPasses();
569569

570570
// Register the standard passes we want.
571571
mlir::registerCSEPass();

deps/ReactantExtra/WORKSPACE

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
12+
ENZYMEXLA_COMMIT = "3b217bbfd5680ecd88c20285fe7b5693c541fa8b"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(
@@ -95,39 +95,39 @@ LLVM_TARGETS = select({
9595
}) + ["AArch64", "X86", "ARM"]
9696

9797
# Uncomment these lines to use a custom LLVM commit
98-
# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908"
99-
# LLVM_SHA256 = ""
100-
# http_archive(
101-
# name = "llvm-raw",
102-
# build_file_content = "# empty",
103-
# sha256 = LLVM_SHA256,
104-
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
105-
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
106-
# )
107-
#
108-
#
109-
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
110-
# maybe(
111-
# http_archive,
112-
# name = "llvm_zlib",
113-
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
114-
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
115-
# strip_prefix = "zlib-ng-2.0.7",
116-
# urls = [
117-
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
118-
# ],
119-
# )
120-
#
121-
# maybe(
122-
# http_archive,
123-
# name = "llvm_zstd",
124-
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
125-
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
126-
# strip_prefix = "zstd-1.5.2",
127-
# urls = [
128-
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
129-
# ],
130-
# )
98+
LLVM_COMMIT = "9b4bf06be33f0fe6a4c487bb9244d8c0f6acab3f"
99+
LLVM_SHA256 = ""
100+
http_archive(
101+
name = "llvm-raw",
102+
build_file_content = "# empty",
103+
sha256 = LLVM_SHA256,
104+
strip_prefix = "llvm-project-" + LLVM_COMMIT,
105+
urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
106+
)
107+
108+
109+
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
110+
maybe(
111+
http_archive,
112+
name = "llvm_zlib",
113+
build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
114+
sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
115+
strip_prefix = "zlib-ng-2.0.7",
116+
urls = [
117+
"https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
118+
],
119+
)
120+
121+
maybe(
122+
http_archive,
123+
name = "llvm_zstd",
124+
build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
125+
sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
126+
strip_prefix = "zstd-1.5.2",
127+
urls = [
128+
"https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
129+
],
130+
)
131131

132132
http_archive(
133133
name = "jax",
@@ -138,9 +138,9 @@ http_archive(
138138
patches = ["@enzyme_ad//:patches/jax.patch"],
139139
)
140140

141-
# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142-
XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
143-
XLA_SHA256 = ""
141+
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142+
# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
143+
# XLA_SHA256 = ""
144144

145145
http_archive(
146146
name = "xla",

ext/ReactantCUDAExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
444444
prev = Any[func.f, args...]
445445
kernelargsym = gensym("kernelarg")
446446
Reactant.make_tracer(seen, prev, (kernelargsym,), Reactant.TracedTrack)
447+
@show prev
448+
@show Core.Typeof(prev)
449+
@show seen
447450
wrapper_tys = MLIR.IR.Type[]
448451
for arg in values(seen)
449452
if !(arg isa TracedRArray || arg isa TracedRNumber)

src/Compiler.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function create_result(
116116
end
117117

118118
# Optimization passes via transform dialect
119-
function optimization_passes(; no_nan::Bool=false)
119+
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
120120
transform_passes_list = [
121121
"patterns=compare_op_canon<16>",
122122
"transpose_transpose<16>",
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295295
",",
296296
)
297297
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
298-
return join(
299-
[
300-
"inline{default-pipeline=canonicalize max-iterations=4}",
301-
"libdevice-funcs-raise",
302-
func_passes,
303-
],
298+
passes = [
299+
"inline{default-pipeline=canonicalize max-iterations=4}"
300+
]
301+
if sroa
302+
push!(passes, "sroa-wrappers")
303+
push!(passes, "libdevice-funcs-raise")
304+
push!(passes, "canonicalize")
305+
end
306+
push!(passes, func_passes)
307+
return join(passes,
304308
',',
305309
)
306310
end
@@ -310,6 +314,8 @@ end
310314
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
311315

312316
function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
317+
@show pass_pipeline
318+
flush(stdout)
313319
pm = MLIR.IR.PassManager()
314320
MLIR.IR.enable_verifier!(pm, enable_verifier)
315321
opm = MLIR.IR.OpPassManager(pm)
@@ -374,9 +380,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
374380
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
375381

376382
opt_passes = optimization_passes(; no_nan)
383+
opt_passes2 = optimization_passes(; no_nan, sroa=false)
377384

378385
if optimize === :all
379-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
386+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
380387
run_pass_pipeline!(
381388
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
382389
)
@@ -387,14 +394,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
387394
"canonicalize",
388395
"remove-unnecessary-enzyme-ops",
389396
"enzyme-simplify-math",
390-
opt_passes,
397+
opt_passes2,
391398
kern,
392399
],
393400
',',
394401
),
395402
)
396403
elseif optimize === :before_kernel
397-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
404+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
398405
run_pass_pipeline!(
399406
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
400407
)
@@ -405,13 +412,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
405412
"canonicalize",
406413
"remove-unnecessary-enzyme-ops",
407414
"enzyme-simplify-math",
408-
opt_passes,
415+
opt_passes2,
409416
],
410417
',',
411418
),
412419
)
413420
elseif optimize === :no_enzyme
414-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
421+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
415422
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
416423
run_pass_pipeline!(
417424
mod,
@@ -420,7 +427,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
420427
"canonicalize",
421428
"remove-unnecessary-enzyme-ops",
422429
"enzyme-simplify-math",
423-
opt_passes,
430+
opt_passes2,
424431
],
425432
',',
426433
),
@@ -449,14 +456,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
449456
"canonicalize",
450457
"remove-unnecessary-enzyme-ops",
451458
"enzyme-simplify-math",
452-
opt_passes,
459+
opt_passes2,
453460
kern,
454461
],
455462
',',
456463
),
457464
)
458465
elseif optimize === :before_enzyme
459-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
466+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
460467
run_pass_pipeline!(
461468
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
462469
)

0 commit comments

Comments
 (0)