Skip to content

Commit 701f240

Browse files
authored
[mlir] fix crash when scf utils work on llvm.func (#120688)
fixed #119378
1 parent 4472648 commit 701f240

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter,
130130

131131
// Outline before current function.
132132
OpBuilder::InsertionGuard g(rewriter);
133-
rewriter.setInsertionPoint(region.getParentOfType<func::FuncOp>());
133+
rewriter.setInsertionPoint(region.getParentOfType<FunctionOpInterface>());
134134

135135
SetVector<Value> captures;
136136
getUsedValuesDefinedAbove(region, captures);

mlir/test/Transforms/scf-if-utils.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,30 @@ func.func @outline_empty_if_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8
7373
}
7474
return
7575
}
76+
77+
// -----
78+
79+
// This test checks scf utils can work on llvm func.
80+
81+
// CHECK: func @outlined_then0() {
82+
// CHECK-NEXT: return
83+
// CHECK-NEXT: }
84+
// CHECK: func @outlined_else0(%{{.*}}: i1, %{{.*}}: i32) {
85+
// CHECK-NEXT: "some_op"(%{{.*}}, %{{.*}}) : (i1, i32) -> ()
86+
// CHECK-NEXT: return
87+
// CHECK-NEXT: }
88+
// CHECK: llvm.func @llvm_func(%{{.*}}: i1, %{{.*}}: i32) {
89+
// CHECK-NEXT: scf.if %{{.*}} {
90+
// CHECK-NEXT: func.call @outlined_then0() : () -> ()
91+
// CHECK-NEXT: } else {
92+
// CHECK-NEXT: func.call @outlined_else0(%{{.*}}, %{{.*}}) : (i1, i32) -> ()
93+
// CHECK-NEXT: }
94+
// CHECK-NEXT: llvm.return
95+
// CHECK-NEXT: }
96+
llvm.func @llvm_func(%cond: i1, %a: i32) {
97+
scf.if %cond {
98+
} else {
99+
"some_op"(%cond, %a) : (i1, i32) -> ()
100+
}
101+
llvm.return
102+
}

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ struct TestSCFIfUtilsPass
7979
StringRef getDescription() const final { return "test scf.if utils"; }
8080
explicit TestSCFIfUtilsPass() = default;
8181

82+
void getDependentDialects(DialectRegistry &registry) const override {
83+
registry.insert<func::FuncDialect>();
84+
}
85+
8286
void runOnOperation() override {
8387
int count = 0;
8488
getOperation().walk([&](scf::IfOp ifOp) {

0 commit comments

Comments
 (0)