Skip to content

Commit 3dbff90

Browse files
committed
[X86] matchPMADDWD/matchPMADDWD_2 - update to use SDPatternMatch matching. NFCI.
Prep work for #118433
1 parent 82c93b6 commit 3dbff90

File tree

1 file changed

+41
-62
lines changed

1 file changed

+41
-62
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 41 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -56447,9 +56447,11 @@ static SDValue combineADC(SDNode *N, SelectionDAG &DAG,
5644756447
return SDValue();
5644856448
}
5644956449

56450-
static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
56450+
static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
5645156451
const SDLoc &DL, EVT VT,
5645256452
const X86Subtarget &Subtarget) {
56453+
using namespace SDPatternMatch;
56454+
5645356455
// Example of pattern we try to detect:
5645456456
// t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1))))
5645556457
//(add (build_vector (extract_elt t, 0),
@@ -56464,15 +56466,16 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
5646456466
if (!Subtarget.hasSSE2())
5646556467
return SDValue();
5646656468

56467-
if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
56468-
Op1.getOpcode() != ISD::BUILD_VECTOR)
56469-
return SDValue();
56470-
5647156469
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
5647256470
VT.getVectorNumElements() < 4 ||
5647356471
!isPowerOf2_32(VT.getVectorNumElements()))
5647456472
return SDValue();
5647556473

56474+
SDValue Op0, Op1;
56475+
if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
56476+
m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
56477+
return SDValue();
56478+
5647656479
// Check if one of Op0,Op1 is of the form:
5647756480
// (build_vector (extract_elt Mul, 0),
5647856481
// (extract_elt Mul, 2),
@@ -56489,26 +56492,23 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
5648956492
SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
5649056493
Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
5649156494
// TODO: Be more tolerant to undefs.
56492-
if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56493-
Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56494-
Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56495-
Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
56496-
return SDValue();
56497-
auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
56498-
auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
56499-
auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
56500-
auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
56501-
if (!Const0L || !Const1L || !Const0H || !Const1H)
56495+
APInt Idx0L, Idx0H, Idx1L, Idx1H;
56496+
if (!sd_match(Op0L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56497+
m_ConstInt(Idx0L))) ||
56498+
!sd_match(Op0H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56499+
m_ConstInt(Idx0H))) ||
56500+
!sd_match(Op1L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56501+
m_ConstInt(Idx1L))) ||
56502+
!sd_match(Op1H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56503+
m_ConstInt(Idx1H))))
5650256504
return SDValue();
56503-
unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
56504-
Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
5650556505
// Commutativity of mul allows factors of a product to reorder.
56506-
if (Idx0L > Idx1L)
56506+
if (Idx0L.getZExtValue() > Idx1L.getZExtValue())
5650756507
std::swap(Idx0L, Idx1L);
56508-
if (Idx0H > Idx1H)
56508+
if (Idx0H.getZExtValue() > Idx1H.getZExtValue())
5650956509
std::swap(Idx0H, Idx1H);
5651056510
// Commutativity of add allows pairs of factors to reorder.
56511-
if (Idx0L > Idx0H) {
56511+
if (Idx0L.getZExtValue() > Idx0H.getZExtValue()) {
5651256512
std::swap(Idx0L, Idx0H);
5651356513
std::swap(Idx1L, Idx1H);
5651456514
}
@@ -56555,39 +56555,26 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
5655556555
// Attempt to turn this pattern into PMADDWD.
5655656556
// (add (mul (sext (build_vector)), (sext (build_vector))),
5655756557
// (mul (sext (build_vector)), (sext (build_vector)))
56558-
static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
56558+
static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDNode *N,
5655956559
const SDLoc &DL, EVT VT,
5656056560
const X86Subtarget &Subtarget) {
56561-
if (!Subtarget.hasSSE2())
56562-
return SDValue();
56561+
using namespace SDPatternMatch;
5656356562

56564-
if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
56563+
if (!Subtarget.hasSSE2())
5656556564
return SDValue();
5656656565

5656756566
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
5656856567
VT.getVectorNumElements() < 4 ||
5656956568
!isPowerOf2_32(VT.getVectorNumElements()))
5657056569
return SDValue();
5657156570

