@@ -843,9 +843,8 @@ namespace {
843
843
// 3) The iter arguments have no use and the corresponding (operation) results
844
844
// have no use.
845
845
//
846
- // These arguments must be defined outside of
847
- // the ForOp region and can just be forwarded after simplifying the op inits,
848
- // yields and returns.
846
+ // These arguments must be defined outside of the ForOp region and can just be
847
+ // forwarded after simplifying the op inits, yields and returns.
849
848
//
850
849
// The implementation uses `inlineBlockBefore` to steal the content of the
851
850
// original ForOp and avoid cloning.
@@ -871,6 +870,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
871
870
newIterArgs.reserve (forOp.getInitArgs ().size ());
872
871
newYieldValues.reserve (numResults);
873
872
newResultValues.reserve (numResults);
873
+ DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
874
874
for (auto [init, arg, result, yielded] :
875
875
llvm::zip (forOp.getInitArgs (), // iter from outside
876
876
forOp.getRegionIterArgs (), // iter inside region
@@ -884,13 +884,32 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
884
884
// result has no use.
885
885
bool forwarded = (arg == yielded) || (init == yielded) ||
886
886
(arg.use_empty () && result.use_empty ());
887
- keepMask.push_back (!forwarded);
888
- canonicalize |= forwarded;
889
887
if (forwarded) {
888
+ canonicalize = true ;
889
+ keepMask.push_back (false );
890
890
newBlockTransferArgs.push_back (init);
891
891
newResultValues.push_back (init);
892
892
continue ;
893
893
}
894
+
895
+ // Check if a previous kept argument always has the same values for init
896
+ // and yielded values.
897
+ if (auto it = initYieldToArg.find ({init, yielded});
898
+ it != initYieldToArg.end ()) {
899
+ canonicalize = true ;
900
+ keepMask.push_back (false );
901
+ auto [sameArg, sameResult] = it->second ;
902
+ rewriter.replaceAllUsesWith (arg, sameArg);
903
+ rewriter.replaceAllUsesWith (result, sameResult);
904
+ // The replacement value doesn't matter because there are no uses.
905
+ newBlockTransferArgs.push_back (init);
906
+ newResultValues.push_back (init);
907
+ continue ;
908
+ }
909
+
910
+ // This value is kept.
911
+ initYieldToArg.insert ({{init, yielded}, {arg, result}});
912
+ keepMask.push_back (true );
894
913
newIterArgs.push_back (init);
895
914
newYieldValues.push_back (yielded);
896
915
newBlockTransferArgs.push_back (Value ()); // placeholder with null value
0 commit comments