@@ -14805,6 +14805,9 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
14805
14805
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
14806
14806
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
14807
14807
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
14808
+ // If we have vectors larger than v16i8, we extract v16i8 subvectors,
14809
+ // follow the same steps as above to get DOT instructions, concatenate them,
14810
+ // and generate vecreduce.add(concat_vector(DOT, DOT2, ..)).
14808
14811
static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
14809
14812
const AArch64Subtarget *ST) {
14810
14813
if (!ST->hasDotProd())
@@ -14830,7 +14833,9 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
14830
14833
return SDValue();
14831
14834
14832
14835
EVT Op0VT = A.getOperand(0).getValueType();
14833
- if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
14836
+ bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
14837
+ bool IsValidSize = Op0VT.getScalarSizeInBits() == 8;
14838
+ if (!IsValidElementCount || !IsValidSize)
14834
14839
return SDValue();
14835
14840
14836
14841
SDLoc DL(Op0);
@@ -14841,13 +14846,65 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
14841
14846
else
14842
14847
B = B.getOperand(0);
14843
14848
14844
- SDValue Zeros =
14845
- DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
14849
+ unsigned IsMultipleOf16 = Op0VT.getVectorNumElements() % 16 == 0;
14850
+ unsigned NumOfVecReduce;
14851
+ EVT TargetType;
14852
+ if (IsMultipleOf16) {
14853
+ NumOfVecReduce = Op0VT.getVectorNumElements() / 16;
14854
+ TargetType = MVT::v4i32;
14855
+ } else {
14856
+ NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
14857
+ TargetType = MVT::v2i32;
14858
+ }
14846
14859
auto DotOpcode =
14847
14860
(ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
14848
- SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
14849
- A.getOperand(0), B);
14850
- return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
14861
+ // Handle the case where we need to generate only one Dot operation.
14862
+ if (NumOfVecReduce == 1) {
14863
+ SDValue Zeros = DAG.getConstant(0, DL, TargetType);
14864
+ SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
14865
+ A.getOperand(0), B);
14866
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
14867
+ }
14868
+ // Generate one Dot operation for each 16-element subvector of the inputs.
14869
+ unsigned VecReduce16Num = Op0VT.getVectorNumElements() / 16;
14870
+ SmallVector<SDValue, 4> SDotVec16;
14871
+ unsigned I = 0;
14872
+ for (; I < VecReduce16Num; I += 1) {
14873
+ SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
14874
+ SDValue Op0 =
14875
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0),
14876
+ DAG.getConstant(I * 16, DL, MVT::i64));
14877
+ SDValue Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, B,
14878
+ DAG.getConstant(I * 16, DL, MVT::i64));
14879
+ SDValue Dot =
14880
+ DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Op0, Op1);
14881
+ SDotVec16.push_back(Dot);
14882
+ }
14883
+ // Concatenate dot operations.
14884
+ EVT SDot16EVT =
14885
+ EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4 * VecReduce16Num);
14886
+ SDValue ConcatSDot16 =
14887
+ DAG.getNode(ISD::CONCAT_VECTORS, DL, SDot16EVT, SDotVec16);
14888
+ SDValue VecReduceAdd16 =
14889
+ DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), ConcatSDot16);
14890
+ unsigned VecReduce8Num = (Op0VT.getVectorNumElements() % 16) / 8;
14891
+ if (VecReduce8Num == 0)
14892
+ return VecReduceAdd16;
14893
+
14894
+ // Generate the Dot operation for the remaining 8 elements.
14895
+ SmallVector<SDValue, 4> SDotVec8;
14896
+ SDValue Zeros = DAG.getConstant(0, DL, MVT::v2i32);
14897
+ SDValue Vec8Op0 =
14898
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, A.getOperand(0),
14899
+ DAG.getConstant(I * 16, DL, MVT::i64));
14900
+ SDValue Vec8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, B,
14901
+ DAG.getConstant(I * 16, DL, MVT::i64));
14902
+ SDValue Dot =
14903
+ DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Vec8Op0, Vec8Op1);
14904
+ SDValue VecReudceAdd8 =
14905
+ DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
14906
+ return DAG.getNode(ISD::ADD, DL, N->getValueType(0), VecReduceAdd16,
14907
+ VecReudceAdd8);
14851
14908
}
14852
14909
14853
14910
// Given an (integer) vecreduce, we know the order of the inputs does not
0 commit comments