Skip to content

Commit e23a692

Browse files
authored
[LoopIdiom] Improve code; use SCEVPatternMatch (NFC) (#139540)
1 parent b544853 commit e23a692

File tree

1 file changed

+45
-63
lines changed

1 file changed

+45
-63
lines changed

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 45 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,6 @@
3232
#include "llvm/ADT/ArrayRef.h"
3333
#include "llvm/ADT/DenseMap.h"
3434
#include "llvm/ADT/MapVector.h"
35-
#include "llvm/ADT/STLExtras.h"
3635
#include "llvm/ADT/SetVector.h"
3736
#include "llvm/ADT/SmallPtrSet.h"
3837
#include "llvm/ADT/SmallVector.h"
@@ -49,6 +48,7 @@
4948
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
5049
#include "llvm/Analysis/ScalarEvolution.h"
5150
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
51+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
5252
#include "llvm/Analysis/TargetLibraryInfo.h"
5353
#include "llvm/Analysis/TargetTransformInfo.h"
5454
#include "llvm/Analysis/ValueTracking.h"
@@ -91,6 +91,7 @@
9191
#include <vector>
9292

9393
using namespace llvm;
94+
using namespace SCEVPatternMatch;
9495

9596
#define DEBUG_TYPE "loop-idiom"
9697

@@ -340,9 +341,8 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
340341

341342
// If this loop executes exactly one time, then it should be peeled, not
342343
// optimized by this pass.
343-
if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
344-
if (BECst->getAPInt() == 0)
345-
return false;
344+
if (BECount->isZero())
345+
return false;
346346

347347
SmallVector<BasicBlock *, 8> ExitBlocks;
348348
CurLoop->getUniqueExitBlocks(ExitBlocks);
@@ -453,13 +453,10 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
453453
// See if the pointer expression is an AddRec like {base,+,1} on the current
454454
// loop, which indicates a strided store. If we have something else, it's a
455455
// random store we can't handle.
456-
const SCEVAddRecExpr *StoreEv =
457-
dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
458-
if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
459-
return LegalStoreKind::None;
460-
461-
// Check to see if we have a constant stride.
462-
if (!isa<SCEVConstant>(StoreEv->getOperand(1)))
456+
const SCEV *StoreEv = SE->getSCEV(StorePtr);
457+
const SCEVConstant *Stride;
458+
if (!match(StoreEv, m_scev_AffineAddRec(m_SCEV(), m_SCEVConstant(Stride))) ||
459+
cast<SCEVAddRecExpr>(StoreEv)->getLoop() != CurLoop)
463460
return LegalStoreKind::None;
464461

465462
// See if the store can be turned into a memset.
@@ -494,9 +491,9 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
494491
if (HasMemcpy && !DisableLIRP::Memcpy) {
495492
// Check to see if the stride matches the size of the store. If so, then we
496493
// know that every byte is touched in the loop.
497-
APInt Stride = getStoreStride(StoreEv);
498494
unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
499-
if (StoreSize != Stride && StoreSize != -Stride)
495+
APInt StrideAP = Stride->getAPInt();
496+
if (StoreSize != StrideAP && StoreSize != -StrideAP)
500497
return LegalStoreKind::None;
501498

502499
// The store must be feeding a non-volatile load.
@@ -512,13 +509,12 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) {
512509
// See if the pointer expression is an AddRec like {base,+,1} on the current
513510
// loop, which indicates a strided load. If we have something else, it's a
514511
// random load we can't handle.
515-
const SCEVAddRecExpr *LoadEv =
516-
dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));
517-
if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
518-
return LegalStoreKind::None;
512+
const SCEV *LoadEv = SE->getSCEV(LI->getPointerOperand());
519513

520514
// The store and load must share the same stride.
521-
if (StoreEv->getOperand(1) != LoadEv->getOperand(1))
515+
if (!match(LoadEv,
516+
m_scev_AffineAddRec(m_SCEV(), m_scev_Specific(Stride))) ||
517+
cast<SCEVAddRecExpr>(LoadEv)->getLoop() != CurLoop)
522518
return LegalStoreKind::None;
523519

524520
// Success. This store can be converted into a memcpy.
@@ -805,20 +801,17 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
805801

806802
// Check if the stride matches the size of the memcpy. If so, then we know
807803
// that every byte is touched in the loop.
808-
const SCEVConstant *ConstStoreStride =
809-
dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
810-
const SCEVConstant *ConstLoadStride =
811-
dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
812-
if (!ConstStoreStride || !ConstLoadStride)
804+
const APInt *StoreStrideValue, *LoadStrideValue;
805+
if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
806+
!match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
813807
return false;
814808

815-
APInt StoreStrideValue = ConstStoreStride->getAPInt();
816-
APInt LoadStrideValue = ConstLoadStride->getAPInt();
817809
// Huge stride value - give up
818-
if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
810+
if (StoreStrideValue->getBitWidth() > 64 ||
811+
LoadStrideValue->getBitWidth() > 64)
819812
return false;
820813

