Skip to content

Commit 65fd055

Browse files
authored
[MLIR] Added check for IsTerminator trait (#79317)
This PR adds a check for the IsTerminator trait to prevent the RemoveDeadValues pass from deleting ops like gpu.terminator as a "simple op".
1 parent 6ccb06a commit 65fd055

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -594,13 +594,9 @@ void RemoveDeadValues::runOnOperation() {
594594
cleanFuncOp(funcOp, module, la);
595595
} else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
596596
cleanRegionBranchOp(regionBranchOp, la);
597-
} else if (op->hasTrait<OpTrait::ReturnLike>()) {
598-
// Nothing to do because this terminator is associated with either a
599-
// function op or a region branch op and gets cleaned when these ops are
600-
// cleaned.
601-
} else if (isa<RegionBranchTerminatorOpInterface>(op)) {
602-
// Nothing to do because this terminator is associated with a region
603-
// branch op and gets cleaned when the latter is cleaned.
597+
} else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
598+
// Nothing to do here because this is a terminator op and it should be
599+
// honored with respect to its parent
604600
} else if (isa<CallOpInterface>(op)) {
605601
// Nothing to do because this op is associated with a function op and gets
606602
// cleaned when the latter is cleaned.

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,25 @@ func.func @main(%arg3 : i32, %arg4 : i1) {
335335
%non_live_0 = func.call @clean_region_branch_op_erase_it(%arg3, %arg4) : (i32, i1) -> (i32)
336336
return
337337
}
338+
339+
// -----
340+
341+
#map = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
342+
func.func @kernel(%arg0: memref<18xf32>) {
343+
%c1 = arith.constant 1 : index
344+
%c18 = arith.constant 18 : index
345+
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c18, %arg10 = %c18, %arg11 = %c18) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
346+
%c1_0 = arith.constant 1 : index
347+
%c0_1 = arith.constant 0 : index
348+
%cst_2 = arith.constant 25.4669495 : f32
349+
%6 = affine.apply #map(%arg3)[%c1_0, %c0_1]
350+
memref.store %cst_2, %arg0[%6] : memref<18xf32>
351+
gpu.terminator
352+
} {SCFToGPU_visited}
353+
return
354+
}
355+
356+
// CHECK-LABEL: func.func @kernel(%arg0: memref<18xf32>) {
357+
// CHECK: gpu.launch blocks
358+
// CHECK: memref.store
359+
// CHECK-NEXT: gpu.terminator

0 commit comments

Comments
 (0)