-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][scf] Remove identical scf.for
iter args
#127145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Jeff Niu (Mogball) ChangesThis augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same. Full diff: https://github.com/llvm/llvm-project/pull/127145.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 83ae79ce48266..448141735ba7f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -843,9 +843,8 @@ namespace {
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
-// These arguments must be defined outside of
-// the ForOp region and can just be forwarded after simplifying the op inits,
-// yields and returns.
+// These arguments must be defined outside of the ForOp region and can just be
+// forwarded after simplifying the op inits, yields and returns.
//
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
+ DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
for (auto [init, arg, result, yielded] :
llvm::zip(forOp.getInitArgs(), // iter from outside
forOp.getRegionIterArgs(), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
// result has no use.
bool forwarded = (arg == yielded) || (init == yielded) ||
(arg.use_empty() && result.use_empty());
- keepMask.push_back(!forwarded);
- canonicalize |= forwarded;
if (forwarded) {
+ canonicalize = true;
+ keepMask.push_back(false);
newBlockTransferArgs.push_back(init);
newResultValues.push_back(init);
continue;
}
+
+ // Check if a previous kept argument always has the same values for init
+ // and yielded values.
+ if (auto it = initYieldToArg.find({init, yielded});
+ it != initYieldToArg.end()) {
+ canonicalize = true;
+ keepMask.push_back(false);
+ auto [sameArg, sameResult] = it->second;
+ rewriter.replaceAllUsesWith(arg, sameArg);
+ rewriter.replaceAllUsesWith(result, sameResult);
+ // The replacement value doesn't matter because there are no uses.
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
+ continue;
+ }
+
+ // This value is kept.
+ initYieldToArg.insert({{init, yielded}, {arg, result}});
+ keepMask.push_back(true);
newIterArgs.push_back(init);
newYieldValues.push_back(yielded);
newBlockTransferArgs.push_back(Value()); // placeholder with null value
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 828758df6d31c..c18bd617216f1 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
// -----
+// CHECK-LABEL: @replace_duplicate_iter_args
+// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index
+func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) {
+ // CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]])
+ %0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) {
+ // CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]]
+ %1 = arith.addi %k0, %k1 : index
+ // CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]]
+ %2 = arith.addi %k2, %k3 : index
+ // CHECK-NEXT: yield [[V0]], [[V1]]
+ scf.yield %1, %2, %2, %1 : index, index, index, index
+ }
+ // CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0
+ return %0#0, %0#1, %0#2, %0#3 : index, index, index, index
+}
+
+// -----
+
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {
|
This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.
6e01664
to
6a47aaa
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.
This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.
This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.