Skip to content

Commit c739dd8

Browse files
Zain Jaffal authored and fhahn committed
[AArch64] turn extended vecreduce bigger than v16i8 into udot/sdot
We can do this by breaking the vecreduce into v16i8 vectors, generating a udot/sdot for each of them, and concatenating the results. Differential Revision: https://reviews.llvm.org/D141693
1 parent e59ec57 commit c739dd8

File tree

2 files changed

+1039
-1369
lines changed

2 files changed

+1039
-1369
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14805,6 +14805,9 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
1480514805
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
1480614806
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
1480714807
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
14808+
// If we have vectors larger than v16i8 we extract v16i8 vectors,
14809+
// follow the same steps above to get DOT instructions, concatenate them,
14810+
// and generate vecreduce.add(concat_vector(DOT, DOT2, ..)).
1480814811
static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1480914812
const AArch64Subtarget *ST) {
1481014813
if (!ST->hasDotProd())
@@ -14830,7 +14833,9 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1483014833
return SDValue();
1483114834

1483214835
EVT Op0VT = A.getOperand(0).getValueType();
14833-
if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
14836+
bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
14837+
bool IsValidSize = Op0VT.getScalarSizeInBits() == 8;
14838+
if (!IsValidElementCount || !IsValidSize)
1483414839
return SDValue();
1483514840

1483614841
SDLoc DL(Op0);
@@ -14841,13 +14846,65 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1484114846
else
1484214847
B = B.getOperand(0);
1484314848

14844-
SDValue Zeros =
14845-
DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
14849+
unsigned IsMultipleOf16 = Op0VT.getVectorNumElements() % 16 == 0;
14850+
unsigned NumOfVecReduce;
14851+
EVT TargetType;
14852+
if (IsMultipleOf16) {
14853+
NumOfVecReduce = Op0VT.getVectorNumElements() / 16;
14854+
TargetType = MVT::v4i32;
14855+
} else {
14856+
NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
14857+
TargetType = MVT::v2i32;
14858+
}
1484614859
auto DotOpcode =
1484714860
(ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
14848-
SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
14849-
A.getOperand(0), B);
14850-
return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
14861+
// Handle the case where we need to generate only one Dot operation.
14862+
if (NumOfVecReduce == 1) {
14863+
SDValue Zeros = DAG.getConstant(0, DL, TargetType);
14864+
SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
14865+
A.getOperand(0), B);
14866+
return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
14867+
}
14868+
// Generate Dot instructions that are multiple of 16.
14869+
unsigned VecReduce16Num = Op0VT.getVectorNumElements() / 16;
14870+
SmallVector<SDValue, 4> SDotVec16;
14871+
unsigned I = 0;
14872+
for (; I < VecReduce16Num; I += 1) {
14873+
SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
14874+
SDValue Op0 =
14875+
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0),
14876+
DAG.getConstant(I * 16, DL, MVT::i64));
14877+
SDValue Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, B,
14878+
DAG.getConstant(I * 16, DL, MVT::i64));
14879+
SDValue Dot =
14880+
DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Op0, Op1);
14881+
SDotVec16.push_back(Dot);
14882+
}
14883+
// Concatenate dot operations.
14884+
EVT SDot16EVT =
14885+
EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4 * VecReduce16Num);
14886+
SDValue ConcatSDot16 =
14887+
DAG.getNode(ISD::CONCAT_VECTORS, DL, SDot16EVT, SDotVec16);
14888+
SDValue VecReduceAdd16 =
14889+
DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), ConcatSDot16);
14890+
unsigned VecReduce8Num = (Op0VT.getVectorNumElements() % 16) / 8;
14891+
if (VecReduce8Num == 0)
14892+
return VecReduceAdd16;
14893+
14894+
// Generate the remainder Dot operation that is multiple of 8.
14895+
SmallVector<SDValue, 4> SDotVec8;
14896+
SDValue Zeros = DAG.getConstant(0, DL, MVT::v2i32);
14897+
SDValue Vec8Op0 =
14898+
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, A.getOperand(0),
14899+
DAG.getConstant(I * 16, DL, MVT::i64));
14900+
SDValue Vec8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, B,
14901+
DAG.getConstant(I * 16, DL, MVT::i64));
14902+
SDValue Dot =
14903+
DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Vec8Op0, Vec8Op1);
14904+
SDValue VecReudceAdd8 =
14905+
DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
14906+
return DAG.getNode(ISD::ADD, DL, N->getValueType(0), VecReduceAdd16,
14907+
VecReudceAdd8);
1485114908
}
1485214909

1485314910
// Given an (integer) vecreduce, we know the order of the inputs does not

0 commit comments

Comments
 (0)