Skip to content

Commit 21babe4

Browse files
committed
[X86] Combine reduce(add (mul x, y)) to VNNI instruction.
For the C code below, we can use VNNI to combine the mul and add operations. int usdot_prod_qi(unsigned char *restrict a, char *restrict b, int c, int n) { int i; for (i = 0; i < 32; i++) { c += ((int)a[i] * (int)b[i]); } return c; } We do not support combining across basic blocks in this patch. Differential Revision: https://reviews.llvm.org/D116039
1 parent 3aec4b3 commit 21babe4

File tree

4 files changed

+892
-7
lines changed

4 files changed

+892
-7
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41799,6 +41799,40 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
4179941799
return SDValue();
4180041800
}
4180141801

41802+
// (mul (zext a), (sext, b))
41803+
static bool detectExtMul(SelectionDAG &DAG, const SDValue &Mul, SDValue &Op0,
41804+
SDValue &Op1) {
41805+
Op0 = Mul.getOperand(0);
41806+
Op1 = Mul.getOperand(1);
41807+
41808+
// The operand1 should be signed extend
41809+
if (Op0.getOpcode() == ISD::SIGN_EXTEND)
41810+
std::swap(Op0, Op1);
41811+
41812+
if (Op0.getOpcode() != ISD::ZERO_EXTEND)
41813+
return false;
41814+
41815+
auto IsFreeTruncation = [](SDValue &Op) -> bool {
41816+
if ((Op.getOpcode() == ISD::ZERO_EXTEND ||
41817+
Op.getOpcode() == ISD::SIGN_EXTEND) &&
41818+
Op.getOperand(0).getScalarValueSizeInBits() <= 8)
41819+
return true;
41820+
41821+
// TODO: Support contant value.
41822+
return false;
41823+
};
41824+
41825+
// (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned
41826+
// value, we need to check Op0 is zero extended value. Op1 should be signed
41827+
// value, so we just check the signed bits.
41828+
if ((IsFreeTruncation(Op0) &&
41829+
DAG.computeKnownBits(Op0).countMaxActiveBits() <= 8) &&
41830+
(IsFreeTruncation(Op1) && DAG.ComputeMaxSignificantBits(Op1) <= 8))
41831+
return true;
41832+
41833+
return false;
41834+
}
41835+
4180241836
// Given a ABS node, detect the following pattern:
4180341837
// (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
4180441838
// This is useful as it is the input into a SAD pattern.
@@ -41820,6 +41854,50 @@ static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
4182041854
return true;
4182141855
}
4182241856

41857+
// Build a VPDPBUSD dot product of LHS (unsigned i8 lanes) and RHS (signed i8
// lanes) accumulated into a zero vector. LogBias is set to the number of
// reduction stages the instruction itself already performs, so the caller can
// skip that many shuffle+add stages.
static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
                              unsigned &LogBias, const SDLoc &DL,
                              const X86Subtarget &Subtarget) {
  // Bring both operands to vXi8 first.
  MVT ByteVT =
      MVT::getVectorVT(MVT::i8, LHS.getValueType().getVectorElementCount());
  LHS = DAG.getZExtOrTrunc(LHS, DL, ByteVT);
  RHS = DAG.getSExtOrTrunc(RHS, DL, ByteVT);

  // VPDPBUSD(<16 x i32> C, <16 x i8> A, <16 x i8> B) computes, per i32 lane:
  //   C[0] = C[0] + A[0]*B[0] + A[1]*B[1] + A[2]*B[2] + A[3]*B[3]
  // i.e. it folds 4 (= 2^2) i8 source lanes into one i32 result lane. The
  // caller derives its reduction stage count from the vXi8 source type, so
  // report a bias of 2 to drop the two stages the instruction covers.
  LogBias = 2;

  // Operate in at least a 128-bit register; with AVX512VNNI but no VLX, only
  // the 512-bit form exists.
  unsigned RegSize = std::max(128u, (unsigned)ByteVT.getSizeInBits());
  if (Subtarget.hasVNNI() && !Subtarget.hasVLX())
    RegSize = std::max(512u, RegSize);

  // Widen the i8 vectors to the register width by concatenating zero vectors
  // behind them. This pads with zero *elements*; it is not a per-lane zext.
  unsigned NumPieces = RegSize / ByteVT.getSizeInBits();
  SmallVector<SDValue, 16> Pieces(NumPieces, DAG.getConstant(0, DL, ByteVT));
  Pieces[0] = LHS;
  MVT WideVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
  SDValue DpOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, WideVT, Pieces);
  Pieces[0] = RHS;
  SDValue DpOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, WideVT, Pieces);

  // Emit the dot product, letting SplitOpsAndApply break it into the 256/512
  // bit pieces appropriate for AVXVNNI/AVX512VNNI.
  auto DpBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                      ArrayRef<SDValue> Ops) {
    MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
    return DAG.getNode(X86ISD::VPDPBUSD, DL, VT, Ops);
  };
  MVT DpVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
  SDValue Zero = DAG.getConstant(0, DL, DpVT);

  return SplitOpsAndApply(DAG, Subtarget, DL, DpVT, {Zero, DpOp0, DpOp1},
                          DpBuilder, false);
}
41900+
4182341901
// Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
4182441902
// to these zexts.
4182541903
static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
@@ -42069,6 +42147,77 @@ static SDValue combinePredicateReduction(SDNode *Extract, SelectionDAG &DAG,
4206942147
return DAG.getNode(ISD::SUB, DL, ExtractVT, Zero, Zext);
4207042148
}
4207142149