821-
if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
814+
if (SizeInBytes != *StoreStrideValue && SizeInBytes != -*StoreStrideValue) {
822815
ORE.emit([&]() {
823816
return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
824817
<< ore::NV("Inst", "memcpy") << " in "
@@ -829,8 +822,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
829822
return false;
830823
}
831824

832-
int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
833-
int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
825+
int64_t StoreStrideInt = StoreStrideValue->getSExtValue();
826+
int64_t LoadStrideInt = LoadStrideValue->getSExtValue();
834827
// Check if the load stride matches the store stride.
835828
if (StoreStrideInt != LoadStrideInt)
836829
return false;
@@ -857,15 +850,15 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
857850
// See if the pointer expression is an AddRec like {base,+,1} on the current
858851
// loop, which indicates a strided store. If we have something else, it's a
859852
// random store we can't handle.
860-
const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer));
861-
if (!Ev || Ev->getLoop() != CurLoop)
862-
return false;
863-
if (!Ev->isAffine()) {
853+
const SCEV *Ev = SE->getSCEV(Pointer);
854+
const SCEV *PointerStrideSCEV;
855+
if (!match(Ev, m_scev_AffineAddRec(m_SCEV(), m_SCEV(PointerStrideSCEV)))) {
864856
LLVM_DEBUG(dbgs() << " Pointer is not affine, abort\n");
865857
return false;
866858
}
859+
if (cast<SCEVAddRecExpr>(Ev)->getLoop() != CurLoop)
860+
return false;
867861

868-
const SCEV *PointerStrideSCEV = Ev->getOperand(1);
869862
const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
870863
if (!PointerStrideSCEV || !MemsetSizeSCEV)
871864
return false;
@@ -879,15 +872,14 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
879872
// we know that every byte is touched in the loop.
880873
LLVM_DEBUG(dbgs() << " memset size is constant\n");
881874
uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
882-
const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
883-
if (!ConstStride)
875+
const APInt *Stride;
876+
if (!match(PointerStrideSCEV, m_scev_APInt(Stride)))
884877
return false;
885878

886-
APInt Stride = ConstStride->getAPInt();
887-
if (SizeInBytes != Stride && SizeInBytes != -Stride)
879+
if (SizeInBytes != *Stride && SizeInBytes != -*Stride)
888880
return false;
889881

890-
IsNegStride = SizeInBytes == -Stride;
882+
IsNegStride = SizeInBytes == -*Stride;
891883
} else {
892884
// Memset size is non-constant.
893885
// Check if the pointer stride matches the memset size.
@@ -944,8 +936,9 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
944936
SmallPtrSet<Instruction *, 1> MSIs;
945937
MSIs.insert(MSI);
946938
return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
947-
MSI->getDestAlign(), SplatValue, MSI, MSIs, Ev,
948-
BECount, IsNegStride, /*IsLoopMemset=*/true);
939+
MSI->getDestAlign(), SplatValue, MSI, MSIs,
940+
cast<SCEVAddRecExpr>(Ev), BECount, IsNegStride,
941+
/*IsLoopMemset=*/true);
949942
}
950943

951944
/// mayLoopAccessLocation - Return true if the specified loop might access the
@@ -963,11 +956,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
963956

964957
// If the loop iterates a fixed number of times, we can refine the access size
965958
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
966-
const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
967-
const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
968-
if (BECst && ConstSize) {
969-
std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
970-
std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
959+
const APInt *BECst, *ConstSize;
960+
if (match(BECount, m_scev_APInt(BECst)) &&
961+
match(StoreSizeSCEV, m_scev_APInt(ConstSize))) {
962+
std::optional<uint64_t> BEInt = BECst->tryZExtValue();
963+
std::optional<uint64_t> SizeInt = ConstSize->tryZExtValue();
971964
// FIXME: Should this check for overflow?
972965
if (BEInt && SizeInt)
973966
AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
@@ -1579,24 +1572,15 @@ class StrlenVerifier {
15791572
// See if the pointer expression is an AddRec with constant step a of form
15801573
// ({n,+,a}) where a is the width of the char type.
15811574
Value *IncPtr = LoopLoad->getPointerOperand();
1582-
const SCEVAddRecExpr *LoadEv =
1583-
dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IncPtr));
1584-
if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
1575+
const SCEV *LoadEv = SE->getSCEV(IncPtr);
1576+
const APInt *Step;
1577+
if (!match(LoadEv,
1578+
m_scev_AffineAddRec(m_SCEV(LoadBaseEv), m_scev_APInt(Step))))
15851579
return false;
1586-
LoadBaseEv = LoadEv->getStart();
15871580

15881581
LLVM_DEBUG(dbgs() << "pointer load scev: " << *LoadEv << "\n");
15891582

1590-
const SCEVConstant *Step =
1591-
dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
1592-
if (!Step)
1593-
return false;
1594-
1595-
unsigned StepSize = 0;
1596-
StepSizeCI = dyn_cast<ConstantInt>(Step->getValue());
1597-
if (!StepSizeCI)
1598-
return false;
1599-
StepSize = StepSizeCI->getZExtValue();
1583+
unsigned StepSize = Step->getZExtValue();
16001584

16011585
// Verify that StepSize is consistent with platform char width.
16021586
OpWidth = OperandType->getIntegerBitWidth();
@@ -3277,9 +3261,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
32773261
// Ok, transform appears worthwhile.
32783262
MadeChange = true;
32793263

3280-
bool OffsetIsZero = false;
3281-
if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
3282-
OffsetIsZero = ExtraOffsetExprC->isZero();
3264+
bool OffsetIsZero = ExtraOffsetExpr->isZero();
32833265

32843266
// Step 1: Compute the loop's final IV value / trip count.
32853267

0 commit comments

Comments
 (0)