@@ -21999,13 +21999,18 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
21999
21999
return SDValue();
22000
22000
22001
22001
unsigned InputOpcode = Input->getOpcode();
22002
+ EVT AccVT = Acc->getValueType(0);
22003
+ if (AccVT.getVectorElementCount() * 4 ==
22004
+ Input->getValueType(0).getVectorElementCount() &&
22005
+ InputOpcode != ISD::MUL)
22006
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22002
22007
if (InputOpcode != ISD::MUL)
22003
22008
return SDValue();
22009
+
22004
22010
auto A = Input->getOperand(0);
22005
22011
auto B = Input->getOperand(1);
22006
22012
unsigned AOpcode = A->getOpcode();
22007
22013
unsigned BOpcode = B->getOpcode();
22008
- EVT AccVT = Acc->getValueType(0);
22009
22014
22010
22015
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22011
22016
return DAG.expandPartialReduceAdd(DL, Acc, Input);
@@ -22080,6 +22085,8 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22080
22085
Input = Input->getOperand(0);
22081
22086
EVT InputVT = Input.getValueType();
22082
22087
EVT AccVT = Acc->getValueType(0);
22088
+ if (!AccVT.isScalableVector())
22089
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22083
22090
22084
22091
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22085
22092
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -29180,6 +29187,9 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
29180
29187
29181
29188
unsigned Opcode = Op.getOpcode();
29182
29189
29190
+ // Non-MUL inputs with this 4:1 element-count ratio are already expanded
29191
+ // during the DAG-combine, so when the condition below holds the input
29192
+ // opcode must have been ISD::MUL.
29183
29193
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29184
29194
unsigned IndexAdd = 0;
29185
29195
// ISD::MUL may have already been lowered, meaning the operands would be in
0 commit comments