@@ -18177,16 +18177,38 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
18177
18177
unsigned ExtOpcode = Op0.getOpcode();
18178
18178
SDValue A = Op0;
18179
18179
SDValue B;
18180
+ unsigned DotOpcode;
18180
18181
if (ExtOpcode == ISD::MUL) {
18181
18182
A = Op0.getOperand(0);
18182
18183
B = Op0.getOperand(1);
18183
- if (A.getOpcode() != B.getOpcode() ||
18184
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
18184
+ if (A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
18185
18185
return SDValue();
18186
- ExtOpcode = A.getOpcode();
18187
- }
18188
- if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
18186
+ auto OpCodeA = A.getOpcode();
18187
+ if (OpCodeA != ISD::ZERO_EXTEND && OpCodeA != ISD::SIGN_EXTEND)
18188
+ return SDValue();
18189
+
18190
+ auto OpCodeB = B.getOpcode();
18191
+ if (OpCodeB != ISD::ZERO_EXTEND && OpCodeB != ISD::SIGN_EXTEND)
18192
+ return SDValue();
18193
+
18194
+ if (OpCodeA == OpCodeB) {
18195
+ DotOpcode =
18196
+ OpCodeA == ISD::ZERO_EXTEND ? AArch64ISD::UDOT : AArch64ISD::SDOT;
18197
+ } else {
18198
+ // Check USDOT support
18199
+ if (!ST->hasMatMulInt8())
18200
+ return SDValue();
18201
+ DotOpcode = AArch64ISD::USDOT;
18202
+ if (OpCodeA == ISD::SIGN_EXTEND)
18203
+ std::swap(A, B);
18204
+ }
18205
+ } else if (ExtOpcode == ISD::ZERO_EXTEND) {
18206
+ DotOpcode = AArch64ISD::UDOT;
18207
+ } else if (ExtOpcode == ISD::SIGN_EXTEND) {
18208
+ DotOpcode = AArch64ISD::SDOT;
18209
+ } else {
18189
18210
return SDValue();
18211
+ }
18190
18212
18191
18213
EVT Op0VT = A.getOperand(0).getValueType();
18192
18214
bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
@@ -18212,8 +18234,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
18212
18234
NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
18213
18235
TargetType = MVT::v2i32;
18214
18236
}
18215
- auto DotOpcode =
18216
- (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
18217
18237
// Handle the case where we need to generate only one Dot operation.
18218
18238
if (NumOfVecReduce == 1) {
18219
18239
SDValue Zeros = DAG.getConstant(0, DL, TargetType);
0 commit comments