@@ -2049,28 +2049,6 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
2049
2049
return false;
2050
2050
}
2051
2051
2052
- bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
2053
- const IntrinsicInst *I) const {
2054
- if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
2055
- return true;
2056
-
2057
- EVT VT = EVT::getEVT(I->getType());
2058
- auto Input = I->getOperand(1);
2059
- EVT InputVT = EVT::getEVT(Input->getType());
2060
-
2061
- if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
2062
- (InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
2063
- (InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
2064
- (InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
2065
- (InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
2066
- (InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
2067
- (InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
2068
- (InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
2069
- (InputVT == MVT::v8i32 && VT == MVT::v2i32))
2070
- return false;
2071
- return true;
2072
- }
2073
-
2074
2052
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
2075
2053
if (!Subtarget->isSVEorStreamingSVEAvailable())
2076
2054
return true;
@@ -22037,9 +22015,9 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
22037
22015
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22038
22016
bool Scalable = Op0->getValueType(0).isScalableVector();
22039
22017
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22040
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22018
+ return SDValue( );
22041
22019
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22042
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22020
+ return SDValue( );
22043
22021
22044
22022
unsigned Op1Opcode = Op1->getOpcode();
22045
22023
SDValue MulOpLHS, MulOpRHS;
@@ -22056,7 +22034,7 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
22056
22034
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
22057
22035
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
22058
22036
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
22059
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22037
+ return SDValue( );
22060
22038
22061
22039
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
22062
22040
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -22066,7 +22044,7 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
22066
22044
EVT MulOpLHSVT = MulOpLHS.getValueType();
22067
22045
22068
22046
if (MulOpLHSVT != MulOpRHS.getValueType())
22069
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22047
+ return SDValue( );
22070
22048
22071
22049
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
22072
22050
MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
@@ -22092,12 +22070,12 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
22092
22070
unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22093
22071
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
22094
22072
if (!Subtarget->hasMatMulInt8())
22095
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22073
+ return SDValue( );
22096
22074
22097
22075
bool Scalable = ReducedVT.isScalableVT();
22098
22076
// There's no nxv2i64 version of usdot
22099
22077
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
22100
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22078
+ return SDValue( );
22101
22079
22102
22080
if (!MulOpRHSIsSigned)
22103
22081
std::swap(MulOpLHS, MulOpRHS);
@@ -22134,10 +22112,10 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
22134
22112
SelectionDAG &DAG,
22135
22113
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22136
22114
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22137
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22115
+ return SDValue( );
22138
22116
unsigned Op1Opcode = Op1->getOpcode();
22139
22117
if (!ISD::isExtOpcode(Op1Opcode))
22140
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22118
+ return SDValue( );
22141
22119
22142
22120
EVT AccVT = Op0->getValueType(0);
22143
22121
Op1 = Op1->getOperand(0);
@@ -22146,7 +22124,7 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
22146
22124
SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
22147
22125
22148
22126
if (!AccVT.isScalableVector())
22149
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22127
+ return SDValue( );
22150
22128
22151
22129
if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22152
22130
!(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -22177,7 +22155,10 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22177
22155
return Dot;
22178
22156
if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
22179
22157
return WideAdd;
22180
- return SDValue();
22158
+ // N->getOperand needs calling again because the Op variables may have been
22159
+ // changed by the functions above
22160
+ return DAG.expandPartialReduceAdd(DL, N->getOperand(0), N->getOperand(1),
22161
+ N->getOperand(2));
22181
22162
}
22182
22163
22183
22164
static SDValue performIntrinsicCombine(SDNode *N,
0 commit comments