42150+
// Try to replace (extract_vector_elt (add-reduction (mul (zext a), (sext b))))
// with a VPDPBUSD-based dot product.
static SDValue combineVPDPBUSDPattern(SDNode *Extract, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
  if (!Subtarget.hasVNNI() && !Subtarget.hasAVXVNNI())
    return SDValue();

  // vpdpbusd produces i32 result elements, so that is the only scalar type we
  // can extract here.
  EVT ExtractVT = Extract->getValueType(0);
  if (ExtractVT != MVT::i32)
    return SDValue();

  EVT SrcVT = Extract->getOperand(0).getValueType();
  if (!isPowerOf2_32(SrcVT.getVectorNumElements()))
    return SDValue();

  // Match the shuffle + add reduction pyramid feeding the extract.
  ISD::NodeType BinOp;
  SDValue Root = DAG.matchBinOpReduction(Extract, BinOp, {ISD::ADD});

  // We can't combine to vpdpbusd through a zext, because each of the 4
  // multiplies done by vpdpbusd computes a signed 16-bit product that is sign
  // extended before being added into the accumulator.
  // TODO:
  // We also need to verify that the multiply has at least 2x the number of
  // bits of the input. We shouldn't match
  // (sign_extend (mul (vXi9 (zext (vXi8 X))), (vXi9 (zext (vXi8 Y)))).
  // if (Root && (Root.getOpcode() == ISD::SIGN_EXTEND))
  //   Root = Root.getOperand(0);

  // The reduction must be fed by a multiply.
  if (!Root || Root.getOpcode() != ISD::MUL)
    return SDValue();

  // ... whose operands are a matching zext/sext pair of 8-bit values.
  SDValue LHS, RHS;
  if (!detectExtMul(DAG, Root, LHS, RHS))
    return SDValue();

  // Emit the dot product itself.
  SDLoc DL(Extract);
  unsigned StageBias;
  SDValue DP = createVPDPBUSD(DAG, LHS, RHS, StageBias, DL, Subtarget);

  // vpdpbusd already folded StageBias reduction stages; if the original
  // reduction had more stages than that, finish the sum with a shuffle + add
  // pyramid over the i32 result lanes.
  unsigned Stages = Log2_32(SrcVT.getVectorNumElements());
  EVT DpVT = DP.getValueType();
  unsigned DpElems = DpVT.getVectorNumElements();

  for (unsigned Stage = Stages; Stage > StageBias; --Stage) {
    // Shuffle the upper half of the still-active lanes down onto the lower
    // half and add, halving the active lane count each stage.
    unsigned HalfElems = 1u << (Stage - StageBias - 1);
    SmallVector<int, 16> Mask(DpElems, -1);
    for (unsigned Lane = 0; Lane != HalfElems; ++Lane)
      Mask[Lane] = HalfElems + Lane;

    SDValue Hi = DAG.getVectorShuffle(DpVT, DL, DP, DAG.getUNDEF(DpVT), Mask);
    DP = DAG.getNode(ISD::ADD, DL, DpVT, DP, Hi);
  }

  // Cast back to a vector of ExtractVT elements and extract the lane the
  // original node asked for.
  EVT ResVT =
      EVT::getVectorVT(*DAG.getContext(), ExtractVT,
                       DpVT.getSizeInBits() / ExtractVT.getSizeInBits());
  DP = DAG.getBitcast(ResVT, DP);
  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT, DP,
                     Extract->getOperand(1));
}
42220+
4207242221
static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
4207342222
const X86Subtarget &Subtarget) {
4207442223
// PSADBW is only supported on SSE2 and up.
@@ -42676,6 +42825,9 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
4267642825
if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
4267742826
return SAD;
4267842827

42828+
if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
42829+
return VPDPBUSD;
42830+
4267942831
// Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
4268042832
if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
4268142833
return Cmp;

llvm/lib/Target/X86/X86PartialReduction.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "X86.h"
16+
#include "X86TargetMachine.h"
1617
#include "llvm/Analysis/ValueTracking.h"
1718
#include "llvm/CodeGen/TargetPassConfig.h"
1819
#include "llvm/IR/Constants.h"
20+
#include "llvm/IR/IRBuilder.h"
1921
#include "llvm/IR/Instructions.h"
2022
#include "llvm/IR/IntrinsicsX86.h"
21-
#include "llvm/IR/IRBuilder.h"
2223
#include "llvm/IR/Operator.h"
2324
#include "llvm/Pass.h"
24-
#include "X86TargetMachine.h"
25+
#include "llvm/Support/KnownBits.h"
2526

2627
using namespace llvm;
2728

@@ -49,7 +50,7 @@ class X86PartialReduction : public FunctionPass {
4950
}
5051

5152
private:
52-
bool tryMAddReplacement(Instruction *Op);
53+
bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
5354
bool trySADReplacement(Instruction *Op);
5455
};
5556
}
@@ -63,7 +64,46 @@ char X86PartialReduction::ID = 0;
6364
INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
6465
"X86 Partial Reduction", false, false)
6566

