
Commit 26aa1bb

eopXD authored and committed
[NFCI] [LoopIdiom] Let processLoopStridedStore take StoreSize as SCEV instead of unsigned
Letting it take a SCEV allows the function to be extended later to handle the case where the StoreSize / Stride is determined at runtime. This is a precursor to D107353. The big picture is to let LoopIdiom deal with runtime-determined sizes.

Reviewed By: Whitney, lebedev.ri

Differential Revision: https://reviews.llvm.org/D104595
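For context, a minimal sketch (not taken from the patch) of the strided-store idiom LoopIdiomRecognize already rewrites into a memset; here the store size is the compile-time constant sizeof(int), whereas the follow-up work targets sizes that are only known at runtime. Function name and types are illustrative only.

// Sketch: a splat store with unit stride that the pass can collapse into
// a single memset in the loop preheader.
void zero_fill(int *a, unsigned n) {
  for (unsigned i = 0; i != n; ++i)
    a[i] = 0; // becomes memset(a, 0, (size_t)n * sizeof(int))
}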
1 parent 9c3345a commit 26aa1bb

1 file changed: +68 −38 lines


llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

@@ -217,7 +217,7 @@ class LoopIdiomRecognize {
   bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
   bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);

-  bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
+  bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
                                MaybeAlign StoreAlignment, Value *StoredVal,
                                Instruction *TheStore,
                                SmallPtrSetImpl<Instruction *> &Stores,
@@ -786,7 +786,8 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,

     bool NegStride = StoreSize == -Stride;

-    if (processLoopStridedStore(StorePtr, StoreSize,
+    const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+    if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
                                 MaybeAlign(HeadStore->getAlignment()),
                                 StoredVal, HeadStore, AdjacentStores, StoreEv,
                                 BECount, NegStride)) {
@@ -936,17 +937,18 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
   bool NegStride = SizeInBytes == -Stride;
-  return processLoopStridedStore(
-      Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
-      SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
+  return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
+                                 MaybeAlign(MSI->getDestAlignment()),
+                                 SplatValue, MSI, MSIs, Ev, BECount, NegStride,
+                                 /*IsLoopMemset=*/true);
 }

 /// mayLoopAccessLocation - Return true if the specified loop might access the
 /// specified pointer location, which is a loop-strided access. The 'Access'
 /// argument specifies what the verboten forms of access are (read or write).
 static bool
 mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
-                      const SCEV *BECount, unsigned StoreSize,
+                      const SCEV *BECount, const SCEV *StoreSizeSCEV,
                       AliasAnalysis &AA,
                       SmallPtrSetImpl<Instruction *> &IgnoredStores) {
   // Get the location that may be stored across the loop. Since the access is
@@ -956,9 +958,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,

   // If the loop iterates a fixed number of times, we can refine the access size
   // to be exactly the size of the memset, which is (BECount+1)*StoreSize
-  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
+  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (BECst && ConstSize)
     AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
-                                       StoreSize);
+                                       ConstSize->getValue()->getZExtValue());

   // TODO: For this to be really effective, we have to dive into the pointer
   // operand in the store. Store to &A[i] of 100 will always return may alias
@@ -973,62 +977,85 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
           isModOrRefSet(
               intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
         return true;
-
   return false;
 }

 // If we have a negative stride, Start refers to the end of the memory location
 // we're trying to memset. Therefore, we need to recompute the base pointer,
 // which is just Start - BECount*Size.
 static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
-                                        Type *IntPtr, unsigned StoreSize,
+                                        Type *IntPtr, const SCEV *StoreSizeSCEV,
                                         ScalarEvolution *SE) {
   const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-  if (StoreSize != 1)
-    Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize),
+  if (!StoreSizeSCEV->isOne()) {
+    // index = back edge count * store size
+    Index = SE->getMulExpr(Index,
+                           SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
                            SCEV::FlagNUW);
+  }
+  // base pointer = start - index * store size
   return SE->getMinusSCEV(Start, Index);
 }

-/// Compute the number of bytes as a SCEV from the backedge taken count.
-///
-/// This also maps the SCEV into the provided type and tries to handle the
-/// computation in a way that will fold cleanly.
-static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
-                               unsigned StoreSize, Loop *CurLoop,
-                               const DataLayout *DL, ScalarEvolution *SE) {
-  const SCEV *NumBytesS;
-  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
+/// Compute trip count from the backedge taken count.
+static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
+                                Loop *CurLoop, const DataLayout *DL,
+                                ScalarEvolution *SE) {
+  const SCEV *TripCountS = nullptr;
+  // The # stored bytes is (BECount+1). Expand the trip count out to
   // pointer size if it isn't already.
   //
   // If we're going to need to zero extend the BE count, check if we can add
   // one to it prior to zero extending without overflow. Provided this is safe,
   // it allows better simplification of the +1.
-  if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() <
-          DL->getTypeSizeInBits(IntPtr).getFixedSize() &&
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
       SE->isLoopEntryGuardedByCond(
           CurLoop, ICmpInst::ICMP_NE, BECount,
           SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
-    NumBytesS = SE->getZeroExtendExpr(
+    TripCountS = SE->getZeroExtendExpr(
         SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
         IntPtr);
   } else {
-    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
-                               SE->getOne(IntPtr), SCEV::FlagNUW);
+    TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                                SE->getOne(IntPtr), SCEV::FlagNUW);
   }

+  return TripCountS;
+}
+
+/// Compute the number of bytes as a SCEV from the backedge taken count.
+///
+/// This also maps the SCEV into the provided type and tries to handle the
+/// computation in a way that will fold cleanly.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
+
   // And scale it based on the store size.
   if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
+    return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
+                          SCEV::FlagNUW);
   }
-  return NumBytesS;
+  return TripCountSCEV;
+}
+
+/// getNumBytes that takes StoreSize as a SCEV
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               const SCEV *StoreSizeSCEV, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
+
+  return SE->getMulExpr(TripCountSCEV,
+                        SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
+                        SCEV::FlagNUW);
 }

 /// processLoopStridedStore - We see a strided store of some value. If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
-    Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
+    Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
     Value *StoredVal, Instruction *TheStore,
     SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
     const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
@@ -1057,7 +1084,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   const SCEV *Start = Ev->getStart();
   // Handle negative strided loops.
   if (NegStride)
-    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
+    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);

   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1082,7 +1109,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   Changed = true;

   if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores))
+                            StoreSizeSCEV, *AA, Stores))
     return Changed;

   if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
@@ -1091,7 +1118,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   // Okay, everything looks good, insert the memset.

   const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
+      getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);

   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1215,9 +1242,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   APInt Stride = getStoreStride(StoreEv);
   bool NegStride = StoreSize == -Stride;

+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
   // Handle negative strided loops.
   if (NegStride)
-    StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE);
+    StrStart =
+        getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE);

   // Okay, we have a strided store "p[i]" of a loaded value. We can turn
   // this into a memcpy in the loop preheader now if we want. However, this
@@ -1245,11 +1274,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(

   bool UseMemMove =
       mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores);
+                            StoreSizeSCEV, *AA, Stores);
   if (UseMemMove) {
     Stores.insert(TheLoad);
     if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
-                              BECount, StoreSize, *AA, Stores)) {
+                              BECount, StoreSizeSCEV, *AA, Stores)) {
       ORE.emit([&]() {
         return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
                                         TheStore)
@@ -1268,7 +1297,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(

   // Handle negative strided loops.
   if (NegStride)
-    LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE);
+    LdStart =
+        getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE);

   // For a memcpy, we have to make sure that the input array is not being
   // mutated by the loop.
@@ -1280,7 +1310,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   if (IsMemCpy)
     Stores.erase(TheStore);
   if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
-                            StoreSize, *AA, Stores)) {
+                            StoreSizeSCEV, *AA, Stores)) {
     ORE.emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
              << ore::NV("Inst", InstRemark) << " in "
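To illustrate the formulas the patch keeps intact (this worked example is not part of the commit): for a positive-stride store of 4 bytes per iteration with a backedge-taken count of 99, getTripCount yields BECount + 1 = 100 and getNumBytes yields 100 * 4 = 400 bytes; for a negative stride, getStartForNegStride rewrites the base pointer to Start - 99 * 4, the lowest address the loop touches. With the SCEV-based overloads, the same expressions can be formed even when the 4 is not a compile-time constant.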
