Skip to content

Commit 43f73c2

Browse files
Separate lowering code for PARTIAL_REDUCE_U/SADD
Separate lowering code from all being in the DAG-combine function. Now the DAG-combine decides whether the node should be the signed or unsigned version of partial reduce add. Then there is a function in LowerOperation that does the actual lowering to wide adds or dot products if it is able to.
1 parent ba98a37 commit 43f73c2

File tree

2 files changed

+146
-120
lines changed

2 files changed

+146
-120
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 145 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,8 +1846,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18461846
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
18471847
Custom);
18481848
}
1849+
1850+
for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
1851+
setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
1852+
setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
1853+
}
18491854
}
18501855

1856+
for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
1857+
setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
1858+
setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
1859+
}
18511860

18521861
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
18531862
// Only required for llvm.aarch64.mops.memset.tag
@@ -2046,17 +2055,18 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
20462055
return true;
20472056

20482057
EVT VT = EVT::getEVT(I->getType());
2049-
auto Op1 = I->getOperand(1);
2050-
EVT Op1VT = EVT::getEVT(Op1->getType());
2051-
if ((Op1VT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
2052-
(Op1VT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
2053-
(Op1VT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
2054-
(Op1VT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
2055-
(Op1VT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
2056-
(Op1VT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
2057-
(Op1VT == MVT::v16i64 && VT == MVT::v4i64) ||
2058-
(Op1VT == MVT::v16i32 && VT == MVT::v4i32) ||
2059-
(Op1VT == MVT::v8i32 && VT == MVT::v2i32))
2058+
auto Input = I->getOperand(1);
2059+
EVT InputVT = EVT::getEVT(Input->getType());
2060+
2061+
if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
2062+
(InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
2063+
(InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
2064+
(InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
2065+
(InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
2066+
(InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
2067+
(InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
2068+
(InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
2069+
(InputVT == MVT::v8i32 && VT == MVT::v2i32))
20602070
return false;
20612071
return true;
20622072
}
@@ -7659,6 +7669,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
76597669
return LowerFLDEXP(Op, DAG);
76607670
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
76617671
return LowerVECTOR_HISTOGRAM(Op, DAG);
7672+
case ISD::PARTIAL_REDUCE_UADD:
7673+
case ISD::PARTIAL_REDUCE_SADD:
7674+
return LowerPARTIAL_REDUCE_ADD(Op, DAG);
76627675
}
76637676
}
76647677

@@ -22019,147 +22032,126 @@ static SDValue tryCombineWhileLo(SDNode *N,
2201922032
return SDValue(N, 0);
2202022033
}
2202122034

22022-
SDValue tryLowerPartialReductionToDot(SDNode *N,
22023-
const AArch64Subtarget *Subtarget,
22024-
SelectionDAG &DAG) {
22025-
22026-
bool Scalable = N->getValueType(0).isScalableVector();
22035+
SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22036+
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22037+
bool Scalable = Acc.getValueType().isScalableVector();
2202722038
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
2202822039
return SDValue();
2202922040
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
2203022041
return SDValue();
2203122042

22032-
SDLoc DL(N);
22033-
22034-
// The narrower of the two operands. Used as the accumulator
22035-
auto NarrowOp = N->getOperand(0);
22036-
auto MulOp = N->getOperand(1);
22037-
if (MulOp->getOpcode() != ISD::MUL)
22043+
unsigned InputOpcode = Input->getOpcode();
22044+
if (InputOpcode != ISD::MUL)
2203822045
return SDValue();
22039-
22040-
auto A = MulOp->getOperand(0);
22041-
auto B = MulOp->getOperand(1);
22042-
22046+
auto A = Input->getOperand(0);
22047+
auto B = Input->getOperand(1);
2204322048
unsigned AOpcode = A->getOpcode();
2204422049
unsigned BOpcode = B->getOpcode();
22045-
unsigned Opcode;
22046-
EVT ReducedType = N->getValueType(0);
22047-
EVT MulSrcType;
22048-
if (ISD::isExtOpcode(AOpcode) || ISD::isExtOpcode(BOpcode)) {
22049-
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22050-
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
22051-
22052-
A = A->getOperand(0);
22053-
B = B->getOperand(0);
22054-
if (A.getValueType() != B.getValueType())
22055-
return SDValue();
22050+
EVT AccVT = Acc->getValueType(0);
2205622051

22057-
if (AIsSigned != BIsSigned) {
22058-
if (!Subtarget->hasMatMulInt8())
22059-
return SDValue();
22052+
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22053+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2206022054

22061-
bool Scalable = N->getValueType(0).isScalableVT();
22062-
// There's no nxv2i64 version of usdot
22063-
if (Scalable && ReducedType != MVT::nxv4i32 &&
22064-
ReducedType != MVT::nxv4i64)
22065-
return SDValue();
22055+
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22056+
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
2206622057

22067-
Opcode = AArch64ISD::USDOT;
22068-
// USDOT expects the signed operand to be last
22069-
if (!BIsSigned)
22070-
std::swap(A, B);
22071-
} else if (AIsSigned)
22072-
Opcode = AArch64ISD::SDOT;
22073-
else
22074-
Opcode = AArch64ISD::UDOT;
22075-
MulSrcType = A.getValueType();
22076-
}
22058+
A = A->getOperand(0);
22059+
B = B->getOperand(0);
22060+
EVT MulSrcVT = A.getValueType();
2207722061

2207822062
// Dot products operate on chunks of four elements so there must be four times
2207922063
// as many elements in the wide type
22080-
if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
22081-
!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
22082-
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
22083-
!(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
22084-
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
22085-
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
22086-
return SDValue();
22064+
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22065+
!(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22066+
!(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22067+
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22068+
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22069+
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22070+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22071+
22072+
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22073+
if (AIsSigned != BIsSigned) {
22074+
if (!Subtarget->hasMatMulInt8())
22075+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22076+
22077+
bool Scalable = AccVT.isScalableVT();
22078+
// There's no nxv2i64 version of usdot
22079+
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22080+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22081+
22082+
if (!BIsSigned)
22083+
std::swap(A, B);
22084+
DotOpcode = AArch64ISD::USDOT;
22085+
// Lower usdot patterns here because legalisation would attempt to split it
22086+
// unless exts are removed. But, removing the exts would lose the
22087+
// information about whether each operand is signed.
22088+
if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
22089+
(AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
22090+
return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
22091+
}
2208722092

2208822093
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
22089-
// product followed by a zero / sign extension
22090-
if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22091-
(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22092-
EVT ReducedVTI32 =
22093-
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22094+
// product followed by a zero / sign extension. Need to lower this here
22095+
// because legalisation would attempt to split it.
22096+
if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22097+
(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22098+
EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
2209422099

22095-
SDValue DotI32 =
22096-
DAG.getNode(Opcode, DL, ReducedVTI32,
22097-
DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22098-
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22099-
return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
22100+
auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
22101+
DAG.getConstant(0, DL, AccVTI32), A, B);
22102+
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
22103+
return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
2210022104
}
2210122105

22102-
return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
22103-
}
22106+
if (A.getValueType() != B.getValueType())
22107+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
2210422108

22105-
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
22106-
const AArch64Subtarget *Subtarget,
22107-
SelectionDAG &DAG) {
22109+
unsigned NewOpcode =
22110+
AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
22111+
auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
22112+
return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
22113+
}
2210822114

22115+
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
22116+
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2210922117
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22110-
return SDValue();
22111-
22112-
SDLoc DL(N);
22113-
22114-
auto Acc = N->getOperand(0);
22115-
auto Input = N->getOperand(1);
22116-
22117-
unsigned Opcode = N->getOpcode();
22118-
unsigned InputOpcode = Input.getOpcode();
22119-
if (ISD::isExtOpcode(InputOpcode)) {
22120-
Input = Input.getOperand(0);
22121-
if (InputOpcode == ISD::SIGN_EXTEND)
22122-
Opcode = ISD::PARTIAL_REDUCE_SADD;
22123-
}
22124-
22118+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22119+
unsigned InputOpcode = Input->getOpcode();
22120+
if (!ISD::isExtOpcode(InputOpcode))
22121+
return DAG.expandPartialReduceAdd(DL, Acc, Input);
22122+
Input = Input->getOperand(0);
2212522123
EVT InputVT = Input.getValueType();
22126-
EVT AccVT = Acc.getValueType();
22124+
EVT AccVT = Acc->getValueType(0);
2212722125

22128-
if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22129-
!(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22130-
!(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22126+
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22127+
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22128+
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
2213122129
return SDValue();
2213222130

22133-
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
22134-
auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22135-
auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22136-
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
22137-
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
22138-
}
22139-
22140-
static SDValue
22141-
performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22142-
const AArch64Subtarget *Subtarget) {
22143-
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22144-
return Dot;
22145-
if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22146-
return WideAdd;
22147-
return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0),
22148-
N->getOperand(1));
22131+
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
22132+
? ISD::PARTIAL_REDUCE_SADD
22133+
: ISD::PARTIAL_REDUCE_UADD;
22134+
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
2214922135
}
2215022136

22137+
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22138+
const AArch64Subtarget *Subtarget) {
22139+
SDLoc DL(N);
22140+
auto Acc = N->getOperand(0);
22141+
auto Input = N->getOperand(1);
22142+
EVT AccElemVT = Acc.getValueType().getVectorElementType();
22143+
EVT InputElemVT = Input.getValueType().getVectorElementType();
2215122144

22145+
// If the exts have already been removed or it has already been lowered to an
22146+
// usdot instruction, then the element types will not be equal
22147+
if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
22148+
return SDValue(N, 0);
2215222149

22153-
static SDValue
22154-
performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
22155-
const AArch64Subtarget *Subtarget) {
22156-
auto *PR = cast<PartialReduceAddSDNode>(N);
22157-
if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
22150+
if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
2215822151
return Dot;
22159-
if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
22152+
if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
2216022153
return WideAdd;
22161-
return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
22162-
PR->getInput());
22154+
return SDValue();
2216322155
}
2216422156

2216522157
static SDValue performIntrinsicCombine(SDNode *N,
@@ -29372,6 +29364,39 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2937229364
return Scatter;
2937329365
}
2937429366

29367+
SDValue
29368+
AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
29369+
SelectionDAG &DAG) const {
29370+
SDLoc DL(Op);
29371+
SDValue Acc = Op.getOperand(0);
29372+
SDValue Input = Op.getOperand(1);
29373+
29374+
EVT AccVT = Acc.getValueType();
29375+
EVT InputVT = Input.getValueType();
29376+
29377+
unsigned Opcode = Op.getOpcode();
29378+
29379+
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
29380+
unsigned IndexAdd = 0;
29381+
// ISD::MUL may have already been lowered, meaning the operands would be in
29382+
// different positions.
29383+
if (Input.getOpcode() != ISD::MUL)
29384+
IndexAdd = 1;
29385+
auto A = Input.getOperand(IndexAdd);
29386+
auto B = Input.getOperand(IndexAdd + 1);
29387+
29388+
unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
29389+
: AArch64ISD::UDOT;
29390+
return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
29391+
}
29392+
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
29393+
unsigned BottomOpcode =
29394+
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
29395+
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
29396+
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
29397+
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
29398+
}
29399+
2937529400
SDValue
2937629401
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
2937729402
SelectionDAG &DAG) const {

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,7 @@ class AArch64TargetLowering : public TargetLowering {
11841184
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11851185
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11861186
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
1187+
SDValue LowerPARTIAL_REDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
11871188
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
11881189
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
11891190
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)