Skip to content

Commit b16e868

Browse files
committed
[CodeGenPrepare][X86] Teach optimizeGatherScatterInst to turn a splat pointer into GEP with scalar base and 0 index
This helps SelectionDAGBuilder recognize that the splat can be used as a uniform base. Reviewed By: RKSimon. Differential Revision: https://reviews.llvm.org/D86371
1 parent ca13437 commit b16e868

File tree

5 files changed

+118
-96
lines changed

5 files changed

+118
-96
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ int getSplatIndex(ArrayRef<int> Mask);
358358
/// Get splat value if the input is a splat vector or return nullptr.
359359
/// The value may be extracted from a splat constants vector or from
360360
/// a sequence of instructions that broadcast a single value into a vector.
361-
const Value *getSplatValue(const Value *V);
361+
Value *getSplatValue(const Value *V);
362362

363363
/// Return true if each element of the vector value \p V is poisoned or equal to
364364
/// every other non-poisoned element. If an index element is specified, either

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ int llvm::getSplatIndex(ArrayRef<int> Mask) {
342342
/// This function is not fully general. It checks only 2 cases:
343343
/// the input value is (1) a splat constant vector or (2) a sequence
344344
/// of instructions that broadcasts a scalar at element 0.
345-
const llvm::Value *llvm::getSplatValue(const Value *V) {
345+
Value *llvm::getSplatValue(const Value *V) {
346346
if (isa<VectorType>(V->getType()))
347347
if (auto *C = dyn_cast<Constant>(V))
348348
return C->getSplatValue();

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 89 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -5314,88 +5314,112 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
53145314
/// zero index.
53155315
bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
53165316
Value *Ptr) {
5317-
const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
5318-
if (!GEP || !GEP->hasIndices())
5317+
// FIXME: Support scalable vectors.
5318+
if (isa<ScalableVectorType>(Ptr->getType()))
53195319
return false;
53205320

5321-
// If the GEP and the gather/scatter aren't in the same BB, don't optimize.
5322-
// FIXME: We should support this by sinking the GEP.
5323-
if (MemoryInst->getParent() != GEP->getParent())
5324-
return false;
5325-
5326-
SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
5321+
Value *NewAddr;
53275322

5328-
bool RewriteGEP = false;
5323+
if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
5324+
// Don't optimize GEPs that don't have indices.
5325+
if (!GEP->hasIndices())
5326+
return false;
53295327

5330-
if (Ops[0]->getType()->isVectorTy()) {
5331-
Ops[0] = const_cast<Value *>(getSplatValue(Ops[0]));
5332-
if (!Ops[0])
5328+
// If the GEP and the gather/scatter aren't in the same BB, don't optimize.
5329+
// FIXME: We should support this by sinking the GEP.
5330+
if (MemoryInst->getParent() != GEP->getParent())
53335331
return false;
5334-
RewriteGEP = true;
5335-
}
53365332

5337-
unsigned FinalIndex = Ops.size() - 1;
5333+
SmallVector<Value *, 2> Ops(GEP->op_begin(), GEP->op_end());
53385334

5339-
// Ensure all but the last index is 0.
5340-
// FIXME: This isn't strictly required. All that's required is that they are
5341-
// all scalars or splats.
5342-
for (unsigned i = 1; i < FinalIndex; ++i) {
5343-
auto *C = dyn_cast<Constant>(Ops[i]);
5344-
if (!C)
5345-
return false;
5346-
if (isa<VectorType>(C->getType()))
5347-
C = C->getSplatValue();
5348-
auto *CI = dyn_cast_or_null<ConstantInt>(C);
5349-
if (!CI || !CI->isZero())
5350-
return false;
5351-
// Scalarize the index if needed.
5352-
Ops[i] = CI;
5353-
}
5354-
5355-
// Try to scalarize the final index.
5356-
if (Ops[FinalIndex]->getType()->isVectorTy()) {
5357-
if (Value *V = const_cast<Value *>(getSplatValue(Ops[FinalIndex]))) {
5358-
auto *C = dyn_cast<ConstantInt>(V);
5359-
// Don't scalarize all zeros vector.
5360-
if (!C || !C->isZero()) {
5361-
Ops[FinalIndex] = V;
5362-
RewriteGEP = true;
5363-
}
5335+
bool RewriteGEP = false;
5336+
5337+
if (Ops[0]->getType()->isVectorTy()) {
5338+
Ops[0] = getSplatValue(Ops[0]);
5339+
if (!Ops[0])
5340+
return false;
5341+
RewriteGEP = true;
53645342
}
5365-
}
53665343

5367-
// If we made any changes or we have extra operands, we need to generate
5368-
// new instructions.
5369-
if (!RewriteGEP && Ops.size() == 2)
5370-
return false;
5344+
unsigned FinalIndex = Ops.size() - 1;
53715345

5372-
unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
5346+
// Ensure all but the last index is 0.
5347+
// FIXME: This isn't strictly required. All that's required is that they are
5348+
// all scalars or splats.
5349+
for (unsigned i = 1; i < FinalIndex; ++i) {
5350+
auto *C = dyn_cast<Constant>(Ops[i]);
5351+
if (!C)
5352+
return false;
5353+
if (isa<VectorType>(C->getType()))
5354+
C = C->getSplatValue();
5355+
auto *CI = dyn_cast_or_null<ConstantInt>(C);
5356+
if (!CI || !CI->isZero())
5357+
return false;
5358+
// Scalarize the index if needed.
5359+
Ops[i] = CI;
5360+
}
5361+
5362+
// Try to scalarize the final index.
5363+
if (Ops[FinalIndex]->getType()->isVectorTy()) {
5364+
if (Value *V = getSplatValue(Ops[FinalIndex])) {
5365+
auto *C = dyn_cast<ConstantInt>(V);
5366+
// Don't scalarize all zeros vector.
5367+
if (!C || !C->isZero()) {
5368+
Ops[FinalIndex] = V;
5369+
RewriteGEP = true;
5370+
}
5371+
}
5372+
}
53735373

5374-
IRBuilder<> Builder(MemoryInst);
5374+
// If we made any changes or we have extra operands, we need to generate
5375+
// new instructions.
5376+
if (!RewriteGEP && Ops.size() == 2)
5377+
return false;
53755378

5376-
Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
5379+
unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
53775380

5378-
Value *NewAddr;
5381+
IRBuilder<> Builder(MemoryInst);
53795382

5380-
// If the final index isn't a vector, emit a scalar GEP containing all ops
5381-
// and a vector GEP with all zeroes final index.
5382-
if (!Ops[FinalIndex]->getType()->isVectorTy()) {
5383-
NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
5384-
auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
5385-
NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
5386-
} else {
5387-
Value *Base = Ops[0];
5388-
Value *Index = Ops[FinalIndex];
5383+
Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
53895384

5390-
// Create a scalar GEP if there are more than 2 operands.
5391-
if (Ops.size() != 2) {
5392-
// Replace the last index with 0.
5393-
Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
5394-
Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
5385+
// If the final index isn't a vector, emit a scalar GEP containing all ops
5386+
// and a vector GEP with all zeroes final index.
5387+
if (!Ops[FinalIndex]->getType()->isVectorTy()) {
5388+
NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front());
5389+
auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
5390+
NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy));
5391+
} else {
5392+
Value *Base = Ops[0];
5393+
Value *Index = Ops[FinalIndex];
5394+
5395+
// Create a scalar GEP if there are more than 2 operands.
5396+
if (Ops.size() != 2) {
5397+
// Replace the last index with 0.
5398+
Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy);
5399+
Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front());
5400+
}
5401+
5402+
// Now create the GEP with scalar pointer and vector index.
5403+
NewAddr = Builder.CreateGEP(Base, Index);
53955404
}
5405+
} else if (!isa<Constant>(Ptr)) {
5406+
// Not a GEP, maybe it's a splat and we can create a GEP to enable
5407+
// SelectionDAGBuilder to use it as a uniform base.
5408+
Value *V = getSplatValue(Ptr);
5409+
if (!V)
5410+
return false;
5411+
5412+
unsigned NumElts = cast<FixedVectorType>(Ptr->getType())->getNumElements();
5413+
5414+
IRBuilder<> Builder(MemoryInst);
53965415

5397-
// Now create the GEP with scalar pointer and vector index.
5398-
NewAddr = Builder.CreateGEP(Base, Index);
5416+
// Emit a vector GEP with a scalar pointer and all 0s vector index.
5417+
Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
5418+
auto *IndexTy = FixedVectorType::get(ScalarIndexTy, NumElts);
5419+
NewAddr = Builder.CreateGEP(V, Constant::getNullValue(IndexTy));
5420+
} else {
5421+
// Constant, SelectionDAGBuilder knows to check if it's a splat.
5422+
return false;
53995423
}
54005424

