@@ -3872,14 +3872,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
3872
3872
return success ();
3873
3873
}
3874
3874
};
3875
+
3876
+ // / If both ranges contain same values return mappping indices from args1 to
3877
+ // / args2. Otherwise return std::nullopt
3878
+ static std::optional<SmallVector<unsigned >> getArgsMapping (ValueRange args1,
3879
+ ValueRange args2) {
3880
+ if (args1.size () != args2.size ())
3881
+ return std::nullopt;
3882
+
3883
+ SmallVector<unsigned > ret (args1.size ());
3884
+ for (auto &&[i, arg1] : llvm::enumerate (args1)) {
3885
+ auto it = llvm::find (args2, arg1);
3886
+ if (it == args2.end ())
3887
+ return std::nullopt;
3888
+
3889
+ auto j = it - args2.begin ();
3890
+ ret[j] = static_cast <unsigned >(i);
3891
+ }
3892
+
3893
+ return ret;
3894
+ }
3895
+
3896
+ // / If `before` block args are directly forwarded to `scf.condition`, rearrange
3897
+ // / `scf.condition` args into same order as block args. Update `after` block
3898
+ // args and results values accordingly.
3899
+ // / Needed to simplify `scf.while` -> `scf.for` uplifting.
3900
+ struct WhileOpAlignBeforeArgs : public OpRewritePattern <WhileOp> {
3901
+ using OpRewritePattern::OpRewritePattern;
3902
+
3903
+ LogicalResult matchAndRewrite (WhileOp loop,
3904
+ PatternRewriter &rewriter) const override {
3905
+ auto oldBefore = loop.getBeforeBody ();
3906
+ ConditionOp oldTerm = loop.getConditionOp ();
3907
+ ValueRange beforeArgs = oldBefore->getArguments ();
3908
+ ValueRange termArgs = oldTerm.getArgs ();
3909
+ if (beforeArgs == termArgs)
3910
+ return failure ();
3911
+
3912
+ auto mapping = getArgsMapping (beforeArgs, termArgs);
3913
+ if (!mapping)
3914
+ return failure ();
3915
+
3916
+ {
3917
+ OpBuilder::InsertionGuard g (rewriter);
3918
+ rewriter.setInsertionPoint (oldTerm);
3919
+ rewriter.replaceOpWithNewOp <ConditionOp>(oldTerm, oldTerm.getCondition (),
3920
+ beforeArgs);
3921
+ }
3922
+
3923
+ auto oldAfter = loop.getAfterBody ();
3924
+
3925
+ SmallVector<Type> newResultTypes (beforeArgs.size ());
3926
+ for (auto &&[i, j] : llvm::enumerate (*mapping))
3927
+ newResultTypes[j] = loop.getResult (i).getType ();
3928
+
3929
+ auto newLoop = rewriter.create <WhileOp>(loop.getLoc (), newResultTypes,
3930
+ loop.getInits (), nullptr , nullptr );
3931
+ auto newBefore = newLoop.getBeforeBody ();
3932
+ auto newAfter = newLoop.getAfterBody ();
3933
+
3934
+ SmallVector<Value> newResults (beforeArgs.size ());
3935
+ SmallVector<Value> newAfterArgs (beforeArgs.size ());
3936
+ for (auto &&[i, j] : llvm::enumerate (*mapping)) {
3937
+ newResults[i] = newLoop.getResult (j);
3938
+ newAfterArgs[i] = newAfter->getArgument (j);
3939
+ }
3940
+
3941
+ rewriter.inlineBlockBefore (oldBefore, newBefore, newBefore->begin (),
3942
+ newBefore->getArguments ());
3943
+ rewriter.inlineBlockBefore (oldAfter, newAfter, newAfter->begin (),
3944
+ newAfterArgs);
3945
+
3946
+ rewriter.replaceOp (loop, newResults);
3947
+ return success ();
3948
+ }
3949
+ };
3875
3950
} // namespace
3876
3951
3877
3952
void WhileOp::getCanonicalizationPatterns (RewritePatternSet &results,
3878
3953
MLIRContext *context) {
3879
3954
results.add <RemoveLoopInvariantArgsFromBeforeBlock,
3880
3955
RemoveLoopInvariantValueYielded, WhileConditionTruth,
3881
3956
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3882
- WhileRemoveUnusedArgs>(context);
3957
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs >(context);
3883
3958
}
3884
3959
3885
3960
// ===----------------------------------------------------------------------===//
0 commit comments