@@ -3884,14 +3884,89 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
3884
3884
return success ();
3885
3885
}
3886
3886
};
3887
+
3888
+ // / If both ranges contain same values return mappping indices from args1 to
3889
+ // / args2. Otherwise return std::nullopt
3890
+ static std::optional<SmallVector<unsigned >> getArgsMapping (ValueRange args1,
3891
+ ValueRange args2) {
3892
+ if (args1.size () != args2.size ())
3893
+ return std::nullopt;
3894
+
3895
+ SmallVector<unsigned > ret (args1.size ());
3896
+ for (auto &&[i, arg1] : llvm::enumerate (args1)) {
3897
+ auto it = llvm::find (args2, arg1);
3898
+ if (it == args2.end ())
3899
+ return std::nullopt;
3900
+
3901
+ auto j = it - args2.begin ();
3902
+ ret[j] = static_cast <unsigned >(i);
3903
+ }
3904
+
3905
+ return ret;
3906
+ }
3907
+
3908
+ // / If `before` block args are directly forwarded to `scf.condition`, rearrange
3909
+ // / `scf.condition` args into same order as block args. Update `after` block
3910
+ // args and results values accordingly.
3911
+ // / Needed to simplify `scf.while` -> `scf.for` uplifting.
3912
+ struct WhileOpAlignBeforeArgs : public OpRewritePattern <WhileOp> {
3913
+ using OpRewritePattern::OpRewritePattern;
3914
+
3915
+ LogicalResult matchAndRewrite (WhileOp loop,
3916
+ PatternRewriter &rewriter) const override {
3917
+ auto oldBefore = loop.getBeforeBody ();
3918
+ ConditionOp oldTerm = loop.getConditionOp ();
3919
+ ValueRange beforeArgs = oldBefore->getArguments ();
3920
+ ValueRange termArgs = oldTerm.getArgs ();
3921
+ if (beforeArgs == termArgs)
3922
+ return failure ();
3923
+
3924
+ auto mapping = getArgsMapping (beforeArgs, termArgs);
3925
+ if (!mapping)
3926
+ return failure ();
3927
+
3928
+ {
3929
+ OpBuilder::InsertionGuard g (rewriter);
3930
+ rewriter.setInsertionPoint (oldTerm);
3931
+ rewriter.replaceOpWithNewOp <ConditionOp>(oldTerm, oldTerm.getCondition (),
3932
+ beforeArgs);
3933
+ }
3934
+
3935
+ auto oldAfter = loop.getAfterBody ();
3936
+
3937
+ SmallVector<Type> newResultTypes (beforeArgs.size ());
3938
+ for (auto &&[i, j] : llvm::enumerate (*mapping))
3939
+ newResultTypes[j] = loop.getResult (i).getType ();
3940
+
3941
+ auto newLoop = rewriter.create <WhileOp>(loop.getLoc (), newResultTypes,
3942
+ loop.getInits (), nullptr , nullptr );
3943
+ auto newBefore = newLoop.getBeforeBody ();
3944
+ auto newAfter = newLoop.getAfterBody ();
3945
+
3946
+ SmallVector<Value> newResults (beforeArgs.size ());
3947
+ SmallVector<Value> newAfterArgs (beforeArgs.size ());
3948
+ for (auto &&[i, j] : llvm::enumerate (*mapping)) {
3949
+ newResults[i] = newLoop.getResult (j);
3950
+ newAfterArgs[i] = newAfter->getArgument (j);
3951
+ }
3952
+
3953
+ rewriter.inlineBlockBefore (oldBefore, newBefore, newBefore->begin (),
3954
+ newBefore->getArguments ());
3955
+ rewriter.inlineBlockBefore (oldAfter, newAfter, newAfter->begin (),
3956
+ newAfterArgs);
3957
+
3958
+ rewriter.replaceOp (loop, newResults);
3959
+ return success ();
3960
+ }
3961
+ };
3887
3962
} // namespace
3888
3963
3889
3964
void WhileOp::getCanonicalizationPatterns (RewritePatternSet &results,
3890
3965
MLIRContext *context) {
3891
3966
results.add <RemoveLoopInvariantArgsFromBeforeBlock,
3892
3967
RemoveLoopInvariantValueYielded, WhileConditionTruth,
3893
3968
WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
3894
- WhileRemoveUnusedArgs>(context);
3969
+ WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs >(context);
3895
3970
}
3896
3971
3897
3972
// ===----------------------------------------------------------------------===//
0 commit comments