@@ -1239,7 +1239,15 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
1239
1239
setOperationAction(ISD::TRUNCATE, MVT::v2i32, Custom);
1240
1240
setOperationAction(ISD::TRUNCATE, MVT::v4i8, Custom);
1241
1241
setOperationAction(ISD::TRUNCATE, MVT::v4i16, Custom);
1242
+ setOperationAction(ISD::TRUNCATE, MVT::v4i32, Custom);
1242
1243
setOperationAction(ISD::TRUNCATE, MVT::v8i8, Custom);
1244
+ setOperationAction(ISD::TRUNCATE, MVT::v8i16, Custom);
1245
+ setOperationAction(ISD::TRUNCATE, MVT::v8i32, Custom);
1246
+ setOperationAction(ISD::TRUNCATE, MVT::v8i64, Custom);
1247
+ setOperationAction(ISD::TRUNCATE, MVT::v16i8, Custom);
1248
+ setOperationAction(ISD::TRUNCATE, MVT::v16i16, Custom);
1249
+ setOperationAction(ISD::TRUNCATE, MVT::v16i32, Custom);
1250
+ setOperationAction(ISD::TRUNCATE, MVT::v16i64, Custom);
1243
1251
1244
1252
// In the customized shift lowering, the legal v4i32/v2i64 cases
1245
1253
// in AVX2 will be recognized.
@@ -1480,9 +1488,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
1480
1488
setOperationAction(ISD::ANY_EXTEND, VT, Custom);
1481
1489
}
1482
1490
1483
- setOperationAction(ISD::TRUNCATE, MVT::v16i8, Custom);
1484
- setOperationAction(ISD::TRUNCATE, MVT::v8i16, Custom);
1485
- setOperationAction(ISD::TRUNCATE, MVT::v4i32, Custom);
1491
+ setOperationAction(ISD::TRUNCATE, MVT::v32i8, Custom);
1492
+ setOperationAction(ISD::TRUNCATE, MVT::v32i16, Custom);
1493
+ setOperationAction(ISD::TRUNCATE, MVT::v32i32, Custom);
1494
+ setOperationAction(ISD::TRUNCATE, MVT::v32i64, Custom);
1495
+
1486
1496
setOperationAction(ISD::BITREVERSE, MVT::v32i8, Custom);
1487
1497
1488
1498
for (auto VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 }) {
@@ -1802,7 +1812,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
1802
1812
setOperationAction(ISD::TRUNCATE, MVT::v8i32, Legal);
1803
1813
setOperationAction(ISD::TRUNCATE, MVT::v16i16, Legal);
1804
1814
setOperationAction(ISD::TRUNCATE, MVT::v32i8, HasBWI ? Legal : Custom);
1805
- setOperationAction(ISD::TRUNCATE, MVT::v16i64, Custom);
1806
1815
setOperationAction(ISD::ZERO_EXTEND, MVT::v32i16, Custom);
1807
1816
setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom);
1808
1817
setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64, Custom);
@@ -2338,10 +2347,6 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
2338
2347
setOperationAction(ISD::FP_EXTEND, MVT::v4f16, Custom);
2339
2348
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::v4f16, Custom);
2340
2349
}
2341
-
2342
- setOperationAction(ISD::TRUNCATE, MVT::v16i32, Custom);
2343
- setOperationAction(ISD::TRUNCATE, MVT::v8i64, Custom);
2344
- setOperationAction(ISD::TRUNCATE, MVT::v16i64, Custom);
2345
2350
}
2346
2351
2347
2352
if (Subtarget.hasAMXTILE()) {
@@ -22869,6 +22874,84 @@ static SDValue truncateVectorWithPACKSS(EVT DstVT, SDValue In, const SDLoc &DL,
22869
22874
return truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG, Subtarget);
22870
22875
}
22871
22876
22877
+ /// This function lowers a vector truncation of 'extended sign-bits' or
22878
+ /// 'extended zero-bits' values.
22879
+ /// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
22880
+ static SDValue LowerTruncateVecPackWithSignBits(MVT DstVT, SDValue In,
22881
+ const SDLoc &DL,
22882
+ const X86Subtarget &Subtarget,
22883
+ SelectionDAG &DAG) {
22884
+ MVT SrcVT = In.getSimpleValueType();
22885
+ MVT DstSVT = DstVT.getVectorElementType();
22886
+ MVT SrcSVT = SrcVT.getVectorElementType();
22887
+ if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
22888
+ (DstSVT == MVT::i8 || DstSVT == MVT::i16 || DstSVT == MVT::i32)))
22889
+ return SDValue();
22890
+
22891
+ unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
22892
+ unsigned NumPackedSignBits = std::min<unsigned>(DstSVT.getSizeInBits(), 16);
22893
+ unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
22894
+
22895
+ // Truncate with PACKUS if we are truncating a vector with leading zero
22896
+ // bits that extend all the way to the packed/truncated value. Pre-SSE41
22897
+ // we can only use PACKUSWB.
22898
+ KnownBits Known = DAG.computeKnownBits(In);
22899
+ if ((NumSrcEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros())
22900
+ if (SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, DstVT, In, DL, DAG,
22901
+ Subtarget))
22902
+ return V;
22903
+
22904
+ // Truncate with PACKSS if we are truncating a vector with sign-bits
22905
+ // that extend all the way to the packed/truncated value.
22906
+ if ((NumSrcEltBits - NumPackedSignBits) < DAG.ComputeNumSignBits(In))
22907
+ if (SDValue V = truncateVectorWithPACK(X86ISD::PACKSS, DstVT, In, DL, DAG,
22908
+ Subtarget))
22909
+ return V;
22910
+
22911
+ return SDValue();
22912
+ }
22913
+
22914
+ /// This function lowers a vector truncation from vXi32/vXi64 to vXi8/vXi16 into
22915
+ /// X86ISD::PACKUS/X86ISD::PACKSS operations.
22916
+ static SDValue LowerTruncateVecPack(MVT DstVT, SDValue In, const SDLoc &DL,
22917
+ const X86Subtarget &Subtarget,
22918
+ SelectionDAG &DAG) {
22919
+ MVT SrcVT = In.getSimpleValueType();
22920
+ MVT DstSVT = DstVT.getVectorElementType();
22921
+ MVT SrcSVT = SrcVT.getVectorElementType();
22922
+ unsigned NumElems = DstVT.getVectorNumElements();
22923
+ if (!((SrcSVT == MVT::i16 || SrcSVT == MVT::i32 || SrcSVT == MVT::i64) &&
22924
+ (DstSVT == MVT::i8 || DstSVT == MVT::i16) && isPowerOf2_32(NumElems) &&
22925
+ NumElems >= 8))
22926
+ return SDValue();
22927
+
22928
+ // SSSE3's pshufb results in less instructions in the cases below.
22929
+ if (Subtarget.hasSSSE3() && NumElems == 8) {
22930
+ if (SrcSVT == MVT::i16)
22931
+ return SDValue();
22932
+ if (SrcSVT == MVT::i32 && (DstSVT == MVT::i8 || !Subtarget.hasSSE41()))
22933
+ return SDValue();
22934
+ }
22935
+
22936
+ // SSE2 provides PACKUS for only 2 x v8i16 -> v16i8 and SSE4.1 provides PACKUS
22937
+ // for 2 x v4i32 -> v8i16. For SSSE3 and below, we need to use PACKSS to
22938
+ // truncate 2 x v4i32 to v8i16.
22939
+ if (Subtarget.hasSSE41() || DstSVT == MVT::i8)
22940
+ return truncateVectorWithPACKUS(DstVT, In, DL, Subtarget, DAG);
22941
+
22942
+ if (SrcSVT == MVT::i16 || SrcSVT == MVT::i32)
22943
+ return truncateVectorWithPACKSS(DstVT, In, DL, Subtarget, DAG);
22944
+
22945
+ // Special case vXi64 -> vXi16, shuffle to vXi32 and then use PACKSS.
22946
+ if (DstSVT == MVT::i16 && SrcSVT == MVT::i64) {
22947
+ MVT TruncVT = MVT::getVectorVT(MVT::i32, NumElems);
22948
+ SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, In);
22949
+ return truncateVectorWithPACKSS(DstVT, Trunc, DL, Subtarget, DAG);
22950
+ }
22951
+
22952
+ return SDValue();
22953
+ }
22954
+
22872
22955
static SDValue LowerTruncateVecI1(SDValue Op, SelectionDAG &DAG,
22873
22956
const X86Subtarget &Subtarget) {
22874
22957
@@ -22955,16 +23038,14 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
22955
23038
MVT VT = Op.getSimpleValueType();
22956
23039
SDValue In = Op.getOperand(0);
22957
23040
MVT InVT = In.getSimpleValueType();
22958
- unsigned InNumEltBits = InVT.getScalarSizeInBits();
22959
-
22960
23041
assert(VT.getVectorNumElements() == InVT.getVectorNumElements() &&
22961
23042
"Invalid TRUNCATE operation");
22962
23043
22963
23044
// If we're called by the type legalizer, handle a few cases.
22964
23045
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22965
23046
if (!TLI.isTypeLegal(InVT)) {
22966
23047
if ((InVT == MVT::v8i64 || InVT == MVT::v16i32 || InVT == MVT::v16i64) &&
22967
- VT.is128BitVector()) {
23048
+ VT.is128BitVector() && Subtarget.hasAVX512() ) {
22968
23049
assert((InVT == MVT::v16i64 || Subtarget.hasVLX()) &&
22969
23050
"Unexpected subtarget!");
22970
23051
// The default behavior is to truncate one step, concatenate, and then
@@ -22981,35 +23062,28 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
22981
23062
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
22982
23063
}
22983
23064
23065
+ // Pre-AVX512 see if we can make use of PACKSS/PACKUS.
23066
+ if (!Subtarget.hasAVX512()) {
23067
+ if (SDValue SignPack =
23068
+ LowerTruncateVecPackWithSignBits(VT, In, DL, Subtarget, DAG))
23069
+ return SignPack;
23070
+
23071
+ return LowerTruncateVecPack(VT, In, DL, Subtarget, DAG);
23072
+ }
23073
+
22984
23074
// Otherwise let default legalization handle it.
22985
23075
return SDValue();
22986
23076
}
22987
23077
22988
23078
if (VT.getVectorElementType() == MVT::i1)
22989
23079
return LowerTruncateVecI1(Op, DAG, Subtarget);
22990
23080
22991
- unsigned NumPackedSignBits = std::min<unsigned>(VT.getScalarSizeInBits(), 16);
22992
- unsigned NumPackedZeroBits = Subtarget.hasSSE41() ? NumPackedSignBits : 8;
22993
-
22994
23081
// Attempt to truncate with PACKUS/PACKSS even on AVX512 if we'd have to
22995
23082
// concat from subvectors to use VPTRUNC etc.
22996
- if (!Subtarget.hasAVX512() || isFreeToSplitVector(In.getNode(), DAG)) {
22997
- // Truncate with PACKUS if we are truncating a vector with leading zero
22998
- // bits that extend all the way to the packed/truncated value. Pre-SSE41
22999
- // we can only use PACKUSWB.
23000
- KnownBits Known = DAG.computeKnownBits(In);
23001
- if ((InNumEltBits - NumPackedZeroBits) <= Known.countMinLeadingZeros())
23002
- if (SDValue V = truncateVectorWithPACK(X86ISD::PACKUS, VT, In, DL, DAG,
23003
- Subtarget))
23004
- return V;
23005
-
23006
- // Truncate with PACKSS if we are truncating a vector with sign-bits
23007
- // that extend all the way to the packed/truncated value.
23008
- if ((InNumEltBits - NumPackedSignBits) < DAG.ComputeNumSignBits(In))
23009
- if (SDValue V = truncateVectorWithPACK(X86ISD::PACKSS, VT, In, DL, DAG,
23010
- Subtarget))
23011
- return V;
23012
- }
23083
+ if (!Subtarget.hasAVX512() || isFreeToSplitVector(In.getNode(), DAG))
23084
+ if (SDValue SignPack =
23085
+ LowerTruncateVecPackWithSignBits(VT, In, DL, Subtarget, DAG))
23086
+ return SignPack;
23013
23087
23014
23088
// vpmovqb/w/d, vpmovdb/w, vpmovwb
23015
23089
if (Subtarget.hasAVX512()) {
@@ -23068,27 +23142,9 @@ SDValue X86TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const {
23068
23142
return DAG.getBitcast(MVT::v8i16, In);
23069
23143
}
23070
23144
23071
- SDValue OpLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i32, In,
23072
- DAG.getIntPtrConstant(0, DL));
23073
- SDValue OpHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i32, In,
23074
- DAG.getIntPtrConstant(4, DL));
23075
-
23076
- // The PSHUFB mask:
23077
- static const int ShufMask1[] = {0, 2, 4, 6, -1, -1, -1, -1};
23078
-
23079
- OpLo = DAG.getBitcast(MVT::v8i16, OpLo);
23080
- OpHi = DAG.getBitcast(MVT::v8i16, OpHi);
23081
-
23082
- OpLo = DAG.getVectorShuffle(MVT::v8i16, DL, OpLo, OpLo, ShufMask1);
23083
- OpHi = DAG.getVectorShuffle(MVT::v8i16, DL, OpHi, OpHi, ShufMask1);
23084
-
23085
- OpLo = DAG.getBitcast(MVT::v4i32, OpLo);
23086
- OpHi = DAG.getBitcast(MVT::v4i32, OpHi);
23087
-
23088
- // The MOVLHPS Mask:
23089
- static const int ShufMask2[] = {0, 1, 4, 5};
23090
- SDValue res = DAG.getVectorShuffle(MVT::v4i32, DL, OpLo, OpHi, ShufMask2);
23091
- return DAG.getBitcast(MVT::v8i16, res);
23145
+ return Subtarget.hasSSE41()
23146
+ ? truncateVectorWithPACKUS(VT, In, DL, Subtarget, DAG)
23147
+ : truncateVectorWithPACKSS(VT, In, DL, Subtarget, DAG);
23092
23148
}
23093
23149
23094
23150
if (VT == MVT::v16i8 && InVT == MVT::v16i16)
@@ -53152,6 +53208,7 @@ static SDValue combineTruncatedArithmetic(SDNode *N, SelectionDAG &DAG,
53152
53208
/// legalization the truncation will be translated into a BUILD_VECTOR with each
53153
53209
/// element that is extracted from a vector and then truncated, and it is
53154
53210
/// difficult to do this optimization based on them.
53211
+ /// TODO: Remove this and just use LowerTruncateVecPack.
53155
53212
static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG,
53156
53213
const X86Subtarget &Subtarget) {
53157
53214
EVT OutVT = N->getValueType(0);
@@ -53200,6 +53257,7 @@ static SDValue combineVectorTruncation(SDNode *N, SelectionDAG &DAG,
53200
53257
/// This function transforms vector truncation of 'extended sign-bits' or
53201
53258
/// 'extended zero-bits' values.
53202
53259
/// vXi16/vXi32/vXi64 to vXi8/vXi16/vXi32 into X86ISD::PACKSS/PACKUS operations.
53260
+ /// TODO: Remove this and just use LowerTruncateVecPackWithSignBits.
53203
53261
static SDValue combineVectorSignBitsTruncation(SDNode *N, const SDLoc &DL,
53204
53262
SelectionDAG &DAG,
53205
53263
const X86Subtarget &Subtarget) {
0 commit comments