@@ -22032,144 +22032,150 @@ static SDValue tryCombineWhileLo(SDNode *N,
22032
22032
return SDValue(N, 0);
22033
22033
}
22034
22034
22035
- SDValue tryCombineToDotProduct(SDValue &Acc , SDValue &Input1 , SDValue &Input2 ,
22035
+ SDValue tryCombineToDotProduct(SDValue &Op0 , SDValue &Op1 , SDValue &Op2 ,
22036
22036
SelectionDAG &DAG,
22037
22037
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22038
- bool Scalable = Acc. getValueType().isScalableVector();
22038
+ bool Scalable = Op0-> getValueType(0 ).isScalableVector();
22039
22039
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22040
- return SDValue( );
22040
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22041
22041
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22042
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22043
+
22044
+ unsigned Op1Opcode = Op1->getOpcode();
22045
+ SDValue MulOpLHS, MulOpRHS;
22046
+ bool MulOpLHSIsSigned, MulOpRHSIsSigned;
22047
+ if (ISD::isExtOpcode(Op1Opcode)) {
22048
+ MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
22049
+ MulOpLHS = Op1->getOperand(0);
22050
+ MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
22051
+ } else if (Op1Opcode == ISD::MUL) {
22052
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
22053
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
22054
+
22055
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
22056
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
22057
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
22058
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
22059
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22060
+
22061
+ MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
22062
+ MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
22063
+
22064
+ MulOpLHS = ExtMulOpLHS->getOperand(0);
22065
+ MulOpRHS = ExtMulOpRHS->getOperand(0);
22066
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
22067
+
22068
+ if (MulOpLHSVT != MulOpRHS.getValueType())
22069
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22070
+
22071
+ Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
22072
+ MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
22073
+ MulOpRHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpRHS, Op2);
22074
+ } else
22042
22075
return SDValue();
22043
22076
22044
- unsigned Input1Opcode = Input1->getOpcode();
22045
- EVT AccVT = Acc->getValueType(0);
22046
- if (AccVT.getVectorElementCount() * 4 ==
22047
- Input1->getValueType(0).getVectorElementCount() &&
22048
- Input1Opcode != ISD::MUL)
22049
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22050
- if (Input1Opcode != ISD::MUL)
22051
- return SDValue();
22052
-
22053
- auto A = Input1->getOperand(0);
22054
- auto B = Input1->getOperand(1);
22055
- unsigned AOpcode = A->getOpcode();
22056
- unsigned BOpcode = B->getOpcode();
22057
-
22058
- if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22059
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22060
-
22061
- bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22062
- bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
22063
-
22064
- A = A->getOperand(0);
22065
- B = B->getOperand(0);
22066
- EVT MulSrcVT = A.getValueType();
22067
-
22068
- Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
22069
- A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
22070
- B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
22077
+ SDValue Acc = Op0;
22078
+ EVT ReducedVT = Acc->getValueType(0);
22079
+ EVT MulSrcVT = MulOpLHS.getValueType();
22071
22080
22072
22081
// Dot products operate on chunks of four elements so there must be four times
22073
22082
// as many elements in the wide type
22074
- if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22075
- !(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22076
- !(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22077
- !(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22078
- !(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22079
- !(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22080
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22081
-
22082
- unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22083
- if (AIsSigned != BIsSigned) {
22083
+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22084
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22085
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22086
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22087
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22088
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22089
+ return SDValue();
22090
+
22091
+ // If the extensions are mixed, we should lower it to a usdot instead
22092
+ unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22093
+ if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
22084
22094
if (!Subtarget->hasMatMulInt8())
22085
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22095
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22086
22096
22087
- bool Scalable = AccVT .isScalableVT();
22097
+ bool Scalable = ReducedVT .isScalableVT();
22088
22098
// There's no nxv2i64 version of usdot
22089
- if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22090
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22099
+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
22100
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22091
22101
22092
- if (!BIsSigned )
22093
- std::swap(A, B );
22102
+ if (!MulOpRHSIsSigned )
22103
+ std::swap(MulOpLHS, MulOpRHS );
22094
22104
DotOpcode = AArch64ISD::USDOT;
22095
22105
// Lower usdot patterns here because legalisation would attempt to split it
22096
22106
// unless exts are removed. But, removing the exts would lose the
22097
22107
// information about whether each operand is signed.
22098
- if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
22099
- (AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
22100
- return DAG.getNode(DotOpcode, DL, AccVT , Acc, A, B );
22108
+ if ((ReducedVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
22109
+ (ReducedVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
22110
+ return DAG.getNode(DotOpcode, DL, ReducedVT , Acc, MulOpLHS, MulOpRHS );
22101
22111
}
22102
22112
22103
22113
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
22104
22114
// product followed by a zero / sign extension. Need to lower this here
22105
22115
// because legalisation would attempt to split it.
22106
- if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22107
- (AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22108
- EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22116
+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22117
+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22118
+ EVT ReducedVTI32 =
22119
+ (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22109
22120
22110
- auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
22111
- DAG.getConstant(0, DL, AccVTI32), A, B);
22112
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
22113
- return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
22121
+ SDValue DotI32 =
22122
+ DAG.getNode(DotOpcode, DL, ReducedVTI32,
22123
+ DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22124
+ SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22125
+ return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
22114
22126
}
22115
22127
22116
- if (A.getValueType() != B.getValueType())
22117
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22118
-
22119
22128
unsigned NewOpcode =
22120
- AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22121
- return DAG.getNode(NewOpcode, DL, AccVT , Acc, A, B );
22129
+ MulOpLHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22130
+ return DAG.getNode(NewOpcode, DL, ReducedVT , Acc, MulOpLHS, MulOpRHS );
22122
22131
}
22123
22132
22124
- SDValue tryCombineToWideAdd(SDValue &Acc , SDValue &Input1 , SDValue &Input2 ,
22133
+ SDValue tryCombineToWideAdd(SDValue &Op0 , SDValue &Op1 , SDValue &Op2 ,
22125
22134
SelectionDAG &DAG,
22126
22135
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22127
22136
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22128
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22129
- unsigned Input1Opcode = Input1 ->getOpcode();
22130
- if (!ISD::isExtOpcode(Input1Opcode ))
22131
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22137
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22138
+ unsigned Op1Opcode = Op1 ->getOpcode();
22139
+ if (!ISD::isExtOpcode(Op1Opcode ))
22140
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22132
22141
22133
- EVT AccVT = Acc ->getValueType(0);
22134
- Input1 = Input1 ->getOperand(0);
22135
- EVT InputVT = Input1 .getValueType();
22136
- Input2 = DAG.getAnyExtOrTrunc(Input2 , DL, InputVT );
22137
- SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2 );
22142
+ EVT AccVT = Op0 ->getValueType(0);
22143
+ Op1 = Op1 ->getOperand(0);
22144
+ EVT Op1VT = Op1 .getValueType();
22145
+ Op2 = DAG.getAnyExtOrTrunc(Op2 , DL, Op1VT );
22146
+ SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2 );
22138
22147
22139
22148
if (!AccVT.isScalableVector())
22140
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2 );
22149
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2 );
22141
22150
22142
- if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22143
- !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22144
- !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22151
+ if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22152
+ !(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22153
+ !(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22145
22154
return SDValue();
22146
22155
22147
- unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
22148
- ? ISD::PARTIAL_REDUCE_SMLA
22149
- : ISD::PARTIAL_REDUCE_UMLA;
22150
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
22151
- DAG.getConstant(1, DL, InputVT));
22156
+ unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
22157
+ : ISD::PARTIAL_REDUCE_UMLA;
22158
+ return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
22159
+ DAG.getConstant(1, DL, Op1VT));
22152
22160
}
22153
22161
22154
22162
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22155
22163
const AArch64Subtarget *Subtarget) {
22156
22164
SDLoc DL(N);
22157
- auto Acc = N->getOperand(0);
22158
- auto Input1 = N->getOperand(1);
22159
- auto Input2 = N->getOperand(2);
22160
- EVT AccElemVT = Acc .getValueType().getVectorElementType();
22161
- EVT InputElemVT = Input1 .getValueType().getVectorElementType();
22165
+ SDValue Op0 = N->getOperand(0);
22166
+ SDValue Op1 = N->getOperand(1);
22167
+ SDValue Op2 = N->getOperand(2);
22168
+ EVT Op0ElemVT = Op0 .getValueType().getVectorElementType();
22169
+ EVT Op1ElemVT = Op1 .getValueType().getVectorElementType();
22162
22170
22163
22171
// If the exts have already been removed or it has already been lowered to an
22164
22172
// usdot instruction, then the element types will not be equal
22165
- if (InputElemVT != AccElemVT || Input1 .getOpcode() == AArch64ISD::USDOT)
22173
+ if (Op0ElemVT != Op1ElemVT || Op1 .getOpcode() == AArch64ISD::USDOT)
22166
22174
return SDValue(N, 0);
22167
22175
22168
- if (auto Dot =
22169
- tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
22176
+ if (auto Dot = tryCombineToDotProduct(Op0, Op1, Op2, DAG, Subtarget, DL))
22170
22177
return Dot;
22171
- if (auto WideAdd =
22172
- tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
22178
+ if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
22173
22179
return WideAdd;
22174
22180
return SDValue();
22175
22181
}
0 commit comments