Skip to content

Commit 5bc3fcc

Browse files
MUL instructions now included in DAG combines.
1 parent b43db72 commit 5bc3fcc

File tree

4 files changed

+62
-37
lines changed

4 files changed

+62
-37
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1606,7 +1606,13 @@ class SelectionDAG {
16061606
/// \p Op1 Accumulator for where the result is stored for the partial
16071607
/// reduction operation.
16081608
/// \p Op2 Input for the partial reduction operation.
1609-
SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
1609+
/// Expands PARTIAL_REDUCE_S/UMLA nodes.
1610+
/// \p Acc Accumulator for where the result is stored for the partial
1611+
/// reduction operation.
1612+
/// \p Input1 First input for the partial reduction operation.
1613+
/// \p Input2 Second input for the partial reduction operation.
1614+
SDValue expandPartialReduceAdd(SDLoc DL, SDValue Acc, SDValue Input1,
1615+
SDValue Input2);
16101616

16111617
/// Expands a node with multiple results to an FP or vector libcall. The
16121618
/// libcall is expected to take all the operands of the \p Node followed by

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2467,20 +2467,24 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
24672467
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
24682468
}
24692469

2470-
SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Op1,
2471-
SDValue Op2) {
2472-
EVT ReducedTy = Op1.getValueType();
2473-
EVT FullTy = Op2.getValueType();
2470+
SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Acc,
2471+
SDValue Input1, SDValue Input2) {
2472+
2473+
EVT FullTy = Input1.getValueType();
2474+
Input2 = getAnyExtOrTrunc(Input2, DL, FullTy);
2475+
SDValue Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
2476+
2477+
EVT ReducedTy = Acc.getValueType();
24742478

24752479
unsigned Stride = ReducedTy.getVectorMinNumElements();
24762480
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
24772481

24782482
// Collect all of the subvectors
2479-
std::deque<SDValue> Subvectors = {Op1};
2483+
std::deque<SDValue> Subvectors = {Acc};
24802484
for (unsigned I = 0; I < ScaleFactor; I++) {
24812485
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
24822486
Subvectors.push_back(
2483-
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
2487+
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Input, SourceIndex}));
24842488
}
24852489

24862490
// Flatten the subvector tree

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8141,11 +8141,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81418141
SDValue Input = getValue(I.getOperand(1));
81428142

81438143
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
8144-
setValue(&I,
8145-
DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input));
8144+
setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
8145+
DAG.getConstant(1, dl, Input.getValueType())));
81468146
return;
81478147
}
8148-
setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
8148+
setValue(&I,
8149+
DAG.expandPartialReduceAdd(
8150+
dl, Acc, Input, DAG.getConstant(1, dl, Input.getValueType())));
81498151
return;
81508152
}
81518153
case Intrinsic::experimental_cttz_elts: {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21990,30 +21990,31 @@ static SDValue tryCombineWhileLo(SDNode *N,
2199021990
return SDValue(N, 0);
2199121991
}
2199221992

21993-
SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
21993+
SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
21994+
SelectionDAG &DAG,
2199421995
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2199521996
bool Scalable = Acc.getValueType().isScalableVector();
2199621997
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
2199721998
return SDValue();
2199821999
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
2199922000
return SDValue();
2200022001

22001-
unsigned InputOpcode = Input->getOpcode();
22002+
unsigned Input1Opcode = Input1->getOpcode();
2200222003
EVT AccVT = Acc->getValueType(0);
2200322004
if (AccVT.getVectorElementCount() * 4 ==
22004-
Input->getValueType(0).getVectorElementCount() &&
22005-
InputOpcode != ISD::MUL)
22006-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22007-
if (InputOpcode != ISD::MUL)
22005+
Input1->getValueType(0).getVectorElementCount() &&
22006+
Input1Opcode != ISD::MUL)
22007+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22008+
if (Input1Opcode != ISD::MUL)
2200822009
return SDValue();
2200922010

