@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
839
839
namespace {
840
840
// Fold away ForOp iter arguments when:
841
841
// 1) The op yields the iter arguments.
842
- // 2) The iter arguments have no use and the corresponding outer region
843
- // iterators (inputs) are yielded.
842
+ // 2) The argument's corresponding outer region iterators (inputs) are yielded.
844
843
// 3) The iter arguments have no use and the corresponding (operation) results
845
844
// have no use.
846
845
//
@@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
872
871
newIterArgs.reserve (forOp.getInitArgs ().size ());
873
872
newYieldValues.reserve (numResults);
874
873
newResultValues.reserve (numResults);
875
- for (auto it : llvm::zip (forOp.getInitArgs (), // iter from outside
876
- forOp.getRegionIterArgs (), // iter inside region
877
- forOp.getResults (), // op results
878
- forOp.getYieldedValues () // iter yield
879
- )) {
874
+ for (auto [init, arg, result, yielded] :
875
+ llvm::zip (forOp.getInitArgs (), // iter from outside
876
+ forOp.getRegionIterArgs (), // iter inside region
877
+ forOp.getResults (), // op results
878
+ forOp.getYieldedValues () // iter yield
879
+ )) {
880
880
// Forwarded is `true` when:
881
881
// 1) The region `iter` argument is yielded.
882
- // 2) The region `iter` argument has no use, and the corresponding iter
883
- // operand (input) is yielded.
882
+ // 2) The region `iter` argument the corresponding input is yielded.
884
883
// 3) The region `iter` argument has no use, and the corresponding op
885
884
// result has no use.
886
- bool forwarded = ((std::get<1 >(it) == std::get<3 >(it)) ||
887
- (std::get<1 >(it).use_empty () &&
888
- (std::get<0 >(it) == std::get<3 >(it) ||
889
- std::get<2 >(it).use_empty ())));
885
+ bool forwarded = (arg == yielded) || (init == yielded) ||
886
+ (arg.use_empty () && result.use_empty ());
890
887
keepMask.push_back (!forwarded);
891
888
canonicalize |= forwarded;
892
889
if (forwarded) {
893
- newBlockTransferArgs.push_back (std::get< 0 >(it) );
894
- newResultValues.push_back (std::get< 0 >(it) );
890
+ newBlockTransferArgs.push_back (init );
891
+ newResultValues.push_back (init );
895
892
continue ;
896
893
}
897
- newIterArgs.push_back (std::get< 0 >(it) );
898
- newYieldValues.push_back (std::get< 3 >(it) );
894
+ newIterArgs.push_back (init );
895
+ newYieldValues.push_back (yielded );
899
896
newBlockTransferArgs.push_back (Value ()); // placeholder with null value
900
897
newResultValues.push_back (Value ()); // placeholder with null value
901
898
}
0 commit comments