@@ -2041,9 +2041,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
2041
2041
return true;
2042
2042
2043
2043
EVT VT = EVT::getEVT(I->getType());
2044
- return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
2045
- VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
2046
- VT != MVT::v2i32 && VT != MVT::v8i16;
2044
+ auto Op1 = I->getOperand(1);
2045
+ EVT Op1VT = EVT::getEVT(Op1->getType());
2046
+ if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
2047
+ (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
2048
+ VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
2049
+ return false;
2050
+ return true;
2047
2051
}
2048
2052
2049
2053
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21793,36 +21797,34 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
21793
21797
Intrinsic::experimental_vector_partial_reduce_add &&
21794
21798
"Expected a partial reduction node");
21795
21799
21796
- bool Scalable = N->getValueType(0).isScalableVector();
21797
- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
21800
+ if (!Subtarget->isSVEorStreamingSVEAvailable())
21798
21801
return SDValue();
21799
21802
21800
21803
SDLoc DL(N);
21801
21804
21802
- auto Accumulator = N->getOperand(1);
21805
+ auto Acc = N->getOperand(1);
21803
21806
auto ExtInput = N->getOperand(2);
21804
21807
21805
- EVT AccumulatorType = Accumulator .getValueType();
21806
- EVT AccumulatorElementType = AccumulatorType .getVectorElementType();
21808
+ EVT AccVT = Acc .getValueType();
21809
+ EVT AccElemVT = AccVT .getVectorElementType();
21807
21810
21808
- if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType )
21811
+ if (ExtInput.getValueType().getVectorElementType() != AccElemVT )
21809
21812
return SDValue();
21810
21813
21811
21814
unsigned ExtInputOpcode = ExtInput->getOpcode();
21812
21815
if (!ISD::isExtOpcode(ExtInputOpcode))
21813
21816
return SDValue();
21814
21817
21815
21818
auto Input = ExtInput->getOperand(0);
21816
- EVT InputType = Input.getValueType();
21819
+ EVT InputVT = Input.getValueType();
21817
21820
21818
21821
// To do this transformation, output element size needs to be double input
21819
21822
// element size, and output number of elements needs to be half the input
21820
21823
// number of elements
21821
- if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
21822
- AccumulatorElementType.getSizeInBits()) ||
21823
- !(AccumulatorType.getVectorElementCount() * 2 ==
21824
- InputType.getVectorElementCount()) ||
21825
- !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
21824
+ if (InputVT.getVectorElementType().getSizeInBits() * 2 !=
21825
+ AccElemVT.getSizeInBits() ||
21826
+ AccVT.getVectorElementCount() * 2 != InputVT.getVectorElementCount() ||
21827
+ AccVT.isScalableVector() != InputVT.isScalableVector())
21826
21828
return SDValue();
21827
21829
21828
21830
bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
@@ -21831,13 +21833,12 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
21831
21833
auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
21832
21834
: Intrinsic::aarch64_sve_uaddwt;
21833
21835
21834
- auto BottomID =
21835
- DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
21836
- auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
21837
- BottomID, Accumulator, Input);
21838
- auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
21839
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
21840
- BottomNode, Input);
21836
+ auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
21837
+ auto BottomNode =
21838
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
21839
+ auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
21840
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
21841
+ Input);
21841
21842
}
21842
21843
21843
21844
static SDValue performIntrinsicCombine(SDNode *N,
0 commit comments