Skip to content

Commit fdd9f0c

Browse files
committed
[LoopIdiom] Improve code; use SCEVPatternMatch (NFC)
1 parent 47ce75e commit fdd9f0c

File tree

1 file changed

+25
-36
lines changed

1 file changed

+25
-36
lines changed

llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "llvm/ADT/ArrayRef.h"
3333
#include "llvm/ADT/DenseMap.h"
3434
#include "llvm/ADT/MapVector.h"
35-
#include "llvm/ADT/STLExtras.h"
3635
#include "llvm/ADT/SetVector.h"
3736
#include "llvm/ADT/SmallPtrSet.h"
3837
#include "llvm/ADT/SmallVector.h"
@@ -49,6 +48,7 @@
4948
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
5049
#include "llvm/Analysis/ScalarEvolution.h"
5150
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
51+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
5252
#include "llvm/Analysis/TargetLibraryInfo.h"
5353
#include "llvm/Analysis/TargetTransformInfo.h"
5454
#include "llvm/Analysis/ValueTracking.h"
@@ -91,6 +91,7 @@
9191
#include <vector>
9292

9393
using namespace llvm;
94+
using namespace SCEVPatternMatch;
9495

9596
#define DEBUG_TYPE "loop-idiom"
9697

@@ -340,9 +341,8 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
340341

341342
// If this loop executes exactly one time, then it should be peeled, not
342343
// optimized by this pass.
343-
if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
344-
if (BECst->getAPInt() == 0)
345-
return false;
344+
if (match(BECount, m_scev_SpecificInt(0)))
345+
return false;
346346

347347
SmallVector<BasicBlock *, 8> ExitBlocks;
348348
CurLoop->getUniqueExitBlocks(ExitBlocks);
@@ -805,20 +805,17 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
805805

806806
// Check if the stride matches the size of the memcpy. If so, then we know
807807
// that every byte is touched in the loop.
808-
const SCEVConstant *ConstStoreStride =
809-
dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
810-
const SCEVConstant *ConstLoadStride =
811-
dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
812-
if (!ConstStoreStride || !ConstLoadStride)
808+
const APInt *StoreStrideValue, *LoadStrideValue;
809+
if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
810+
!match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
813811
return false;
814812

815-
APInt StoreStrideValue = ConstStoreStride->getAPInt();
816-
APInt LoadStrideValue = ConstLoadStride->getAPInt();
817813
// Huge stride value - give up
818-
if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
814+
if (StoreStrideValue->getBitWidth() > 64 ||
815+
LoadStrideValue->getBitWidth() > 64)
819816
return false;
820817

821-
if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
818+
if (SizeInBytes != *StoreStrideValue && SizeInBytes != -*StoreStrideValue) {
822819
ORE.emit([&]() {
823820
return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
824821
<< ore::NV("Inst", "memcpy") << " in "
@@ -829,8 +826,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
829826
return false;
830827
}
831828

832-
int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
833-
int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
829+
int64_t StoreStrideInt = StoreStrideValue->getSExtValue();
830+
int64_t LoadStrideInt = LoadStrideValue->getSExtValue();
834831
// Check if the load stride matches the store stride.
835832
if (StoreStrideInt != LoadStrideInt)
836833
return false;
@@ -879,15 +876,14 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
879876
// we know that every byte is touched in the loop.
880877
LLVM_DEBUG(dbgs() << " memset size is constant\n");
881878
uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
882-
const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
883-
if (!ConstStride)
879+
const APInt *Stride;
880+
if (!match(Ev->getOperand(1), m_scev_APInt(Stride)))
884881
return false;
885882

886-
APInt Stride = ConstStride->getAPInt();
887-
if (SizeInBytes != Stride && SizeInBytes != -Stride)
883+
if (SizeInBytes != *Stride && SizeInBytes != -*Stride)
888884
return false;
889885

890-
IsNegStride = SizeInBytes == -Stride;
886+
IsNegStride = SizeInBytes == -*Stride;
891887
} else {
892888
// Memset size is non-constant.
893889
// Check if the pointer stride matches the memset size.
@@ -963,11 +959,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
963959

964960
// If the loop iterates a fixed number of times, we can refine the access size
965961
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
966-
const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
967-
const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
968-
if (BECst && ConstSize) {
969-
std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
970-
std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
962+
const APInt *BECst, *ConstSize;
963+
if (match(BECount, m_scev_APInt(BECst)) &&
964+
match(StoreSizeSCEV, m_scev_APInt(ConstSize))) {
965+
std::optional<uint64_t> BEInt = BECst->tryZExtValue();
966+
std::optional<uint64_t> SizeInt = ConstSize->tryZExtValue();
971967
// FIXME: Should this check for overflow?
972968
if (BEInt && SizeInt)
973969
AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
@@ -1605,16 +1601,11 @@ class StrlenVerifier {
16051601

16061602
LLVM_DEBUG(dbgs() << "pointer load scev: " << *LoadEv << "\n");
16071603

1608-
const SCEVConstant *Step =
1609-
dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
1610-
if (!Step)
1604+
const APInt *Step;
1605+
if (!match(LoadEv->getStepRecurrence(*SE), m_scev_APInt(Step)))
16111606
return false;
16121607

1613-
unsigned StepSize = 0;
1614-
StepSizeCI = dyn_cast<ConstantInt>(Step->getValue());
1615-
if (!StepSizeCI)
1616-
return false;
1617-
StepSize = StepSizeCI->getZExtValue();
1608+
unsigned StepSize = Step->getZExtValue();
16181609

16191610
// Verify that StepSize is consistent with platform char width.
16201611
OpWidth = OperandType->getIntegerBitWidth();
@@ -3294,9 +3285,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
32943285
// Ok, transform appears worthwhile.
32953286
MadeChange = true;
32963287

3297-
bool OffsetIsZero = false;
3298-
if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
3299-
OffsetIsZero = ExtraOffsetExprC->isZero();
3288+
bool OffsetIsZero = match(ExtraOffsetExpr, m_scev_SpecificInt(0));
33003289

33013290
// Step 1: Compute the loop's final IV value / trip count.
33023291

0 commit comments

Comments
 (0)