@@ -22041,13 +22041,18 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22041
22041
return SDValue();
22042
22042
22043
22043
unsigned InputOpcode = Input->getOpcode();
22044
+ EVT AccVT = Acc->getValueType(0);
22045
+ if (AccVT.getVectorElementCount() * 4 ==
22046
+ Input->getValueType(0).getVectorElementCount() &&
22047
+ InputOpcode != ISD::MUL)
22048
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22044
22049
if (InputOpcode != ISD::MUL)
22045
22050
return SDValue();
22051
+
22046
22052
auto A = Input->getOperand(0);
22047
22053
auto B = Input->getOperand(1);
22048
22054
unsigned AOpcode = A->getOpcode();
22049
22055
unsigned BOpcode = B->getOpcode();
22050
- EVT AccVT = Acc->getValueType(0);
22051
22056
22052
22057
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22053
22058
return DAG.expandPartialReduceAdd(DL, Acc, Input);
@@ -22122,6 +22127,8 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22122
22127
Input = Input->getOperand(0);
22123
22128
EVT InputVT = Input.getValueType();
22124
22129
EVT AccVT = Acc->getValueType(0);
22130
+ if (!AccVT.isScalableVector())
22131
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22125
22132
22126
22133
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22127
22134
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -29376,6 +29383,9 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
29376
29383
29377
29384
unsigned Opcode = Op.getOpcode();
29378
29385
29386
+ // If the condition below holds and the input opcode was not ISD::MUL during
29387
+ // the DAG combine, the node has already been expanded there. Reaching this
29388
+ // point with the condition true therefore means the input was an ISD::MUL.
29379
29389
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29380
29390
unsigned IndexAdd = 0;
29381
29391
// ISD::MUL may have already been lowered, meaning the operands would be in
0 commit comments