Skip to content

Commit 5e31db5

Browse files
Change the way the dot product pattern is checked for lowering.
Add condition in wide add combine to not allow fixed length vectors.
1 parent 43f73c2 commit 5e31db5

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22041,13 +22041,18 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2204122041
return SDValue();
2204222042

2204322043
unsigned InputOpcode = Input->getOpcode();
22044+
EVT AccVT = Acc->getValueType(0);
22045+
if (AccVT.getVectorElementCount() * 4 ==
22046+
Input->getValueType(0).getVectorElementCount() &&
22047+
InputOpcode != ISD::MUL)
22048+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2204422049
if (InputOpcode != ISD::MUL)
2204522050
return SDValue();
22051+
2204622052
auto A = Input->getOperand(0);
2204722053
auto B = Input->getOperand(1);
2204822054
unsigned AOpcode = A->getOpcode();
2204922055
unsigned BOpcode = B->getOpcode();
22050-
EVT AccVT = Acc->getValueType(0);
2205122056

2205222057
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
2205322058
return DAG.expandPartialReduceAdd(DL, Acc, Input);
@@ -22122,6 +22127,8 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2212222127
Input = Input->getOperand(0);
2212322128
EVT InputVT = Input.getValueType();
2212422129
EVT AccVT = Acc->getValueType(0);
22130+
if (!AccVT.isScalableVector())
22131+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2212522132

2212622133
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
2212722134
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -29376,6 +29383,9 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
2937629383

2937729384
unsigned Opcode = Op.getOpcode();
2937829385

29386+
// If the following condition is true and the input opcode was not ISD::MUL
29387+
// during the DAG-combine, it is already expanded. So this condition means the
29388+
// input opcode must have been ISD::MUL.
2937929389
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
2938029390
unsigned IndexAdd = 0;
2938129391
// ISD::MUL may have already been lowered, meaning the operands would be in

0 commit comments

Comments
 (0)