Skip to content

Commit 0781d7b

Browse files
author
Elena Demikhovsky
committed
Fixed a failure in cost calculation for vector GEP
Cost calculation for vector GEP failed due to an invalid cast of the GEP index operand. The bug is fixed and a test is added. http://reviews.llvm.org/D14976 llvm-svn: 254408
1 parent a00de63 commit 0781d7b

File tree

5 files changed

+43
-17
lines changed

5 files changed

+43
-17
lines changed

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "llvm/IR/GetElementPtrTypeIterator.h"
2323
#include "llvm/IR/Operator.h"
2424
#include "llvm/IR/Type.h"
25+
#include "llvm/Analysis/VectorUtils.h"
2526

2627
namespace llvm {
2728

@@ -415,21 +416,28 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase {
415416
(Ptr == nullptr ? 0 : Ptr->getType()->getPointerAddressSpace());
416417
auto GTI = gep_type_begin(PointerType::get(PointeeType, AS), Operands);
417418
for (auto I = Operands.begin(); I != Operands.end(); ++I, ++GTI) {
419+
// We assume that the cost of Scalar GEP with constant index and the
420+
// cost of Vector GEP with splat constant index are the same.
421+
const ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I);
422+
if (!ConstIdx)
423+
if (auto Splat = getSplatValue(*I))
424+
ConstIdx = dyn_cast<ConstantInt>(Splat);
418425
if (isa<SequentialType>(*GTI)) {
419426
int64_t ElementSize = DL.getTypeAllocSize(GTI.getIndexedType());
420-
if (const ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I)) {
427+
if (ConstIdx)
421428
BaseOffset += ConstIdx->getSExtValue() * ElementSize;
422-
} else {
429+
else {
423430
// Needs scale register.
424-
if (Scale != 0) {
431+
if (Scale != 0)
425432
// No addressing mode takes two scale registers.
426433
return TTI::TCC_Basic;
427-
}
428434
Scale = ElementSize;
429435
}
430436
} else {
431437
StructType *STy = cast<StructType>(*GTI);
432-
uint64_t Field = cast<ConstantInt>(*I)->getZExtValue();
438+
// For structures the index is always splat or scalar constant
439+
assert(ConstIdx && "Unexpected GEP index");
440+
uint64_t Field = ConstIdx->getZExtValue();
433441
BaseOffset += DL.getStructLayout(STy)->getElementOffset(Field);
434442
}
435443
}

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Value *findScalarElement(Value *V, unsigned EltNo);
8686
/// \brief Get splat value if the input is a splat vector or return nullptr.
8787
/// The value may be extracted from a splat constants vector or from
8888
/// a sequence of instructions that broadcast a single value into a vector.
89-
Value *getSplatValue(Value *V);
89+
const Value *getSplatValue(const Value *V);
9090

9191
/// \brief Compute a map of integer instructions to their minimum legal type
9292
/// size.

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,10 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
417417
/// the input value is (1) a splat constants vector or (2) a sequence
418418
/// of instructions that broadcast a single value into a vector.
419419
///
420-
llvm::Value *llvm::getSplatValue(Value *V) {
421-
if (auto *CV = dyn_cast<ConstantDataVector>(V))
422-
return CV->getSplatValue();
420+
const llvm::Value *llvm::getSplatValue(const Value *V) {
421+
422+
if (auto *C = dyn_cast<Constant>(V))
423+
return C->getSplatValue();
423424

424425
auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
425426
if (!ShuffleInst)

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3301,18 +3301,18 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
33013301
// extract the splat value and use it as a uniform base.
33023302
// In all other cases the function returns 'false'.
33033303
//
3304-
static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
3304+
static bool getUniformBase(const Value *& Ptr, SDValue& Base, SDValue& Index,
33053305
SelectionDAGBuilder* SDB) {
33063306

33073307
SelectionDAG& DAG = SDB->DAG;
33083308
LLVMContext &Context = *DAG.getContext();
33093309

33103310
assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type");
3311-
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
3311+
const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
33123312
if (!GEP || GEP->getNumOperands() > 2)
33133313
return false;
33143314

3315-
Value *GEPPtr = GEP->getPointerOperand();
3315+
const Value *GEPPtr = GEP->getPointerOperand();
33163316
if (!GEPPtr->getType()->isVectorTy())
33173317
Ptr = GEPPtr;
33183318
else if (!(Ptr = getSplatValue(GEPPtr)))
@@ -3348,7 +3348,7 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
33483348
SDLoc sdl = getCurSDLoc();
33493349

33503350
// llvm.masked.scatter.*(Src0, Ptrs, alignment, Mask)
3351-
Value *Ptr = I.getArgOperand(1);
3351+
const Value *Ptr = I.getArgOperand(1);
33523352
SDValue Src0 = getValue(I.getArgOperand(0));
33533353
SDValue Mask = getValue(I.getArgOperand(3));
33543354
EVT VT = Src0.getValueType();
@@ -3362,10 +3362,10 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
33623362

33633363
SDValue Base;
33643364
SDValue Index;
3365-
Value *BasePtr = Ptr;
3365+
const Value *BasePtr = Ptr;
33663366
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
33673367

3368-
Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
3368+
const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr;
33693369
MachineMemOperand *MMO = DAG.getMachineFunction().
33703370
getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
33713371
MachineMemOperand::MOStore, VT.getStoreSize(),
@@ -3425,7 +3425,7 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
34253425
SDLoc sdl = getCurSDLoc();
34263426

34273427
// @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
3428-
Value *Ptr = I.getArgOperand(0);
3428+
const Value *Ptr = I.getArgOperand(0);
34293429
SDValue Src0 = getValue(I.getArgOperand(3));
34303430
SDValue Mask = getValue(I.getArgOperand(2));
34313431

@@ -3442,7 +3442,7 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
34423442
SDValue Root = DAG.getRoot();
34433443
SDValue Base;
34443444
SDValue Index;
3445-
Value *BasePtr = Ptr;
3445+
const Value *BasePtr = Ptr;
34463446
bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
34473447
bool ConstantMemory = false;
34483448
if (UniformBase &&
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
; RUN: opt < %s -cost-model -analyze -mtriple=x86_64-linux-unknown-unknown -mattr=+avx512f | FileCheck %s
2+
3+
%struct.S = type { [1000 x i32] }
4+
5+
6+
declare <4 x i32> @llvm.masked.gather.v4i32(<4 x i32*>, i32, <4 x i1>, <4 x i32>)
7+
8+
define <4 x i32> @foov(<4 x %struct.S*> %s, i64 %base){
9+
%temp = insertelement <4 x i64> undef, i64 %base, i32 0
10+
%vector = shufflevector <4 x i64> %temp, <4 x i64> undef, <4 x i32> zeroinitializer
11+
;CHECK: cost of 0 for instruction: {{.*}} getelementptr inbounds %struct.S
12+
%B = getelementptr inbounds %struct.S, <4 x %struct.S*> %s, <4 x i32> zeroinitializer, <4 x i32> zeroinitializer
13+
;CHECK: cost of 0 for instruction: {{.*}} getelementptr inbounds [1000 x i32]
14+
%arrayidx = getelementptr inbounds [1000 x i32], <4 x [1000 x i32]*> %B, <4 x i64> zeroinitializer, <4 x i64> %vector
15+
%res = call <4 x i32> @llvm.masked.gather.v4i32(<4 x i32*> %arrayidx, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> undef)
16+
ret <4 x i32> %res
17+
}

0 commit comments

Comments
 (0)