Skip to content

Commit 5a31403

Browse files
committed
[MLIR][SCF] Create selects from if yield results which are not defined in the body
Previously, the canonicalizer to create ifs from selects would only work if the if did not have a body other than yielding. This patch upgrade the functionality to be able to create selects from any if result whose operands are not defined within the body. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121943
1 parent 1f001b2 commit 5a31403

File tree

2 files changed

+89
-42
lines changed

2 files changed

+89
-42
lines changed

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,8 @@ struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
13261326
}
13271327
};
13281328

1329+
/// Hoist any yielded results whose operands are defined outside
1330+
/// the if, to a select instruction.
13291331
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
13301332
using OpRewritePattern<IfOp>::OpRewritePattern;
13311333

@@ -1334,31 +1336,58 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
13341336
if (op->getNumResults() == 0)
13351337
return failure();
13361338

1337-
if (!llvm::hasSingleElement(op.getThenRegion().front()) ||
1338-
!llvm::hasSingleElement(op.getElseRegion().front()))
1339+
auto cond = op.getCondition();
1340+
auto thenYieldArgs = op.thenYield().getOperands();
1341+
auto elseYieldArgs = op.elseYield().getOperands();
1342+
1343+
SmallVector<Type> nonHoistable;
1344+
for (const auto &it :
1345+
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1346+
Value trueVal = std::get<0>(it.value());
1347+
Value falseVal = std::get<1>(it.value());
1348+
if (&op.getThenRegion() == trueVal.getParentRegion() ||
1349+
&op.getElseRegion() == falseVal.getParentRegion())
1350+
nonHoistable.push_back(trueVal.getType());
1351+
}
1352+
// Early exit if there aren't any yielded values we can
1353+
// hoist outside the if.
1354+
if (nonHoistable.size() == op->getNumResults())
13391355
return failure();
13401356

1341-
auto cond = op.getCondition();
1342-
auto thenYieldArgs =
1343-
cast<scf::YieldOp>(op.getThenRegion().front().getTerminator())
1344-
.getOperands();
1345-
auto elseYieldArgs =
1346-
cast<scf::YieldOp>(op.getElseRegion().front().getTerminator())
1347-
.getOperands();
1357+
IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond);
1358+
if (replacement.thenBlock())
1359+
rewriter.eraseBlock(replacement.thenBlock());
1360+
replacement.getThenRegion().takeBody(op.getThenRegion());
1361+
replacement.getElseRegion().takeBody(op.getElseRegion());
1362+
13481363
SmallVector<Value> results(op->getNumResults());
13491364
assert(thenYieldArgs.size() == results.size());
13501365
assert(elseYieldArgs.size() == results.size());
1366+
1367+
SmallVector<Value> trueYields;
1368+
SmallVector<Value> falseYields;
13511369
for (const auto &it :
13521370
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
13531371
Value trueVal = std::get<0>(it.value());
13541372
Value falseVal = std::get<1>(it.value());
1355-
if (trueVal == falseVal)
1373+
if (&replacement.getThenRegion() == trueVal.getParentRegion() ||
1374+
&replacement.getElseRegion() == falseVal.getParentRegion()) {
1375+
results[it.index()] = replacement.getResult(trueYields.size());
1376+
trueYields.push_back(trueVal);
1377+
falseYields.push_back(falseVal);
1378+
} else if (trueVal == falseVal)
13561379
results[it.index()] = trueVal;
13571380
else
13581381
results[it.index()] = rewriter.create<arith::SelectOp>(
13591382
op.getLoc(), cond, trueVal, falseVal);
13601383
}
13611384

1385+
rewriter.setInsertionPointToEnd(replacement.thenBlock());
1386+
rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
1387+
1388+
rewriter.setInsertionPointToEnd(replacement.elseBlock());
1389+
rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
1390+
13621391
rewriter.replaceOp(op, results);
13631392
return success();
13641393
}

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -136,26 +136,26 @@ func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
136136

137137
func private @side_effect()
138138
func @one_unused(%cond: i1) -> (index) {
139-
%c0 = arith.constant 0 : index
140-
%c1 = arith.constant 1 : index
141-
%c2 = arith.constant 2 : index
142-
%c3 = arith.constant 3 : index
143139
%0, %1 = scf.if %cond -> (index, index) {
144140
call @side_effect() : () -> ()
141+
%c0 = "test.value0"() : () -> (index)
142+
%c1 = "test.value1"() : () -> (index)
145143
scf.yield %c0, %c1 : index, index
146144
} else {
145+
%c2 = "test.value2"() : () -> (index)
146+
%c3 = "test.value3"() : () -> (index)
147147
scf.yield %c2, %c3 : index, index
148148
}
149149
return %1 : index
150150
}
151151

