@@ -41799,6 +41799,40 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
41799
41799
return SDValue();
41800
41800
}
41801
41801
41802
+ // (mul (zext a), (sext, b))
41803
+ static bool detectExtMul(SelectionDAG &DAG, const SDValue &Mul, SDValue &Op0,
41804
+ SDValue &Op1) {
41805
+ Op0 = Mul.getOperand(0);
41806
+ Op1 = Mul.getOperand(1);
41807
+
41808
+ // The operand1 should be signed extend
41809
+ if (Op0.getOpcode() == ISD::SIGN_EXTEND)
41810
+ std::swap(Op0, Op1);
41811
+
41812
+ if (Op0.getOpcode() != ISD::ZERO_EXTEND)
41813
+ return false;
41814
+
41815
+ auto IsFreeTruncation = [](SDValue &Op) -> bool {
41816
+ if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
41817
+ Op.getOpcode() == ISD::SIGN_EXTEND) &&
41818
+ Op.getOperand(0).getScalarValueSizeInBits() <= 8)
41819
+ return true;
41820
+
41821
+ // TODO: Support contant value.
41822
+ return false;
41823
+ };
41824
+
41825
+ // (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned
41826
+ // value, we need to check Op0 is zero extended value. Op1 should be signed
41827
+ // value, so we just check the signed bits.
41828
+ if ((IsFreeTruncation(Op0) &&
41829
+ DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8) &&
41830
+ (IsFreeTruncation(Op1) && DAG.ComputeMaxSignificantBits(Op1) <= 8))
41831
+ return true;
41832
+
41833
+ return false;
41834
+ }
41835
+
41802
41836
// Given an ABS node, detect the following pattern:
41803
41837
// (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
41804
41838
// This is useful as it is the input into a SAD pattern.
@@ -41820,6 +41854,50 @@ static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
41820
41854
return true;
41821
41855
}
41822
41856
41857
+ static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
41858
+ unsigned &LogBias, const SDLoc &DL,
41859
+ const X86Subtarget &Subtarget) {
41860
+ // Extend or truncate to MVT::i8 first.
41861
+ MVT Vi8VT =
41862
+ MVT::getVectorVT(MVT::i8, LHS.getValueType().getVectorElementCount());
41863
+ LHS = DAG.getZExtOrTrunc(LHS, DL, Vi8VT);
41864
+ RHS = DAG.getSExtOrTrunc(RHS, DL, Vi8VT);
41865
+
41866
+ // VPDPBUSD(<16 x i32>C, <16 x i8>A, <16 x i8>B). For each dst element
41867
+ // C[0] = C[0] + A[0]B[0] + A[1]B[1] + A[2]B[2] + A[3]B[3].
41868
+ // The src A, B element type is i8, but the dst C element type is i32.
41869
+ // When we calculate the reduce stage, we use src vector type vXi8 for it
41870
+ // so we need logbias 2 to avoid extra 2 stages.
41871
+ LogBias = 2;
41872
+
41873
+ unsigned RegSize = std::max(128u, (unsigned)Vi8VT.getSizeInBits());
41874
+ if (Subtarget.hasVNNI() && !Subtarget.hasVLX())
41875
+ RegSize = std::max(512u, RegSize);
41876
+
41877
+ // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
41878
+ // fill in the missing vector elements with 0.
41879
+ unsigned NumConcat = RegSize / Vi8VT.getSizeInBits();
41880
+ SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, Vi8VT));
41881
+ Ops[0] = LHS;
41882
+ MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
41883
+ SDValue DpOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
41884
+ Ops[0] = RHS;
41885
+ SDValue DpOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
41886
+
41887
+ // Actually build the DotProduct, split as 256/512 bits for
41888
+ // AVXVNNI/AVX512VNNI.
41889
+ auto DpBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
41890
+ ArrayRef<SDValue> Ops) {
41891
+ MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
41892
+ return DAG.getNode(X86ISD::VPDPBUSD, DL, VT, Ops);
41893
+ };
41894
+ MVT DpVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
41895
+ SDValue Zero = DAG.getConstant(0, DL, DpVT);
41896
+
41897
+ return SplitOpsAndApply(DAG, Subtarget, DL, DpVT, {Zero, DpOp0, DpOp1},
41898
+ DpBuilder, false);
41899
+ }
41900
+
41823
41901
// Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
41824
41902
// to these zexts.
41825
41903
static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
@@ -42069,6 +42147,77 @@ static SDValue combinePredicateReduction(SDNode *Extract, SelectionDAG &DAG,
42069
42147
return DAG.getNode(ISD::SUB, DL, ExtractVT, Zero, Zext);
42070
42148
}
42071
42149
42150
+ static SDValue combineVPDPBUSDPattern(SDNode *Extract, SelectionDAG &DAG,
42151
+ const X86Subtarget &Subtarget) {
42152
+ if (!Subtarget.hasVNNI() && !Subtarget.hasAVXVNNI())
42153
+ return SDValue();
42154
+
42155
+ EVT ExtractVT = Extract->getValueType(0);
42156
+ // Verify the type we're extracting is i32, as the output element type of
42157
+ // vpdpbusd is i32.
42158
+ if (ExtractVT != MVT::i32)
42159
+ return SDValue();
42160
+
42161
+ EVT VT = Extract->getOperand(0).getValueType();
42162
+ if (!isPowerOf2_32(VT.getVectorNumElements()))
42163
+ return SDValue();
42164
+
42165
+ // Match shuffle + add pyramid.
42166
+ ISD::NodeType BinOp;
42167
+ SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});
42168
+
42169
+ // We can't combine to vpdpbusd for zext, because each of the 4 multiplies
42170
+ // done by vpdpbusd compute a signed 16-bit product that will be sign extended
42171
+ // before adding into the accumulator.
42172
+ // TODO:
42173
+ // We also need to verify that the multiply has at least 2x the number of bits
42174
+ // of the input. We shouldn't match
42175
+ // (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))).
42176
+ // if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND))
42177
+ // Root = Root.getOperand(0);
42178
+
42179
+ // If there was a match, we want Root to be a mul.
42180
+ if (!Root || Root.getOpcode() != ISD::MUL)
42181
+ return SDValue();
42182
+
42183
+ // Check whether we have an extend and mul pattern
42184
+ SDValue LHS, RHS;
42185
+ if (!detectExtMul(DAG, Root, LHS, RHS))
42186
+ return SDValue();
42187
+
42188
+ // Create the dot product instruction.
42189
+ SDLoc DL(Extract);
42190
+ unsigned StageBias;
42191
+ SDValue DP = createVPDPBUSD(DAG, LHS, RHS, StageBias, DL, Subtarget);
42192
+
42193
+ // If the original vector was wider than 4 elements, sum over the results
42194
+ // in the DP vector.
42195
+ unsigned Stages = Log2_32(VT.getVectorNumElements());
42196
+ EVT DpVT = DP.getValueType();
42197
+
42198
+ if (Stages > StageBias) {
42199
+ unsigned DpElems = DpVT.getVectorNumElements();
42200
+
42201
+ for (unsigned i = Stages - StageBias; i > 0; --i) {
42202
+ SmallVector<int, 16> Mask(DpElems, -1);
42203
+ for (unsigned j = 0, MaskEnd = 1 << (i - 1); j < MaskEnd; ++j)
42204
+ Mask[j] = MaskEnd + j;
42205
+
42206
+ SDValue Shuffle =
42207
+ DAG.getVectorShuffle(DpVT, DL, DP, DAG.getUNDEF(DpVT), Mask);
42208
+ DP = DAG.getNode(ISD::ADD, DL, DpVT, DP, Shuffle);
42209
+ }
42210
+ }
42211
+
42212
+ // Return the lowest ExtractSizeInBits bits.
42213
+ EVT ResVT =
42214
+ EVT::getVectorVT(*DAG.getContext(), ExtractVT,
42215
+ DpVT.getSizeInBits() / ExtractVT.getSizeInBits());
42216
+ DP = DAG.getBitcast(ResVT, DP);
42217
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, DP,
42218
+ Extract->getOperand(1));
42219
+ }
42220
+
42072
42221
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
42073
42222
const X86Subtarget &Subtarget) {
42074
42223
// PSADBW is only supported on SSE2 and up.
@@ -42676,6 +42825,9 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
42676
42825
if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
42677
42826
return SAD;
42678
42827
42828
+ if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
42829
+ return VPDPBUSD;
42830
+
42679
42831
// Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
42680
42832
if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
42681
42833
return Cmp;
0 commit comments