
Commit 3a30d8e

[SLP] Check if a masked gather can be emitted as a series of loads/insertsubvector.

A masked gather is a very expensive operation, and it is sometimes better to represent it as a series of consecutive/strided loads + insertsubvector sequences. The patch adds some basic cost estimation and, if loads + insertsubvector are cheaper, represents the loads that way rather than as a masked gather.

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: llvm#83481
1 parent dd426fa commit 3a30d8e
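To illustrate the transformation being costed, here is a minimal hand-written IR sketch (not taken from the patch; the pointer %p, the byte offsets, and the element type are hypothetical, mirroring the updated test below):

; Before: one masked gather of four i64 values at byte offsets 0, 8, 24, 32.
%p.ins = insertelement <4 x ptr> poison, ptr %p, i32 0
%p.splat = shufflevector <4 x ptr> %p.ins, <4 x ptr> poison, <4 x i32> zeroinitializer
%ptrs = getelementptr i8, <4 x ptr> %p.splat, <4 x i64> <i64 0, i64 8, i64 24, i64 32>
%v = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> %ptrs, i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)

; After: two consecutive <2 x i64> loads combined with an insert-subvector
; shuffle, the form the new cost check prefers when it is cheaper.
%lo = load <2 x i64>, ptr %p, align 8
%hi.addr = getelementptr inbounds i8, ptr %p, i64 24
%hi = load <2 x i64>, ptr %hi.addr, align 8
%v2 = shufflevector <2 x i64> %lo, <2 x i64> %hi, <4 x i32> <i32 0, i32 1, i32 2, i32 3>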

2 files changed: +102, -16 lines

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 89 additions & 7 deletions
@@ -4000,12 +4000,14 @@ static bool isReverseOrder(ArrayRef<unsigned> Order) {
 
 /// Checks if the given array of loads can be represented as a vectorized,
 /// scatter or just simple gather.
-static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
+static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
+                                    const Value *VL0,
                                     const TargetTransformInfo &TTI,
                                     const DataLayout &DL, ScalarEvolution &SE,
                                     LoopInfo &LI, const TargetLibraryInfo &TLI,
                                     SmallVectorImpl<unsigned> &Order,
-                                    SmallVectorImpl<Value *> &PointerOps) {
+                                    SmallVectorImpl<Value *> &PointerOps,
+                                    bool TryRecursiveCheck = true) {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -4098,6 +4100,78 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
       }
     }
   }
+  auto CheckForShuffledLoads = [&](Align CommonAlignment) {
+    unsigned Sz = DL.getTypeSizeInBits(ScalarTy);
+    unsigned MinVF = R.getMinVF(Sz);
+    unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
+    MaxVF = std::min(R.getMaximumVF(Sz, Instruction::Load), MaxVF);
+    for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
+      unsigned VectorizedCnt = 0;
+      SmallVector<LoadsState> States;
+      for (unsigned Cnt = 0, End = VL.size(); Cnt + VF <= End;
+           Cnt += VF, ++VectorizedCnt) {
+        ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
+        SmallVector<unsigned> Order;
+        SmallVector<Value *> PointerOps;
+        LoadsState LS =
+            canVectorizeLoads(R, Slice, Slice.front(), TTI, DL, SE, LI, TLI,
+                              Order, PointerOps, /*TryRecursiveCheck=*/false);
+        // Check that the sorted loads are consecutive.
+        if (LS == LoadsState::Gather)
+          break;
+        // If a reorder is needed, treat as a high-cost masked gather for now.
+        if ((LS == LoadsState::Vectorize ||
+             LS == LoadsState::StridedVectorize) &&
+            !Order.empty() && !isReverseOrder(Order))
+          LS = LoadsState::ScatterVectorize;
+        States.push_back(LS);
+      }
+      // Can be vectorized later as a series of loads/insertelements.
+      if (VectorizedCnt == VL.size() / VF) {
+        // Compare masked gather cost and loads + insertsubvector costs.
+        TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+        InstructionCost MaskedGatherCost = TTI.getGatherScatterOpCost(
+            Instruction::Load, VecTy,
+            cast<LoadInst>(VL0)->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind);
+        InstructionCost VecLdCost = 0;
+        auto *SubVecTy = FixedVectorType::get(ScalarTy, VF);
+        for (auto [I, LS] : enumerate(States)) {
+          auto *LI0 = cast<LoadInst>(VL[I * VF]);
+          switch (LS) {
+          case LoadsState::Vectorize:
+            VecLdCost += TTI.getMemoryOpCost(
+                Instruction::Load, SubVecTy, LI0->getAlign(),
+                LI0->getPointerAddressSpace(), CostKind,
+                TTI::OperandValueInfo());
+            break;
+          case LoadsState::StridedVectorize:
+            VecLdCost += TTI.getStridedMemoryOpCost(
+                Instruction::Load, SubVecTy, LI0->getPointerOperand(),
+                /*VariableMask=*/false, CommonAlignment, CostKind);
+            break;
+          case LoadsState::ScatterVectorize:
+            VecLdCost += TTI.getGatherScatterOpCost(
+                Instruction::Load, SubVecTy, LI0->getPointerOperand(),
+                /*VariableMask=*/false, CommonAlignment, CostKind);
+            break;
+          case LoadsState::Gather:
+            llvm_unreachable(
+                "Expected only consecutive, strided or masked gather loads.");
+          }
+          VecLdCost +=
+              TTI.getShuffleCost(TTI::SK_InsertSubvector, VecTy,
+                                 std::nullopt, CostKind, I * VF, SubVecTy);
+        }
+        // If the masked gather cost is higher, it is better to vectorize,
+        // so consider it as a gather node. It will be better estimated
+        // later.
+        if (MaskedGatherCost > VecLdCost)
+          return true;
+      }
+    }
+    return false;
+  };
   // TODO: need to improve analysis of the pointers, if not all of them are
   // GEPs or have > 2 operands, we end up with a gather node, which just
   // increases the cost.
41014175
// TODO: need to improve analysis of the pointers, if not all of them are
41024176
// GEPs or have > 2 operands, we end up with a gather node, which just
41034177
// increases the cost.
@@ -4114,8 +4188,17 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
         })) {
       Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
       if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) &&
-          !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment))
+          !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
+        // Check if a potential masked gather can be represented as a series
+        // of loads + insertsubvectors.
+        if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
+          // If the masked gather cost is higher, it is better to vectorize,
+          // so consider it as a gather node. It will be better estimated
+          // later.
+          return LoadsState::Gather;
+        }
         return LoadsState::ScatterVectorize;
+      }
     }
   }

@@ -5554,8 +5637,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
   // treats loading/storing it as an i8 struct. If we vectorize loads/stores
   // from such a struct, we read/write packed bits disagreeing with the
   // unvectorized version.
-  switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, CurrentOrder,
-                            PointerOps)) {
+  switch (canVectorizeLoads(*this, VL, VL0, *TTI, *DL, *SE, *LI, *TLI,
+                            CurrentOrder, PointerOps)) {
   case LoadsState::Vectorize:
     return TreeEntry::Vectorize;
   case LoadsState::ScatterVectorize:
@@ -7336,7 +7419,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
         SmallVector<Value *> PointerOps;
         OrdersType CurrentOrder;
         LoadsState LS =
-            canVectorizeLoads(Slice, Slice.front(), TTI, *R.DL, *R.SE,
+            canVectorizeLoads(R, Slice, Slice.front(), TTI, *R.DL, *R.SE,
                               *R.LI, *R.TLI, CurrentOrder, PointerOps);
         switch (LS) {
         case LoadsState::Vectorize:
@@ -7599,7 +7682,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
         transformMaskAfterShuffle(CommonMask, CommonMask);
       }
       SameNodesEstimated = false;
-      Cost += createShuffle(&E1, E2, Mask);
       if (!E2 && InVectors.size() == 1) {
         unsigned VF = E1.getVectorFactor();
         if (Value *V1 = InVectors.front().dyn_cast<Value *>()) {
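To make the comparison in CheckForShuffledLoads concrete, here is a hedged worked example; the numbers are invented for illustration, and real values come from the target's TTI hooks. Suppose VL holds 4 loads of i64, the largest VF tried is 2, and both 2-element slices are classified as LoadsState::Vectorize:

MaskedGatherCost = getGatherScatterOpCost(Load, <4 x i64>)           = 20 (assumed)
VecLdCost        = 2 * getMemoryOpCost(Load, <2 x i64>)              = 2 * 1 (assumed)
                 + 2 * getShuffleCost(SK_InsertSubvector, <4 x i64>) = 2 * 2 (assumed)
                                                                     = 6

Since MaskedGatherCost (20) > VecLdCost (6), CheckForShuffledLoads returns true and canVectorizeLoads answers LoadsState::Gather, so the node is kept as a gather for now and the cheaper loads + insertsubvector form is recovered when the slices are re-vectorized later. The /*TryRecursiveCheck=*/false argument on the per-slice calls keeps this check from re-running recursively on each slice.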

llvm/test/Transforms/SLPVectorizer/X86/scatter-vectorize-reused-pointer.ll

Lines changed: 13 additions & 9 deletions
@@ -5,19 +5,23 @@ define void @test(i1 %c, ptr %arg) {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x ptr> poison, ptr [[ARG:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x ptr> [[TMP1]], <4 x ptr> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, <4 x ptr> [[TMP2]], <4 x i64> <i64 32, i64 24, i64 8, i64 0>
-; CHECK-NEXT:    [[TMP4:%.*]] = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> [[TMP3]], i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i64>, ptr [[ARG:%.*]], align 8
+; CHECK-NEXT:    [[ARG2_2:%.*]] = getelementptr inbounds i8, ptr [[ARG]], i64 24
+; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x i64>, ptr [[ARG2_2]], align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x i64> [[TMP2]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
 ; CHECK-NEXT:    br label [[JOIN:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x ptr> poison, ptr [[ARG]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x ptr> [[TMP5]], <4 x ptr> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, <4 x ptr> [[TMP6]], <4 x i64> <i64 32, i64 24, i64 8, i64 0>
-; CHECK-NEXT:    [[TMP8:%.*]] = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> [[TMP7]], i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x i64>, ptr [[ARG]], align 8
+; CHECK-NEXT:    [[ARG_2:%.*]] = getelementptr inbounds i8, ptr [[ARG]], i64 24
+; CHECK-NEXT:    [[TMP7:%.*]] = load <2 x i64>, ptr [[ARG_2]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <2 x i64> [[TMP7]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i64> [[TMP6]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <4 x i64> [[TMP8]], <4 x i64> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
 ; CHECK-NEXT:    br label [[JOIN]]
 ; CHECK:       join:
-; CHECK-NEXT:    [[TMP9:%.*]] = phi <4 x i64> [ [[TMP4]], [[IF]] ], [ [[TMP8]], [[ELSE]] ]
+; CHECK-NEXT:    [[TMP11:%.*]] = phi <4 x i64> [ [[TMP5]], [[IF]] ], [ [[TMP10]], [[ELSE]] ]
 ; CHECK-NEXT:    ret void
 ;
   br i1 %c, label %if, label %else
