Skip to content

Commit 9971a6e

Browse files
Make the no bin op changes work with the addition of Partial Reduction
SDNodes.
1 parent 6092426 commit 9971a6e

File tree

1 file changed

+96
-90
lines changed

1 file changed

+96
-90
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 96 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -22032,144 +22032,150 @@ static SDValue tryCombineWhileLo(SDNode *N,
2203222032
return SDValue(N, 0);
2203322033
}
2203422034

22035-
SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22035+
SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2203622036
SelectionDAG &DAG,
2203722037
const AArch64Subtarget *Subtarget, SDLoc &DL) {
22038-
bool Scalable = Acc.getValueType().isScalableVector();
22038+
bool Scalable = Op0->getValueType(0).isScalableVector();
2203922039
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
22040-
return SDValue();
22040+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
2204122041
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
22042+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22043+
22044+
unsigned Op1Opcode = Op1->getOpcode();
22045+
SDValue MulOpLHS, MulOpRHS;
22046+
bool MulOpLHSIsSigned, MulOpRHSIsSigned;
22047+
if (ISD::isExtOpcode(Op1Opcode)) {
22048+
MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
22049+
MulOpLHS = Op1->getOperand(0);
22050+
MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
22051+
} else if (Op1Opcode == ISD::MUL) {
22052+
SDValue ExtMulOpLHS = Op1->getOperand(0);
22053+
SDValue ExtMulOpRHS = Op1->getOperand(1);
22054+
22055+
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
22056+
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
22057+
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
22058+
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
22059+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22060+
22061+
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
22062+
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
22063+
22064+
MulOpLHS = ExtMulOpLHS->getOperand(0);
22065+
MulOpRHS = ExtMulOpRHS->getOperand(0);
22066+
EVT MulOpLHSVT = MulOpLHS.getValueType();
22067+
22068+
if (MulOpLHSVT != MulOpRHS.getValueType())
22069+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22070+
22071+
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
22072+
MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
22073+
MulOpRHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpRHS, Op2);
22074+
} else
2204222075
return SDValue();
2204322076

22044-
unsigned Input1Opcode = Input1->getOpcode();
22045-
EVT AccVT = Acc->getValueType(0);
22046-
if (AccVT.getVectorElementCount() * 4 ==
22047-
Input1->getValueType(0).getVectorElementCount() &&
22048-
Input1Opcode != ISD::MUL)
22049-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22050-
if (Input1Opcode != ISD::MUL)
22051-
return SDValue();
22052-
22053-
auto A = Input1->getOperand(0);
22054-
auto B = Input1->getOperand(1);
22055-
unsigned AOpcode = A->getOpcode();
22056-
unsigned BOpcode = B->getOpcode();
22057-
22058-
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
22059-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22060-
22061-
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
22062-
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
22063-
22064-
A = A->getOperand(0);
22065-
B = B->getOperand(0);
22066-
EVT MulSrcVT = A.getValueType();
22067-
22068-
Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
22069-
A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
22070-
B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
22077+
SDValue Acc = Op0;
22078+
EVT ReducedVT = Acc->getValueType(0);
22079+
EVT MulSrcVT = MulOpLHS.getValueType();
2207122080

