Skip to content

Commit 6092426

Browse files
MUL instructions now included in DAG combines.
1 parent 0a06b2a commit 6092426

File tree

4 files changed

+62
-37
lines changed

4 files changed

+62
-37
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1608,7 +1608,13 @@ class SelectionDAG {
16081608
/// \p Op1 Accumulator for where the result is stored for the partial
16091609
/// reduction operation.
16101610
/// \p Op2 Input for the partial reduction operation.
1611-
SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
1611+
/// Expands PARTIAL_REDUCE_S/UMLA nodes.
1612+
/// \p Acc Accumulator for where the result is stored for the partial
1613+
/// reduction operation.
1614+
/// \p Input1 First multiplicand; the product with \p Input2 is reduced.
1615+
/// \p Input2 Second multiplicand, any-extended/truncated to Input1's type.
1616+
SDValue expandPartialReduceAdd(SDLoc DL, SDValue Acc, SDValue Input1,
1617+
SDValue Input2);
16121618

16131619
/// Expands a node with multiple results to an FP or vector libcall. The
16141620
/// libcall is expected to take all the operands of the \p Node followed by

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,20 +2467,24 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
24672467
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
24682468
}
24692469

2470-
SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Op1,
2471-
SDValue Op2) {
2472-
EVT ReducedTy = Op1.getValueType();
2473-
EVT FullTy = Op2.getValueType();
2470+
SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Acc,
2471+
SDValue Input1, SDValue Input2) {
2472+
2473+
EVT FullTy = Input1.getValueType();
2474+
Input2 = getAnyExtOrTrunc(Input2, DL, FullTy);
2475+
SDValue Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
2476+
2477+
EVT ReducedTy = Acc.getValueType();
24742478

24752479
unsigned Stride = ReducedTy.getVectorMinNumElements();
24762480
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
24772481

24782482
// Collect all of the subvectors
2479-
std::deque<SDValue> Subvectors = {Op1};
2483+
std::deque<SDValue> Subvectors = {Acc};
24802484
for (unsigned I = 0; I < ScaleFactor; I++) {
24812485
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
24822486
Subvectors.push_back(
2483-
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
2487+
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Input, SourceIndex}));
24842488
}
24852489

24862490
// Flatten the subvector tree

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8124,11 +8124,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81248124
SDValue Input = getValue(I.getOperand(1));
81258125

81268126
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
8127-
setValue(&I,
8128-
DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input));
8127+
setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
8128+
DAG.getConstant(1, dl, Input.getValueType())));
81298129
return;
81308130
}
8131-
setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
8131+
setValue(&I,
8132+
DAG.expandPartialReduceAdd(
8133+
dl, Acc, Input, DAG.getConstant(1, dl, Input.getValueType())));
81328134
return;
81338135
}
81348136
case Intrinsic::experimental_cttz_elts: {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22032,30 +22032,31 @@ static SDValue tryCombineWhileLo(SDNode *N,
2203222032
return SDValue(N, 0);
2203322033
}
2203422034

22035-
SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22035+
SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22036+
SelectionDAG &DAG,
2203622037
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2203722038
bool Scalable = Acc.getValueType().isScalableVector();
2203822039
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
2203922040
return SDValue();
2204022041
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
2204122042
return SDValue();
2204222043

22043-
unsigned InputOpcode = Input->getOpcode();
22044+
unsigned Input1Opcode = Input1->getOpcode();
2204422045
EVT AccVT = Acc->getValueType(0);
2204522046
if (AccVT.getVectorElementCount() * 4 ==
22046-
Input->getValueType(0).getVectorElementCount() &&
22047-
InputOpcode != ISD::MUL)
22048-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22049-
if (InputOpcode != ISD::MUL)
22047+
Input1->getValueType(0).getVectorElementCount() &&
22048+
Input1Opcode != ISD::MUL)
22049+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22050+
if (Input1Opcode != ISD::MUL)
2205022051
return SDValue();
2205122052

