Skip to content

Commit f7be948

Browse files
committed
WIP: adapt to sroa jll
1 parent c664414 commit f7be948

File tree

3 files changed

+60
-53
lines changed

3 files changed

+60
-53
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
565565
prepareRegistry(registry);
566566

567567
mlir::registerenzymePasses();
568-
regsiterenzymeXLAPasses();
568+
registerenzymexlaPasses();
569569

570570
// Register the standard passes we want.
571571
mlir::registerCSEPass();

deps/ReactantExtra/WORKSPACE

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "4d7c91e5d71fc98b901f7aa40b6deacb449fa873"
12+
ENZYMEXLA_COMMIT = "3b217bbfd5680ecd88c20285fe7b5693c541fa8b"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(
@@ -95,39 +95,39 @@ LLVM_TARGETS = select({
9595
}) + ["AArch64", "X86", "ARM"]
9696

9797
# Uncomment these lines to use a custom LLVM commit
98-
# LLVM_COMMIT = "023dbbaa3eeddd537e2376aa7355e3bcef618908"
99-
# LLVM_SHA256 = ""
100-
# http_archive(
101-
# name = "llvm-raw",
102-
# build_file_content = "# empty",
103-
# sha256 = LLVM_SHA256,
104-
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
105-
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
106-
# )
107-
#
108-
#
109-
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
110-
# maybe(
111-
# http_archive,
112-
# name = "llvm_zlib",
113-
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
114-
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
115-
# strip_prefix = "zlib-ng-2.0.7",
116-
# urls = [
117-
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
118-
# ],
119-
# )
120-
#
121-
# maybe(
122-
# http_archive,
123-
# name = "llvm_zstd",
124-
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
125-
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
126-
# strip_prefix = "zstd-1.5.2",
127-
# urls = [
128-
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
129-
# ],
130-
# )
98+
LLVM_COMMIT = "9b4bf06be33f0fe6a4c487bb9244d8c0f6acab3f"
99+
LLVM_SHA256 = ""
100+
http_archive(
101+
name = "llvm-raw",
102+
build_file_content = "# empty",
103+
sha256 = LLVM_SHA256,
104+
strip_prefix = "llvm-project-" + LLVM_COMMIT,
105+
urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
106+
)
107+
108+
109+
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
110+
maybe(
111+
http_archive,
112+
name = "llvm_zlib",
113+
build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
114+
sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
115+
strip_prefix = "zlib-ng-2.0.7",
116+
urls = [
117+
"https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
118+
],
119+
)
120+
121+
maybe(
122+
http_archive,
123+
name = "llvm_zstd",
124+
build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
125+
sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
126+
strip_prefix = "zstd-1.5.2",
127+
urls = [
128+
"https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
129+
],
130+
)
131131

132132
http_archive(
133133
name = "jax",
@@ -138,9 +138,9 @@ http_archive(
138138
patches = ["@enzyme_ad//:patches/jax.patch"],
139139
)
140140

141-
# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142-
XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
143-
XLA_SHA256 = ""
141+
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142+
# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
143+
# XLA_SHA256 = ""
144144

145145
http_archive(
146146
name = "xla",

src/Compiler.jl

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ function create_result(
116116
end
117117

118118
# Optimization passes via transform dialect
119-
function optimization_passes(; no_nan::Bool=false)
119+
function optimization_passes(; no_nan::Bool=false, sroa::Bool=false)
120120
transform_passes_list = [
121121
"patterns=compare_op_canon<16>",
122122
"transpose_transpose<16>",
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295295
",",
296296
)
297297
func_passes = join(["canonicalize", "cse", "canonicalize", transform_passes], ",")
298-
return join(
299-
[
300-
"inline{default-pipeline=canonicalize max-iterations=4}",
301-
"libdevice-funcs-raise",
302-
func_passes,
303-
],
298+
passes = [
299+
"inline{default-pipeline=canonicalize max-iterations=4}"
300+
]
301+
if sroa
302+
push!(passes, "sroa-wrappers")
303+
push!(passes, "libdevice-funcs-raise")
304+
push!(passes, "canonicalize")
305+
end
306+
push!(passes, func_passes)
307+
return join(passes,
304308
',',
305309
)
306310
end
@@ -310,6 +314,8 @@ end
310314
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
311315

312316
function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
317+
@show pass_pipeline
318+
flush(stdout)
313319
pm = MLIR.IR.PassManager()
314320
MLIR.IR.enable_verifier!(pm, enable_verifier)
315321
opm = MLIR.IR.OpPassManager(pm)
@@ -382,9 +388,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
382388
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
383389

384390
opt_passes = optimization_passes(; no_nan)
391+
opt_passes2 = optimization_passes(; no_nan, sroa=false)
385392

386393
if optimize === :all
387-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
394+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
388395
run_pass_pipeline!(
389396
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
390397
)
@@ -395,14 +402,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
395402
"canonicalize",
396403
"remove-unnecessary-enzyme-ops",
397404
"enzyme-simplify-math",
398-
opt_passes,
405+
opt_passes2,
399406
kern,
400407
],
401408
',',
402409
),
403410
)
404411
elseif optimize === :before_kernel
405-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
412+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
406413
run_pass_pipeline!(
407414
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
408415
)
@@ -413,13 +420,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
413420
"canonicalize",
414421
"remove-unnecessary-enzyme-ops",
415422
"enzyme-simplify-math",
416-
opt_passes,
423+
opt_passes2,
417424
],
418425
',',
419426
),
420427
)
421428
elseif optimize === :no_enzyme
422-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
429+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
423430
run_pass_pipeline!(mod, "arith-raise{stablehlo=true}"; enable_verifier=false)
424431
run_pass_pipeline!(
425432
mod,
@@ -428,7 +435,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
428435
"canonicalize",
429436
"remove-unnecessary-enzyme-ops",
430437
"enzyme-simplify-math",
431-
opt_passes,
438+
opt_passes2,
432439
],
433440
',',
434441
),
@@ -457,14 +464,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
457464
"canonicalize",
458465
"remove-unnecessary-enzyme-ops",
459466
"enzyme-simplify-math",
460-
opt_passes,
467+
opt_passes2,
461468
kern,
462469
],
463470
',',
464471
),
465472
)
466473
elseif optimize === :before_enzyme
467-
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
474+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes2], ","))
468475
run_pass_pipeline!(
469476
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
470477
)

0 commit comments

Comments
 (0)