2207222081
// Dot products operate on chunks of four elements so there must be four times
2207322082
// as many elements in the multiply source type as in the result type
22074-
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22075-
!(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22076-
!(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22077-
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22078-
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22079-
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22080-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22081-
22082-
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22083-
if (AIsSigned != BIsSigned) {
22083+
if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
22084+
!(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
22085+
!(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
22086+
!(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
22087+
!(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
22088+
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
22089+
return SDValue();
22090+
22091+
// If the extensions are mixed, we should lower it to a usdot instead
22092+
unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
22093+
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
2208422094
if (!Subtarget->hasMatMulInt8())
22085-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22095+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
2208622096

22087-
bool Scalable = AccVT.isScalableVT();
22097+
bool Scalable = ReducedVT.isScalableVT();
2208822098
// There's no nxv2i64 version of usdot
22089-
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
22090-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22099+
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
22100+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
2209122101

22092-
if (!BIsSigned)
22093-
std::swap(A, B);
22102+
if (!MulOpRHSIsSigned)
22103+
std::swap(MulOpLHS, MulOpRHS);
2209422104
DotOpcode = AArch64ISD::USDOT;
2209522105
// Lower usdot patterns here because legalisation would attempt to split it
2209622106
// unless exts are removed. But, removing the exts would lose the
2209722107
// information about whether each operand is signed.
22098-
if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
22099-
(AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
22100-
return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
22108+
if ((ReducedVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
22109+
(ReducedVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
22110+
return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
2210122111
}
2210222112

2210322113
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
2210422114
// product followed by a zero / sign extension. Need to lower this here
2210522115
// because legalisation would attempt to split it.
22106-
if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22107-
(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22108-
EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22116+
if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22117+
(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22118+
EVT ReducedVTI32 =
22119+
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
2210922120

22110-
auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
22111-
DAG.getConstant(0, DL, AccVTI32), A, B);
22112-
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
22113-
return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
22121+
SDValue DotI32 =
22122+
DAG.getNode(DotOpcode, DL, ReducedVTI32,
22123+
DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
22124+
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
22125+
return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
2211422126
}
2211522127

22116-
if (A.getValueType() != B.getValueType())
22117-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22118-
2211922128
unsigned NewOpcode =
22120-
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22121-
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
22129+
MulOpLHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
22130+
return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
2212222131
}
2212322132

22124-
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
22133+
SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
2212522134
SelectionDAG &DAG,
2212622135
const AArch64Subtarget *Subtarget, SDLoc &DL) {
2212722136
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
22128-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22129-
unsigned Input1Opcode = Input1->getOpcode();
22130-
if (!ISD::isExtOpcode(Input1Opcode))
22131-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22137+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
22138+
unsigned Op1Opcode = Op1->getOpcode();
22139+
if (!ISD::isExtOpcode(Op1Opcode))
22140+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
2213222141

22133-
EVT AccVT = Acc->getValueType(0);
22134-
Input1 = Input1->getOperand(0);
22135-
EVT InputVT = Input1.getValueType();
22136-
Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
22137-
SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
22142+
EVT AccVT = Op0->getValueType(0);
22143+
Op1 = Op1->getOperand(0);
22144+
EVT Op1VT = Op1.getValueType();
22145+
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
22146+
SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
2213822147

2213922148
if (!AccVT.isScalableVector())
22140-
return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
22149+
return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
2214122150

22142-
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22143-
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22144-
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22151+
if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22152+
!(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22153+
!(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
2214522154
return SDValue();
2214622155

22147-
unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
22148-
? ISD::PARTIAL_REDUCE_SMLA
22149-
: ISD::PARTIAL_REDUCE_UMLA;
22150-
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
22151-
DAG.getConstant(1, DL, InputVT));
22156+
unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
22157+
: ISD::PARTIAL_REDUCE_UMLA;
22158+
return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
22159+
DAG.getConstant(1, DL, Op1VT));
2215222160
}
2215322161

2215422162
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
2215522163
const AArch64Subtarget *Subtarget) {
2215622164
SDLoc DL(N);
22157-
auto Acc = N->getOperand(0);
22158-
auto Input1 = N->getOperand(1);
22159-
auto Input2 = N->getOperand(2);
22160-
EVT AccElemVT = Acc.getValueType().getVectorElementType();
22161-
EVT InputElemVT = Input1.getValueType().getVectorElementType();
22165+
SDValue Op0 = N->getOperand(0);
22166+
SDValue Op1 = N->getOperand(1);
22167+
SDValue Op2 = N->getOperand(2);
22168+
EVT Op0ElemVT = Op0.getValueType().getVectorElementType();
22169+
EVT Op1ElemVT = Op1.getValueType().getVectorElementType();
2216222170

2216322171
// If the exts have already been removed or it has already been lowered to a
2216422172
// usdot instruction, then the element types will not be equal
22165-
if (InputElemVT != AccElemVT || Input1.getOpcode() == AArch64ISD::USDOT)
22173+
if (Op0ElemVT != Op1ElemVT || Op1.getOpcode() == AArch64ISD::USDOT)
2216622174
return SDValue(N, 0);
2216722175

22168-
if (auto Dot =
22169-
tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
22176+
if (auto Dot = tryCombineToDotProduct(Op0, Op1, Op2, DAG, Subtarget, DL))
2217022177
return Dot;
22171-
if (auto WideAdd =
22172-
tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
22178+
if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
2217322179
return WideAdd;
2217422180
return SDValue();
2217522181
}

0 commit comments

Comments
 (0)