Skip to content
This repository was archived by the owner on Mar 28, 2020. It is now read-only.

Commit c700b40

Browse files
committed
[X86][AVX512] Detect repeated constant patterns in BUILD_VECTOR suitable for broadcasting.
Check if a build_vector node includes a repeated constant pattern and replace it with a broadcast of that pattern. For example: "build_vector <0, 1, 2, 3, 0, 1, 2, 3>" would be replaced by "broadcast <0, 1, 2, 3>" Differential Revision: https://reviews.llvm.org/D26802 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@288804 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 01544ba commit c700b40

File tree

5 files changed

+1329
-11
lines changed

5 files changed

+1329
-11
lines changed

lib/Target/X86/X86ISelLowering.cpp

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6316,8 +6316,47 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
63166316
return SDValue();
63176317
}
63186318

6319-
/// Attempt to use the vbroadcast instruction to generate a splat value for a
6320-
/// splat BUILD_VECTOR which uses a single scalar load, or a constant.
6319+
static Constant *getConstantVector(MVT VT, APInt SplatValue,
6320+
unsigned SplatBitSize, LLVMContext &C) {
6321+
unsigned ScalarSize = VT.getScalarSizeInBits();
6322+
unsigned NumElm = SplatBitSize / ScalarSize;
6323+
6324+
SmallVector<Constant *, 32> ConstantVec;
6325+
for (unsigned i = 0; i < NumElm; i++) {
6326+
APInt Val = SplatValue.lshr(ScalarSize * i).trunc(ScalarSize);
6327+
Constant *Const;
6328+
if (VT.isFloatingPoint()) {
6329+
assert((ScalarSize == 32 || ScalarSize == 64) &&
6330+
"Unsupported floating point scalar size");
6331+
if (ScalarSize == 32)
6332+
Const = ConstantFP::get(Type::getFloatTy(C), Val.bitsToFloat());
6333+
else
6334+
Const = ConstantFP::get(Type::getDoubleTy(C), Val.bitsToDouble());
6335+
} else
6336+
Const = Constant::getIntegerValue(Type::getIntNTy(C, ScalarSize), Val);
6337+
ConstantVec.push_back(Const);
6338+
}
6339+
return ConstantVector::get(ArrayRef<Constant *>(ConstantVec));
6340+
}
6341+
6342+
static bool isUseOfShuffle(SDNode *N) {
6343+
for (auto *U : N->uses()) {
6344+
if (isTargetShuffle(U->getOpcode()))
6345+
return true;
6346+
if (U->getOpcode() == ISD::BITCAST) // Ignore bitcasts
6347+
return isUseOfShuffle(U);
6348+
}
6349+
return false;
6350+
}
6351+
6352+
/// Attempt to use the vbroadcast instruction to generate a splat value for the
6353+
/// following cases:
6354+
/// 1. A splat BUILD_VECTOR which uses:
6355+
/// a. A single scalar load, or a constant.
6356+
/// b. Repeated pattern of constants (e.g. <0,1,0,1> or <0,1,2,3,0,1,2,3>).
6357+
/// 2. A splat shuffle which uses a scalar_to_vector node which comes from
6358+
/// a scalar load, or a constant.
6359+
///
63216360
/// The VBROADCAST node is returned when a pattern is found,
63226361
/// or SDValue() otherwise.
63236362
static SDValue LowerVectorBroadcast(BuildVectorSDNode *BVOp, const X86Subtarget &Subtarget,
@@ -6339,8 +6378,82 @@ static SDValue LowerVectorBroadcast(BuildVectorSDNode *BVOp, const X86Subtarget
63396378

63406379
// We need a splat of a single value to use broadcast, and it doesn't
63416380
// make any sense if the value is only in one element of the vector.
6342-
if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1)
6381+
if (!Ld || (VT.getVectorNumElements() - UndefElements.count()) <= 1) {
6382+
APInt SplatValue, Undef;
6383+
unsigned SplatBitSize;
6384+
bool HasUndef;
6385+
// Check if this is a repeated constant pattern suitable for broadcasting.
6386+
if (BVOp->isConstantSplat(SplatValue, Undef, SplatBitSize, HasUndef) &&
6387+
SplatBitSize > VT.getScalarSizeInBits() &&
6388+
SplatBitSize < VT.getSizeInBits()) {
6389+
// Avoid replacing with broadcast when it's a use of a shuffle
6390+
// instruction to preserve the present custom lowering of shuffles.
6391+
if (isUseOfShuffle(BVOp) || BVOp->hasOneUse())
6392+
return SDValue();
6393+
// replace BUILD_VECTOR with broadcast of the repeated constants.
6394+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6395+
LLVMContext *Ctx = DAG.getContext();
6396+
MVT PVT = TLI.getPointerTy(DAG.getDataLayout());
6397+
if (Subtarget.hasAVX()) {
6398+
if (SplatBitSize <= 64 && Subtarget.hasAVX2() &&
6399+
!(SplatBitSize == 64 && Subtarget.is32Bit())) {
6400+
// Splatted value can fit in one INTEGER constant in constant pool.
6401+
// Load the constant and broadcast it.
6402+
MVT CVT = MVT::getIntegerVT(SplatBitSize);
6403+
Type *ScalarTy = Type::getIntNTy(*Ctx, SplatBitSize);
6404+
Constant *C = Constant::getIntegerValue(ScalarTy, SplatValue);
6405+
SDValue CP = DAG.getConstantPool(C, PVT);
6406+
unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
6407+
6408+
unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
6409+
Ld = DAG.getLoad(
6410+
CVT, dl, DAG.getEntryNode(), CP,
6411+
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
6412+
Alignment);
6413+
SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
6414+
MVT::getVectorVT(CVT, Repeat), Ld);
6415+
return DAG.getBitcast(VT, Brdcst);
6416+
} else if (SplatBitSize == 32 || SplatBitSize == 64) {
6417+
// Splatted value can fit in one FLOAT constant in constant pool.
6418+
// Load the constant and broadcast it.
6419+
// AVX have support for 32 and 64 bit broadcast for floats only.
6420+
// No 64bit integer in 32bit subtarget.
6421+
MVT CVT = MVT::getFloatingPointVT(SplatBitSize);
6422+
Constant *C = SplatBitSize == 32
6423+
? ConstantFP::get(Type::getFloatTy(*Ctx),
6424+
SplatValue.bitsToFloat())
6425+
: ConstantFP::get(Type::getDoubleTy(*Ctx),
6426+
SplatValue.bitsToDouble());
6427+
SDValue CP = DAG.getConstantPool(C, PVT);
6428+
unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
6429+
6430+
unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
6431+
Ld = DAG.getLoad(
6432+
CVT, dl, DAG.getEntryNode(), CP,
6433+
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
6434+
Alignment);
6435+
SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
6436+
MVT::getVectorVT(CVT, Repeat), Ld);
6437+
return DAG.getBitcast(VT, Brdcst);
6438+
} else if (SplatBitSize > 64) {
6439+
// Load the vector of constants and broadcast it.
6440+
MVT CVT = VT.getScalarType();
6441+
Constant *VecC = getConstantVector(VT, SplatValue, SplatBitSize,
6442+
*Ctx);
6443+
SDValue VCP = DAG.getConstantPool(VecC, PVT);
6444+
unsigned NumElm = SplatBitSize / VT.getScalarSizeInBits();
6445+
unsigned Alignment = cast<ConstantPoolSDNode>(VCP)->getAlignment();
6446+
Ld = DAG.getLoad(
6447+
MVT::getVectorVT(CVT, NumElm), dl, DAG.getEntryNode(), VCP,
6448+
MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
6449+
Alignment);
6450+
SDValue Brdcst = DAG.getNode(X86ISD::SUBV_BROADCAST, dl, VT, Ld);
6451+
return DAG.getBitcast(VT, Brdcst);
6452+
}
6453+
}
6454+
}
63436455
return SDValue();
6456+
}
63446457

