Skip to content

[mlir][scf] Remove identical scf.for iter args #127145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 14, 2025
Merged

Conversation

Mogball
Copy link
Contributor

@Mogball Mogball commented Feb 13, 2025

This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.

@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)

Changes

This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.


Full diff: https://github.com/llvm/llvm-project/pull/127145.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+24-5)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+18)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83ae79ce48266..448141735ba7f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -843,9 +843,8 @@ namespace {
 // 3) The iter arguments have no use and the corresponding (operation) results
 // have no use.
 //
-// These arguments must be defined outside of
-// the ForOp region and can just be forwarded after simplifying the op inits,
-// yields and returns.
+// These arguments must be defined outside of the ForOp region and can just be
+// forwarded after simplifying the op inits, yields and returns.
 //
 // The implementation uses `inlineBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
     newIterArgs.reserve(forOp.getInitArgs().size());
     newYieldValues.reserve(numResults);
     newResultValues.reserve(numResults);
+    DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
     for (auto [init, arg, result, yielded] :
          llvm::zip(forOp.getInitArgs(),       // iter from outside
                    forOp.getRegionIterArgs(), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
       // result has no use.
       bool forwarded = (arg == yielded) || (init == yielded) ||
                        (arg.use_empty() && result.use_empty());
-      keepMask.push_back(!forwarded);
-      canonicalize |= forwarded;
       if (forwarded) {
+        canonicalize = true;
+        keepMask.push_back(false);
         newBlockTransferArgs.push_back(init);
         newResultValues.push_back(init);
         continue;
       }
+
+      // Check if a previous kept argument always has the same values for init
+      // and yielded values.
+      if (auto it = initYieldToArg.find({init, yielded});
+          it != initYieldToArg.end()) {
+        canonicalize = true;
+        keepMask.push_back(false);
+        auto [sameArg, sameResult] = it->second;
+        rewriter.replaceAllUsesWith(arg, sameArg);
+        rewriter.replaceAllUsesWith(result, sameResult);
+        // The replacement value doesn't matter because there are no uses.
+        newBlockTransferArgs.push_back(init);
+        newResultValues.push_back(init);
+        continue;
+      }
+
+      // This value is kept.
+      initYieldToArg.insert({{init, yielded}, {arg, result}});
+      keepMask.push_back(true);
       newIterArgs.push_back(init);
       newYieldValues.push_back(yielded);
       newBlockTransferArgs.push_back(Value()); // placeholder with null value
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 828758df6d31c..c18bd617216f1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
 
 // -----
 
+// CHECK-LABEL: @replace_duplicate_iter_args
+// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index
+func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) {
+  // CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]])
+  %0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) {
+    // CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]]
+    %1 = arith.addi %k0, %k1 : index
+    // CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]]
+    %2 = arith.addi %k2, %k3 : index
+    // CHECK-NEXT: yield [[V0]], [[V1]]
+    scf.yield %1, %2, %2, %1 : index, index, index, index
+  }
+  // CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0
+  return %0#0, %0#1, %0#2, %0#3 : index, index, index, index
+}
+
+// -----
+
 func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
 
 func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {

This augments the iter arg canonicalizer to remove iter args that always
have the same value, i.e. their correpsonding init and yielded values
are the same.
Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@Mogball Mogball merged commit a57e58d into main Feb 14, 2025
6 of 7 checks passed
@Mogball Mogball deleted the users/mogball/scf branch February 14, 2025 01:03
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
This augments the iter arg canonicalizer to remove iter args that always
have the same value, i.e. their correpsonding init and yielded values
are the same.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
This augments the iter arg canonicalizer to remove iter args that always
have the same value, i.e. their correpsonding init and yielded values
are the same.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants