@@ -46058,22 +46058,18 @@ static bool detectExtMul(SelectionDAG &DAG, const SDValue &Mul, SDValue &Op0,
46058
46058
// Given an ABS node, detect the following pattern:
46059
46059
// (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
46060
46060
// This is useful as it is the input into a SAD pattern.
// Returns true only if both a and b have i8 vector elements. On a match, Op0
// and Op1 hold the two SUB operands, i.e. the ZERO_EXTEND nodes themselves
// rather than their i8 sources. Op0/Op1 may be written even when the match
// fails, so only use them after a true return.
// The pattern-matcher form also checks the ABS opcode itself, so callers do
// not need to pre-check Abs.getOpcode().
46061
- static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
46062
- SDValue AbsOp1 = Abs->getOperand(0);
46063
- if (AbsOp1.getOpcode() != ISD::SUB)
46064
- return false;
46065
-
46066
- Op0 = AbsOp1.getOperand(0);
46067
- Op1 = AbsOp1.getOperand(1);
46061
+ static bool detectZextAbsDiff(SDValue Abs, SDValue &Op0, SDValue &Op1) {
46062
+ using namespace SDPatternMatch;
46068
46063
46069
46064
// Check if the operands of the sub are zero-extended from vectors of i8.
// m_AllOf binds each SUB operand to Op0/Op1 while also requiring it to be a
// ZERO_EXTEND; SrcVT0/SrcVT1 receive the types of the values being extended,
// whose element type is then compared against i8.
46070
- if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
46071
- Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
46072
- Op1.getOpcode() != ISD::ZERO_EXTEND ||
46073
- Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
46074
- return false;
46075
-
46076
- return true;
46065
+ EVT SrcVT0, SrcVT1;
46066
+ return sd_match(
46067
+ Abs,
46068
+ m_UnaryOp(ISD::ABS,
46069
+ m_Sub(m_AllOf(m_Value(Op0), m_ZExt(m_VT(SrcVT0))),
46070
+ m_AllOf(m_Value(Op1), m_ZExt(m_VT(SrcVT1)))))) &&
46071
+ SrcVT0.getVectorElementType() == MVT::i8 &&
46072
+ SrcVT1.getVectorElementType() == MVT::i8;
46077
46073
}
46078
46074
46079
46075
static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
@@ -46455,6 +46451,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
46455
46451
// Match shuffle + add pyramid.
46456
46452
ISD::NodeType BinOp;
46457
46453
SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
46454
+ if (!Root)
46455
+ return SDValue();
46458
46456
46459
46457
// The operand is expected to be zero extended from i8
46460
46458
// (verified in detectZextAbsDiff).
@@ -46464,16 +46462,11 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
46464
46462
// Also the sign extend is basically zero extend
46465
46463
// (extends the sign bit which is zero).
46466
46464
// So it is correct to skip the sign/zero extend instruction.
46467
- if (Root && (Root .getOpcode() == ISD::SIGN_EXTEND ||
46468
- Root.getOpcode() == ISD::ZERO_EXTEND ||
46469
- Root.getOpcode() == ISD::ANY_EXTEND) )
46465
+ if (Root.getOpcode() == ISD::SIGN_EXTEND ||
46466
+ Root.getOpcode() == ISD::ZERO_EXTEND ||
46467
+ Root.getOpcode() == ISD::ANY_EXTEND)
46470
46468
Root = Root.getOperand(0);
46471
46469
46472
- // If there was a match, we want Root to be a select that is the root of an
46473
- // abs-diff pattern.
46474
- if (!Root || Root.getOpcode() != ISD::ABS)
46475
- return SDValue();
46476
-
46477
46470
// Check whether Root is the root of a zero-extended abs-diff pattern.
46478
46471
SDValue Zext0, Zext1;
46479
46472
if (!detectZextAbsDiff(Root, Zext0, Zext1))
0 commit comments