Skip to content

Commit 982d6e0

Browse files
Add the MUL in LowerPARTIAL_REDUCE_MLA()
Only do it if Input2 is not a splat vector of constant 1s. Still create the MUL in the DAG combine for the wide add pattern, because a MUL created there is pruned if an operand is constant 1s, or changed to a shift instruction if an operand is a power of 2. This would not happen if the MUL were made in LowerPARTIAL_REDUCE_MLA.
1 parent 3cdf92f commit 982d6e0

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22120,7 +22120,10 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2212022120
EVT AccVT = Op0->getValueType(0);
2212122121
Op1 = Op1->getOperand(0);
2212222122
EVT Op1VT = Op1.getValueType();
22123+
// Makes Op2's value type match the value type of Op1 without its extend.
2212322124
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
22125+
// Make a MUL between Op1 and Op2 here so the MUL can be changed if possible
22126+
// (can be pruned or changed to a shift instruction for example).
2212422127
SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
2212522128

2212622129
if (!AccVT.isScalableVector())
@@ -22133,6 +22136,7 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2213322136

2213422137
unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
2213522138
: ISD::PARTIAL_REDUCE_UMLA;
22139+
// Return a constant of 1s for Op2 so the MUL is not performed again.
2213622140
return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
2213722141
DAG.getConstant(1, DL, Op1VT));
2213822142
}
@@ -29389,11 +29393,19 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
2938929393
: AArch64ISD::UDOT;
2939029394
return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
2939129395
}
29396+
29397+
SDValue MulInput = Input1;
29398+
// If Input2 is a splat vector of constant 1 then the MUL instruction is not
29399+
// needed. If it was created here it would not be automatically pruned.
29400+
if (Input2.getOpcode() != ISD::SPLAT_VECTOR || Input2.getNumOperands() == 0 ||
29401+
!isOneConstant(Input2.getOperand(0)))
29402+
MulInput = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
29403+
2939229404
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
2939329405
unsigned BottomOpcode =
2939429406
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
2939529407
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
29396-
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
29408+
SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
2939729409
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
2939829410
}
2939929411

0 commit comments

Comments
 (0)