@@ -21990,30 +21990,31 @@ static SDValue tryCombineWhileLo(SDNode *N,
21990
21990
return SDValue(N, 0);
21991
21991
}
21992
21992
21993
- SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
21993
+ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
21994
+ SelectionDAG &DAG,
21994
21995
const AArch64Subtarget *Subtarget, SDLoc &DL) {
21995
21996
bool Scalable = Acc.getValueType().isScalableVector();
21996
21997
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
21997
21998
return SDValue();
21998
21999
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
21999
22000
return SDValue();
22000
22001
22001
- unsigned InputOpcode = Input ->getOpcode();
22002
+ unsigned Input1Opcode = Input1 ->getOpcode();
22002
22003
EVT AccVT = Acc->getValueType(0);
22003
22004
if (AccVT.getVectorElementCount() * 4 ==
22004
- Input ->getValueType(0).getVectorElementCount() &&
22005
- InputOpcode != ISD::MUL)
22006
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22007
- if (InputOpcode != ISD::MUL)
22005
+ Input1 ->getValueType(0).getVectorElementCount() &&
22006
+ Input1Opcode != ISD::MUL)
22007
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22008
+ if (Input1Opcode != ISD::MUL)
22008
22009
return SDValue();
22009
22010
22010
- auto A = Input ->getOperand(0);
22011
- auto B = Input ->getOperand(1);
22011
+ auto A = Input1 ->getOperand(0);
22012
+ auto B = Input1 ->getOperand(1);
22012
22013
unsigned AOpcode = A->getOpcode();
22013
22014
unsigned BOpcode = B->getOpcode();
22014
22015
22015
22016
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22016
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22017
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22017
22018
22018
22019
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22019
22020
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
@@ -22022,6 +22023,10 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22022
22023
B = B->getOperand(0);
22023
22024
EVT MulSrcVT = A.getValueType();
22024
22025
22026
+ Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
22027
+ A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
22028
+ B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
22029
+
22025
22030
// Dot products operate on chunks of four elements so there must be four times
22026
22031
// as many elements in the wide type
22027
22032
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
@@ -22030,17 +22035,17 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22030
22035
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22031
22036
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22032
22037
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22033
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22038
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22034
22039
22035
22040
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22036
22041
if (AIsSigned != BIsSigned) {
22037
22042
if (!Subtarget->hasMatMulInt8())
22038
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22043
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22039
22044
22040
22045
bool Scalable = AccVT.isScalableVT();
22041
22046
// There's no nxv2i64 version of usdot
22042
22047
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22043
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22048
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22044
22049
22045
22050
if (!BIsSigned)
22046
22051
std::swap(A, B);
@@ -22067,32 +22072,37 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22067
22072
}
22068
22073
22069
22074
if (A.getValueType() != B.getValueType())
22070
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22075
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22071
22076
22072
22077
unsigned NewOpcode =
22073
22078
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22074
22079
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
22075
22080
}
22076
22081
22077
- SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22082
+ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22083
+ SelectionDAG &DAG,
22078
22084
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22079
22085
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22080
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
22081
- unsigned InputOpcode = Input->getOpcode();
22082
- if (!ISD::isExtOpcode(InputOpcode))
22083
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
22084
- Input = Input->getOperand(0);
22085
- EVT InputVT = Input.getValueType();
22086
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22087
+ unsigned Input1Opcode = Input1->getOpcode();
22088
+ if (!ISD::isExtOpcode(Input1Opcode))
22089
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22090
+
22086
22091
EVT AccVT = Acc->getValueType(0);
22092
+ Input1 = Input1->getOperand(0);
22093
+ EVT InputVT = Input1.getValueType();
22094
+ Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
22095
+ SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
22096
+
22087
22097
if (!AccVT.isScalableVector())
22088
- return DAG.expandPartialReduceAdd(DL, Acc, Input );
22098
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22089
22099
22090
22100
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22091
22101
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22092
22102
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22093
22103
return SDValue();
22094
22104
22095
- unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22105
+ unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
22096
22106
? ISD::PARTIAL_REDUCE_SMLA
22097
22107
: ISD::PARTIAL_REDUCE_UMLA;
22098
22108
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
@@ -22103,18 +22113,21 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22103
22113
const AArch64Subtarget *Subtarget) {
22104
22114
SDLoc DL(N);
22105
22115
auto Acc = N->getOperand(0);
22106
- auto Input = N->getOperand(1);
22116
+ auto Input1 = N->getOperand(1);
22117
+ auto Input2 = N->getOperand(2);
22107
22118
EVT AccElemVT = Acc.getValueType().getVectorElementType();
22108
- EVT InputElemVT = Input .getValueType().getVectorElementType();
22119
+ EVT InputElemVT = Input1 .getValueType().getVectorElementType();
22109
22120
22110
22121
// If the exts have already been removed or it has already been lowered to an
22111
22122
// usdot instruction, then the element types will not be equal
22112
- if (InputElemVT != AccElemVT || Input .getOpcode() == AArch64ISD::USDOT)
22123
+ if (InputElemVT != AccElemVT || Input1 .getOpcode() == AArch64ISD::USDOT)
22113
22124
return SDValue(N, 0);
22114
22125
22115
- if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
22126
+ if (auto Dot =
22127
+ tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
22116
22128
return Dot;
22117
- if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
22129
+ if (auto WideAdd =
22130
+ tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
22118
22131
return WideAdd;
22119
22132
return SDValue();
22120
22133
}
0 commit comments