63456458
bool ConstSplatVal =
63466459
(Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);

test/CodeGen/X86/avg.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,7 +2132,7 @@ define void @avg_v64i8_const(<64 x i8>* %a) {
21322132
; AVX512F-NEXT: vpmovzxbd {{.*#+}} zmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
21332133
; AVX512F-NEXT: vpmovzxbd {{.*#+}} zmm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
21342134
; AVX512F-NEXT: vpmovzxbd {{.*#+}} zmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero,mem[8],zero,zero,zero,mem[9],zero,zero,zero,mem[10],zero,zero,zero,mem[11],zero,zero,zero,mem[12],zero,zero,zero,mem[13],zero,zero,zero,mem[14],zero,zero,zero,mem[15],zero,zero,zero
2135-
; AVX512F-NEXT: vmovdqa32 {{.*#+}} zmm4 = [1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8]
2135+
; AVX512F-NEXT: vbroadcasti64x4 {{.*#+}} zmm4 = mem[0,1,2,3,0,1,2,3]
21362136
; AVX512F-NEXT: vpaddd %zmm4, %zmm3, %zmm3
21372137
; AVX512F-NEXT: vpaddd %zmm4, %zmm2, %zmm2
21382138
; AVX512F-NEXT: vpaddd %zmm4, %zmm1, %zmm1
@@ -2405,7 +2405,7 @@ define void @avg_v32i16_const(<32 x i16>* %a) {
24052405
; AVX512F: # BB#0:
24062406
; AVX512F-NEXT: vpmovzxwd {{.*#+}} zmm0 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
24072407
; AVX512F-NEXT: vpmovzxwd {{.*#+}} zmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero,mem[8],zero,mem[9],zero,mem[10],zero,mem[11],zero,mem[12],zero,mem[13],zero,mem[14],zero,mem[15],zero
2408-
; AVX512F-NEXT: vmovdqa32 {{.*#+}} zmm2 = [1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8]
2408+
; AVX512F-NEXT: vbroadcasti64x4 {{.*#+}} zmm2 = mem[0,1,2,3,0,1,2,3]
24092409
; AVX512F-NEXT: vpaddd %zmm2, %zmm1, %zmm1
24102410
; AVX512F-NEXT: vpaddd %zmm2, %zmm0, %zmm0
24112411
; AVX512F-NEXT: vpsrld $1, %zmm0, %zmm0

0 commit comments

Comments
 (0)