Skip to content

Commit 9231804

Browse files
Change how the dot-product pattern is checked for lowering.
Add a condition to the wide-add combine so that it does not apply to fixed-length vectors.
1 parent bd02348 commit 9231804

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21999,13 +21999,18 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2199921999
return SDValue();
2200022000

2200122001
unsigned InputOpcode = Input->getOpcode();
22002+
EVT AccVT = Acc->getValueType(0);
22003+
if (AccVT.getVectorElementCount() * 4 ==
22004+
Input->getValueType(0).getVectorElementCount() &&
22005+
InputOpcode != ISD::MUL)
22006+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2200222007
if (InputOpcode != ISD::MUL)
2200322008
return SDValue();
22009+
2200422010
auto A = Input->getOperand(0);
2200522011
auto B = Input->getOperand(1);
2200622012
unsigned AOpcode = A->getOpcode();
2200722013
unsigned BOpcode = B->getOpcode();
22008-
EVT AccVT = Acc->getValueType(0);
2200922014

2201022015
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
2201122016
return DAG.expandPartialReduceAdd(DL, Acc, Input);
@@ -22080,6 +22085,8 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2208022085
Input = Input->getOperand(0);
2208122086
EVT InputVT = Input.getValueType();
2208222087
EVT AccVT = Acc->getValueType(0);
22088+
if (!AccVT.isScalableVector())
22089+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2208322090

2208422091
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
2208522092
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -29180,6 +29187,9 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
2918029187

2918129188
unsigned Opcode = Op.getOpcode();
2918229189

29190+
// If the following condition holds but the input opcode was not ISD::MUL
29191+
// during the DAG combine, the node was already expanded there. Reaching this
29192+
// point with the condition true therefore means the input opcode was ISD::MUL.
2918329193
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
2918429194
unsigned IndexAdd = 0;
2918529195
// ISD::MUL may have already been lowered, meaning the operands would be in

0 commit comments

Comments
 (0)