152152
// CHECK-LABEL: func @one_unused
153-
// CHECK-DAG: [[C0:%.*]] = arith.constant 1 : index
154-
// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
155153
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
156154
// CHECK: call @side_effect() : () -> ()
157-
// CHECK: scf.yield [[C0]] : index
155+
// CHECK: [[C1:%.*]] = "test.value1"
156+
// CHECK: scf.yield [[C1]] : index
158157
// CHECK: } else
158+
// CHECK: [[C3:%.*]] = "test.value3"
159159
// CHECK: scf.yield [[C3]] : index
160160
// CHECK: }
161161
// CHECK: return [[V0]] : index
@@ -164,37 +164,40 @@ func @one_unused(%cond: i1) -> (index) {
164164

165165
func private @side_effect()
166166
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
167-
%c0 = arith.constant 0 : index
168-
%c1 = arith.constant 1 : index
169-
%c2 = arith.constant 2 : index
170-
%c3 = arith.constant 3 : index
171167
%0, %1 = scf.if %cond1 -> (index, index) {
172168
%2, %3 = scf.if %cond2 -> (index, index) {
173169
call @side_effect() : () -> ()
170+
%c0 = "test.value0"() : () -> (index)
171+
%c1 = "test.value1"() : () -> (index)
174172
scf.yield %c0, %c1 : index, index
175173
} else {
174+
%c2 = "test.value2"() : () -> (index)
175+
%c3 = "test.value3"() : () -> (index)
176176
scf.yield %c2, %c3 : index, index
177177
}
178178
scf.yield %2, %3 : index, index
179179
} else {
180+
%c0 = "test.value0_2"() : () -> (index)
181+
%c1 = "test.value1_2"() : () -> (index)
180182
scf.yield %c0, %c1 : index, index
181183
}
182184
return %1 : index
183185
}
184186

185187
// CHECK-LABEL: func @nested_unused
186-
// CHECK-DAG: [[C0:%.*]] = arith.constant 1 : index
187-
// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
188188
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
189189
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
190190
// CHECK: call @side_effect() : () -> ()
191-
// CHECK: scf.yield [[C0]] : index
191+
// CHECK: [[C1:%.*]] = "test.value1"
192+
// CHECK: scf.yield [[C1]] : index
192193
// CHECK: } else
194+
// CHECK: [[C3:%.*]] = "test.value3"
193195
// CHECK: scf.yield [[C3]] : index
194196
// CHECK: }
195197
// CHECK: scf.yield [[V1]] : index
196198
// CHECK: } else
197-
// CHECK: scf.yield [[C0]] : index
199+
// CHECK: [[C1_2:%.*]] = "test.value1_2"
200+
// CHECK: scf.yield [[C1_2]] : index
198201
// CHECK: }
199202
// CHECK: return [[V0]] : index
200203

@@ -302,6 +305,27 @@ func @to_select_same_val(%cond: i1) -> (index, index) {
302305
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
303306
// CHECK: return [[V0]], [[C1]] : index, index
304307

308+
309+
func @to_select_with_body(%cond: i1) -> index {
310+
%c0 = arith.constant 0 : index
311+
%c1 = arith.constant 1 : index
312+
%0 = scf.if %cond -> index {
313+
"test.op"() : () -> ()
314+
scf.yield %c0 : index
315+
} else {
316+
scf.yield %c1 : index
317+
}
318+
return %0 : index
319+
}
320+
321+
// CHECK-LABEL: func @to_select_with_body
322+
// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
323+
// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
324+
// CHECK: scf.if {{.*}} {
325+
// CHECK: "test.op"() : () -> ()
326+
// CHECK: }
327+
// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
328+
// CHECK: return [[V0]] : index
305329
// -----
306330

307331
func @to_select2(%cond: i1) -> (index, index) {
@@ -731,38 +755,32 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
731755

732756
// CHECK-LABEL: @cond_prop
733757
func @cond_prop(%arg0 : i1) -> index {
734-
%c1 = arith.constant 1 : index
735-
%c2 = arith.constant 2 : index
736-
%c3 = arith.constant 3 : index
737-
%c4 = arith.constant 4 : index
738758
%res = scf.if %arg0 -> index {
739759
%res1 = scf.if %arg0 -> index {
740-
%v1 = "test.get_some_value"() : () -> i32
741-
scf.yield %c1 : index
760+
%v1 = "test.get_some_value1"() : () -> index
761+
scf.yield %v1 : index
742762
} else {
743-
%v2 = "test.get_some_value"() : () -> i32
744-
scf.yield %c2 : index
763+
%v2 = "test.get_some_value2"() : () -> index
764+
scf.yield %v2 : index
745765
}
746766
scf.yield %res1 : index
747767
} else {
748768
%res2 = scf.if %arg0 -> index {
749-
%v3 = "test.get_some_value"() : () -> i32
750-
scf.yield %c3 : index
769+
%v3 = "test.get_some_value3"() : () -> index
770+
scf.yield %v3 : index
751771
} else {
752-
%v4 = "test.get_some_value"() : () -> i32
753-
scf.yield %c4 : index
772+
%v4 = "test.get_some_value4"() : () -> index
773+
scf.yield %v4 : index
754774
}
755775
scf.yield %res2 : index
756776
}
757777
return %res : index
758778
}
759-
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
760-
// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
761779
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) {
762-
// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
780+
// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index
763781
// CHECK-NEXT: scf.yield %[[c1]] : index
764782
// CHECK-NEXT: } else {
765-
// CHECK-NEXT: %{{.+}} = "test.get_some_value"() : () -> i32
783+
// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index
766784
// CHECK-NEXT: scf.yield %[[c4]] : index
767785
// CHECK-NEXT: }
768786
// CHECK-NEXT: return %[[if]] : index
@@ -808,14 +826,14 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
808826
return %res#0, %res#1 : i32, i1
809827
}
810828
// CHECK-NEXT: %true = arith.constant true
811-
// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
812829
// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) {
813830
// CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32
814831
// CHECK-NEXT: scf.yield %[[sv1]] : i32
815832
// CHECK-NEXT: } else {
816833
// CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32
817834
// CHECK-NEXT: scf.yield %[[sv2]] : i32
818835
// CHECK-NEXT: }
836+
// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1
819837
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
820838

821839
// -----

0 commit comments

Comments
 (0)