@@ -22032,30 +22032,31 @@ static SDValue tryCombineWhileLo(SDNode *N,
22032
22032
return SDValue(N, 0);
22033
22033
}
22034
22034
22035
- SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22035
+ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22036
+ SelectionDAG &DAG,
22036
22037
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22037
22038
bool Scalable = Acc.getValueType().isScalableVector();
22038
22039
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22039
22040
return SDValue();
22040
22041
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22041
22042
return SDValue();
22042
22043
22043
- unsigned InputOpcode = Input ->getOpcode();
22044
+ unsigned Input1Opcode = Input1 ->getOpcode();
22044
22045
EVT AccVT = Acc->getValueType(0);
22045
22046
if (AccVT.getVectorElementCount() * 4 ==
22046
- Input ->getValueType(0).getVectorElementCount() &&
22047
- InputOpcode != ISD::MUL)
22048
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22049
- if (InputOpcode != ISD::MUL)
22047
+ Input1 ->getValueType(0).getVectorElementCount() &&
22048
+ Input1Opcode != ISD::MUL)
22049
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22050
+ if (Input1Opcode != ISD::MUL)
22050
22051
return SDValue();
22051
22052
22052
- auto A = Input ->getOperand(0);
22053
- auto B = Input ->getOperand(1);
22053
+ auto A = Input1 ->getOperand(0);
22054
+ auto B = Input1 ->getOperand(1);
22054
22055
unsigned AOpcode = A->getOpcode();
22055
22056
unsigned BOpcode = B->getOpcode();
22056
22057
22057
22058
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22058
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22059
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22059
22060
22060
22061
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22061
22062
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
@@ -22064,6 +22065,10 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22064
22065
B = B->getOperand(0);
22065
22066
EVT MulSrcVT = A.getValueType();
22066
22067
22068
+ Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
22069
+ A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
22070
+ B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
22071
+
22067
22072
// Dot products operate on chunks of four elements so there must be four times
22068
22073
// as many elements in the multiply source type as in the accumulator type
22069
22074
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
@@ -22072,17 +22077,17 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22072
22077
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22073
22078
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22074
22079
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22075
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22080
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22076
22081
22077
22082
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22078
22083
if (AIsSigned != BIsSigned) {
22079
22084
if (!Subtarget->hasMatMulInt8())
22080
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22085
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22081
22086
22082
22087
bool Scalable = AccVT.isScalableVT();
22083
22088
// There's no nxv2i64 version of usdot
22084
22089
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22085
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22090
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22086
22091
22087
22092
if (!BIsSigned)
22088
22093
std::swap(A, B);
@@ -22109,32 +22114,37 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22109
22114
}
22110
22115
22111
22116
if (A.getValueType() != B.getValueType())
22112
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22117
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22113
22118
22114
22119
unsigned NewOpcode =
22115
22120
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22116
22121
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
22117
22122
}
22118
22123
22119
- SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22124
+ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22125
+ SelectionDAG &DAG,
22120
22126
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22121
22127
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22122
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
22123
- unsigned InputOpcode = Input->getOpcode();
22124
- if (!ISD::isExtOpcode(InputOpcode))
22125
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
22126
- Input = Input->getOperand(0);
22127
- EVT InputVT = Input.getValueType();
22128
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22129
+ unsigned Input1Opcode = Input1->getOpcode();
22130
+ if (!ISD::isExtOpcode(Input1Opcode))
22131
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22132
+
22128
22133
EVT AccVT = Acc->getValueType(0);
22134
+ Input1 = Input1->getOperand(0);
22135
+ EVT InputVT = Input1.getValueType();
22136
+ Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
22137
+ SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
22138
+
22129
22139
if (!AccVT.isScalableVector())
22130
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22140
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22131
22141
22132
22142
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22133
22143
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22134
22144
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22135
22145
return SDValue();
22136
22146
22137
- unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22147
+ unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
22138
22148
? ISD::PARTIAL_REDUCE_SMLA
22139
22149
: ISD::PARTIAL_REDUCE_UMLA;
22140
22150
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
@@ -22145,18 +22155,21 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22145
22155
const AArch64Subtarget *Subtarget) {
22146
22156
SDLoc DL(N);
22147
22157
auto Acc = N->getOperand(0);
22148
- auto Input = N->getOperand(1);
22158
+ auto Input1 = N->getOperand(1);
22159
+ auto Input2 = N->getOperand(2);
22149
22160
EVT AccElemVT = Acc.getValueType().getVectorElementType();
22150
- EVT InputElemVT = Input .getValueType().getVectorElementType();
22161
+ EVT InputElemVT = Input1 .getValueType().getVectorElementType();
22151
22162
22152
22163
// If the exts have already been removed or it has already been lowered to a
22153
22164
// usdot instruction, then the element types will not be equal
22154
- if (InputElemVT != AccElemVT || Input .getOpcode() == AArch64ISD::USDOT)
22165
+ if (InputElemVT != AccElemVT || Input1 .getOpcode() == AArch64ISD::USDOT)
22155
22166
return SDValue(N, 0);
22156
22167
22157
- if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
22168
+ if (auto Dot =
22169
+ tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
22158
22170
return Dot;
22159
- if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
22171
+ if (auto WideAdd =
22172
+ tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
22160
22173
return WideAdd;
22161
22174
return SDValue();
22162
22175
}
0 commit comments