Skip to content

Commit a35640f

Browse files
authored
[AArch64] Extend vecreduce to udot/sdot transformation to support usdot (#120094)
1 parent 3cc311a commit a35640f

File tree

2 files changed

+1130
-10
lines changed

2 files changed

+1130
-10
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18177,16 +18177,38 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1817718177
unsigned ExtOpcode = Op0.getOpcode();
1817818178
SDValue A = Op0;
1817918179
SDValue B;
18180+
unsigned DotOpcode;
1818018181
if (ExtOpcode == ISD::MUL) {
1818118182
A = Op0.getOperand(0);
1818218183
B = Op0.getOperand(1);
18183-
if (A.getOpcode() != B.getOpcode() ||
18184-
A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
18184+
if (A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
1818518185
return SDValue();
18186-
ExtOpcode = A.getOpcode();
18187-
}
18188-
if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
18186+
auto OpCodeA = A.getOpcode();
18187+
if (OpCodeA != ISD::ZERO_EXTEND && OpCodeA != ISD::SIGN_EXTEND)
18188+
return SDValue();
18189+
18190+
auto OpCodeB = B.getOpcode();
18191+
if (OpCodeB != ISD::ZERO_EXTEND && OpCodeB != ISD::SIGN_EXTEND)
18192+
return SDValue();
18193+
18194+
if (OpCodeA == OpCodeB) {
18195+
DotOpcode =
18196+
OpCodeA == ISD::ZERO_EXTEND ? AArch64ISD::UDOT : AArch64ISD::SDOT;
18197+
} else {
18198+
// Check that the subtarget supports USDOT (requires MatMulInt8).
18199+
if (!ST->hasMatMulInt8())
18200+
return SDValue();
18201+
DotOpcode = AArch64ISD::USDOT;
18202+
if (OpCodeA == ISD::SIGN_EXTEND)
18203+
std::swap(A, B);
18204+
}
18205+
} else if (ExtOpcode == ISD::ZERO_EXTEND) {
18206+
DotOpcode = AArch64ISD::UDOT;
18207+
} else if (ExtOpcode == ISD::SIGN_EXTEND) {
18208+
DotOpcode = AArch64ISD::SDOT;
18209+
} else {
1818918210
return SDValue();
18211+
}
1819018212

1819118213
EVT Op0VT = A.getOperand(0).getValueType();
1819218214
bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
@@ -18212,8 +18234,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1821218234
NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
1821318235
TargetType = MVT::v2i32;
1821418236
}
18215-
auto DotOpcode =
18216-
(ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
1821718237
// Handle the case where we need to generate only one Dot operation.
1821818238
if (NumOfVecReduce == 1) {
1821918239
SDValue Zeros = DAG.getConstant(0, DL, TargetType);

0 commit comments

Comments
 (0)