Skip to content

Commit 4cd617a

Browse files
Address comments
Move the getPartialReduceAdd function from SelectionDAG to a static member of TargetLowering. Make the new code path work for fixed-length NEON vectors too.
1 parent c036082 commit 4cd617a

File tree

9 files changed

+365
-68
lines changed

9 files changed

+365
-68
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,11 +1607,6 @@ class SelectionDAG {
16071607
/// the target's desired shift amount type.
16081608
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
16091609

1610-
/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
1611-
/// its operands and ReducedTY is the intrinsic's return type.
1612-
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
1613-
SDValue Op2);
1614-
16151610
/// Expands a node with multiple results to an FP or vector libcall. The
16161611
/// libcall is expected to take all the operands of the \p Node followed by
16171612
/// output pointers for each of the results. \p CallRetResNo can be optionally

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5636,7 +5636,14 @@ class TargetLowering : public TargetLoweringBase {
56365636

56375637
// Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
56385638
// consisting of zext/sext, extract_subvector, mul and add operations.
5639-
SDValue expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) const;
5639+
SDValue expandPartialReduceMLA(SDLoc DL, SDValue Acc, SDValue Input1,
5640+
SDValue Input2, SelectionDAG &DAG) const;
5641+
5642+
// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
5643+
// its operands and ReducedTy is the return type.
5644+
static SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
5645+
SDValue Op1, SDValue Op2,
5646+
SelectionDAG &DAG);
56405647

56415648
private:
56425649
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,9 +1198,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
11981198
Results.push_back(TLI.expandVecReduce(Node, DAG));
11991199
return;
12001200
case ISD::PARTIAL_REDUCE_UMLA:
1201-
case ISD::PARTIAL_REDUCE_SMLA:
1202-
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
1201+
case ISD::PARTIAL_REDUCE_SMLA: {
1202+
SDLoc DL(Node);
1203+
Results.push_back(TLI.expandPartialReduceMLA(DL, Node->getOperand(0),
1204+
Node->getOperand(1),
1205+
Node->getOperand(2), DAG));
12031206
return;
1207+
}
12041208
case ISD::VECREDUCE_SEQ_FADD:
12051209
case ISD::VECREDUCE_SEQ_FMUL:
12061210
Results.push_back(TLI.expandVecReduceSeq(Node, DAG));

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3186,7 +3186,9 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
31863186
}
31873187

31883188
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
3189-
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
3189+
SDLoc DL(N);
3190+
SDValue Res = TLI.expandPartialReduceMLA(
3191+
DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
31903192
ReplaceValueWith(SDValue(N, 0), Res);
31913193
}
31923194

@@ -4447,7 +4449,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
44474449
}
44484450

44494451
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
4450-
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
4452+
SDLoc DL(N);
4453+
SDValue Res = TLI.expandPartialReduceMLA(
4454+
DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
44514455
ReplaceValueWith(SDValue(N, 0), Res);
44524456
return SDValue();
44534457
}

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,35 +2474,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
24742474
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
24752475
}
24762476

2477-
SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
2478-
SDValue Op2) {
2479-
EVT FullTy = Op2.getValueType();
2480-
2481-
unsigned Stride = ReducedTy.getVectorMinNumElements();
2482-
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
2483-
2484-
// Collect all of the subvectors
2485-
std::deque<SDValue> Subvectors = {Op1};
2486-
for (unsigned I = 0; I < ScaleFactor; I++) {
2487-
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
2488-
Subvectors.push_back(
2489-
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
2490-
}
2491-
2492-
// Flatten the subvector tree
2493-
while (Subvectors.size() > 1) {
2494-
Subvectors.push_back(
2495-
getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
2496-
Subvectors.pop_front();
2497-
Subvectors.pop_front();
2498-
}
2499-
2500-
assert(Subvectors.size() == 1 &&
2501-
"There should only be one subvector after tree flattening");
2502-
2503-
return Subvectors[0];
2504-
}
2505-
25062477
/// Given a store node \p StoreNode, return true if it is safe to fold that node
25072478
/// into \p FPNode, which expands to a library call with output pointers.
25082479
static bool canFoldStoreIntoLibCallOutputPointers(StoreSDNode *StoreNode,

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8125,21 +8125,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81258125
return;
81268126
}
81278127
case Intrinsic::experimental_vector_partial_reduce_add: {
8128+
SDValue Acc = getValue(I.getOperand(0));
8129+
EVT AccVT = Acc.getValueType();
8130+
SDValue Input = getValue(I.getOperand(1));
8131+
EVT InputVT = Input.getValueType();
8132+
8133+
assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
8134+
"Expected operands to have the same vector element type!");
8135+
assert(InputVT.getVectorElementCount().getKnownMinValue() %
8136+
AccVT.getVectorElementCount().getKnownMinValue() ==
8137+
0 &&
8138+
"Expected the element count of the Input operand to be a positive "
8139+
"integer multiple of the element count of the Accumulator operand!");
81288140
if (NewPartialReduceLowering) {
8129-
SDValue Acc = getValue(I.getOperand(0));
8130-
EVT AccVT = Acc.getValueType();
8131-
SDValue Input = getValue(I.getOperand(1));
8132-
EVT InputVT = Input.getValueType();
8133-
8134-
assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
8135-
"Expected operands to have the same vector element type!");
8136-
assert(
8137-
InputVT.getVectorElementCount().getKnownMinValue() %
8138-
AccVT.getVectorElementCount().getKnownMinValue() ==
8139-
0 &&
8140-
"Expected the element count of the Input operand to be a positive "
8141-
"integer multiple of the element count of the Accumulator operand!");
8142-
81438141
// ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
81448142
// same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
81458143
// changed to its correct signedness when combining or expanding,
@@ -8154,9 +8152,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81548152
return;
81558153
}
81568154

8157-
setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
8158-
getValue(I.getOperand(0)),
8159-
getValue(I.getOperand(1))));
8155+
setValue(&I, TLI.expandPartialReduceMLA(
8156+
sdl, Acc, Input, DAG.getConstant(1, sdl, InputVT), DAG));
81608157
return;
81618158
}
81628159
case Intrinsic::experimental_cttz_elts: {

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/Support/MathExtras.h"
3535
#include "llvm/Target/TargetMachine.h"
3636
#include <cctype>
37+
#include <deque>
3738
using namespace llvm;
3839

3940
/// NOTE: The TargetMachine owns TLOF.
@@ -12189,20 +12190,15 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
1218912190
return Load;
1219012191
}
1219112192

