Skip to content

Commit 6a47aaa

Browse files
committed
[mlir][scf] Remove identical scf.for iter args
This augments the iter arg canonicalizer to remove iter args that always have the same value, i.e. their correpsonding init and yielded values are the same.
1 parent 059722d commit 6a47aaa

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,8 @@ namespace {
843843
// 3) The iter arguments have no use and the corresponding (operation) results
844844
// have no use.
845845
//
846-
// These arguments must be defined outside of
847-
// the ForOp region and can just be forwarded after simplifying the op inits,
848-
// yields and returns.
846+
// These arguments must be defined outside of the ForOp region and can just be
847+
// forwarded after simplifying the op inits, yields and returns.
849848
//
850849
// The implementation uses `inlineBlockBefore` to steal the content of the
851850
// original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
871870
newIterArgs.reserve(forOp.getInitArgs().size());
872871
newYieldValues.reserve(numResults);
873872
newResultValues.reserve(numResults);
873+
DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
874874
for (auto [init, arg, result, yielded] :
875875
llvm::zip(forOp.getInitArgs(), // iter from outside
876876
forOp.getRegionIterArgs(), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
884884
// result has no use.
885885
bool forwarded = (arg == yielded) || (init == yielded) ||
886886
(arg.use_empty() && result.use_empty());
887-
keepMask.push_back(!forwarded);
888-
canonicalize |= forwarded;
889887
if (forwarded) {
888+
canonicalize = true;
889+
keepMask.push_back(false);
890890
newBlockTransferArgs.push_back(init);
891891
newResultValues.push_back(init);
892892
continue;
893893
}
894+
895+
// Check if a previous kept argument always has the same values for init
896+
// and yielded values.
897+
if (auto it = initYieldToArg.find({init, yielded});
898+
it != initYieldToArg.end()) {
899+
canonicalize = true;
900+
keepMask.push_back(false);
901+
auto [sameArg, sameResult] = it->second;
902+
rewriter.replaceAllUsesWith(arg, sameArg);
903+
rewriter.replaceAllUsesWith(result, sameResult);
904+
// The replacement value doesn't matter because there are no uses.
905+
newBlockTransferArgs.push_back(init);
906+
newResultValues.push_back(init);
907+
continue;
908+
}
909+
910+
// This value is kept.
911+
initYieldToArg.insert({{init, yielded}, {arg, result}});
912+
keepMask.push_back(true);
894913
newIterArgs.push_back(init);
895914
newYieldValues.push_back(yielded);
896915
newBlockTransferArgs.push_back(Value()); // placeholder with null value

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
821821

822822
// -----
823823

824+
// CHECK-LABEL: @replace_duplicate_iter_args
825+
// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index
826+
func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) {
827+
// CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]])
828+
%0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) {
829+
// CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]]
830+
%1 = arith.addi %k0, %k1 : index
831+
// CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]]
832+
%2 = arith.addi %k2, %k3 : index
833+
// CHECK-NEXT: yield [[V0]], [[V1]]
834+
scf.yield %1, %2, %2, %1 : index, index, index, index
835+
}
836+
// CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0
837+
return %0#0, %0#1, %0#2, %0#3 : index, index, index, index
838+
}
839+
840+
// -----
841+
824842
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
825843

826844
func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {

0 commit comments

Comments
 (0)