Skip to content

Commit 9cb5482

Browse files
Minor changes to previous patch
Rename variables, eliminate a redundant condition in an if statement and refactor code checking types
1 parent daa1cdb commit 9cb5482

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,9 +2041,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
20412041
return true;
20422042

20432043
EVT VT = EVT::getEVT(I->getType());
2044-
return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
2045-
VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
2046-
VT != MVT::v2i32 && VT != MVT::v8i16;
2044+
auto Op1 = I->getOperand(1);
2045+
EVT Op1VT = EVT::getEVT(Op1->getType());
2046+
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
2047+
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
2048+
VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
2049+
return false;
2050+
return true;
20472051
}
20482052

20492053
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21793,36 +21797,34 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
2179321797
Intrinsic::experimental_vector_partial_reduce_add &&
2179421798
"Expected a partial reduction node");
2179521799

21796-
bool Scalable = N->getValueType(0).isScalableVector();
21797-
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
21800+
if (!Subtarget->isSVEorStreamingSVEAvailable())
2179821801
return SDValue();
2179921802

2180021803
SDLoc DL(N);
2180121804

21802-
auto Accumulator = N->getOperand(1);
21805+
auto Acc = N->getOperand(1);
2180321806
auto ExtInput = N->getOperand(2);
2180421807

21805-
EVT AccumulatorType = Accumulator.getValueType();
21806-
EVT AccumulatorElementType = AccumulatorType.getVectorElementType();
21808+
EVT AccVT = Acc.getValueType();
21809+
EVT AccElemVT = AccVT.getVectorElementType();
2180721810

21808-
if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType)
21811+
if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
2180921812
return SDValue();
2181021813

2181121814
unsigned ExtInputOpcode = ExtInput->getOpcode();
2181221815
if (!ISD::isExtOpcode(ExtInputOpcode))
2181321816
return SDValue();
2181421817

2181521818
auto Input = ExtInput->getOperand(0);
21816-
EVT InputType = Input.getValueType();
21819+
EVT InputVT = Input.getValueType();
2181721820

2181821821
// To do this transformation, output element size needs to be double input
2181921822
// element size, and output number of elements needs to be half the input
2182021823
// number of elements
21821-
if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
21822-
AccumulatorElementType.getSizeInBits()) ||
21823-
!(AccumulatorType.getVectorElementCount() * 2 ==
21824-
InputType.getVectorElementCount()) ||
21825-
!(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
21824+
if (InputVT.getVectorElementType().getSizeInBits() * 2 !=
21825+
AccElemVT.getSizeInBits() ||
21826+
AccVT.getVectorElementCount() * 2 != InputVT.getVectorElementCount() ||
21827+
AccVT.isScalableVector() != InputVT.isScalableVector())
2182621828
return SDValue();
2182721829

2182821830
bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
@@ -21831,13 +21833,12 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
2183121833
auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
2183221834
: Intrinsic::aarch64_sve_uaddwt;
2183321835

21834-
auto BottomID =
21835-
DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
21836-
auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
21837-
BottomID, Accumulator, Input);
21838-
auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
21839-
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
21840-
BottomNode, Input);
21836+
auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
21837+
auto BottomNode =
21838+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
21839+
auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
21840+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
21841+
Input);
2184121842
}
2184221843

2184321844
static SDValue performIntrinsicCombine(SDNode *N,

0 commit comments

Comments
 (0)