Skip to content

Commit 07a7cca

Browse files
jcranmer-intelsvenvh
authored andcommitted
Replace uses of Type::getPointerElementType in SPIRVReader.
1 parent 33012ca commit 07a7cca

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,8 @@ void SPIRVToLLVM::addMemAliasMetadata(Instruction *I, SPIRVId AliasListId,
12831283
I->setMetadata(AliasMDKind, MDAliasListMap[AliasListId]);
12841284
}
12851285

1286-
void transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI) {
1286+
void SPIRVToLLVM::transFunctionPointerCallArgumentAttributes(
1287+
SPIRVValue *BV, CallInst *CI, SPIRVTypeFunction *CalledFnTy) {
12871288
std::vector<SPIRVDecorate const *> ArgumentAttributes =
12881289
BV->getDecorations(internal::DecorationArgumentAttributeINTEL);
12891290

@@ -1296,8 +1297,8 @@ void transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI) {
12961297
auto LlvmAttr =
12971298
Attribute::isTypeAttrKind(LlvmAttrKind)
12981299
? Attribute::get(CI->getContext(), LlvmAttrKind,
1299-
cast<PointerType>(CI->getOperand(ArgNo)->getType())
1300-
->getPointerElementType())
1300+
transType(CalledFnTy->getParameterType(ArgNo)
1301+
->getPointerElementType()))
13011302
: Attribute::get(CI->getContext(), LlvmAttrKind);
13021303
CI->addParamAttr(ArgNo, LlvmAttr);
13031304
}
@@ -1733,7 +1734,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
17331734
// A ptr.annotation may have been generated for the source variable.
17341735
replaceOperandWithAnnotationIntrinsicCallResult(V);
17351736

1736-
Type *Ty = V->getType()->getPointerElementType();
1737+
Type *Ty = transType(BL->getType());
17371738
LoadInst *LI = nullptr;
17381739
uint64_t AlignValue = BL->SPIRVMemoryAccess::getAlignment();
17391740
if (0 == AlignValue) {
@@ -2082,7 +2083,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
20822083
case OpInBoundsPtrAccessChain: {
20832084
auto AC = static_cast<SPIRVAccessChainBase *>(BV);
20842085
auto Base = transValue(AC->getBase(), F, BB);
2085-
Type *BaseTy = cast<PointerType>(Base->getType())->getPointerElementType();
2086+
Type *BaseTy = transType(AC->getBase()->getType()->getPointerElementType());
20862087
auto Index = transValue(AC->getIndices(), F, BB);
20872088
if (!AC->hasPtrIndex())
20882089
Index.insert(Index.begin(), getInt32(M, 0));
@@ -2237,10 +2238,12 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
22372238
SPIRVFunctionPointerCallINTEL *BC =
22382239
static_cast<SPIRVFunctionPointerCallINTEL *>(BV);
22392240
auto *V = transValue(BC->getCalledValue(), F, BB);
2240-
auto Call = CallInst::Create(
2241-
cast<FunctionType>(V->getType()->getPointerElementType()), V,
2242-
transValue(BC->getArgumentValues(), F, BB), BC->getName(), BB);
2243-
transFunctionPointerCallArgumentAttributes(BV, Call);
2241+
auto *SpirvFnTy = BC->getCalledValue()->getType()->getPointerElementType();
2242+
auto *FnTy = cast<FunctionType>(transType(SpirvFnTy));
2243+
auto *Call = CallInst::Create(
2244+
FnTy, V, transValue(BC->getArgumentValues(), F, BB), BC->getName(), BB);
2245+
transFunctionPointerCallArgumentAttributes(
2246+
BV, Call, static_cast<SPIRVTypeFunction *>(SpirvFnTy));
22442247
// Assuming we are calling a regular device function
22452248
Call->setCallingConv(CallingConv::SPIR_FUNC);
22462249
// Don't set attributes, because at translation time we don't know which
@@ -2403,7 +2406,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
24032406
else {
24042407
IID = Intrinsic::ptr_annotation;
24052408
auto *PtrTy = dyn_cast<PointerType>(Ty);
2406-
if (PtrTy && isa<IntegerType>(PtrTy->getPointerElementType()))
2409+
if (PtrTy &&
2410+
(PtrTy->isOpaque() ||
2411+
isa<IntegerType>(PtrTy->getNonOpaquePointerElementType())))
24072412
RetTy = PtrTy;
24082413
// Whether a struct or a pointer to some other type,
24092414
// bitcast to i8*
@@ -2771,10 +2776,8 @@ void SPIRVToLLVM::transFunctionAttrs(SPIRVFunction *BF, Function *F) {
27712776
Type *AttrTy = nullptr;
27722777
switch (LLVMKind) {
27732778
case Attribute::AttrKind::ByVal:
2774-
AttrTy = cast<PointerType>(I->getType())->getPointerElementType();
2775-
break;
27762779
case Attribute::AttrKind::StructRet:
2777-
AttrTy = cast<PointerType>(I->getType())->getPointerElementType();
2780+
AttrTy = transType(BA->getType()->getPointerElementType());
27782781
break;
27792782
default:
27802783
break; // do nothing

lib/SPIRV/SPIRVReader.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ class SPIRVToLLVM {
244244
void transMemAliasingINTELDecorations(SPIRVValue *BV, Value *V);
245245
void transVarDecorationsToMetadata(SPIRVValue *BV, Value *V);
246246
void transFunctionDecorationsToMetadata(SPIRVFunction *BF, Function *F);
247+
void
248+
transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI,
249+
SPIRVTypeFunction *CalledFnTy);
247250
}; // class SPIRVToLLVM
248251

249252
} // namespace SPIRV

0 commit comments

Comments
 (0)