@@ -116,7 +116,7 @@ function create_result(
116
116
end
117
117
118
118
# Optimization passes via transform dialect
119
- function optimization_passes (; no_nan:: Bool = false )
119
+ function optimization_passes (; no_nan:: Bool = false , sroa :: Bool = false )
120
120
transform_passes_list = [
121
121
" patterns=compare_op_canon<16>" ,
122
122
" transpose_transpose<16>" ,
@@ -295,12 +295,16 @@ function optimization_passes(; no_nan::Bool=false)
295
295
" ," ,
296
296
)
297
297
func_passes = join ([" canonicalize" , " cse" , " canonicalize" , transform_passes], " ," )
298
- return join (
299
- [
300
- " inline{default-pipeline=canonicalize max-iterations=4}" ,
301
- " libdevice-funcs-raise" ,
302
- func_passes,
303
- ],
298
+ passes = [
299
+ " inline{default-pipeline=canonicalize max-iterations=4}"
300
+ ]
301
+ if sroa
302
+ push! (passes, " sroa-wrappers" )
303
+ push! (passes, " libdevice-funcs-raise" )
304
+ push! (passes, " canonicalize" )
305
+ end
306
+ push! (passes, func_passes)
307
+ return join (passes,
304
308
' ,' ,
305
309
)
306
310
end
310
314
const enzyme_pass:: String = " enzyme{postpasses=\" arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\" }"
311
315
312
316
function run_pass_pipeline! (mod, pass_pipeline; enable_verifier= true )
317
+ @show pass_pipeline
318
+ flush (stdout )
313
319
pm = MLIR. IR. PassManager ()
314
320
MLIR. IR. enable_verifier! (pm, enable_verifier)
315
321
opm = MLIR. IR. OpPassManager (pm)
@@ -374,9 +380,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
374
380
kern = " lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) },symbol-dce"
375
381
376
382
opt_passes = optimization_passes (; no_nan)
383
+ opt_passes2 = optimization_passes (; no_nan, sroa= false )
377
384
378
385
if optimize === :all
379
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
386
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
380
387
run_pass_pipeline! (
381
388
mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
382
389
)
@@ -387,14 +394,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
387
394
" canonicalize" ,
388
395
" remove-unnecessary-enzyme-ops" ,
389
396
" enzyme-simplify-math" ,
390
- opt_passes ,
397
+ opt_passes2 ,
391
398
kern,
392
399
],
393
400
' ,' ,
394
401
),
395
402
)
396
403
elseif optimize === :before_kernel
397
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
404
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
398
405
run_pass_pipeline! (
399
406
mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
400
407
)
@@ -405,13 +412,13 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
405
412
" canonicalize" ,
406
413
" remove-unnecessary-enzyme-ops" ,
407
414
" enzyme-simplify-math" ,
408
- opt_passes ,
415
+ opt_passes2 ,
409
416
],
410
417
' ,' ,
411
418
),
412
419
)
413
420
elseif optimize === :no_enzyme
414
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
421
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
415
422
run_pass_pipeline! (mod, " arith-raise{stablehlo=true}" ; enable_verifier= false )
416
423
run_pass_pipeline! (
417
424
mod,
@@ -420,7 +427,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
420
427
" canonicalize" ,
421
428
" remove-unnecessary-enzyme-ops" ,
422
429
" enzyme-simplify-math" ,
423
- opt_passes ,
430
+ opt_passes2 ,
424
431
],
425
432
' ,' ,
426
433
),
@@ -449,14 +456,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
449
456
" canonicalize" ,
450
457
" remove-unnecessary-enzyme-ops" ,
451
458
" enzyme-simplify-math" ,
452
- opt_passes ,
459
+ opt_passes2 ,
453
460
kern,
454
461
],
455
462
' ,' ,
456
463
),
457
464
)
458
465
elseif optimize === :before_enzyme
459
- run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes ], " ," ))
466
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes2 ], " ," ))
460
467
run_pass_pipeline! (
461
468
mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
462
469
)
0 commit comments