@@ -5923,9 +5923,9 @@ static bool isMaskedLoadCompress(
   // Check for very large distances between elements.
   if (*Diff / Sz >= MaxRegSize / 8)
     return false;
-  Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
   LoadVecTy = getWidenedType(ScalarTy, *Diff + 1);
   auto *LI = cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()]);
+  Align CommonAlignment = LI->getAlign();
   IsMasked = !isSafeToLoadUnconditionally(
       Ptr0, LoadVecTy, CommonAlignment, DL,
       cast<LoadInst>(Order.empty() ? VL.back() : VL[Order.back()]), &AC, &DT,
@@ -5964,26 +5964,28 @@ static bool isMaskedLoadCompress(
         TTI.getMaskedMemoryOpCost(Instruction::Load, LoadVecTy, CommonAlignment,
                                   LI->getPointerAddressSpace(), CostKind);
   } else {
-    CommonAlignment = LI->getAlign();
     LoadCost =
         TTI.getMemoryOpCost(Instruction::Load, LoadVecTy, CommonAlignment,
                             LI->getPointerAddressSpace(), CostKind);
   }
-  if (IsStrided) {
+  if (IsStrided && !IsMasked) {
     // Check for potential segmented(interleaved) loads.
-    if (TTI.isLegalInterleavedAccessType(LoadVecTy, CompressMask[1],
+    auto *AlignedLoadVecTy = getWidenedType(
+        ScalarTy, getFullVectorNumberOfElements(TTI, ScalarTy, *Diff + 1));
+    if (TTI.isLegalInterleavedAccessType(AlignedLoadVecTy, CompressMask[1],
                                          CommonAlignment,
                                          LI->getPointerAddressSpace())) {
       InstructionCost InterleavedCost =
           VectorGEPCost + TTI.getInterleavedMemoryOpCost(
-                              Instruction::Load, LoadVecTy, CompressMask[1],
-                              std::nullopt, CommonAlignment,
+                              Instruction::Load, AlignedLoadVecTy,
+                              CompressMask[1], std::nullopt, CommonAlignment,
                               LI->getPointerAddressSpace(), CostKind, IsMasked);
       if (!Mask.empty())
         InterleavedCost += ::getShuffleCost(TTI, TTI::SK_PermuteSingleSrc,
                                             VecTy, Mask, CostKind);
       if (InterleavedCost < GatherCost) {
         InterleaveFactor = CompressMask[1];
+        LoadVecTy = AlignedLoadVecTy;
         return true;
       }
     }
@@ -6001,6 +6003,24 @@ static bool isMaskedLoadCompress(
   return TotalVecCost < GatherCost;
 }
 
+/// Checks if the \p VL can be transformed to a (masked)load + compress or
+/// (masked) interleaved load.
+static bool
+isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
+                     const DataLayout &DL, ScalarEvolution &SE,
+                     AssumptionCache &AC, const DominatorTree &DT,
+                     const TargetLibraryInfo &TLI,
+                     const function_ref<bool(Value *)> AreAllUsersVectorized) {
+  bool IsMasked;
+  unsigned InterleaveFactor;
+  SmallVector<int> CompressMask;
+  VectorType *LoadVecTy;
+  return isMaskedLoadCompress(VL, PointerOps, Order, TTI, DL, SE, AC, DT, TLI,
+                              AreAllUsersVectorized, IsMasked, InterleaveFactor,
+                              CompressMask, LoadVecTy);
+}
+
 /// Checks if strided loads can be generated out of \p VL loads with pointers \p
 /// PointerOps:
 /// 1. Target with strided load support is detected.
@@ -6137,6 +6157,12 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
     // Check that the sorted loads are consecutive.
     if (static_cast<unsigned>(*Diff) == Sz - 1)
       return LoadsState::Vectorize;
+    if (isMaskedLoadCompress(VL, PointerOps, Order, *TTI, *DL, *SE, *AC, *DT,
+                             *TLI, [&](Value *V) {
+                               return areAllUsersVectorized(
+                                   cast<Instruction>(V), UserIgnoreList);
+                             }))
+      return LoadsState::CompressVectorize;
     // Simple check if not a strided access - clear order.
     bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
     // Try to generate strided load node.
@@ -6150,18 +6176,6 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
         isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
                       IsAnyPointerUsedOutGraph, *Diff))
       return LoadsState::StridedVectorize;
-    bool IsMasked;
-    unsigned InterleaveFactor;
-    SmallVector<int> CompressMask;
-    VectorType *LoadVecTy;
-    if (isMaskedLoadCompress(
-            VL, PointerOps, Order, *TTI, *DL, *SE, *AC, *DT, *TLI,
-            [&](Value *V) {
-              return areAllUsersVectorized(cast<Instruction>(V),
-                                           UserIgnoreList);
-            },
-            IsMasked, InterleaveFactor, CompressMask, LoadVecTy))
-      return LoadsState::CompressVectorize;
   }
   if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
       TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment))
@@ -13439,11 +13453,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         assert(IsVectorized && "Expected to be vectorized");
         CompressEntryToData.try_emplace(E, CompressMask, LoadVecTy,
                                         InterleaveFactor, IsMasked);
-        Align CommonAlignment;
-        if (IsMasked)
-          CommonAlignment = computeCommonAlignment<LoadInst>(VL);
-        else
-          CommonAlignment = LI0->getAlign();
+        Align CommonAlignment = LI0->getAlign();
         if (InterleaveFactor) {
           VecLdCost = TTI->getInterleavedMemoryOpCost(
               Instruction::Load, LoadVecTy, InterleaveFactor, std::nullopt,
@@ -18049,14 +18059,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
         PointerOps[I] = cast<LoadInst>(V)->getPointerOperand();
       auto [CompressMask, LoadVecTy, InterleaveFactor, IsMasked] =
           CompressEntryToData.at(E);
-      Align CommonAlignment;
-      if (IsMasked)
-        CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars);
-      else
-        CommonAlignment = LI->getAlign();
+      Align CommonAlignment = LI->getAlign();
       if (IsMasked) {
+        unsigned VF = getNumElements(LoadVecTy);
         SmallVector<Constant *> MaskValues(
-            getNumElements(LoadVecTy) / getNumElements(LI->getType()),
+            VF / getNumElements(LI->getType()),
             ConstantInt::getFalse(VecTy->getContext()));
         for (int I : CompressMask)
           MaskValues[I] = ConstantInt::getTrue(VecTy->getContext());