66-
bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
67+
// This function should be aligned with detectExtMul() in X86ISelLowering.cpp.
68+
static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
69+
const DataLayout *DL) {
70+
if (!ST->hasVNNI() && !ST->hasAVXVNNI())
71+
return false;
72+
73+
Value *LHS = Mul->getOperand(0);
74+
Value *RHS = Mul->getOperand(1);
75+
76+
if (isa<SExtInst>(LHS))
77+
std::swap(LHS, RHS);
78+
79+
if (!isa<ZExtInst>(LHS))
80+
return false;
81+
82+
auto IsFreeTruncation = [&](Value *Op) {
83+
if (auto *Cast = dyn_cast<CastInst>(Op)) {
84+
if (Cast->getParent() == Mul->getParent() &&
85+
(Cast->getOpcode() == Instruction::SExt ||
86+
Cast->getOpcode() == Instruction::ZExt) &&
87+
Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
88+
return true;
89+
}
90+
// TODO: Support constant in ISel.
91+
return false;
92+
};
93+
94+
// (dpbusd (zext a), (sext, b)). Since the first operand should be unsigned
95+
// value, we need to check LHS is zero extended value. RHS should be signed
96+
// value, so we just check the signed bits.
97+
if ((IsFreeTruncation(LHS) &&
98+
computeKnownBits(LHS, *DL).countMaxActiveBits() <= 8) &&
99+
(IsFreeTruncation(RHS) && ComputeMaxSignificantBits(RHS, *DL) <= 8))
100+
return true;
101+
102+
return false;
103+
}
104+
105+
bool X86PartialReduction::tryMAddReplacement(Instruction *Op,
106+
bool ReduceInOneBB) {
67107
if (!ST->hasSSE2())
68108
return false;
69109

@@ -82,6 +122,13 @@ bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
82122
Value *LHS = Mul->getOperand(0);
83123
Value *RHS = Mul->getOperand(1);
84124

125+
// If the target support VNNI, leave it to ISel to combine reduce operation
126+
// to VNNI instruction.
127+
// TODO: we can support transforming reduce to VNNI intrinsic for across block
128+
// in this pass.
129+
if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))
130+
return false;
131+
85132
// LHS and RHS should be only used once or if they are the same then only
86133
// used twice. Only check this when SSE4.1 is enabled and we have zext/sext
87134
// instructions, otherwise we use punpck to emulate zero extend in stages. The
@@ -300,7 +347,9 @@ bool X86PartialReduction::trySADReplacement(Instruction *Op) {
300347

301348
// Walk backwards from the ExtractElementInst and determine if it is the end of
302349
// a horizontal reduction. Return the input to the reduction if we find one.
303-
static Value *matchAddReduction(const ExtractElementInst &EE) {
350+
static Value *matchAddReduction(const ExtractElementInst &EE,
351+
bool &ReduceInOneBB) {
352+
ReduceInOneBB = true;
304353
// Make sure we're extracting index 0.
305354
auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
306355
if (!Index || !Index->isNullValue())
@@ -309,6 +358,8 @@ static Value *matchAddReduction(const ExtractElementInst &EE) {
309358
const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
310359
if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
311360
return nullptr;
361+
if (EE.getParent() != BO->getParent())
362+
ReduceInOneBB = false;
312363

313364
unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
314365
// Ensure the reduction size is a power of 2.
@@ -321,6 +372,8 @@ static Value *matchAddReduction(const ExtractElementInst &EE) {
321372
const auto *BO = dyn_cast<BinaryOperator>(Op);
322373
if (!BO || BO->getOpcode() != Instruction::Add)
323374
return nullptr;
375+
if (EE.getParent() != BO->getParent())
376+
ReduceInOneBB = false;
324377

325378
// If this isn't the first add, then it should only have 2 users, the
326379
// shuffle and another add which we checked in the previous iteration.
@@ -460,17 +513,18 @@ bool X86PartialReduction::runOnFunction(Function &F) {
460513
if (!EE)
461514
continue;
462515

516+
bool ReduceInOneBB;
463517
// First find a reduction tree.
464518
// FIXME: Do we need to handle other opcodes than Add?
465-
Value *Root = matchAddReduction(*EE);
519+
Value *Root = matchAddReduction(*EE, ReduceInOneBB);
466520
if (!Root)
467521
continue;
468522

469523
SmallVector<Instruction *, 8> Leaves;
470524
collectLeaves(Root, Leaves);
471525

472526
for (Instruction *I : Leaves) {
473-
if (tryMAddReplacement(I)) {
527+
if (tryMAddReplacement(I, ReduceInOneBB)) {
474528
MadeChange = true;
475529
continue;
476530
}

0 commit comments

Comments
 (0)