22052-
auto A = Input->getOperand(0);
22053-
auto B = Input->getOperand(1);
22053+
auto A = Input1->getOperand(0);
22054+
auto B = Input1->getOperand(1);
2205422055
unsigned AOpcode = A->getOpcode();
2205522056
unsigned BOpcode = B->getOpcode();
2205622057

2205722058
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22058-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22059+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2205922060

2206022061
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
2206122062
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
@@ -22064,6 +22065,10 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2206422065
B = B->getOperand(0);
2206522066
EVT MulSrcVT = A.getValueType();
2206622067

22068+
Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
22069+
A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
22070+
B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
22071+
2206722072
// Dot products operate on chunks of four elements so there must be four times
2206822073
// as many elements in the wide type
2206922074
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
@@ -22072,17 +22077,17 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2207222077
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
2207322078
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
2207422079
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22075-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22080+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2207622081

2207722082
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
2207822083
if (AIsSigned != BIsSigned) {
2207922084
if (!Subtarget->hasMatMulInt8())
22080-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22085+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2208122086

2208222087
bool Scalable = AccVT.isScalableVT();
2208322088
// There's no nxv2i64 version of usdot
2208422089
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22085-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22090+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2208622091

2208722092
if (!BIsSigned)
2208822093
std::swap(A, B);
@@ -22109,32 +22114,37 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2210922114
}
2211022115

2211122116
if (A.getValueType() != B.getValueType())
22112-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22117+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2211322118

2211422119
unsigned NewOpcode =
2211522120
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
2211622121
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
2211722122
}
2211822123

22119-
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22124+
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22125+
SelectionDAG &DAG,
2212022126
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2212122127
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22122-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22123-
unsigned InputOpcode = Input->getOpcode();
22124-
if (!ISD::isExtOpcode(InputOpcode))
22125-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22126-
Input = Input->getOperand(0);
22127-
EVT InputVT = Input.getValueType();
22128+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22129+
unsigned Input1Opcode = Input1->getOpcode();
22130+
if (!ISD::isExtOpcode(Input1Opcode))
22131+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22132+
2212822133
EVT AccVT = Acc->getValueType(0);
22134+
Input1 = Input1->getOperand(0);
22135+
EVT InputVT = Input1.getValueType();
22136+
Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
22137+
SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
22138+
2212922139
if (!AccVT.isScalableVector())
22130-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22140+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2213122141

2213222142
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
2213322143
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
2213422144
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
2213522145
return SDValue();
2213622146

22137-
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22147+
unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
2213822148
? ISD::PARTIAL_REDUCE_SMLA
2213922149
: ISD::PARTIAL_REDUCE_UMLA;
2214022150
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
@@ -22145,18 +22155,21 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
2214522155
const AArch64Subtarget *Subtarget) {
2214622156
SDLoc DL(N);
2214722157
auto Acc = N->getOperand(0);
22148-
auto Input = N->getOperand(1);
22158+
auto Input1 = N->getOperand(1);
22159+
auto Input2 = N->getOperand(2);
2214922160
EVT AccElemVT = Acc.getValueType().getVectorElementType();
22150-
EVT InputElemVT = Input.getValueType().getVectorElementType();
22161+
EVT InputElemVT = Input1.getValueType().getVectorElementType();
2215122162

2215222163
// If the exts have already been removed or it has already been lowered to a
2215322164
// usdot instruction, then the element types will not be equal
22154-
if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
22165+
if (InputElemVT != AccElemVT || Input1.getOpcode() == AArch64ISD::USDOT)
2215522166
return SDValue(N, 0);
2215622167

22157-
if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
22168+
if (auto Dot =
22169+
tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
2215822170
return Dot;
22159-
if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
22171+
if (auto WideAdd =
22172+
tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
2216022173
return WideAdd;
2216122174
return SDValue();
2216222175
}

0 commit comments

Comments
 (0)