22010-
auto A = Input->getOperand(0);
22011-
auto B = Input->getOperand(1);
22011+
auto A = Input1->getOperand(0);
22012+
auto B = Input1->getOperand(1);
2201222013
unsigned AOpcode = A->getOpcode();
2201322014
unsigned BOpcode = B->getOpcode();
2201422015

2201522016
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22016-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22017+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2201722018

2201822019
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
2201922020
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
@@ -22022,6 +22023,10 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2202222023
B = B->getOperand(0);
2202322024
EVT MulSrcVT = A.getValueType();
2202422025

22026+
Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
22027+
A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
22028+
B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
22029+
2202522030
// Dot products operate on chunks of four elements so there must be four times
2202622031
// as many elements in the wide type
2202722032
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
@@ -22030,17 +22035,17 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2203022035
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
2203122036
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
2203222037
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22033-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22038+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2203422039

2203522040
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
2203622041
if (AIsSigned != BIsSigned) {
2203722042
if (!Subtarget->hasMatMulInt8())
22038-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22043+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2203922044

2204022045
bool Scalable = AccVT.isScalableVT();
2204122046
// There's no nxv2i64 version of usdot
2204222047
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22043-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22048+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2204422049

2204522050
if (!BIsSigned)
2204622051
std::swap(A, B);
@@ -22067,32 +22072,37 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
2206722072
}
2206822073

2206922074
if (A.getValueType() != B.getValueType())
22070-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22075+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2207122076

2207222077
unsigned NewOpcode =
2207322078
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
2207422079
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
2207522080
}
2207622081

22077-
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22082+
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22083+
SelectionDAG &DAG,
2207822084
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2207922085
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22080-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22081-
unsigned InputOpcode = Input->getOpcode();
22082-
if (!ISD::isExtOpcode(InputOpcode))
22083-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22084-
Input = Input->getOperand(0);
22085-
EVT InputVT = Input.getValueType();
22086+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22087+
unsigned Input1Opcode = Input1->getOpcode();
22088+
if (!ISD::isExtOpcode(Input1Opcode))
22089+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22090+
2208622091
EVT AccVT = Acc->getValueType(0);
22092+
Input1 = Input1->getOperand(0);
22093+
EVT InputVT = Input1.getValueType();
22094+
Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
22095+
SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
22096+
2208722097
if (!AccVT.isScalableVector())
22088-
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22098+
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
2208922099

2209022100
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
2209122101
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
2209222102
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
2209322103
return SDValue();
2209422104

22095-
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22105+
unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
2209622106
? ISD::PARTIAL_REDUCE_SMLA
2209722107
: ISD::PARTIAL_REDUCE_UMLA;
2209822108
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
@@ -22103,18 +22113,21 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
2210322113
const AArch64Subtarget *Subtarget) {
2210422114
SDLoc DL(N);
2210522115
auto Acc = N->getOperand(0);
22106-
auto Input = N->getOperand(1);
22116+
auto Input1 = N->getOperand(1);
22117+
auto Input2 = N->getOperand(2);
2210722118
EVT AccElemVT = Acc.getValueType().getVectorElementType();
22108-
EVT InputElemVT = Input.getValueType().getVectorElementType();
22119+
EVT InputElemVT = Input1.getValueType().getVectorElementType();
2210922120

2211022121
// If the exts have already been removed or it has already been lowered to an
2211122122
// usdot instruction, then the element types will not be equal
22112-
if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
22123+
if (InputElemVT != AccElemVT || Input1.getOpcode() == AArch64ISD::USDOT)
2211322124
return SDValue(N, 0);
2211422125

22115-
if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
22126+
if (auto Dot =
22127+
tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
2211622128
return Dot;
22117-
if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
22129+
if (auto WideAdd =
22130+
tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
2211822131
return WideAdd;
2211922132
return SDValue();
2212022133
}

0 commit comments

Comments
 (0)