56572-
SDValue N00 = N0.getOperand(0);
56573-
SDValue N01 = N0.getOperand(1);
56574-
SDValue N10 = N1.getOperand(0);
56575-
SDValue N11 = N1.getOperand(1);
56576-
5657756571
// All inputs need to be sign extends.
5657856572
// TODO: Support ZERO_EXTEND from known positive?
56579-
if (N00.getOpcode() != ISD::SIGN_EXTEND ||
56580-
N01.getOpcode() != ISD::SIGN_EXTEND ||
56581-
N10.getOpcode() != ISD::SIGN_EXTEND ||
56582-
N11.getOpcode() != ISD::SIGN_EXTEND)
56573+
SDValue N00, N01, N10, N11;
56574+
if (!sd_match(N, m_Add(m_Mul(m_SExt(m_Value(N00)), m_SExt(m_Value(N01))),
56575+
m_Mul(m_SExt(m_Value(N10)), m_SExt(m_Value(N11))))))
5658356576
return SDValue();
5658456577

56585-
// Peek through the extends.
56586-
N00 = N00.getOperand(0);
56587-
N01 = N01.getOperand(0);
56588-
N10 = N10.getOperand(0);
56589-
N11 = N11.getOperand(0);
56590-
5659156578
// Must be extending from vXi16.
5659256579
EVT InVT = N00.getValueType();
5659356580
if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
@@ -56614,34 +56601,26 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
5661456601
SDValue N10Elt = N10.getOperand(i);
5661556602
SDValue N11Elt = N11.getOperand(i);
5661656603
// TODO: Be more tolerant to undefs.
56617-
if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56618-
N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56619-
N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56620-
N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
56621-
return SDValue();
56622-
auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
56623-
auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
56624-
auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
56625-
auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
56626-
if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
56604+
SDValue N00In, N01In, N10In, N11In;
56605+
APInt IdxN00, IdxN01, IdxN10, IdxN11;
56606+
if (!sd_match(N00Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N00In),
56607+
m_ConstInt(IdxN00))) ||
56608+
!sd_match(N01Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N01In),
56609+
m_ConstInt(IdxN01))) ||
56610+
!sd_match(N10Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N10In),
56611+
m_ConstInt(IdxN10))) ||
56612+
!sd_match(N11Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N11In),
56613+
m_ConstInt(IdxN11))))
5662756614
return SDValue();
56628-
unsigned IdxN00 = ConstN00Elt->getZExtValue();
56629-
unsigned IdxN01 = ConstN01Elt->getZExtValue();
56630-
unsigned IdxN10 = ConstN10Elt->getZExtValue();
56631-
unsigned IdxN11 = ConstN11Elt->getZExtValue();
5663256615
// Add is commutative so indices can be reordered.
56633-
if (IdxN00 > IdxN10) {
56616+
if (IdxN00.getZExtValue() > IdxN10.getZExtValue()) {
5663456617
std::swap(IdxN00, IdxN10);
5663556618
std::swap(IdxN01, IdxN11);
5663656619
}
5663756620
// N0 indices be the even element. N1 indices must be the next odd element.
56638-
if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
56639-
IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
56621+
if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || IdxN01 != 2 * i ||
56622+
IdxN11 != 2 * i + 1)
5664056623
return SDValue();
56641-
SDValue N00In = N00Elt.getOperand(0);
56642-
SDValue N01In = N01Elt.getOperand(0);
56643-
SDValue N10In = N10Elt.getOperand(0);
56644-
SDValue N11In = N11Elt.getOperand(0);
5664556624

5664656625
// First time we find an input capture it.
5664756626
if (!In0) {
@@ -56815,9 +56794,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
5681556794
if (SDValue Select = pushAddIntoCmovOfConsts(N, DL, DAG, Subtarget))
5681656795
return Select;
5681756796

56818-
if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, DL, VT, Subtarget))
56797+
if (SDValue MAdd = matchPMADDWD(DAG, N, DL, VT, Subtarget))
5681956798
return MAdd;
56820-
if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, DL, VT, Subtarget))
56799+
if (SDValue MAdd = matchPMADDWD_2(DAG, N, DL, VT, Subtarget))
5682156800
return MAdd;
5682256801
if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT))
5682356802
return MAdd;

0 commit comments

Comments
 (0)