@@ -217,7 +217,7 @@ class LoopIdiomRecognize {
   bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
   bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
 
-  bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
+  bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
                                MaybeAlign StoreAlignment, Value *StoredVal,
                                Instruction *TheStore,
                                SmallPtrSetImpl<Instruction *> &Stores,
@@ -786,7 +786,8 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
 
   bool NegStride = StoreSize == -Stride;
 
-  if (processLoopStridedStore(StorePtr, StoreSize,
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+  if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
                               MaybeAlign(HeadStore->getAlignment()),
                               StoredVal, HeadStore, AdjacentStores, StoreEv,
                               BECount, NegStride)) {
@@ -936,17 +937,18 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
   bool NegStride = SizeInBytes == -Stride;
-  return processLoopStridedStore(
-      Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
-      SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
+  return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
+                                 MaybeAlign(MSI->getDestAlignment()),
+                                 SplatValue, MSI, MSIs, Ev, BECount, NegStride,
+                                 /*IsLoopMemset=*/true);
 }
 
 /// mayLoopAccessLocation - Return true if the specified loop might access the
 /// specified pointer location, which is a loop-strided access.  The 'Access'
 /// argument specifies what the verboten forms of access are (read or write).
 static bool
 mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
-                      const SCEV *BECount, unsigned StoreSize,
+                      const SCEV *BECount, const SCEV *StoreSizeSCEV,
                       AliasAnalysis &AA,
                       SmallPtrSetImpl<Instruction *> &IgnoredStores) {
   // Get the location that may be stored across the loop.  Since the access is
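
Note on the hunk above: switching from (unsigned)SizeInBytes to SE->getSCEV(MSI->getLength()) is what lets processLoopMemSet forward a memset whose length is only known at runtime. A hypothetical example of the kind of loop this generalization targets (names and values are illustrative, not taken from the patch):

  #include <cstring>

  // Each iteration clears one row of n ints, where n is a runtime value.
  // Because the pointer stride (n * sizeof(int)) equals the per-iteration
  // size, the loop can collapse into one memset of m * n * sizeof(int)
  // bytes, with the length expressed as a SCEV rather than a constant.
  void zero_rows(int *p, long n, long m) {
    for (long i = 0; i < m; ++i)
      std::memset(p + i * n, 0, n * sizeof(int));
  }
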
@@ -956,9 +958,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
 
   // If the loop iterates a fixed number of times, we can refine the access size
   // to be exactly the size of the memset, which is (BECount+1)*StoreSize
-  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
+  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (BECst && ConstSize)
     AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
-                                       StoreSize);
+                                       ConstSize->getValue()->getZExtValue());
 
   // TODO: For this to be really effective, we have to dive into the pointer
   // operand in the store.  Store to &A[i] of 100 will always return may alias
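
A concrete instance of the refinement above, with made-up numbers: if the backedge-taken count is the constant 99 and the store size is the constant 4, the loop touches exactly (99 + 1) * 4 = 400 bytes, so alias queries can use a precise LocationSize. When either value is non-constant (the new runtime-size case), the refinement is skipped and AccessSize keeps its conservative default.
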
@@ -973,62 +977,85 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
         isModOrRefSet(
             intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
       return true;
-
   return false;
 }
 
 // If we have a negative stride, Start refers to the end of the memory location
 // we're trying to memset.  Therefore, we need to recompute the base pointer,
 // which is just Start - BECount*Size.
 static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
-                                        Type *IntPtr, unsigned StoreSize,
+                                        Type *IntPtr, const SCEV *StoreSizeSCEV,
                                         ScalarEvolution *SE) {
   const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-  if (StoreSize != 1)
-    Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize),
+  if (!StoreSizeSCEV->isOne()) {
+    // index = back edge count * store size
+    Index = SE->getMulExpr(Index,
+                           SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
                            SCEV::FlagNUW);
+  }
+  // base pointer = start - index * store size
   return SE->getMinusSCEV(Start, Index);
 }
 
-/// Compute the number of bytes as a SCEV from the backedge taken count.
-///
-/// This also maps the SCEV into the provided type and tries to handle the
-/// computation in a way that will fold cleanly.
-static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
-                               unsigned StoreSize, Loop *CurLoop,
-                               const DataLayout *DL, ScalarEvolution *SE) {
-  const SCEV *NumBytesS;
-  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
+/// Compute trip count from the backedge taken count.
+static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
+                                Loop *CurLoop, const DataLayout *DL,
+                                ScalarEvolution *SE) {
+  const SCEV *TripCountS = nullptr;
+  // The # stored bytes is (BECount+1).  Expand the trip count out to
   // pointer size if it isn't already.
   //
   // If we're going to need to zero extend the BE count, check if we can add
   // one to it prior to zero extending without overflow. Provided this is safe,
   // it allows better simplification of the +1.
-  if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() <
-          DL->getTypeSizeInBits(IntPtr).getFixedSize() &&
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
       SE->isLoopEntryGuardedByCond(
          CurLoop, ICmpInst::ICMP_NE, BECount,
          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
-    NumBytesS = SE->getZeroExtendExpr(
+    TripCountS = SE->getZeroExtendExpr(
        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
        IntPtr);
  } else {
-    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
-                               SE->getOne(IntPtr), SCEV::FlagNUW);
+    TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                                SE->getOne(IntPtr), SCEV::FlagNUW);
  }
 
+  return TripCountS;
+}
+
+/// Compute the number of bytes as a SCEV from the backedge taken count.
+///
+/// This also maps the SCEV into the provided type and tries to handle the
+/// computation in a way that will fold cleanly.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
+
   // And scale it based on the store size.
   if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
+    return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
+                          SCEV::FlagNUW);
  }
-  return NumBytesS;
+  return TripCountSCEV;
+}
+
+/// getNumBytes that takes StoreSize as a SCEV
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               const SCEV *StoreSizeSCEV, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
+
+  return SE->getMulExpr(TripCountSCEV,
+                        SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
+                        SCEV::FlagNUW);
 }
 
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
-    Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
+    Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
     Value *StoredVal, Instruction *TheStore,
     SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
     const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
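
To see what getStartForNegStride and the getNumBytes overloads compute, take a hypothetical backwards fill over 4-byte elements, for (long i = 99; i >= 0; --i) a[i] = 0: the store AddRec starts at &a[99] and BECount is 99, so the recomputed base is &a[99] minus 99 * 4 bytes, i.e. &a[0], and the byte count is (99 + 1) * 4 = 400, the whole array. For a constant store size the new SCEV-based overload should build exactly the SCEV the unsigned overload builds, since getTruncateOrZeroExtend of a constant already of type IntPtr is a no-op. A minimal sketch of that equivalence, assuming the SE, IntIdxTy, CurLoop, and DL values used in processLoopStridedStore are in scope:

  // Sketch only. ScalarEvolution uniques SCEVs, so structurally identical
  // expressions compare equal as pointers.
  const SCEV *ByUnsigned =
      getNumBytes(BECount, IntIdxTy, /*StoreSize=*/4u, CurLoop, DL, SE);
  const SCEV *BySCEV = getNumBytes(BECount, IntIdxTy,
                                   SE->getConstant(IntIdxTy, 4),
                                   CurLoop, DL, SE);
  assert(ByUnsigned == BySCEV);
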
@@ -1057,7 +1084,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   const SCEV *Start = Ev->getStart();
   // Handle negative strided loops.
   if (NegStride)
-    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
+    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1082,7 +1109,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
     Changed = true;
 
   if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores))
+                            StoreSizeSCEV, *AA, Stores))
     return Changed;
 
   if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
@@ -1091,7 +1118,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   // Okay, everything looks good, insert the memset.
 
   const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
+      getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1215,9 +1242,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   APInt Stride = getStoreStride(StoreEv);
   bool NegStride = StoreSize == -Stride;
 
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
   // Handle negative strided loops.
   if (NegStride)
-    StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE);
+    StrStart =
+        getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // Okay, we have a strided store "p[i]" of a loaded value.  We can turn
   // this into a memcpy in the loop preheader now if we want.  However, this
@@ -1245,11 +1274,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
 
   bool UseMemMove =
       mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores);
+                            StoreSizeSCEV, *AA, Stores);
   if (UseMemMove) {
     Stores.insert(TheLoad);
     if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
-                              BECount, StoreSize, *AA, Stores)) {
+                              BECount, StoreSizeSCEV, *AA, Stores)) {
       ORE.emit([&]() {
         return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
                                         TheStore)
@@ -1268,7 +1297,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
 
   // Handle negative strided loops.
   if (NegStride)
-    LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE);
+    LdStart =
+        getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // For a memcpy, we have to make sure that the input array is not being
   // mutated by the loop.
@@ -1280,7 +1310,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   if (IsMemCpy)
     Stores.erase(TheStore);
   if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
-                            StoreSize, *AA, Stores)) {
+                            StoreSizeSCEV, *AA, Stores)) {
     ORE.emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
              << ore::NV("Inst", InstRemark) << " in "
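
For the memcpy path, the same mayLoopAccessLocation checks now take the size as a SCEV as well, even though processLoopStoreOfLoopLoad itself still starts from a constant StoreSize and wraps it. A hypothetical loop it handles, for orientation (illustrative, not from the patch's tests):

  // A strided store of a strided load with stride == store size becomes a
  // memcpy call in the preheader. If the first mayLoopAccessLocation query
  // above reports that the destination may also be accessed inside the
  // loop, the transform falls back to memmove (UseMemMove).
  void copy_fwd(int *dst, const int *src, long n) {
    for (long i = 0; i < n; ++i)
      dst[i] = src[i];
  }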