12192-
SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
12193+
SDValue TargetLowering::expandPartialReduceMLA(SDLoc DL, SDValue Acc,
12194+
SDValue Input1, SDValue Input2,
1219312195
SelectionDAG &DAG) const {
12194-
SDLoc DL(N);
12195-
SDValue Acc = N->getOperand(0);
12196-
SDValue Input1 = N->getOperand(1);
12197-
SDValue Input2 = N->getOperand(2);
12198-
1219912196
EVT ReducedTy = Acc.getValueType();
1220012197
EVT FullTy = Input1.getValueType();
1220112198

1220212199
auto ExtendToAccEltVT = [&](SDValue V) {
12203-
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
12204-
? ISD::ZERO_EXTEND
12205-
: ISD::SIGN_EXTEND;
12200+
unsigned ExtOpc = V->getOpcode() == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND
12201+
: ISD::ZERO_EXTEND;
1220612202
EVT ExtVT = V.getValueType().changeVectorElementType(
1220712203
Acc.getValueType().getVectorElementType());
1220812204
if (ExtVT != FullTy)
@@ -12224,5 +12220,34 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1222412220
Input = ExtendToAccEltVT(Input1);
1222512221
}
1222612222

12227-
return DAG.getPartialReduceAdd(DL, ReducedTy, Acc, Input);
12223+
return TargetLowering::getPartialReduceAdd(DL, ReducedTy, FullTy, Acc, Input,
12224+
DAG);
12225+
}
12226+
12227+
SDValue TargetLowering::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
12228+
SDValue Op1, SDValue Op2,
12229+
SelectionDAG &DAG) {
12230+
unsigned Stride = ReducedTy.getVectorMinNumElements();
12231+
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
12232+
12233+
// Collect all of the subvectors
12234+
std::deque<SDValue> Subvectors = {Op1};
12235+
for (unsigned I = 0; I < ScaleFactor; I++) {
12236+
auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
12237+
Subvectors.push_back(
12238+
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
12239+
}
12240+
12241+
// Flatten the subvector tree
12242+
while (Subvectors.size() > 1) {
12243+
Subvectors.push_back(
12244+
DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
12245+
Subvectors.pop_front();
12246+
Subvectors.pop_front();
12247+
}
12248+
12249+
assert(Subvectors.size() == 1 &&
12250+
"There should only be one subvector after tree flattening");
12251+
12252+
return Subvectors[0];
1222812253
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
13581358
setOperationAction(ISD::BSWAP, VT, Expand);
13591359
setOperationAction(ISD::CTTZ, VT, Expand);
13601360

1361+
setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
1362+
setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
1363+
13611364
for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
13621365
setTruncStoreAction(VT, InnerVT, Expand);
13631366
setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
@@ -22015,8 +22018,10 @@ static SDValue performIntrinsicCombine(SDNode *N,
2201522018
return Dot;
2201622019
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
2201722020
return WideAdd;
22018-
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
22019-
N->getOperand(1), N->getOperand(2));
22021+
SDValue Input = N->getOperand(2);
22022+
return TargetLowering::getPartialReduceAdd(SDLoc(N), N->getValueType(0),
22023+
Input.getValueType(),
22024+
N->getOperand(1), Input, DAG);
2202022025
}
2202122026
case Intrinsic::aarch64_neon_vcvtfxs2fp:
2202222027
case Intrinsic::aarch64_neon_vcvtfxu2fp:

0 commit comments

Comments
 (0)