Skip to content

Commit e4e5c42

Browse files
committed
[X86][SSE] isTargetShuffleEquivalent - ensure shuffle inputs are the correct size.
Preliminary patch for the next stage of PR45974 - we don't want to create 'padded' vectors on-the-fly at all in combineX86ShufflesRecursively, and should only pad the source inputs if we have a definite match inside combineX86ShuffleChain. This means that the inputs to combineX86ShuffleChain might soon be smaller than the final root value type, so we should ensure that isTargetShuffleEquivalent only matches against the inputs if they are the correct size.
1 parent a566f05 commit e4e5c42

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10930,7 +10930,7 @@ static bool isShuffleEquivalent(ArrayRef<int> Mask, ArrayRef<int> ExpectedMask,
1093010930
///
1093110931
/// SM_SentinelZero is accepted as a valid negative index but must match in
1093210932
/// both.
10933-
static bool isTargetShuffleEquivalent(ArrayRef<int> Mask,
10933+
static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask,
1093410934
ArrayRef<int> ExpectedMask,
1093510935
SDValue V1 = SDValue(),
1093610936
SDValue V2 = SDValue()) {
@@ -10944,6 +10944,12 @@ static bool isTargetShuffleEquivalent(ArrayRef<int> Mask,
1094410944
if (!isUndefOrZeroOrInRange(Mask, 0, 2 * Size))
1094510945
return false;
1094610946

10947+
// Don't use V1/V2 if they're not the same size as the shuffle mask type.
10948+
if (V1 && V1.getValueSizeInBits() != VT.getSizeInBits())
10949+
V1 = SDValue();
10950+
if (V2 && V2.getValueSizeInBits() != VT.getSizeInBits())
10951+
V2 = SDValue();
10952+
1094710953
for (int i = 0; i < Size; ++i) {
1094810954
int MaskIdx = Mask[i];
1094910955
int ExpectedIdx = ExpectedMask[i];
@@ -11002,8 +11008,8 @@ static bool isUnpackWdShuffleMask(ArrayRef<int> Mask, MVT VT) {
1100211008
SmallVector<int, 8> Unpckhwd;
1100311009
createUnpackShuffleMask(MVT::v8i16, Unpckhwd, /* Lo = */ false,
1100411010
/* Unary = */ false);
11005-
bool IsUnpackwdMask = (isTargetShuffleEquivalent(Mask, Unpcklwd) ||
11006-
isTargetShuffleEquivalent(Mask, Unpckhwd));
11011+
bool IsUnpackwdMask = (isTargetShuffleEquivalent(VT, Mask, Unpcklwd) ||
11012+
isTargetShuffleEquivalent(VT, Mask, Unpckhwd));
1100711013
return IsUnpackwdMask;
1100811014
}
1100911015

@@ -11020,8 +11026,8 @@ static bool is128BitUnpackShuffleMask(ArrayRef<int> Mask) {
1102011026
for (unsigned i = 0; i != 4; ++i) {
1102111027
SmallVector<int, 16> UnpackMask;
1102211028
createUnpackShuffleMask(VT, UnpackMask, (i >> 1) % 2, i % 2);
11023-
if (isTargetShuffleEquivalent(Mask, UnpackMask) ||
11024-
isTargetShuffleEquivalent(CommutedMask, UnpackMask))
11029+
if (isTargetShuffleEquivalent(VT, Mask, UnpackMask) ||
11030+
isTargetShuffleEquivalent(VT, CommutedMask, UnpackMask))
1102511031
return true;
1102611032
}
1102711033
return false;
@@ -11214,15 +11220,15 @@ static bool matchShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
1121411220
// Attempt to match the target mask against the unpack lo/hi mask patterns.
1121511221
SmallVector<int, 64> Unpckl, Unpckh;
1121611222
createUnpackShuffleMask(VT, Unpckl, /* Lo = */ true, IsUnary);
11217-
if (isTargetShuffleEquivalent(TargetMask, Unpckl)) {
11223+
if (isTargetShuffleEquivalent(VT, TargetMask, Unpckl)) {
1121811224
UnpackOpcode = X86ISD::UNPCKL;
1121911225
V2 = (Undef2 ? DAG.getUNDEF(VT) : (IsUnary ? V1 : V2));
1122011226
V1 = (Undef1 ? DAG.getUNDEF(VT) : V1);
1122111227
return true;
1122211228
}
1122311229

1122411230
createUnpackShuffleMask(VT, Unpckh, /* Lo = */ false, IsUnary);
11225-
if (isTargetShuffleEquivalent(TargetMask, Unpckh)) {
11231+
if (isTargetShuffleEquivalent(VT, TargetMask, Unpckh)) {
1122611232
UnpackOpcode = X86ISD::UNPCKH;
1122711233
V2 = (Undef2 ? DAG.getUNDEF(VT) : (IsUnary ? V1 : V2));
1122811234
V1 = (Undef1 ? DAG.getUNDEF(VT) : V1);
@@ -11260,14 +11266,14 @@ static bool matchShuffleWithUNPCK(MVT VT, SDValue &V1, SDValue &V2,
1126011266
// If a binary shuffle, commute and try again.
1126111267
if (!IsUnary) {
1126211268
ShuffleVectorSDNode::commuteMask(Unpckl);
11263-
if (isTargetShuffleEquivalent(TargetMask, Unpckl)) {
11269+
if (isTargetShuffleEquivalent(VT, TargetMask, Unpckl)) {
1126411270
UnpackOpcode = X86ISD::UNPCKL;
1126511271
std::swap(V1, V2);
1126611272
return true;
1126711273
}
1126811274

1126911275
ShuffleVectorSDNode::commuteMask(Unpckh);
11270-
if (isTargetShuffleEquivalent(TargetMask, Unpckh)) {
11276+
if (isTargetShuffleEquivalent(VT, TargetMask, Unpckh)) {
1127111277
UnpackOpcode = X86ISD::UNPCKH;
1127211278
std::swap(V1, V2);
1127311279
return true;
@@ -11638,14 +11644,14 @@ static bool matchShuffleWithPACK(MVT VT, MVT &SrcVT, SDValue &V1, SDValue &V2,
1163811644
// Try binary shuffle.
1163911645
SmallVector<int, 32> BinaryMask;
1164011646
createPackShuffleMask(VT, BinaryMask, false, NumStages);
11641-
if (isTargetShuffleEquivalent(TargetMask, BinaryMask, V1, V2))
11647+
if (isTargetShuffleEquivalent(VT, TargetMask, BinaryMask, V1, V2))
1164211648
if (MatchPACK(V1, V2, PackVT))
1164311649
return true;
1164411650

1164511651
// Try unary shuffle.
1164611652
SmallVector<int, 32> UnaryMask;
1164711653
createPackShuffleMask(VT, UnaryMask, true, NumStages);
11648-
if (isTargetShuffleEquivalent(TargetMask, UnaryMask, V1))
11654+
if (isTargetShuffleEquivalent(VT, TargetMask, UnaryMask, V1))
1164911655
if (MatchPACK(V1, V1, PackVT))
1165011656
return true;
1165111657
}
@@ -34522,17 +34528,17 @@ static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
3452234528
// instructions are no slower than UNPCKLPD but has the option to
3452334529
// fold the input operand into even an unaligned memory load.
3452434530
if (MaskVT.is128BitVector() && Subtarget.hasSSE3() && AllowFloatDomain) {
34525-
if (isTargetShuffleEquivalent(Mask, {0, 0}, V1)) {
34531+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0}, V1)) {
3452634532
Shuffle = X86ISD::MOVDDUP;
3452734533
SrcVT = DstVT = MVT::v2f64;
3452834534
return true;
3452934535
}
34530-
if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2}, V1)) {
34536+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2}, V1)) {
3453134537
Shuffle = X86ISD::MOVSLDUP;
3453234538
SrcVT = DstVT = MVT::v4f32;
3453334539
return true;
3453434540
}
34535-
if (isTargetShuffleEquivalent(Mask, {1, 1, 3, 3}, V1)) {
34541+
if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1, 3, 3}, V1)) {
3453634542
Shuffle = X86ISD::MOVSHDUP;
3453734543
SrcVT = DstVT = MVT::v4f32;
3453834544
return true;
@@ -34541,17 +34547,17 @@ static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
3454134547

3454234548
if (MaskVT.is256BitVector() && AllowFloatDomain) {
3454334549
assert(Subtarget.hasAVX() && "AVX required for 256-bit vector shuffles");
34544-
if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2}, V1)) {
34550+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2}, V1)) {
3454534551
Shuffle = X86ISD::MOVDDUP;
3454634552
SrcVT = DstVT = MVT::v4f64;
3454734553
return true;
3454834554
}
34549-
if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
34555+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
3455034556
Shuffle = X86ISD::MOVSLDUP;
3455134557
SrcVT = DstVT = MVT::v8f32;
3455234558
return true;
3455334559
}
34554-
if (isTargetShuffleEquivalent(Mask, {1, 1, 3, 3, 5, 5, 7, 7}, V1)) {
34560+
if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1, 3, 3, 5, 5, 7, 7}, V1)) {
3455534561
Shuffle = X86ISD::MOVSHDUP;
3455634562
SrcVT = DstVT = MVT::v8f32;
3455734563
return true;
@@ -34561,19 +34567,21 @@ static bool matchUnaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
3456134567
if (MaskVT.is512BitVector() && AllowFloatDomain) {
3456234568
assert(Subtarget.hasAVX512() &&
3456334569
"AVX512 required for 512-bit vector shuffles");
34564-
if (isTargetShuffleEquivalent(Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
34570+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0, 2, 2, 4, 4, 6, 6}, V1)) {
3456534571
Shuffle = X86ISD::MOVDDUP;
3456634572
SrcVT = DstVT = MVT::v8f64;
3456734573
return true;
3456834574
}
3456934575
if (isTargetShuffleEquivalent(
34570-
Mask, {0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}, V1)) {
34576+
MaskVT, Mask,
34577+
{0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14}, V1)) {
3457134578
Shuffle = X86ISD::MOVSLDUP;
3457234579
SrcVT = DstVT = MVT::v16f32;
3457334580
return true;
3457434581
}
3457534582
if (isTargetShuffleEquivalent(
34576-
Mask, {1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}, V1)) {
34583+
MaskVT, Mask,
34584+
{1, 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}, V1)) {
3457734585
Shuffle = X86ISD::MOVSHDUP;
3457834586
SrcVT = DstVT = MVT::v16f32;
3457934587
return true;
@@ -34732,27 +34740,27 @@ static bool matchBinaryShuffle(MVT MaskVT, ArrayRef<int> Mask,
3473234740
unsigned EltSizeInBits = MaskVT.getScalarSizeInBits();
3473334741

3473434742
if (MaskVT.is128BitVector()) {
34735-
if (isTargetShuffleEquivalent(Mask, {0, 0}) && AllowFloatDomain) {
34743+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 0}) && AllowFloatDomain) {
3473634744
V2 = V1;
3473734745
V1 = (SM_SentinelUndef == Mask[0] ? DAG.getUNDEF(MVT::v4f32) : V1);
3473834746
Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKL : X86ISD::MOVLHPS;
3473934747
SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
3474034748
return true;
3474134749
}
34742-
if (isTargetShuffleEquivalent(Mask, {1, 1}) && AllowFloatDomain) {
34750+
if (isTargetShuffleEquivalent(MaskVT, Mask, {1, 1}) && AllowFloatDomain) {
3474334751
V2 = V1;
3474434752
Shuffle = Subtarget.hasSSE2() ? X86ISD::UNPCKH : X86ISD::MOVHLPS;
3474534753
SrcVT = DstVT = Subtarget.hasSSE2() ? MVT::v2f64 : MVT::v4f32;
3474634754
return true;
3474734755
}
34748-
if (isTargetShuffleEquivalent(Mask, {0, 3}) && Subtarget.hasSSE2() &&
34749-
(AllowFloatDomain || !Subtarget.hasSSE41())) {
34756+
if (isTargetShuffleEquivalent(MaskVT, Mask, {0, 3}) &&
34757+
Subtarget.hasSSE2() && (AllowFloatDomain || !Subtarget.hasSSE41())) {
3475034758
std::swap(V1, V2);
3475134759
Shuffle = X86ISD::MOVSD;
3475234760
SrcVT = DstVT = MVT::v2f64;
3475334761
return true;
3475434762
}
34755-
if (isTargetShuffleEquivalent(Mask, {4, 1, 2, 3}) &&
34763+
if (isTargetShuffleEquivalent(MaskVT, Mask, {4, 1, 2, 3}) &&
3475634764
(AllowFloatDomain || !Subtarget.hasSSE41())) {
3475734765
Shuffle = X86ISD::MOVSS;
3475834766
SrcVT = DstVT = MVT::v4f32;
@@ -35325,7 +35333,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
3532535333
// from a scalar.
3532635334
// TODO: Handle other insertions here as well?
3532735335
if (!UnaryShuffle && AllowFloatDomain && RootSizeInBits == 128 &&
35328-
Subtarget.hasSSE41() && !isTargetShuffleEquivalent(Mask, {4, 1, 2, 3})) {
35336+
Subtarget.hasSSE41() &&
35337+
!isTargetShuffleEquivalent(MaskVT, Mask, {4, 1, 2, 3})) {
3532935338
if (MaskEltSizeInBits == 32) {
3533035339
SDValue SrcV1 = V1, SrcV2 = V2;
3533135340
if (matchShuffleAsInsertPS(SrcV1, SrcV2, PermuteImm, Zeroable, Mask,
@@ -35340,7 +35349,8 @@ static SDValue combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
3534035349
return DAG.getBitcast(RootVT, Res);
3534135350
}
3534235351
}
35343-
if (MaskEltSizeInBits == 64 && isTargetShuffleEquivalent(Mask, {0, 2}) &&
35352+
if (MaskEltSizeInBits == 64 &&
35353+
isTargetShuffleEquivalent(MaskVT, Mask, {0, 2}) &&
3534435354
V2.getOpcode() == ISD::SCALAR_TO_VECTOR &&
3534535355
V2.getScalarValueSizeInBits() <= 32) {
3534635356
if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTPS)

0 commit comments

Comments (0)