@@ -1846,8 +1846,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1846
1846
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
1847
1847
Custom);
1848
1848
}
1849
+
1850
+ for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
1851
+ setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
1852
+ setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
1853
+ }
1849
1854
}
1850
1855
1856
+ for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
1857
+ setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
1858
+ setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
1859
+ }
1851
1860
1852
1861
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
1853
1862
// Only required for llvm.aarch64.mops.memset.tag
@@ -2046,17 +2055,18 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
2046
2055
return true;
2047
2056
2048
2057
EVT VT = EVT::getEVT(I->getType());
2049
- auto Op1 = I->getOperand(1);
2050
- EVT Op1VT = EVT::getEVT(Op1->getType());
2051
- if ((Op1VT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
2052
- (Op1VT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
2053
- (Op1VT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
2054
- (Op1VT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
2055
- (Op1VT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
2056
- (Op1VT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
2057
- (Op1VT == MVT::v16i64 && VT == MVT::v4i64) ||
2058
- (Op1VT == MVT::v16i32 && VT == MVT::v4i32) ||
2059
- (Op1VT == MVT::v8i32 && VT == MVT::v2i32))
2058
+ auto Input = I->getOperand(1);
2059
+ EVT InputVT = EVT::getEVT(Input->getType());
2060
+
2061
+ if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
2062
+ (InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
2063
+ (InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
2064
+ (InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
2065
+ (InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
2066
+ (InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
2067
+ (InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
2068
+ (InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
2069
+ (InputVT == MVT::v8i32 && VT == MVT::v2i32))
2060
2070
return false;
2061
2071
return true;
2062
2072
}
@@ -7659,6 +7669,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
7659
7669
return LowerFLDEXP(Op, DAG);
7660
7670
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
7661
7671
return LowerVECTOR_HISTOGRAM(Op, DAG);
7672
+ case ISD::PARTIAL_REDUCE_UADD:
7673
+ case ISD::PARTIAL_REDUCE_SADD:
7674
+ return LowerPARTIAL_REDUCE_ADD(Op, DAG);
7662
7675
}
7663
7676
}
7664
7677
@@ -22019,147 +22032,126 @@ static SDValue tryCombineWhileLo(SDNode *N,
22019
22032
return SDValue(N, 0);
22020
22033
}
22021
22034
22022
- SDValue tryLowerPartialReductionToDot(SDNode *N,
22023
- const AArch64Subtarget *Subtarget,
22024
- SelectionDAG &DAG) {
22025
-
22026
- bool Scalable = N->getValueType(0).isScalableVector();
22035
+ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22036
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
22037
+ bool Scalable = Acc.getValueType().isScalableVector();
22027
22038
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22028
22039
return SDValue();
22029
22040
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22030
22041
return SDValue();
22031
22042
22032
- SDLoc DL(N);
22033
-
22034
- // The narrower of the two operands. Used as the accumulator
22035
- auto NarrowOp = N->getOperand(0);
22036
- auto MulOp = N->getOperand(1);
22037
- if (MulOp->getOpcode() != ISD::MUL)
22043
+ unsigned InputOpcode = Input->getOpcode();
22044
+ if (InputOpcode != ISD::MUL)
22038
22045
return SDValue();
22039
-
22040
- auto A = MulOp->getOperand(0);
22041
- auto B = MulOp->getOperand(1);
22042
-
22046
+ auto A = Input->getOperand(0);
22047
+ auto B = Input->getOperand(1);
22043
22048
unsigned AOpcode = A->getOpcode();
22044
22049
unsigned BOpcode = B->getOpcode();
22045
- unsigned Opcode;
22046
- EVT ReducedType = N->getValueType(0);
22047
- EVT MulSrcType;
22048
- if (ISD::isExtOpcode(AOpcode) || ISD::isExtOpcode(BOpcode)) {
22049
- bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22050
- bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
22051
-
22052
- A = A->getOperand(0);
22053
- B = B->getOperand(0);
22054
- if (A.getValueType() != B.getValueType())
22055
- return SDValue();
22050
+ EVT AccVT = Acc->getValueType(0);
22056
22051
22057
- if (AIsSigned != BIsSigned) {
22058
- if (!Subtarget->hasMatMulInt8())
22059
- return SDValue();
22052
+ if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22053
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22060
22054
22061
- bool Scalable = N->getValueType(0).isScalableVT();
22062
- // There's no nxv2i64 version of usdot
22063
- if (Scalable && ReducedType != MVT::nxv4i32 &&
22064
- ReducedType != MVT::nxv4i64)
22065
- return SDValue();
22055
+ bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22056
+ bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
22066
22057
22067
- Opcode = AArch64ISD::USDOT;
22068
- // USDOT expects the signed operand to be last
22069
- if (!BIsSigned)
22070
- std::swap(A, B);
22071
- } else if (AIsSigned)
22072
- Opcode = AArch64ISD::SDOT;
22073
- else
22074
- Opcode = AArch64ISD::UDOT;
22075
- MulSrcType = A.getValueType();
22076
- }
22058
+ A = A->getOperand(0);
22059
+ B = B->getOperand(0);
22060
+ EVT MulSrcVT = A.getValueType();
22077
22061
22078
22062
// Dot products operate on chunks of four elements so there must be four times
22079
22063
// as many elements in the wide type
22080
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
22081
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
22082
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
22083
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
22084
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
22085
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
22086
- return SDValue();
22064
+ if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22065
+ !(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22066
+ !(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22067
+ !(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22068
+ !(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22069
+ !(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22070
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22071
+
22072
+ unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22073
+ if (AIsSigned != BIsSigned) {
22074
+ if (!Subtarget->hasMatMulInt8())
22075
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22076
+
22077
+ bool Scalable = AccVT.isScalableVT();
22078
+ // There's no nxv2i64 version of usdot
22079
+ if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22080
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22081
+
22082
+ if (!BIsSigned)
22083
+ std::swap(A, B);
22084
+ DotOpcode = AArch64ISD::USDOT;
22085
+ // Lower usdot patterns here because legalisation would attempt to split it
22086
+ // unless exts are removed. But, removing the exts would lose the
22087
+ // information about whether each operand is signed.
22088
+ if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
22089
+ (AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
22090
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
22091
+ }
22087
22092
22088
22093
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
22089
- // product followed by a zero / sign extension
22090
- if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22091
- (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22092
- EVT ReducedVTI32 =
22093
- (ReducedVT .isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22094
+ // product followed by a zero / sign extension. Need to lower this here
22095
+ // because legalisation would attempt to split it.
22096
+ if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22097
+ (AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22098
+ EVT AccVTI32 = (AccVT .isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22094
22099
22095
- SDValue DotI32 =
22096
- DAG.getNode(Opcode, DL, ReducedVTI32,
22097
- DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22098
- SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22099
- return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
22100
+ auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
22101
+ DAG.getConstant(0, DL, AccVTI32), A, B);
22102
+ auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
22103
+ return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
22100
22104
}
22101
22105
22102
- return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
22103
- }
22106
+ if (A.getValueType() != B.getValueType())
22107
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22104
22108
22105
- SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
22106
- const AArch64Subtarget *Subtarget,
22107
- SelectionDAG &DAG) {
22109
+ unsigned NewOpcode =
22110
+ AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
22111
+ auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
22112
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
22113
+ }
22108
22114
22115
+ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22116
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
22109
22117
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22110
- return SDValue();
22111
-
22112
- SDLoc DL(N);
22113
-
22114
- auto Acc = N->getOperand(0);
22115
- auto Input = N->getOperand(1);
22116
-
22117
- unsigned Opcode = N->getOpcode();
22118
- unsigned InputOpcode = Input.getOpcode();
22119
- if (ISD::isExtOpcode(InputOpcode)) {
22120
- Input = Input.getOperand(0);
22121
- if (InputOpcode == ISD::SIGN_EXTEND)
22122
- Opcode = ISD::PARTIAL_REDUCE_SADD;
22123
- }
22124
-
22118
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22119
+ unsigned InputOpcode = Input->getOpcode();
22120
+ if (!ISD::isExtOpcode(InputOpcode))
22121
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
22122
+ Input = Input->getOperand(0);
22125
22123
EVT InputVT = Input.getValueType();
22126
- EVT AccVT = Acc. getValueType();
22124
+ EVT AccVT = Acc-> getValueType(0 );
22127
22125
22128
- if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22129
- !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22130
- !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22126
+ if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22127
+ !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22128
+ !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22131
22129
return SDValue();
22132
22130
22133
- bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
22134
- auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22135
- auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22136
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
22137
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
22138
- }
22139
-
22140
- static SDValue
22141
- performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22142
- const AArch64Subtarget *Subtarget) {
22143
- if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22144
- return Dot;
22145
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22146
- return WideAdd;
22147
- return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0),
22148
- N->getOperand(1));
22131
+ unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22132
+ ? ISD::PARTIAL_REDUCE_SADD
22133
+ : ISD::PARTIAL_REDUCE_UADD;
22134
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
22149
22135
}
22150
22136
22137
+ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22138
+ const AArch64Subtarget *Subtarget) {
22139
+ SDLoc DL(N);
22140
+ auto Acc = N->getOperand(0);
22141
+ auto Input = N->getOperand(1);
22142
+ EVT AccElemVT = Acc.getValueType().getVectorElementType();
22143
+ EVT InputElemVT = Input.getValueType().getVectorElementType();
22151
22144
22145
+ // If the exts have already been removed or it has already been lowered to an
22146
+ // usdot instruction, then the element types will not be equal
22147
+ if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
22148
+ return SDValue(N, 0);
22152
22149
22153
- static SDValue
22154
- performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22155
- const AArch64Subtarget *Subtarget) {
22156
- auto *PR = cast<PartialReduceAddSDNode>(N);
22157
- if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
22150
+ if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
22158
22151
return Dot;
22159
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget , DAG))
22152
+ if (auto WideAdd = tryCombineToWideAdd(Acc, Input , DAG, Subtarget, DL ))
22160
22153
return WideAdd;
22161
- return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
22162
- PR->getInput());
22154
+ return SDValue();
22163
22155
}
22164
22156
22165
22157
static SDValue performIntrinsicCombine(SDNode *N,
@@ -29372,6 +29364,39 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
29372
29364
return Scatter;
29373
29365
}
29374
29366
29367
+ SDValue
29368
+ AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
29369
+ SelectionDAG &DAG) const {
29370
+ SDLoc DL(Op);
29371
+ SDValue Acc = Op.getOperand(0);
29372
+ SDValue Input = Op.getOperand(1);
29373
+
29374
+ EVT AccVT = Acc.getValueType();
29375
+ EVT InputVT = Input.getValueType();
29376
+
29377
+ unsigned Opcode = Op.getOpcode();
29378
+
29379
+ if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29380
+ unsigned IndexAdd = 0;
29381
+ // ISD::MUL may have already been lowered, meaning the operands would be in
29382
+ // different positions.
29383
+ if (Input.getOpcode() != ISD::MUL)
29384
+ IndexAdd = 1;
29385
+ auto A = Input.getOperand(IndexAdd);
29386
+ auto B = Input.getOperand(IndexAdd + 1);
29387
+
29388
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
29389
+ : AArch64ISD::UDOT;
29390
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
29391
+ }
29392
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
29393
+ unsigned BottomOpcode =
29394
+ InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
29395
+ unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
29396
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
29397
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
29398
+ }
29399
+
29375
29400
SDValue
29376
29401
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
29377
29402
SelectionDAG &DAG) const {
0 commit comments