Skip to content

Commit 246f345

Browse files
authored
[SLP][REVEC] Make CastInst support vector instructions. (#103216)
1 parent 2d7a2c1 commit 246f345

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9877,16 +9877,18 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
98779877
auto *SrcVecTy = getWidenedType(SrcScalarTy, VL.size());
98789878
unsigned Opcode = ShuffleOrOp;
98799879
unsigned VecOpcode = Opcode;
9880-
if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
9880+
if (!ScalarTy->isFPOrFPVectorTy() && !SrcScalarTy->isFPOrFPVectorTy() &&
98819881
(SrcIt != MinBWs.end() || It != MinBWs.end())) {
98829882
// Check if the values are candidates to demote.
9883-
unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
9883+
unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy->getScalarType());
98849884
if (SrcIt != MinBWs.end()) {
98859885
SrcBWSz = SrcIt->second.first;
9886+
unsigned SrcScalarTyNumElements = getNumElements(SrcScalarTy);
98869887
SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
9887-
SrcVecTy = getWidenedType(SrcScalarTy, VL.size());
9888+
SrcVecTy =
9889+
getWidenedType(SrcScalarTy, VL.size() * SrcScalarTyNumElements);
98889890
}
9889-
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
9891+
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy->getScalarType());
98909892
if (BWSz == SrcBWSz) {
98919893
VecOpcode = Instruction::BitCast;
98929894
} else if (BWSz < SrcBWSz) {
@@ -13452,14 +13454,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1345213454
Instruction::CastOps VecOpcode = CI->getOpcode();
1345313455
Type *SrcScalarTy = cast<VectorType>(InVec->getType())->getElementType();
1345413456
auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
13455-
if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
13457+
if (!ScalarTy->isFPOrFPVectorTy() && !SrcScalarTy->isFPOrFPVectorTy() &&
1345613458
(SrcIt != MinBWs.end() || It != MinBWs.end() ||
13457-
SrcScalarTy != CI->getOperand(0)->getType())) {
13459+
SrcScalarTy != CI->getOperand(0)->getType()->getScalarType())) {
1345813460
// Check if the values are candidates to demote.
1345913461
unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
1346013462
if (SrcIt != MinBWs.end())
1346113463
SrcBWSz = SrcIt->second.first;
13462-
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
13464+
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy->getScalarType());
1346313465
if (BWSz == SrcBWSz) {
1346413466
VecOpcode = Instruction::BitCast;
1346513467
} else if (BWSz < SrcBWSz) {

llvm/test/Transforms/SLPVectorizer/revec.ll

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,43 @@ for.body13: ; preds = %for.body13, %entry
296296
store <4 x i32> %vmovl.i110, ptr %add.ptr29, align 4
297297
br label %for.body13
298298
}
299+
300+
define void @test10() {
301+
; CHECK-LABEL: @test10(
302+
; CHECK-NEXT: entry:
303+
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i8>, ptr null, align 1
304+
; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i8> @llvm.vector.insert.v32i8.v16i8(<32 x i8> poison, <16 x i8> poison, i64 16)
305+
; CHECK-NEXT: [[TMP2:%.*]] = call <32 x i8> @llvm.vector.insert.v32i8.v16i8(<32 x i8> [[TMP1]], <16 x i8> [[TMP0]], i64 0)
306+
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <32 x i8> [[TMP2]], <32 x i8> poison, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
307+
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <32 x i8> [[TMP3]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
308+
; CHECK-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[TMP4]] to <16 x i16>
309+
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <16 x i16> [[TMP5]], <16 x i16> poison, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
310+
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <32 x i16> [[TMP6]], <32 x i16> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 28, i32 29, i32 30, i32 31>
311+
; CHECK-NEXT: [[TMP8:%.*]] = trunc <16 x i16> [[TMP7]] to <16 x i8>
312+
; CHECK-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[TMP8]] to <16 x i32>
313+
; CHECK-NEXT: store <16 x i32> [[TMP9]], ptr null, align 4
314+
; CHECK-NEXT: ret void
315+
;
316+
entry:
317+
%0 = load <16 x i8>, ptr null, align 1
318+
%shuffle.i = shufflevector <16 x i8> %0, <16 x i8> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
319+
%shuffle.i107 = shufflevector <16 x i8> %0, <16 x i8> zeroinitializer, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
320+
%vmovl.i106 = sext <8 x i8> %shuffle.i to <8 x i16>
321+
%vmovl.i = sext <8 x i8> %shuffle.i107 to <8 x i16>
322+
%shuffle.i113 = shufflevector <8 x i16> %vmovl.i106, <8 x i16> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
323+
%shuffle.i115 = shufflevector <8 x i16> %vmovl.i106, <8 x i16> zeroinitializer, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
324+
%shuffle.i112 = shufflevector <8 x i16> %vmovl.i, <8 x i16> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
325+
%shuffle.i114 = shufflevector <8 x i16> %vmovl.i, <8 x i16> zeroinitializer, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
326+
%vmovl.i111 = sext <4 x i16> %shuffle.i113 to <4 x i32>
327+
%vmovl.i110 = sext <4 x i16> %shuffle.i115 to <4 x i32>
328+
%vmovl.i109 = sext <4 x i16> %shuffle.i112 to <4 x i32>
329+
%vmovl.i108 = sext <4 x i16> %shuffle.i114 to <4 x i32>
330+
%add.ptr29 = getelementptr i8, ptr null, i64 16
331+
%add.ptr32 = getelementptr i8, ptr null, i64 32
332+
%add.ptr35 = getelementptr i8, ptr null, i64 48
333+
store <4 x i32> %vmovl.i111, ptr null, align 4
334+
store <4 x i32> %vmovl.i110, ptr %add.ptr29, align 4
335+
store <4 x i32> %vmovl.i109, ptr %add.ptr32, align 4
336+
store <4 x i32> %vmovl.i108, ptr %add.ptr35, align 4
337+
ret void
338+
}

0 commit comments

Comments
 (0)