[mlir][scf] Always remove for iter args that are loop invariant #121555

Merged
merged 2 commits into from
Jan 3, 2025

Conversation

Mogball (Contributor) commented Jan 3, 2025

This alters the condition in ForOpIterArgsFolder to always remove iter args when their initial value equals the yielded value, not just when the arg has no use.
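
To make the change in the condition concrete, here is a minimal sketch that models the old and new `forwarded` predicates outside of MLIR. `IterArgModel`, `forwardedBefore`, `forwardedAfter`, and `constantIterArg` are hypothetical names introduced only for illustration; the real pattern compares `mlir::Value`s and calls `use_empty()`, which the sketch replaces with integer ids and booleans.

```cpp
// Hypothetical stand-in for one loop-carried value of an scf.for.
// Values are modeled as integer ids; equal ids mean "the same SSA value".
struct IterArgModel {
  int init;           // id of the init operand passed in from outside
  int arg;            // id of the region iter argument inside the loop
  int yielded;        // id of the value passed to scf.yield
  bool argHasUses;    // whether the region iter argument has any use
  bool resultHasUses; // whether the matching loop result has any use
};

// Condition before this PR: a yielded init value only counts when the
// region iter argument is unused.
inline bool forwardedBefore(const IterArgModel &m) {
  return m.arg == m.yielded ||
         (!m.argHasUses && (m.init == m.yielded || !m.resultHasUses));
}

// Condition after this PR: a yielded init value always counts.
inline bool forwardedAfter(const IterArgModel &m) {
  return m.arg == m.yielded || m.init == m.yielded ||
         (!m.argHasUses && !m.resultHasUses);
}

// Models the new @constant_iter_arg test: the init value (id 0) is yielded
// unchanged, the region iter argument (id 1) is used by "test.use", and the
// loop result is unused.
inline IterArgModel constantIterArg() {
  return IterArgModel{0, 1, 0, /*argHasUses=*/true, /*resultHasUses=*/false};
}
```

In this model `forwardedBefore(constantIterArg())` is false and `forwardedAfter(constantIterArg())` is true, which matches the new test expecting the iter arg to be dropped and its use replaced by `%c0_i32`. One thing that may be worth a follow-up: the unchanged comment above the condition still describes case 2 as requiring the region `iter` argument to have no use.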

llvmbot (Member) commented Jan 3, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)

Changes

This alters the condition in ForOpIterArgsFolder to always remove iter args when their initial value equals the yielded value, not just when the arg has no use.


Full diff: https://github.com/llvm/llvm-project/pull/121555.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+12-13)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+18-4)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f126c..872d34de4495bf 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -872,30 +872,29 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(numResults);
     newResultValues.reserve(numResults);
-    for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
-                             forOp.getRegionIterArgs(), // iter inside region
-                             forOp.getResults(),        // op results
-                             forOp.getYieldedValues()   // iter yield
-                             )) {
+    for (auto [init, arg, result, yielded] :
+         llvm::zip(forOp.getInitArgs(),       // iter from outside
+                   forOp.getRegionIterArgs(), // iter inside region
+                   forOp.getResults(),        // op results
+                   forOp.getYieldedValues()   // iter yield
+                   )) {
       // Forwarded is `true` when:
       // 1) The region `iter` argument is yielded.
       // 2) The region `iter` argument has no use, and the corresponding iter
       // operand (input) is yielded.
       // 3) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
-                        (std::get<1>(it).use_empty() &&
-                         (std::get<0>(it) == std::get<3>(it) ||
-                          std::get<2>(it).use_empty())));
+      bool forwarded = (arg == yielded) || (init == yielded) ||
+                       (arg.use_empty() && result.use_empty());
       keepMask.push_back(!forwarded);
       canonicalize |= forwarded;
       if (forwarded) {
-        newBlockTransferArgs.push_back(std::get<0>(it));
-        newResultValues.push_back(std::get<0>(it));
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
         continue;
       }
-      newIterArgs.push_back(std::get<0>(it));
-      newYieldValues.push_back(std::get<3>(it));
+      newIterArgs.push_back(init);
+      newYieldValues.push_back(yielded);
       newBlockTransferArgs.push_back(Value()); // placeholder with null value
       newResultValues.push_back(Value());      // placeholder with null value
     }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8c4e7a41ee6bc4..828758df6d31c0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -408,6 +408,20 @@ func.func @for_yields_4() -> i32 {
 
 // -----
 
+// CHECK-LABEL: @constant_iter_arg
+func.func @constant_iter_arg(%arg0: index, %arg1: index, %arg2: index) {
+  %c0_i32 = arith.constant 0 : i32
+  // CHECK: scf.for %arg3 = %arg0 to %arg1 step %arg2 {
+  %0 = scf.for %i = %arg0 to %arg1 step %arg2 iter_args(%arg3 = %c0_i32) -> i32 {
+    // CHECK-NEXT: "test.use"(%c0_i32)
+    "test.use"(%arg3) : (i32) -> ()
+    scf.yield %c0_i32 : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @replace_true_if
 func.func @replace_true_if() {
   %true = arith.constant true
@@ -1789,7 +1803,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
@@ -1832,7 +1846,7 @@ module {
 }
 // CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
 //  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//       CHECK:    %[[RESULT:.*]] = scf.forall
 //  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
 //       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
 //       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
@@ -1856,7 +1870,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   %switch_cst_2 = arith.constant 2: index
   %1 = scf.index_switch %switch_cst_2 -> f32
   case 0 {
@@ -1867,7 +1881,7 @@ func.func @index_switch_fold() -> (f32, f32) {
     %y = arith.constant 42.0 : f32
     scf.yield %y : f32
   }
-  
+
   return %0, %1 : f32, f32
 }
 

Mogball merged commit 9d8e634 into main on Jan 3, 2025
5 of 7 checks passed
Mogball deleted the users/mogball/for_args branch on January 3, 2025 19:44