Skip to content

Commit 3cdf92f

Browse files
Address comments on patch. Remove shouldExpandPartialReductionIntrinsic().
1 parent 9971a6e commit 3cdf92f

File tree

6 files changed

+41
-58
lines changed

6 files changed

+41
-58
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,9 +1451,18 @@ enum NodeType {
14511451
VECREDUCE_UMAX,
14521452
VECREDUCE_UMIN,
14531453

1454-
// Nodes used to represent a partial reduction addition operation (signed and
1455-
// unsigned).
1456-
// Operands: Accumulator, Input
1454+
// Partial Reduction nodes. These represent multiply-add instructions because
1455+
// Input1 and Input2 are multiplied together first. This result is then
1456+
// reduced, by addition, to the number of elements that the Accumulator's type
1457+
// has.
1458+
// Input1 and Input2 must be the same type. Accumulator's element type must
1459+
// match that of Input1 and Input2. The number of elements in Input1 and
1460+
// Input2 must be a positive integer multiple of the number of elements in the
1461+
// Accumulator.
1462+
// The signedness of this node will dictate the signedness of nodes expanded
1463+
// from it. The signedness of the node is dictated by the signedness of
1464+
// Input1.
1465+
// Operands: Accumulator, Input1, Input2
14571466
// Outputs: Output
14581467
PARTIAL_REDUCE_SMLA,
14591468
PARTIAL_REDUCE_UMLA,

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,10 +1604,6 @@ class SelectionDAG {
16041604
/// the target's desired shift amount type.
16051605
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
16061606

1607-
/// Expands PARTIAL_REDUCE_S/UMLA nodes.
1608-
/// \p Op1 Accumulator for where the result is stored for the partial
1609-
/// reduction operation.
1610-
/// \p Op2 Input for the partial reduction operation.
16111607
/// Expands PARTIAL_REDUCE_S/UMLA nodes.
16121608
/// \p Acc Accumulator for where the result is stored for the partial
16131609
/// reduction operation.

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,13 +455,6 @@ class TargetLoweringBase {
455455
return true;
456456
}
457457

458-
/// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
459-
/// should be expanded using generic code in SelectionDAGBuilder.
460-
virtual bool
461-
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
462-
return true;
463-
}
464-
465458
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
466459
/// using generic code in SelectionDAGBuilder.
467460
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8122,15 +8122,22 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81228122
SDValue Acc = getValue(I.getOperand(0));
81238123
EVT AccVT = Acc.getValueType();
81248124
SDValue Input = getValue(I.getOperand(1));
8125-
8126-
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
8127-
setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
8128-
DAG.getConstant(1, dl, Input.getValueType())));
8129-
return;
8130-
}
8131-
setValue(&I,
8132-
DAG.expandPartialReduceAdd(
8133-
dl, Acc, Input, DAG.getConstant(1, dl, Input.getValueType())));
8125+
EVT InputVT = Input.getValueType();
8126+
8127+
assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
8128+
"Expected operands to have the same vector element type!");
8129+
assert(InputVT.getVectorElementCount().getKnownMinValue() %
8130+
AccVT.getVectorElementCount().getKnownMinValue() ==
8131+
0 &&
8132+
"Expected the element count of the Input operand to be a positive "
8133+
"integer multiple of the element count of the Accumulator operand!");
8134+
8135+
// ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
8136+
// same if ISD::PARTIAL_REDUCE_SMLA was used instead. It should be changed
8137+
// to its correct signedness when combining or expanding, according to
8138+
// extends being performed on Input.
8139+
setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
8140+
DAG.getConstant(1, dl, InputVT)));
81348141
return;
81358142
}
81368143
case Intrinsic::experimental_cttz_elts: {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2049,28 +2049,6 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
20492049
return false;
20502050
}
20512051

2052-
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
2053-
const IntrinsicInst *I) const {
2054-
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
2055-
return true;
2056-
2057-
EVT VT = EVT::getEVT(I->getType());
2058-
auto Input = I->getOperand(1);
2059-
EVT InputVT = EVT::getEVT(Input->getType());
2060-
2061-
if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
2062-
(InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
2063-
(InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
2064-
(InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
2065-
(InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
2066-
(InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
2067-
(InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
2068-
(InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
2069-
(InputVT == MVT::v8i32 && VT == MVT::v2i32))
2070-
return false;
2071-
return true;
2072-
}
2073-
20742052
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
20752053
if (!Subtarget->isSVEorStreamingSVEAvailable())
20762054
return true;
@@ -22037,9 +22015,9 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2203722015
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2203822016
bool Scalable = Op0->getValueType(0).isScalableVector();
2203922017
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22040-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22018+
return SDValue();
2204122019
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22042-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22020+
return SDValue();
2204322021

2204422022
unsigned Op1Opcode = Op1->getOpcode();
2204522023
SDValue MulOpLHS, MulOpRHS;
@@ -22056,7 +22034,7 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2205622034
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
2205722035
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
2205822036
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
22059-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22037+
return SDValue();
2206022038

2206122039
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
2206222040
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -22066,7 +22044,7 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2206622044
EVT MulOpLHSVT = MulOpLHS.getValueType();
2206722045

2206822046
if (MulOpLHSVT != MulOpRHS.getValueType())
22069-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22047+
return SDValue();
2207022048

2207122049
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
2207222050
MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
@@ -22092,12 +22070,12 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2209222070
unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
2209322071
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
2209422072
if (!Subtarget->hasMatMulInt8())
22095-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22073+
return SDValue();
2209622074

2209722075
bool Scalable = ReducedVT.isScalableVT();
2209822076
// There's no nxv2i64 version of usdot
2209922077
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
22100-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22078+
return SDValue();
2210122079

2210222080
if (!MulOpRHSIsSigned)
2210322081
std::swap(MulOpLHS, MulOpRHS);
@@ -22134,10 +22112,10 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2213422112
SelectionDAG &DAG,
2213522113
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2213622114
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22137-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22115+
return SDValue();
2213822116
unsigned Op1Opcode = Op1->getOpcode();
2213922117
if (!ISD::isExtOpcode(Op1Opcode))
22140-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22118+
return SDValue();
2214122119

2214222120
EVT AccVT = Op0->getValueType(0);
2214322121
Op1 = Op1->getOperand(0);
@@ -22146,7 +22124,7 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2214622124
SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
2214722125

2214822126
if (!AccVT.isScalableVector())
22149-
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22127+
return SDValue();
2215022128

2215122129
if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
2215222130
!(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -22177,7 +22155,10 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
2217722155
return Dot;
2217822156
if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
2217922157
return WideAdd;
22180-
return SDValue();
22158+
// N->getOperand needs calling again because the Op variables may have been
22159+
// changed by the functions above
22160+
return DAG.expandPartialReduceAdd(DL, N->getOperand(0), N->getOperand(1),
22161+
N->getOperand(2));
2218122162
}
2218222163

2218322164
static SDValue performIntrinsicCombine(SDNode *N,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -993,9 +993,6 @@ class AArch64TargetLowering : public TargetLowering {
993993

994994
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
995995

996-
bool
997-
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
998-
999996
bool shouldExpandCttzElements(EVT VT) const override;
1000997

1001998
bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;

0 commit comments

Comments
 (0)