54015425
MemoryInst->replaceUsesOfWith(Ptr, NewAddr);

llvm/test/CodeGen/X86/masked_gather_scatter.ll

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3323,14 +3323,13 @@ define void @scatter_16i64_constant_indices(i32* %ptr, <16 x i1> %mask, <16 x i3
33233323
define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
33243324
; KNL_64-LABEL: splat_ptr_gather:
33253325
; KNL_64: # %bb.0:
3326-
; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $ymm1
3326+
; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
33273327
; KNL_64-NEXT: vpslld $31, %xmm0, %xmm0
33283328
; KNL_64-NEXT: vptestmd %zmm0, %zmm0, %k0
33293329
; KNL_64-NEXT: kshiftlw $12, %k0, %k0
33303330
; KNL_64-NEXT: kshiftrw $12, %k0, %k1
3331-
; KNL_64-NEXT: vmovq %rdi, %xmm0
3332-
; KNL_64-NEXT: vpbroadcastq %xmm0, %ymm0
3333-
; KNL_64-NEXT: vpgatherqd (,%zmm0), %ymm1 {%k1}
3331+
; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0
3332+
; KNL_64-NEXT: vpgatherdd (%rdi,%zmm0,4), %zmm1 {%k1}
33343333
; KNL_64-NEXT: vmovdqa %xmm1, %xmm0
33353334
; KNL_64-NEXT: vzeroupper
33363335
; KNL_64-NEXT: retq
@@ -3342,8 +3341,9 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
33423341
; KNL_32-NEXT: vptestmd %zmm0, %zmm0, %k0
33433342
; KNL_32-NEXT: kshiftlw $12, %k0, %k0
33443343
; KNL_32-NEXT: kshiftrw $12, %k0, %k1
3345-
; KNL_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
3346-
; KNL_32-NEXT: vpgatherdd (,%zmm0), %zmm1 {%k1}
3344+
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
3345+
; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
3346+
; KNL_32-NEXT: vpgatherdd (%eax,%zmm0,4), %zmm1 {%k1}
33473347
; KNL_32-NEXT: vmovdqa %xmm1, %xmm0
33483348
; KNL_32-NEXT: vzeroupper
33493349
; KNL_32-NEXT: retl
@@ -3352,18 +3352,18 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
33523352
; SKX: # %bb.0:
33533353
; SKX-NEXT: vpslld $31, %xmm0, %xmm0
33543354
; SKX-NEXT: vpmovd2m %xmm0, %k1
3355-
; SKX-NEXT: vpbroadcastq %rdi, %ymm0
3356-
; SKX-NEXT: vpgatherqd (,%ymm0), %xmm1 {%k1}
3355+
; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
3356+
; SKX-NEXT: vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
33573357
; SKX-NEXT: vmovdqa %xmm1, %xmm0
3358-
; SKX-NEXT: vzeroupper
33593358
; SKX-NEXT: retq
33603359
;
33613360
; SKX_32-LABEL: splat_ptr_gather:
33623361
; SKX_32: # %bb.0:
33633362
; SKX_32-NEXT: vpslld $31, %xmm0, %xmm0
33643363
; SKX_32-NEXT: vpmovd2m %xmm0, %k1
3365-
; SKX_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
3366-
; SKX_32-NEXT: vpgatherdd (,%xmm0), %xmm1 {%k1}
3364+
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
3365+
; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
3366+
; SKX_32-NEXT: vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
33673367
; SKX_32-NEXT: vmovdqa %xmm1, %xmm0
33683368
; SKX_32-NEXT: retl
33693369
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
@@ -3376,14 +3376,13 @@ declare <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*>, i32, <4 x i1>,
33763376
define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
33773377
; KNL_64-LABEL: splat_ptr_scatter:
33783378
; KNL_64: # %bb.0:
3379-
; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $ymm1
3379+
; KNL_64-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
33803380
; KNL_64-NEXT: vpslld $31, %xmm0, %xmm0
33813381
; KNL_64-NEXT: vptestmd %zmm0, %zmm0, %k0
33823382
; KNL_64-NEXT: kshiftlw $12, %k0, %k0
33833383
; KNL_64-NEXT: kshiftrw $12, %k0, %k1
3384-
; KNL_64-NEXT: vmovq %rdi, %xmm0
3385-
; KNL_64-NEXT: vpbroadcastq %xmm0, %ymm0
3386-
; KNL_64-NEXT: vpscatterqd %ymm1, (,%zmm0) {%k1}
3384+
; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0
3385+
; KNL_64-NEXT: vpscatterdd %zmm1, (%rdi,%zmm0,4) {%k1}
33873386
; KNL_64-NEXT: vzeroupper
33883387
; KNL_64-NEXT: retq
33893388
;
@@ -3394,26 +3393,27 @@ define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
33943393
; KNL_32-NEXT: vptestmd %zmm0, %zmm0, %k0
33953394
; KNL_32-NEXT: kshiftlw $12, %k0, %k0
33963395
; KNL_32-NEXT: kshiftrw $12, %k0, %k1
3397-
; KNL_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
3398-
; KNL_32-NEXT: vpscatterdd %zmm1, (,%zmm0) {%k1}
3396+
; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax
3397+
; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
3398+
; KNL_32-NEXT: vpscatterdd %zmm1, (%eax,%zmm0,4) {%k1}
33993399
; KNL_32-NEXT: vzeroupper
34003400
; KNL_32-NEXT: retl
34013401
;
34023402
; SKX-LABEL: splat_ptr_scatter:
34033403
; SKX: # %bb.0:
34043404
; SKX-NEXT: vpslld $31, %xmm0, %xmm0
34053405
; SKX-NEXT: vpmovd2m %xmm0, %k1
3406-
; SKX-NEXT: vpbroadcastq %rdi, %ymm0
3407-
; SKX-NEXT: vpscatterqd %xmm1, (,%ymm0) {%k1}
3408-
; SKX-NEXT: vzeroupper
3406+
; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
3407+
; SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
34093408
; SKX-NEXT: retq
34103409
;
34113410
; SKX_32-LABEL: splat_ptr_scatter:
34123411
; SKX_32: # %bb.0:
34133412
; SKX_32-NEXT: vpslld $31, %xmm0, %xmm0
34143413
; SKX_32-NEXT: vpmovd2m %xmm0, %k1
3415-
; SKX_32-NEXT: vpbroadcastd {{[0-9]+}}(%esp), %xmm0
3416-
; SKX_32-NEXT: vpscatterdd %xmm1, (,%xmm0) {%k1}
3414+
; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax
3415+
; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0
3416+
; SKX_32-NEXT: vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
34173417
; SKX_32-NEXT: retl
34183418
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
34193419
%2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer

llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,9 @@ define <4 x i32> @global_struct_splat() {
8787

8888
define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
8989
; CHECK-LABEL: @splat_ptr_gather(
90-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
91-
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
92-
; CHECK-NEXT: [[TMP3:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
93-
; CHECK-NEXT: ret <4 x i32> [[TMP3]]
90+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
91+
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(<4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
92+
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
9493
;
9594
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0
9695
%2 = shufflevector <4 x i32*> %1, <4 x i32*> undef, <4 x i32> zeroinitializer
@@ -100,9 +99,8 @@ define <4 x i32> @splat_ptr_gather(i32* %ptr, <4 x i1> %mask, <4 x i32> %passthr
10099

101100
define void @splat_ptr_scatter(i32* %ptr, <4 x i1> %mask, <4 x i32> %val) {
102101
; CHECK-LABEL: @splat_ptr_scatter(
103-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i32*> undef, i32* [[PTR:%.*]], i32 0
104-
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32*> [[TMP1]], <4 x i32*> undef, <4 x i32> zeroinitializer
105-
; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP2]], i32 4, <4 x i1> [[MASK:%.*]])
102+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, i32* [[PTR:%.*]], <4 x i64> zeroinitializer
103+
; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> [[VAL:%.*]], <4 x i32*> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
106104
; CHECK-NEXT: ret void
107105
;
108106
%1 = insertelement <4 x i32*> undef, i32* %ptr, i32 0

0 commit comments

Comments
 (0)