@@ -15030,10 +15030,18 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
15030
15030
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15031
15031
const DenseMap<const SCEV *, const SCEV *> ⤅
15032
15032
15033
+ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15034
+
15033
15035
public:
15034
15036
SCEVLoopGuardRewriter(ScalarEvolution &SE,
15035
- DenseMap<const SCEV *, const SCEV *> &M)
15036
- : SCEVRewriteVisitor(SE), Map(M) {}
15037
+ DenseMap<const SCEV *, const SCEV *> &M,
15038
+ bool PreserveNUW, bool PreserveNSW)
15039
+ : SCEVRewriteVisitor(SE), Map(M) {
15040
+ if (PreserveNUW)
15041
+ FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15042
+ if (PreserveNSW)
15043
+ FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15044
+ }
15037
15045
15038
15046
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15039
15047
@@ -15089,6 +15097,36 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15089
15097
return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15090
15098
return I->second;
15091
15099
}
15100
+
15101
+ const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15102
+ SmallVector<const SCEV *, 2> Operands;
15103
+ bool Changed = false;
15104
+ for (const auto *Op : Expr->operands()) {
15105
+ Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15106
+ Changed |= Op != Operands.back();
15107
+ }
15108
+ // We are only replacing operands with equivalent values, so transfer the
15109
+ // flags from the original expression.
15110
+ return !Changed
15111
+ ? Expr
15112
+ : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
15113
+ Expr->getNoWrapFlags(), FlagMask));
15114
+ }
15115
+
15116
+ const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
15117
+ SmallVector<const SCEV *, 2> Operands;
15118
+ bool Changed = false;
15119
+ for (const auto *Op : Expr->operands()) {
15120
+ Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
15121
+ Changed |= Op != Operands.back();
15122
+ }
15123
+ // We are only replacing operands with equivalent values, so transfer the
15124
+ // flags from the original expression.
15125
+ return !Changed
15126
+ ? Expr
15127
+ : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
15128
+ Expr->getNoWrapFlags(), FlagMask));
15129
+ }
15092
15130
};
15093
15131
15094
15132
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -15503,18 +15541,29 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15503
15541
if (RewriteMap.empty())
15504
15542
return Expr;
15505
15543
15544
+ // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15545
+ // the replacement expressions are contained in the ranges of the replaced
15546
+ // expressions.
15547
+ bool PreserveNUW = true;
15548
+ bool PreserveNSW = true;
15549
+ for (const SCEV *Expr : ExprsToRewrite) {
15550
+ const SCEV *RewriteTo = RewriteMap[Expr];
15551
+ PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
15552
+ PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
15553
+ }
15554
+
15506
15555
// Now that all rewrite information is collect, rewrite the collected
15507
15556
// expressions with the information in the map. This applies information to
15508
15557
// sub-expressions.
15509
15558
if (ExprsToRewrite.size() > 1) {
15510
15559
for (const SCEV *Expr : ExprsToRewrite) {
15511
15560
const SCEV *RewriteTo = RewriteMap[Expr];
15512
15561
RewriteMap.erase(Expr);
15513
- SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15562
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
15563
+ PreserveNSW);
15514
15564
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
15515
15565
}
15516
15566
}
15517
-
15518
- SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15567
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
15519
15568
return Rewriter.visit(Expr);
15520
15569
}
0 commit comments