@@ -42964,6 +42964,24 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
42964
42964
42965
42965
SDLoc DL(ExtElt);
42966
42966
42967
+ // Extend v4i8/v8i8 vector to v16i8, with undef upper 64-bits.
42968
+ auto WidenToV16I8 = [&](SDValue V, bool ZeroExtend) {
42969
+ if (VecVT == MVT::v4i8) {
42970
+ if (ZeroExtend && Subtarget.hasSSE41()) {
42971
+ V = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v4i32,
42972
+ DAG.getConstant(0, DL, MVT::v4i32),
42973
+ DAG.getBitcast(MVT::i32, V),
42974
+ DAG.getIntPtrConstant(0, DL));
42975
+ return DAG.getBitcast(MVT::v16i8, V);
42976
+ }
42977
+ V = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i8, V,
42978
+ ZeroExtend ? DAG.getConstant(0, DL, MVT::v4i8)
42979
+ : DAG.getUNDEF(MVT::v4i8));
42980
+ }
42981
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V,
42982
+ DAG.getUNDEF(MVT::v8i8));
42983
+ };
42984
+
42967
42985
// vXi8 mul reduction - promote to vXi16 mul reduction.
42968
42986
if (Opc == ISD::MUL) {
42969
42987
unsigned NumElts = VecVT.getVectorNumElements();
@@ -42981,11 +42999,7 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
42981
42999
Rdx = DAG.getNode(Opc, DL, Lo.getValueType(), Lo, Hi);
42982
43000
}
42983
43001
} else {
42984
- if (VecVT == MVT::v4i8)
42985
- Rdx = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i8, Rdx,
42986
- DAG.getUNDEF(MVT::v4i8));
42987
- Rdx = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, Rdx,
42988
- DAG.getUNDEF(MVT::v8i8));
43002
+ Rdx = WidenToV16I8(Rdx, false);
42989
43003
Rdx = getUnpackl(DAG, DL, MVT::v16i8, Rdx, DAG.getUNDEF(MVT::v16i8));
42990
43004
Rdx = DAG.getBitcast(MVT::v8i16, Rdx);
42991
43005
}
@@ -43005,24 +43019,7 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
43005
43019
43006
43020
// vXi8 add reduction - sub 128-bit vector.
43007
43021
if (VecVT == MVT::v4i8 || VecVT == MVT::v8i8) {
43008
- if (VecVT == MVT::v4i8) {
43009
- // Pad with zero.
43010
- if (Subtarget.hasSSE41()) {
43011
- Rdx = DAG.getBitcast(MVT::i32, Rdx);
43012
- Rdx = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v4i32,
43013
- DAG.getConstant(0, DL, MVT::v4i32), Rdx,
43014
- DAG.getIntPtrConstant(0, DL));
43015
- Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
43016
- } else {
43017
- Rdx = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i8, Rdx,
43018
- DAG.getConstant(0, DL, VecVT));
43019
- }
43020
- }
43021
- if (Rdx.getValueType() == MVT::v8i8) {
43022
- // Pad with undef.
43023
- Rdx = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, Rdx,
43024
- DAG.getUNDEF(MVT::v8i8));
43025
- }
43022
+ Rdx = WidenToV16I8(Rdx, true);
43026
43023
Rdx = DAG.getNode(X86ISD::PSADBW, DL, MVT::v2i64, Rdx,
43027
43024
DAG.getConstant(0, DL, MVT::v16i8));
43028
43025
Rdx = DAG.getBitcast(MVT::v16i8, Rdx);
0 commit comments