Skip to content

Commit 6aad491

Browse files
authored
[SLP][REVEC] Make MinBWs support vector instructions. (#103049)
If ScalarTy is FixedVectorType, it should remain as FixedVectorType.
1 parent 5c3a0fa commit 6aad491

File tree

2 files changed: 43 additions and 7 deletions.

2 files changed

+43
-7
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9527,8 +9527,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
95279527
// that the costs will be accurate.
95289528
auto It = MinBWs.find(E);
95299529
Type *OrigScalarTy = ScalarTy;
9530-
if (It != MinBWs.end())
9530+
if (It != MinBWs.end()) {
9531+
auto VecTy = dyn_cast<FixedVectorType>(ScalarTy);
95319532
ScalarTy = IntegerType::get(F->getContext(), It->second.first);
9533+
if (VecTy)
9534+
ScalarTy = getWidenedType(ScalarTy, VecTy->getNumElements());
9535+
}
95329536
auto *VecTy = getWidenedType(ScalarTy, VL.size());
95339537
unsigned EntryVF = E->getVectorFactor();
95349538
auto *FinalVecTy = getWidenedType(ScalarTy, EntryVF);
@@ -13127,8 +13131,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1312713131
else if (auto *IE = dyn_cast<InsertElementInst>(V))
1312813132
ScalarTy = IE->getOperand(1)->getType();
1312913133
auto It = MinBWs.find(E);
13130-
if (It != MinBWs.end())
13134+
if (It != MinBWs.end()) {
13135+
auto VecTy = dyn_cast<FixedVectorType>(ScalarTy);
1313113136
ScalarTy = IntegerType::get(F->getContext(), It->second.first);
13137+
if (VecTy)
13138+
ScalarTy = getWidenedType(ScalarTy, VecTy->getNumElements());
13139+
}
1313213140
auto *VecTy = getWidenedType(ScalarTy, E->Scalars.size());
1313313141
if (E->isGather()) {
1313413142
// Set insert point for non-reduction initial nodes.
@@ -16003,16 +16011,18 @@ void BoUpSLP::computeMinimumValueSizes() {
1600316011
}
1600416012

1600516013
unsigned VF = E.getVectorFactor();
16006-
auto *TreeRootIT =
16007-
dyn_cast<IntegerType>(E.Scalars.front()->getType()->getScalarType());
16014+
Type *ScalarTy = E.Scalars.front()->getType();
16015+
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
16016+
auto *TreeRootIT = dyn_cast<IntegerType>(ScalarTy->getScalarType());
1600816017
if (!TreeRootIT || !Opcode)
1600916018
return 0u;
1601016019

1601116020
if (any_of(E.Scalars,
1601216021
[&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
1601316022
return 0u;
1601416023

16015-
unsigned NumParts = TTI->getNumberOfParts(getWidenedType(TreeRootIT, VF));
16024+
unsigned NumParts = TTI->getNumberOfParts(
16025+
getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
1601616026

1601716027
// The maximum bit width required to represent all the values that can be
1601816028
// demoted without loss of precision. It would be safe to truncate the roots
@@ -16034,7 +16044,8 @@ void BoUpSLP::computeMinimumValueSizes() {
1603416044
// we can truncate the roots to this narrower type.
1603516045
for (Value *Root : E.Scalars) {
1603616046
unsigned NumSignBits = ComputeNumSignBits(Root, *DL, 0, AC, nullptr, DT);
16037-
TypeSize NumTypeBits = DL->getTypeSizeInBits(Root->getType());
16047+
TypeSize NumTypeBits =
16048+
DL->getTypeSizeInBits(Root->getType()->getScalarType());
1603816049
unsigned BitWidth1 = NumTypeBits - NumSignBits;
1603916050
// If we can't prove that the sign bit is zero, we must add one to the
1604016051
// maximum bit width to account for the unknown sign bit. This preserves
@@ -16206,7 +16217,8 @@ void BoUpSLP::computeMinimumValueSizes() {
1620616217
// type, we can proceed with the narrowing. Otherwise, do nothing.
1620716218
if (MaxBitWidth == 0 ||
1620816219
MaxBitWidth >=
16209-
cast<IntegerType>(TreeRoot.front()->getType())->getBitWidth()) {
16220+
cast<IntegerType>(TreeRoot.front()->getType()->getScalarType())
16221+
->getBitWidth()) {
1621016222
if (UserIgnoreList)
1621116223
AnalyzedMinBWVals.insert(TreeRoot.begin(), TreeRoot.end());
1621216224
continue;

llvm/test/Transforms/SLPVectorizer/revec.ll

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -272,3 +272,27 @@ for.body:
272272
%5 = phi <2 x float> [ %5, %for.body ], [ zeroinitializer, %entry ]
273273
br i1 false, label %for0, label %for.body
274274
}
275+
276+
define void @test9() {
277+
; CHECK-LABEL: @test9(
278+
; CHECK-NEXT: entry:
279+
; CHECK-NEXT: [[TMP0:%.*]] = call <8 x i16> @llvm.vector.insert.v8i16.v4i16(<8 x i16> poison, <4 x i16> zeroinitializer, i64 0)
280+
; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.vector.insert.v8i16.v4i16(<8 x i16> [[TMP0]], <4 x i16> zeroinitializer, i64 4)
281+
; CHECK-NEXT: br label [[FOR_BODY13:%.*]]
282+
; CHECK: for.body13:
283+
; CHECK-NEXT: [[TMP2:%.*]] = trunc <8 x i16> [[TMP1]] to <8 x i1>
284+
; CHECK-NEXT: [[TMP3:%.*]] = zext <8 x i1> [[TMP2]] to <8 x i32>
285+
; CHECK-NEXT: store <8 x i32> [[TMP3]], ptr null, align 4
286+
; CHECK-NEXT: br label [[FOR_BODY13]]
287+
;
288+
entry:
289+
br label %for.body13
290+
291+
for.body13: ; preds = %for.body13, %entry
292+
%vmovl.i111 = sext <4 x i16> zeroinitializer to <4 x i32>
293+
%vmovl.i110 = sext <4 x i16> zeroinitializer to <4 x i32>
294+
store <4 x i32> %vmovl.i111, ptr null, align 4
295+
%add.ptr29 = getelementptr i8, ptr null, i64 16
296+
store <4 x i32> %vmovl.i110, ptr %add.ptr29, align 4
297+
br label %for.body13
298+
}

0 commit comments

Comments
 (0)