Skip to content

Commit 9d8e634

Browse files
authored
[mlir][scf] Always remove for iter args that are loop invariant (#121555)
This alters the condition in ForOpIterArgsFolder to always remove iter args when their initial value equals the yielded value, not just when the arg has no use.
1 parent 3b72c62 commit 9d8e634

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
839839
namespace {
840840
// Fold away ForOp iter arguments when:
841841
// 1) The op yields the iter arguments.
842-
// 2) The iter arguments have no use and the corresponding outer region
843-
// iterators (inputs) are yielded.
842+
// 2) The iter arguments' corresponding outer region iterators (inputs) are yielded.
844843
// 3) The iter arguments have no use and the corresponding (operation) results
845844
// have no use.
846845
//
@@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
872871
newIterArgs.reserve(forOp.getInitArgs().size());
873872
newYieldValues.reserve(numResults);
874873
newResultValues.reserve(numResults);
875-
for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
876-
forOp.getRegionIterArgs(), // iter inside region
877-
forOp.getResults(), // op results
878-
forOp.getYieldedValues() // iter yield
879-
)) {
874+
for (auto [init, arg, result, yielded] :
875+
llvm::zip(forOp.getInitArgs(), // iter from outside
876+
forOp.getRegionIterArgs(), // iter inside region
877+
forOp.getResults(), // op results
878+
forOp.getYieldedValues() // iter yield
879+
)) {
880880
// Forwarded is `true` when:
881881
// 1) The region `iter` argument is yielded.
882-
// 2) The region `iter` argument has no use, and the corresponding iter
883-
// operand (input) is yielded.
882+
// 2) The region `iter` argument's corresponding input is yielded.
884883
// 3) The region `iter` argument has no use, and the corresponding op
885884
// result has no use.
886-
bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
887-
(std::get<1>(it).use_empty() &&
888-
(std::get<0>(it) == std::get<3>(it) ||
889-
std::get<2>(it).use_empty())));
885+
bool forwarded = (arg == yielded) || (init == yielded) ||
886+
(arg.use_empty() && result.use_empty());
890887
keepMask.push_back(!forwarded);
891888
canonicalize |= forwarded;
892889
if (forwarded) {
893-
newBlockTransferArgs.push_back(std::get<0>(it));
894-
newResultValues.push_back(std::get<0>(it));
890+
newBlockTransferArgs.push_back(init);
891+
newResultValues.push_back(init);
895892
continue;
896893
}
897-
newIterArgs.push_back(std::get<0>(it));
898-
newYieldValues.push_back(std::get<3>(it));
894+
newIterArgs.push_back(init);
895+
newYieldValues.push_back(yielded);
899896
newBlockTransferArgs.push_back(Value()); // placeholder with null value
900897
newResultValues.push_back(Value()); // placeholder with null value
901898
}

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
408408

409409
// -----
410410

411+
// CHECK-LABEL: @constant_iter_arg
412+
func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
413+
%c0_i32 = arith.constant 0 : i32
414+
// CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
415+
%0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
416+
// CHECK-NEXT: "test.use"(%c0_i32)
417+
"test.use"(%arg3) : (i32) -> ()
418+
scf.yield %c0_i32 : i32
419+
}
420+
return
421+
}
422+
423+
// -----
424+
411425
// CHECK-LABEL: @replace_true_if
412426
func.func @replace_true_if() {
413427
%true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
17891803
}
17901804
// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
17911805
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
1792-
// CHECK: %[[RESULT:.*]] = scf.forall
1806+
// CHECK: %[[RESULT:.*]] = scf.forall
17931807
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
17941808
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
17951809
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
18321846
}
18331847
// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
18341848
// CHECK-SAME: (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
1835-
// CHECK: %[[RESULT:.*]] = scf.forall
1849+
// CHECK: %[[RESULT:.*]] = scf.forall
18361850
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
18371851
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
18381852
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
18561870
%y = arith.constant 42.0 : f32
18571871
scf.yield %y : f32
18581872
}
1859-
1873+
18601874
%switch_cst_2 = arith.constant 2: index
18611875
%1 = scf.index_switch %switch_cst_2 -> f32
18621876
case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
18671881
%y = arith.constant 42.0 : f32
18681882
scf.yield %y : f32
18691883
}
1870-
1884+
18711885
return %0, %1 : f32, f32
18721886
}
18731887

0 commit comments

Comments
 (0)