Skip to content

[SCEV][LV] Invalidate LCSSA exit phis more thoroughly #69909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,10 @@ class ScalarEvolution {
/// def-use chain linking it to a loop.
void forgetValue(Value *V);

/// Forget LCSSA phi node V of loop L to which a new predecessor was added,
/// such that it may no longer be trivial.
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V);

/// Called when the client has changed the disposition of values in
/// this loop.
///
Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8410,6 +8410,44 @@ void ScalarEvolution::forgetValue(Value *V) {
forgetMemoizedResults(ToForget);
}

void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
if (!isSCEVable(V->getType()))
return;

// If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
// directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
// extra predecessor is added, this is no longer valid. Find all Unknowns and
// AddRecs defined in the loop and invalidate any SCEV's making use of them.
if (const SCEV *S = getExistingSCEV(V)) {
struct InvalidationRootCollector {
Loop *L;
SmallVector<const SCEV *, 8> Roots;

InvalidationRootCollector(Loop *L) : L(L) {}

bool follow(const SCEV *S) {
if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
if (auto *I = dyn_cast<Instruction>(SU->getValue()))
if (L->contains(I))
Roots.push_back(S);
} else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
if (L->contains(AddRec->getLoop()))
Roots.push_back(S);
}
return true;
}
bool isDone() const { return false; }
};

InvalidationRootCollector C(L);
visitAll(S, C);
forgetMemoizedResults(C.Roots);
}

// Also perform the normal invalidation.
forgetValue(V);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sufficient to do this if V is SCEVable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. forgetValue() will skip non-SCEVable itself. I've converted this into an early exit for the whole function.

}

void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }

void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3566,7 +3566,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State,
OrigLoop->getExitBlocks(ExitBlocks);
for (BasicBlock *Exit : ExitBlocks)
for (PHINode &PN : Exit->phis())
PSE.getSE()->forgetValue(&PN);
PSE.getSE()->forgetLcssaPhiWithNewPredecessor(OrigLoop, &PN);

VPBasicBlock *LatchVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock();
Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]);
Expand Down
95 changes: 95 additions & 0 deletions llvm/test/Transforms/LoopVectorize/pr66616.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
; RUN: opt -passes="print<scalar-evolution>,loop-vectorize" --verify-scev -S < %s -force-vector-width=4 2>/dev/null | FileCheck %s

; Make sure users of SCEVUnknowns from the scalar loop are invalidated.

