@@ -116,7 +116,7 @@ function create_result(
116
116
end
117
117
118
118
# Optimization passes via transform dialect
119
- function optimization_passes (; no_nan:: Bool = false )
119
+ function optimization_passes (; no_nan:: Bool = false , sroa :: Bool = false )
120
120
transform_passes_list = [
121
121
" patterns=compare_op_canon<16>" ,
122
122
" transpose_transpose<16>" ,
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295
295
" ," ,
296
296
)
297
297
func_passes = join ([" canonicalize" , " cse" , " canonicalize" , transform_passes], " ," )
298
- return join (
299
- [
300
- " inline{default-pipeline=canonicalize max-iterations=4}" ,
301
- " libdevice-funcs-raise" ,
302
- func_passes,
303
- ],
298
+ passes = [
299
+ " inline{default-pipeline=canonicalize max-iterations=4}"
300
+ ]
301
+ if sroa
302
+ push! (passes, " sroa-wrappers" )
303
+ push! (passes, " libdevice-funcs-raise" )
304
+ push! (passes, " canonicalize" )
305
+ end
306
+ push! (passes, func_passes)
307
+ return join (passes,
304
308
' ,' ,
305
309
)
306
310
end
310
314
const enzyme_pass:: String = " enzyme{postpasses=\" arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\" }"
311
315
312
316
function run_pass_pipeline! (mod, pass_pipeline; enable_verifier= true )
317
+ @show pass_pipeline
318
+ flush (stdout )
313
319
pm = MLIR. IR. PassManager ()
314
320
MLIR. IR. enable_verifier! (pm, enable_verifier)
315
321
opm = MLIR. IR. OpPassManager (pm)
@@ -382,9 +388,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
382
388
kern = " lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) },symbol-dce"
383
389
384
390
opt_passes = optimization_passes (; no_nan)
391
+ opt_passes2 = optimization_passes (; no_nan, sroa= false )
385
392
386
393
if optimize === :all
387
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
394
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
388
395
run_pass_pipeline! (
389
396
mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
390
397
)
@@ -395,14 +402,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
395
402
" canonicalize" ,
396
403
" remove-unnecessary-enzyme-ops" ,
397
404
" enzyme-simplify-math" ,
398
- opt_passes ,
405
+ opt_passes2 ,
399
406
kern,
400
407
],
401
408
' ,' ,
402
409
),
403
410
)
404
411
elseif optimize === :before_kernel
405
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
412
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
406
413
run_pass_pipeline! (
407
414
mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
408
415
)
@@ -413,13 +420,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
413
420
" canonicalize" ,
414
421
" remove-unnecessary-enzyme-ops" ,
415
422
" enzyme-simplify-math" ,
416
- opt_passes ,
423
+ opt_passes2 ,
417
424
],
418
425
' ,' ,
419
426
),
420
427
)
421
428
elseif optimize === :no_enzyme
422
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
429
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
423
430
run_pass_pipeline! (mod, " arith-raise{stablehlo=true}" ; enable_verifier= false )
424
431
run_pass_pipeline! (
425
432
mod,
@@ -428,7 +435,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
428
435
" canonicalize" ,
429
436
" remove-unnecessary-enzyme-ops" ,
430
437
" enzyme-simplify-math" ,
431
- opt_passes ,
438
+ opt_passes2 ,
432
439
],
433
440
' ,' ,
434
441
),
@@ -457,14 +464,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
457
464
" canonicalize" ,
458
465
" remove-unnecessary-enzyme-ops" ,
459
466
" enzyme-simplify-math" ,
460
- opt_passes ,
467
+ opt_passes2 ,
461
468
kern,
462
469
],
463
470
' ,' ,
464
471
),
465
472
)
466
473
elseif optimize === :before_enzyme
467
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
474
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
468
475
run_pass_pipeline! (
469
476
mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
470
477
)
0 commit comments