@@ -56447,9 +56447,11 @@ static SDValue combineADC(SDNode *N, SelectionDAG &DAG,
56447
56447
return SDValue();
56448
56448
}
56449
56449
56450
- static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1 ,
56450
+ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N ,
56451
56451
const SDLoc &DL, EVT VT,
56452
56452
const X86Subtarget &Subtarget) {
56453
+ using namespace SDPatternMatch;
56454
+
56453
56455
// Example of pattern we try to detect:
56454
56456
// t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1))))
56455
56457
//(add (build_vector (extract_elt t, 0),
@@ -56464,15 +56466,16 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
56464
56466
if (!Subtarget.hasSSE2())
56465
56467
return SDValue();
56466
56468
56467
- if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
56468
- Op1.getOpcode() != ISD::BUILD_VECTOR)
56469
- return SDValue();
56470
-
56471
56469
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
56472
56470
VT.getVectorNumElements() < 4 ||
56473
56471
!isPowerOf2_32(VT.getVectorNumElements()))
56474
56472
return SDValue();
56475
56473
56474
+ SDValue Op0, Op1;
56475
+ if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
56476
+ m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
56477
+ return SDValue();
56478
+
56476
56479
// Check if one of Op0,Op1 is of the form:
56477
56480
// (build_vector (extract_elt Mul, 0),
56478
56481
// (extract_elt Mul, 2),
@@ -56489,26 +56492,23 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
56489
56492
SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
56490
56493
Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
56491
56494
// TODO: Be more tolerant to undefs.
56492
- if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56493
- Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56494
- Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56495
- Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
56496
- return SDValue();
56497
- auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
56498
- auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
56499
- auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
56500
- auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
56501
- if (!Const0L || !Const1L || !Const0H || !Const1H)
56495
+ APInt Idx0L, Idx0H, Idx1L, Idx1H;
56496
+ if (!sd_match(Op0L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56497
+ m_ConstInt(Idx0L))) ||
56498
+ !sd_match(Op0H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56499
+ m_ConstInt(Idx0H))) ||
56500
+ !sd_match(Op1L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56501
+ m_ConstInt(Idx1L))) ||
56502
+ !sd_match(Op1H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56503
+ m_ConstInt(Idx1H))))
56502
56504
return SDValue();
56503
- unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
56504
- Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
56505
56505
// Commutativity of mul allows factors of a product to reorder.
56506
- if (Idx0L > Idx1L)
56506
+ if (Idx0L.getZExtValue() > Idx1L.getZExtValue() )
56507
56507
std::swap(Idx0L, Idx1L);
56508
- if (Idx0H > Idx1H)
56508
+ if (Idx0H.getZExtValue() > Idx1H.getZExtValue() )
56509
56509
std::swap(Idx0H, Idx1H);
56510
56510
// Commutativity of add allows pairs of factors to reorder.
56511
- if (Idx0L > Idx0H) {
56511
+ if (Idx0L.getZExtValue() > Idx0H.getZExtValue() ) {
56512
56512
std::swap(Idx0L, Idx0H);
56513
56513
std::swap(Idx1L, Idx1H);
56514
56514
}
@@ -56555,39 +56555,26 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
56555
56555
// Attempt to turn this pattern into PMADDWD.
56556
56556
// (add (mul (sext (build_vector)), (sext (build_vector))),
56557
56557
// (mul (sext (build_vector)), (sext (build_vector)))
56558
- static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1 ,
56558
+ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDNode *N ,
56559
56559
const SDLoc &DL, EVT VT,
56560
56560
const X86Subtarget &Subtarget) {
56561
- if (!Subtarget.hasSSE2())
56562
- return SDValue();
56561
+ using namespace SDPatternMatch;
56563
56562
56564
- if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL )
56563
+ if (!Subtarget.hasSSE2() )
56565
56564
return SDValue();
56566
56565
56567
56566
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
56568
56567
VT.getVectorNumElements() < 4 ||
56569
56568
!isPowerOf2_32(VT.getVectorNumElements()))
56570
56569
return SDValue();
56571
56570
56572
- SDValue N00 = N0.getOperand(0);
56573
- SDValue N01 = N0.getOperand(1);
56574
- SDValue N10 = N1.getOperand(0);
56575
- SDValue N11 = N1.getOperand(1);
56576
-
56577
56571
// All inputs need to be sign extends.
56578
56572
// TODO: Support ZERO_EXTEND from known positive?
56579
- if (N00.getOpcode() != ISD::SIGN_EXTEND ||
56580
- N01.getOpcode() != ISD::SIGN_EXTEND ||
56581
- N10.getOpcode() != ISD::SIGN_EXTEND ||
56582
- N11.getOpcode() != ISD::SIGN_EXTEND)
56573
+ SDValue N00, N01, N10, N11;
56574
+ if (!sd_match(N, m_Add(m_Mul(m_SExt(m_Value(N00)), m_SExt(m_Value(N01))),
56575
+ m_Mul(m_SExt(m_Value(N10)), m_SExt(m_Value(N11))))))
56583
56576
return SDValue();
56584
56577
56585
- // Peek through the extends.
56586
- N00 = N00.getOperand(0);
56587
- N01 = N01.getOperand(0);
56588
- N10 = N10.getOperand(0);
56589
- N11 = N11.getOperand(0);
56590
-
56591
56578
// Must be extending from vXi16.
56592
56579
EVT InVT = N00.getValueType();
56593
56580
if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
@@ -56614,34 +56601,26 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
56614
56601
SDValue N10Elt = N10.getOperand(i);
56615
56602
SDValue N11Elt = N11.getOperand(i);
56616
56603
// TODO: Be more tolerant to undefs.
56617
- if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56618
- N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56619
- N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56620
- N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
56621
- return SDValue();
56622
- auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
56623
- auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
56624
- auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
56625
- auto *ConstN11Elt = dyn_cast<ConstantSDNode> (N11Elt.getOperand(1));
56626
- if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt )
56604
+ SDValue N00In, N01In, N10In, N11In;
56605
+ APInt IdxN00, IdxN01, IdxN10, IdxN11;
56606
+ if (!sd_match(N00Elt, m_BinOp( ISD::EXTRACT_VECTOR_ELT, m_Value(N00In),
56607
+ m_ConstInt(IdxN00))) ||
56608
+ !sd_match(N01Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N01In),
56609
+ m_ConstInt(IdxN01))) ||
56610
+ !sd_match(N10Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N10In),
56611
+ m_ConstInt(IdxN10))) ||
56612
+ !sd_match (N11Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N11In),
56613
+ m_ConstInt(IdxN11))) )
56627
56614
return SDValue();
56628
- unsigned IdxN00 = ConstN00Elt->getZExtValue();
56629
- unsigned IdxN01 = ConstN01Elt->getZExtValue();
56630
- unsigned IdxN10 = ConstN10Elt->getZExtValue();
56631
- unsigned IdxN11 = ConstN11Elt->getZExtValue();
56632
56615
// Add is commutative so indices can be reordered.
56633
- if (IdxN00 > IdxN10) {
56616
+ if (IdxN00.getZExtValue() > IdxN10.getZExtValue() ) {
56634
56617
std::swap(IdxN00, IdxN10);
56635
56618
std::swap(IdxN01, IdxN11);
56636
56619
}
56637
56620
// N0 indices must be the even element. N1 indices must be the next odd element.
56638
- if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
56639
- IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
56621
+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || IdxN01 != 2 * i ||
56622
+ IdxN11 != 2 * i + 1)
56640
56623
return SDValue();
56641
- SDValue N00In = N00Elt.getOperand(0);
56642
- SDValue N01In = N01Elt.getOperand(0);
56643
- SDValue N10In = N10Elt.getOperand(0);
56644
- SDValue N11In = N11Elt.getOperand(0);
56645
56624
56646
56625
// First time we find an input capture it.
56647
56626
if (!In0) {
@@ -56815,9 +56794,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
56815
56794
if (SDValue Select = pushAddIntoCmovOfConsts(N, DL, DAG, Subtarget))
56816
56795
return Select;
56817
56796
56818
- if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1 , DL, VT, Subtarget))
56797
+ if (SDValue MAdd = matchPMADDWD(DAG, N , DL, VT, Subtarget))
56819
56798
return MAdd;
56820
- if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1 , DL, VT, Subtarget))
56799
+ if (SDValue MAdd = matchPMADDWD_2(DAG, N , DL, VT, Subtarget))
56821
56800
return MAdd;
56822
56801
if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT))
56823
56802
return MAdd;
0 commit comments