Skip to content

Commit a5f0525

Browse files
authored
[AArch64][SelectionDAG] Enable new partial reduction lowering by default (llvm#143565)
1 parent c7d8581 commit a5f0525

File tree

5 files changed

+1804
-1817
lines changed

5 files changed

+1804
-1817
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 36 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,6 @@ cl::opt<bool> EnableSVEGISel(
153153
cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
154154
cl::init(false));
155155

156-
// FIXME : This is a temporary flag, and is used to help transition to
157-
// performing lowering the proper way using the new PARTIAL_REDUCE_MLA ISD
158-
// nodes.
159-
static cl::opt<bool> EnablePartialReduceNodes(
160-
"aarch64-enable-partial-reduce-nodes", cl::init(false), cl::ReallyHidden,
161-
cl::desc("Use the new method of lowering partial reductions."));
162-
163156
/// Value type used for condition codes.
164157
static const MVT MVT_CC = MVT::i32;
165158

@@ -1457,7 +1450,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14571450
for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
14581451
setOperationAction(ISD::FADD, VT, Custom);
14591452

1460-
if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1453+
if (Subtarget->hasDotProd()) {
14611454
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
14621455
ISD::PARTIAL_REDUCE_UMLA};
14631456

@@ -1895,7 +1888,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18951888
}
18961889

18971890
// Handle partial reduction operations
1898-
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
1891+
if (Subtarget->isSVEorStreamingSVEAvailable()) {
18991892
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
19001893
// Other pairs will default to 'Expand'.
19011894
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
@@ -1957,17 +1950,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19571950
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19581951
Custom);
19591952

1960-
if (EnablePartialReduceNodes) {
1961-
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1962-
ISD::PARTIAL_REDUCE_UMLA};
1963-
// Must be lowered to SVE instructions.
1964-
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1965-
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1966-
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1967-
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1968-
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1969-
setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1970-
}
1953+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1954+
ISD::PARTIAL_REDUCE_UMLA};
1955+
// Must be lowered to SVE instructions.
1956+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1957+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1958+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1959+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1960+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1961+
setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
19711962
}
19721963
}
19731964

@@ -2165,16 +2156,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
21652156
assert(I->getIntrinsicID() ==
21662157
Intrinsic::experimental_vector_partial_reduce_add &&
21672158
"Unexpected intrinsic!");
2168-
if (EnablePartialReduceNodes)
2169-
return true;
2170-
2171-
EVT VT = EVT::getEVT(I->getType());
2172-
auto Op1 = I->getOperand(1);
2173-
EVT Op1VT = EVT::getEVT(Op1->getType());
2174-
if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
2175-
(VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
2176-
VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
2177-
return false;
21782159
return true;
21792160
}
21802161

@@ -2252,37 +2233,32 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22522233
bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22532234
bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22542235

2255-
if (EnablePartialReduceNodes) {
2256-
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2257-
ISD::PARTIAL_REDUCE_UMLA};
2258-
unsigned NumElts = VT.getVectorNumElements();
2259-
if (VT.getVectorElementType() == MVT::i64) {
2260-
setPartialReduceMLAAction(MLAOps, VT,
2261-
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2262-
setPartialReduceMLAAction(
2263-
MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2264-
setPartialReduceMLAAction(
2265-
MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2266-
} else if (VT.getVectorElementType() == MVT::i32) {
2267-
setPartialReduceMLAAction(MLAOps, VT,
2236+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2237+
ISD::PARTIAL_REDUCE_UMLA};
2238+
unsigned NumElts = VT.getVectorNumElements();
2239+
if (VT.getVectorElementType() == MVT::i64) {
2240+
setPartialReduceMLAAction(MLAOps, VT,
2241+
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2242+
setPartialReduceMLAAction(MLAOps, VT,
2243+
MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2244+
setPartialReduceMLAAction(MLAOps, VT,
2245+
MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2246+
} else if (VT.getVectorElementType() == MVT::i32) {
2247+
setPartialReduceMLAAction(MLAOps, VT,
2248+
MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2249+
setPartialReduceMLAAction(MLAOps, VT,
2250+
MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2251+
} else if (VT.getVectorElementType() == MVT::i16) {
2252+
setPartialReduceMLAAction(MLAOps, VT,
2253+
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2254+
}
2255+
if (Subtarget->hasMatMulInt8()) {
2256+
if (VT.getVectorElementType() == MVT::i32)
2257+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
22682258
MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2269-
setPartialReduceMLAAction(
2270-
MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2271-
} else if (VT.getVectorElementType() == MVT::i16) {
2272-
setPartialReduceMLAAction(MLAOps, VT,
2273-
MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2274-
}
2275-
2276-
if (Subtarget->hasMatMulInt8()) {
2277-
if (VT.getVectorElementType() == MVT::i32)
2278-
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
2279-
MVT::getVectorVT(MVT::i8, NumElts * 4),
2280-
Custom);
2281-
else if (VT.getVectorElementType() == MVT::i64)
2282-
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
2283-
MVT::getVectorVT(MVT::i8, NumElts * 8),
2284-
Custom);
2285-
}
2259+
else if (VT.getVectorElementType() == MVT::i64)
2260+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
2261+
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
22862262
}
22872263

22882264
// Lower fixed length vector operations to scalable equivalents.

0 commit comments

Comments
 (0)