Skip to content

[SCEV] Preserve flags in SCEVLoopGuardRewriter for add and mul. #91472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14973,10 +14973,18 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;

SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;

public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
DenseMap<const SCEV *, const SCEV *> &M)
: SCEVRewriteVisitor(SE), Map(M) {}
DenseMap<const SCEV *, const SCEV *> &M,
bool PreserveNUW, bool PreserveNSW)
: SCEVRewriteVisitor(SE), Map(M) {
if (PreserveNUW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
if (PreserveNSW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
}

const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }

Expand Down Expand Up @@ -15032,6 +15040,36 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
return I->second;
}

const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
bool Changed = false;
for (const auto *Op : Expr->operands()) {
Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
Changed |= Op != Operands.back();
}
// We are only replacing operands with equivalent values, so transfer the
// flags from the original expression.
return !Changed
? Expr
: SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
Expr->getNoWrapFlags(), FlagMask));
}

const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
SmallVector<const SCEV *, 2> Operands;
bool Changed = false;
for (const auto *Op : Expr->operands()) {
Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
Changed |= Op != Operands.back();
}
// We are only replacing operands with equivalent values, so transfer the
// flags from the original expression.
return !Changed
? Expr
: SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
Expr->getNoWrapFlags(), FlagMask));
}
};

const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
Expand Down Expand Up @@ -15446,18 +15484,29 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (RewriteMap.empty())
return Expr;

// Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
// the replacement expressions are contained in the ranges of the replaced
// expressions.
bool PreserveNUW = true;
bool PreserveNSW = true;
for (const SCEV *Expr : ExprsToRewrite) {
const SCEV *RewriteTo = RewriteMap[Expr];
PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
}

// Now that all rewrite information is collect, rewrite the collected
// expressions with the information in the map. This applies information to
// sub-expressions.
if (ExprsToRewrite.size() > 1) {
for (const SCEV *Expr : ExprsToRewrite) {
const SCEV *RewriteTo = RewriteMap[Expr];
RewriteMap.erase(Expr);
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
PreserveNSW);
RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
}
}

SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
return Rewriter.visit(Expr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ define void @rewrite_preserve_add_nsw(i32 %a) {
; CHECK-NEXT: %add = add nsw i32 %a, 4
; CHECK-NEXT: --> (4 + %a)<nsw> U: [-2147483644,-2147483648) S: [-2147483644,-2147483648)
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
; CHECK-NEXT: --> {0,+,1}<nuw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (0 smax (4 + %a)<nsw>) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {0,+,1}<nuw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (4 + %a)<nsw> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = add i32 %iv, 1
; CHECK-NEXT: --> {1,+,1}<nuw><%loop> U: [1,-2147483647) S: [1,-2147483647) Exits: (1 + (0 smax (4 + %a)<nsw>))<nuw> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {1,+,1}<nuw><%loop> U: [1,-2147483647) S: [1,-2147483647) Exits: (5 + %a) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @rewrite_preserve_add_nsw
; CHECK-NEXT: Loop %loop: backedge-taken count is (0 smax (4 + %a)<nsw>)
; CHECK-NEXT: Loop %loop: backedge-taken count is (4 + %a)<nsw>
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 2147483647
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (0 smax (4 + %a)<nsw>)
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (4 + %a)<nsw>
; CHECK-NEXT: Loop %loop: Trip multiple is 1
;
entry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ define void @rewrite_preserve_add_nsw(i32 %a) {
; CHECK-NEXT: [[PRE:%.*]] = icmp sgt i32 [[A]], -4
; CHECK-NEXT: br i1 [[PRE]], label [[LOOP_PREHEADER:%.*]], label [[EXIT:%.*]]
; CHECK: loop.preheader:
; CHECK-NEXT: [[SMAX:%.*]] = call i32 @llvm.smax.i32(i32 [[ADD]], i32 0)
; CHECK-NEXT: [[TMP0:%.*]] = add nuw i32 [[SMAX]], 1
; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[A]], 5
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], [[LOOP]] ], [ 0, [[LOOP_PREHEADER]] ]
Expand Down
Loading