define void @pr66616(ptr %ptr) {
; CHECK-LABEL: define void @pr66616(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[PTR]], align 4
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[TMP0]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i32> [[BROADCAST_SPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[BROADCAST_SPLAT]], <i32 1, i32 1, i32 1, i32 1>
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[INDEX_NEXT]], 256
; CHECK-NEXT: br i1 [[TMP2]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK: middle.block:
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[TMP1]], i32 3
; CHECK-NEXT: br i1 true, label [[PREHEADER:%.*]], label [[SCALAR_PH]]
; CHECK: scalar.ph:
; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i8 [ 0, [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ]
; CHECK-NEXT: br label [[LOOP_1:%.*]]
; CHECK: loop.1:
; CHECK-NEXT: [[IV_1:%.*]] = phi i8 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INC:%.*]], [[LOOP_1]] ]
; CHECK-NEXT: [[LOAD:%.*]] = load i32, ptr [[PTR]], align 4
; CHECK-NEXT: [[ADD3:%.*]] = add i32 [[LOAD]], 1
; CHECK-NEXT: [[INC]] = add i8 [[IV_1]], 1
; CHECK-NEXT: [[COND1:%.*]] = icmp eq i8 [[INC]], 0
; CHECK-NEXT: br i1 [[COND1]], label [[PREHEADER]], label [[LOOP_1]], !llvm.loop [[LOOP3:![0-9]+]]
; CHECK: preheader:
; CHECK-NEXT: [[ADD3_LCSSA:%.*]] = phi i32 [ [[ADD3]], [[LOOP_1]] ], [ [[TMP3]], [[MIDDLE_BLOCK]] ]
; CHECK-NEXT: [[TMP4:%.*]] = sub i32 0, [[ADD3_LCSSA]]
; CHECK-NEXT: [[TMP5:%.*]] = zext i32 [[TMP4]] to i64
; CHECK-NEXT: [[TMP6:%.*]] = add nuw nsw i64 [[TMP5]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP6]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH2:%.*]], label [[VECTOR_PH3:%.*]]
; CHECK: vector.ph3:
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP6]], 4
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP6]], [[N_MOD_VF]]
; CHECK-NEXT: [[DOTCAST:%.*]] = trunc i64 [[N_VEC]] to i32
; CHECK-NEXT: [[IND_END:%.*]] = add i32 [[ADD3_LCSSA]], [[DOTCAST]]
; CHECK-NEXT: [[IND_END5:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[N_VEC]]
; CHECK-NEXT: br label [[VECTOR_BODY7:%.*]]
; CHECK: vector.body7:
; CHECK-NEXT: [[INDEX8:%.*]] = phi i64 [ 0, [[VECTOR_PH3]] ], [ [[INDEX_NEXT9:%.*]], [[VECTOR_BODY7]] ]
; CHECK-NEXT: [[INDEX_NEXT9]] = add nuw i64 [[INDEX8]], 4
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT9]], [[N_VEC]]
; CHECK-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK1:%.*]], label [[VECTOR_BODY7]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK: middle.block1:
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP6]], [[N_VEC]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH2]]
; CHECK: scalar.ph2:
; CHECK-NEXT: [[BC_RESUME_VAL4:%.*]] = phi i32 [ [[IND_END]], [[MIDDLE_BLOCK1]] ], [ [[ADD3_LCSSA]], [[PREHEADER]] ]
; CHECK-NEXT: [[BC_RESUME_VAL6:%.*]] = phi ptr [ [[IND_END5]], [[MIDDLE_BLOCK1]] ], [ [[PTR]], [[PREHEADER]] ]
; CHECK-NEXT: br label [[LOOP_2:%.*]]
; CHECK: loop.2:
; CHECK-NEXT: [[IV_2:%.*]] = phi i32 [ [[IV_2_I:%.*]], [[LOOP_2]] ], [ [[BC_RESUME_VAL4]], [[SCALAR_PH2]] ]
; CHECK-NEXT: [[IV_3:%.*]] = phi ptr [ [[IV_3_I:%.*]], [[LOOP_2]] ], [ [[BC_RESUME_VAL6]], [[SCALAR_PH2]] ]
; CHECK-NEXT: [[IV_2_I]] = add i32 [[IV_2]], 1
; CHECK-NEXT: [[IV_3_I]] = getelementptr i8, ptr [[IV_3]], i64 1
; CHECK-NEXT: [[COND2:%.*]] = icmp eq i32 [[IV_2]], 0
; CHECK-NEXT: br i1 [[COND2]], label [[EXIT]], label [[LOOP_2]], !llvm.loop [[LOOP5:![0-9]+]]
; CHECK: exit:
; CHECK-NEXT: ret void
;
entry:
br label %loop.1

loop.1:
%iv.1 = phi i8 [ 0, %entry ], [ %inc, %loop.1 ]
%load = load i32, ptr %ptr, align 4
%add3 = add i32 %load, 1
%inc = add i8 %iv.1, 1
%cond1 = icmp eq i8 %inc, 0
br i1 %cond1, label %preheader, label %loop.1

preheader:
br label %loop.2

loop.2:
%iv.2 = phi i32 [ %iv.2.i, %loop.2 ], [ %add3, %preheader ]
%iv.3 = phi ptr [ %iv.3.i, %loop.2 ], [ %ptr, %preheader ]
%iv.2.i = add i32 %iv.2, 1
%iv.3.i = getelementptr i8, ptr %iv.3, i64 1
%cond2 = icmp eq i32 %iv.2, 0
br i1 %cond2, label %exit, label %loop.2

exit:
ret void
}