Skip to content

Commit a57e58d

Browse files
authored
[mlir][scf] Remove identical scf.for iter args (llvm#127145)
This augments the iter arg canonicalizer to remove duplicate iter args, i.e. iter args that always have the same value as another iter arg because their corresponding init values and yielded values are the same.
1 parent 912b154 commit a57e58d

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,8 @@ namespace {
843843
// 3) The iter arguments have no use and the corresponding (operation) results
844844
// have no use.
845845
//
846-
// These arguments must be defined outside of
847-
// the ForOp region and can just be forwarded after simplifying the op inits,
848-
// yields and returns.
846+
// These arguments must be defined outside of the ForOp region and can just be
847+
// forwarded after simplifying the op inits, yields and returns.
849848
//
850849
// The implementation uses `inlineBlockBefore` to steal the content of the
851850
// original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
871870
newIterArgs.reserve(forOp.getInitArgs().size());
872871
newYieldValues.reserve(numResults);
873872
newResultValues.reserve(numResults);
873+
DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
874874
for (auto [init, arg, result, yielded] :
875875
llvm::zip(forOp.getInitArgs(), // iter from outside
876876
forOp.getRegionIterArgs(), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
884884
// result has no use.
885885
bool forwarded = (arg == yielded) || (init == yielded) ||
886886
(arg.use_empty() && result.use_empty());
887-
keepMask.push_back(!forwarded);
888-
canonicalize |= forwarded;
889887
if (forwarded) {
888+
canonicalize = true;
889+
keepMask.push_back(false);
890890
newBlockTransferArgs.push_back(init);
891891
newResultValues.push_back(init);
892892
continue;
893893
}
894+
895+
// Check if a previous kept argument always has the same values for init
896+
// and yielded values.
897+
if (auto it = initYieldToArg.find({init, yielded});
898+
it != initYieldToArg.end()) {
899+
canonicalize = true;
900+
keepMask.push_back(false);
901+
auto [sameArg, sameResult] = it->second;
902+
rewriter.replaceAllUsesWith(arg, sameArg);
903+
rewriter.replaceAllUsesWith(result, sameResult);
904+
// The replacement value doesn't matter because there are no uses.
905+
newBlockTransferArgs.push_back(init);
906+
newResultValues.push_back(init);
907+
continue;
908+
}
909+
910+
// This value is kept.
911+
initYieldToArg.insert({{init, yielded}, {arg, result}});
912+
keepMask.push_back(true);
894913
newIterArgs.push_back(init);
895914
newYieldValues.push_back(yielded);
896915
newBlockTransferArgs.push_back(Value()); // placeholder with null value

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,24 @@ func.func @fold_away_iter_and_result_with_no_use(%arg0 : i32,
821821

822822
// -----
823823

824+
// CHECK-LABEL: @replace_duplicate_iter_args
825+
// CHECK-SAME: [[LB:%arg[0-9]]]: index, [[UB:%arg[0-9]]]: index, [[STEP:%arg[0-9]]]: index, [[A:%arg[0-9]]]: index, [[B:%arg[0-9]]]: index
826+
func.func @replace_duplicate_iter_args(%lb: index, %ub: index, %step: index, %a: index, %b: index) -> (index, index, index, index) {
827+
// CHECK-NEXT: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[K0:%.*]] = [[A]], [[K1:%.*]] = [[B]])
828+
%0:4 = scf.for %i = %lb to %ub step %step iter_args(%k0 = %a, %k1 = %b, %k2 = %b, %k3 = %a) -> (index, index, index, index) {
829+
// CHECK-NEXT: [[V0:%.*]] = arith.addi [[K0]], [[K1]]
830+
%1 = arith.addi %k0, %k1 : index
831+
// CHECK-NEXT: [[V1:%.*]] = arith.addi [[K1]], [[K0]]
832+
%2 = arith.addi %k2, %k3 : index
833+
// CHECK-NEXT: yield [[V0]], [[V1]]
834+
scf.yield %1, %2, %2, %1 : index, index, index, index
835+
}
836+
// CHECK: return [[RES]]#0, [[RES]]#1, [[RES]]#1, [[RES]]#0
837+
return %0#0, %0#1, %0#2, %0#3 : index, index, index, index
838+
}
839+
840+
// -----
841+
824842
func.func private @do(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
825843

826844
func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> {

0 commit comments

Comments
 (0)