Skip to content

Commit 2347aa1

Browse files
authored
[SLP][REVEC] Fix the mismatch between the result of getAltInstrMask and the VecTy argument of TargetTransformInfo::isLegalAltInstr. (#134795)
We cannot determine ScalarTy from VL because some ScalarTy is determined from VL[0]->getType(), while others are determined from getValueType(VL[0]). Fix "Mask and VecTy are incompatible".
1 parent 97c4cb4 commit 2347aa1

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,9 +1264,8 @@ static void fixupOrderingIndices(MutableArrayRef<unsigned> Order) {
12641264

12651265
/// \returns a bitset for selecting opcodes. false for Opcode0 and true for
12661266
/// Opcode1.
1267-
static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, unsigned Opcode0,
1268-
unsigned Opcode1) {
1269-
Type *ScalarTy = VL[0]->getType();
1267+
static SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, Type *ScalarTy,
1268+
unsigned Opcode0, unsigned Opcode1) {
12701269
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
12711270
SmallBitVector OpcodeMask(VL.size() * ScalarTyNumElements, false);
12721271
for (unsigned Lane : seq<unsigned>(VL.size())) {
@@ -6667,11 +6666,12 @@ void BoUpSLP::reorderTopToBottom() {
66676666
// to take into account their order when looking for the most used order.
66686667
if (TE->hasState() && TE->isAltShuffle() &&
66696668
TE->State != TreeEntry::SplitVectorize) {
6670-
VectorType *VecTy =
6671-
getWidenedType(TE->Scalars[0]->getType(), TE->Scalars.size());
6669+
Type *ScalarTy = TE->Scalars[0]->getType();
6670+
VectorType *VecTy = getWidenedType(ScalarTy, TE->Scalars.size());
66726671
unsigned Opcode0 = TE->getOpcode();
66736672
unsigned Opcode1 = TE->getAltOpcode();
6674-
SmallBitVector OpcodeMask(getAltInstrMask(TE->Scalars, Opcode0, Opcode1));
6673+
SmallBitVector OpcodeMask(
6674+
getAltInstrMask(TE->Scalars, ScalarTy, Opcode0, Opcode1));
66756675
// If this pattern is supported by the target then we consider the order.
66766676
if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
66776677
VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get());
@@ -8352,12 +8352,13 @@ static bool isAlternateInstruction(const Instruction *I,
83528352

83538353
bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S,
83548354
ArrayRef<Value *> VL) const {
8355+
Type *ScalarTy = S.getMainOp()->getType();
83558356
unsigned Opcode0 = S.getOpcode();
83568357
unsigned Opcode1 = S.getAltOpcode();
8357-
SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
8358+
SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
83588359
// If this pattern is supported by the target then consider it profitable.
8359-
if (TTI->isLegalAltInstr(getWidenedType(S.getMainOp()->getType(), VL.size()),
8360-
Opcode0, Opcode1, OpcodeMask))
8360+
if (TTI->isLegalAltInstr(getWidenedType(ScalarTy, VL.size()), Opcode0,
8361+
Opcode1, OpcodeMask))
83618362
return true;
83628363
SmallVector<ValueList> Operands;
83638364
for (unsigned I : seq<unsigned>(S.getMainOp()->getNumOperands())) {
@@ -9270,7 +9271,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
92709271
VectorType *VecTy = getWidenedType(ScalarTy, VL.size());
92719272
unsigned Opcode0 = LocalState.getOpcode();
92729273
unsigned Opcode1 = LocalState.getAltOpcode();
9273-
SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1));
9274+
SmallBitVector OpcodeMask(getAltInstrMask(VL, ScalarTy, Opcode0, Opcode1));
92749275
// Enable split node, only if all nodes do not form legal alternate
92759276
// instruction (like X86 addsub).
92769277
SmallPtrSet<Value *, 4> UOp1(llvm::from_range, Op1);
@@ -13200,7 +13201,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1320013201
// order.
1320113202
unsigned Opcode0 = E->getOpcode();
1320213203
unsigned Opcode1 = E->getAltOpcode();
13203-
SmallBitVector OpcodeMask(getAltInstrMask(E->Scalars, Opcode0, Opcode1));
13204+
SmallBitVector OpcodeMask(
13205+
getAltInstrMask(E->Scalars, ScalarTy, Opcode0, Opcode1));
1320413206
// If this pattern is supported by the target then we consider the
1320513207
// order.
1320613208
if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) {
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt -mtriple=x86_64-unknown-linux-gnu -mattr=+avx -passes=slp-vectorizer -S -slp-revec %s | FileCheck %s
3+
4+
define i32 @test() {
5+
; CHECK-LABEL: @test(
6+
; CHECK-NEXT: entry:
7+
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr null, align 1
8+
; CHECK-NEXT: [[WIDE_LOAD136:%.*]] = load <16 x i8>, ptr null, align 1
9+
; CHECK-NEXT: [[WIDE_LOAD137:%.*]] = load <16 x i8>, ptr null, align 1
10+
; CHECK-NEXT: [[WIDE_LOAD138:%.*]] = load <16 x i8>, ptr null, align 1
11+
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
12+
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
13+
; CHECK-NEXT: [[TMP2:%.*]] = or <16 x i8> [[WIDE_LOAD]], zeroinitializer
14+
; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i8> [[WIDE_LOAD136]], zeroinitializer
15+
; CHECK-NEXT: [[TMP4:%.*]] = or <16 x i8> [[WIDE_LOAD137]], zeroinitializer
16+
; CHECK-NEXT: [[TMP5:%.*]] = or <16 x i8> [[WIDE_LOAD138]], zeroinitializer
17+
; CHECK-NEXT: [[TMP6:%.*]] = icmp ult <16 x i8> [[TMP2]], zeroinitializer
18+
; CHECK-NEXT: [[TMP7:%.*]] = icmp ult <16 x i8> [[TMP3]], zeroinitializer
19+
; CHECK-NEXT: [[TMP8:%.*]] = icmp ult <16 x i8> [[TMP4]], zeroinitializer
20+
; CHECK-NEXT: [[TMP9:%.*]] = icmp ult <16 x i8> [[TMP5]], zeroinitializer
21+
; CHECK-NEXT: [[TMP10:%.*]] = or <16 x i8> [[TMP0]], zeroinitializer
22+
; CHECK-NEXT: [[TMP11:%.*]] = or <16 x i8> [[TMP1]], zeroinitializer
23+
; CHECK-NEXT: [[TMP12:%.*]] = icmp ult <16 x i8> [[TMP10]], zeroinitializer
24+
; CHECK-NEXT: [[TMP13:%.*]] = icmp ult <16 x i8> [[TMP11]], zeroinitializer
25+
; CHECK-NEXT: ret i32 0
26+
;
27+
entry:
28+
%wide.load = load <16 x i8>, ptr null, align 1
29+
%wide.load136 = load <16 x i8>, ptr null, align 1
30+
%wide.load137 = load <16 x i8>, ptr null, align 1
31+
%wide.load138 = load <16 x i8>, ptr null, align 1
32+
%0 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
33+
%1 = insertelement <16 x i8> zeroinitializer, i8 0, i64 0
34+
%2 = or <16 x i8> %wide.load, zeroinitializer
35+
%3 = or <16 x i8> %wide.load136, zeroinitializer
36+
%4 = or <16 x i8> %wide.load137, zeroinitializer
37+
%5 = or <16 x i8> %wide.load138, zeroinitializer
38+
%6 = icmp ult <16 x i8> %2, zeroinitializer
39+
%7 = icmp ult <16 x i8> %3, zeroinitializer
40+
%8 = icmp ult <16 x i8> %4, zeroinitializer
41+
%9 = icmp ult <16 x i8> %5, zeroinitializer
42+
%10 = or <16 x i8> %0, zeroinitializer
43+
%11 = or <16 x i8> %1, zeroinitializer
44+
%12 = icmp ult <16 x i8> %10, zeroinitializer
45+
%13 = icmp ult <16 x i8> %11, zeroinitializer
46+
ret i32 0
47+
}

0 commit comments

Comments
 (0)