@@ -1283,7 +1283,8 @@ void SPIRVToLLVM::addMemAliasMetadata(Instruction *I, SPIRVId AliasListId,
1283
1283
I->setMetadata (AliasMDKind, MDAliasListMap[AliasListId]);
1284
1284
}
1285
1285
1286
- void transFunctionPointerCallArgumentAttributes (SPIRVValue *BV, CallInst *CI) {
1286
+ void SPIRVToLLVM::transFunctionPointerCallArgumentAttributes (
1287
+ SPIRVValue *BV, CallInst *CI, SPIRVTypeFunction *CalledFnTy) {
1287
1288
std::vector<SPIRVDecorate const *> ArgumentAttributes =
1288
1289
BV->getDecorations (internal::DecorationArgumentAttributeINTEL);
1289
1290
@@ -1296,8 +1297,8 @@ void transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI) {
1296
1297
auto LlvmAttr =
1297
1298
Attribute::isTypeAttrKind (LlvmAttrKind)
1298
1299
? Attribute::get (CI->getContext (), LlvmAttrKind,
1299
- cast<PointerType>(CI-> getOperand (ArgNo)-> getType () )
1300
- ->getPointerElementType ())
1300
+ transType (CalledFnTy-> getParameterType (ArgNo)
1301
+ ->getPointerElementType () ))
1301
1302
: Attribute::get (CI->getContext (), LlvmAttrKind);
1302
1303
CI->addParamAttr (ArgNo, LlvmAttr);
1303
1304
}
@@ -1733,7 +1734,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
1733
1734
// A ptr.annotation may have been generated for the source variable.
1734
1735
replaceOperandWithAnnotationIntrinsicCallResult (V);
1735
1736
1736
- Type *Ty = V ->getType ()-> getPointerElementType ( );
1737
+ Type *Ty = transType (BL ->getType ());
1737
1738
LoadInst *LI = nullptr ;
1738
1739
uint64_t AlignValue = BL->SPIRVMemoryAccess ::getAlignment ();
1739
1740
if (0 == AlignValue) {
@@ -2082,7 +2083,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
2082
2083
case OpInBoundsPtrAccessChain: {
2083
2084
auto AC = static_cast <SPIRVAccessChainBase *>(BV);
2084
2085
auto Base = transValue (AC->getBase (), F, BB);
2085
- Type *BaseTy = cast<PointerType>(Base ->getType ()) ->getPointerElementType ();
2086
+ Type *BaseTy = transType (AC-> getBase () ->getType ()->getPointerElementType () );
2086
2087
auto Index = transValue (AC->getIndices (), F, BB);
2087
2088
if (!AC->hasPtrIndex ())
2088
2089
Index.insert (Index.begin (), getInt32 (M, 0 ));
@@ -2237,10 +2238,12 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
2237
2238
SPIRVFunctionPointerCallINTEL *BC =
2238
2239
static_cast <SPIRVFunctionPointerCallINTEL *>(BV);
2239
2240
auto *V = transValue (BC->getCalledValue (), F, BB);
2240
- auto Call = CallInst::Create (
2241
- cast<FunctionType>(V->getType ()->getPointerElementType ()), V,
2242
- transValue (BC->getArgumentValues (), F, BB), BC->getName (), BB);
2243
- transFunctionPointerCallArgumentAttributes (BV, Call);
2241
+ auto *SpirvFnTy = BC->getCalledValue ()->getType ()->getPointerElementType ();
2242
+ auto *FnTy = cast<FunctionType>(transType (SpirvFnTy));
2243
+ auto *Call = CallInst::Create (
2244
+ FnTy, V, transValue (BC->getArgumentValues (), F, BB), BC->getName (), BB);
2245
+ transFunctionPointerCallArgumentAttributes (
2246
+ BV, Call, static_cast <SPIRVTypeFunction *>(SpirvFnTy));
2244
2247
// Assuming we are calling a regular device function
2245
2248
Call->setCallingConv (CallingConv::SPIR_FUNC);
2246
2249
// Don't set attributes, because at translation time we don't know which
@@ -2403,7 +2406,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
2403
2406
else {
2404
2407
IID = Intrinsic::ptr_annotation;
2405
2408
auto *PtrTy = dyn_cast<PointerType>(Ty);
2406
- if (PtrTy && isa<IntegerType>(PtrTy->getPointerElementType ()))
2409
+ if (PtrTy &&
2410
+ (PtrTy->isOpaque () ||
2411
+ isa<IntegerType>(PtrTy->getNonOpaquePointerElementType ())))
2407
2412
RetTy = PtrTy;
2408
2413
// Whether a struct or a pointer to some other type,
2409
2414
// bitcast to i8*
@@ -2771,10 +2776,8 @@ void SPIRVToLLVM::transFunctionAttrs(SPIRVFunction *BF, Function *F) {
2771
2776
Type *AttrTy = nullptr ;
2772
2777
switch (LLVMKind) {
2773
2778
case Attribute::AttrKind::ByVal:
2774
- AttrTy = cast<PointerType>(I->getType ())->getPointerElementType ();
2775
- break ;
2776
2779
case Attribute::AttrKind::StructRet:
2777
- AttrTy = cast<PointerType>(I ->getType ()) ->getPointerElementType ();
2780
+ AttrTy = transType (BA ->getType ()->getPointerElementType () );
2778
2781
break ;
2779
2782
default :
2780
2783
break ; // do nothing
0 commit comments