Skip to content

Commit bd9bcea

Browse files
update type inference for function pointers and update test cases
1 parent 7f79653 commit bd9bcea

File tree

4 files changed

+149
-52
lines changed

4 files changed

+149
-52
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class SPIRVEmitIntrinsics
6969
SPIRVGlobalRegistry *GR = nullptr;
7070
Function *F = nullptr;
7171
bool TrackConstants = true;
72+
bool HaveFunPtrs = false;
7273
DenseMap<Instruction *, Constant *> AggrConsts;
7374
DenseMap<Instruction *, Type *> AggrConstTypes;
7475
DenseSet<Instruction *> AggrStores;
@@ -714,6 +715,37 @@ static bool deduceOperandElementTypeCalledFunction(
714715
return true;
715716
}
716717

718+
// Try to deduce element type for a function pointer.
719+
static void deduceOperandElementTypeFunctionPointer(
720+
SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
721+
SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
722+
Value *Op = CI->getCalledOperand();
723+
if (!Op || !isPointerTy(Op->getType()))
724+
return;
725+
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
726+
FunctionType *FTy = CI->getFunctionType();
727+
bool IsNewFTy = false;
728+
SmallVector<Type *, 4> ArgTys;
729+
for (Value *Arg : CI->args()) {
730+
Type *ArgTy = Arg->getType();
731+
if (ArgTy->isPointerTy())
732+
if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
733+
IsNewFTy = true;
734+
ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
735+
}
736+
ArgTys.push_back(ArgTy);
737+
}
738+
Type *RetTy = FTy->getReturnType();
739+
if (I->getType()->isPointerTy())
740+
if (Type *ElemTy = GR->findDeducedElementType(I)) {
741+
IsNewFTy = true;
742+
RetTy =
743+
TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
744+
}
745+
KnownElemTy =
746+
IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
747+
}
748+
717749
// If the Instruction has Pointer operands with unresolved types, this function
718750
// tries to deduce them. If the Instruction has Pointer operands with known
719751
// types which differ from expected, this function tries to insert a bitcast to
@@ -820,17 +852,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
820852
Ops.push_back(std::make_pair(Op0, 0));
821853
}
822854
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
823-
if (!CI->isIndirectCall()) {
855+
if (!CI->isIndirectCall())
824856
deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
825857
KnownElemTy);
826-
} else if (TM->getSubtarget<SPIRVSubtarget>(*F).canUseExtension(
827-
SPIRV::Extension::SPV_INTEL_function_pointers)) {
828-
Value *Op = CI->getCalledOperand();
829-
if (!Op || !isPointerTy(Op->getType()))
830-
return;
831-
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
832-
KnownElemTy = CI->getFunctionType();
833-
}
858+
else if (HaveFunPtrs)
859+
deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
834860
}
835861

836862
// There is no enough info to deduce types or all is valid.
@@ -1710,23 +1736,53 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
17101736
}
17111737
}
17121738

1739+
static FunctionType *getFunctionPointerElemType(Function *F,
1740+
SPIRVGlobalRegistry *GR) {
1741+
FunctionType *FTy = F->getFunctionType();
1742+
bool IsNewFTy = false;
1743+
SmallVector<Type *, 4> ArgTys;
1744+
for (Argument &Arg : F->args()) {
1745+
Type *ArgTy = Arg.getType();
1746+
if (ArgTy->isPointerTy())
1747+
if (Type *ElemTy = GR->findDeducedElementType(&Arg)) {
1748+
IsNewFTy = true;
1749+
ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
1750+
}
1751+
ArgTys.push_back(ArgTy);
1752+
}
1753+
return IsNewFTy
1754+
? FunctionType::get(FTy->getReturnType(), ArgTys, FTy->isVarArg())
1755+
: FTy;
1756+
}
1757+
17131758
bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
1714-
bool IsExt = false;
17151759
SmallVector<Function *> Worklist;
17161760
for (auto &F : M) {
1717-
if (!IsExt) {
1718-
if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
1719-
SPIRV::Extension::SPV_INTEL_function_pointers))
1720-
return false;
1721-
IsExt = true;
1722-
}
1723-
if (!F.isDeclaration() || F.isIntrinsic())
1761+
if (F.isIntrinsic())
17241762
continue;
1725-
for (User *U : F.users()) {
1726-
CallInst *CI = dyn_cast<CallInst>(U);
1727-
if (!CI || CI->getCalledFunction() != &F) {
1728-
Worklist.push_back(&F);
1729-
break;
1763+
if (F.isDeclaration()) {
1764+
for (User *U : F.users()) {
1765+
CallInst *CI = dyn_cast<CallInst>(U);
1766+
if (!CI || CI->getCalledFunction() != &F) {
1767+
Worklist.push_back(&F);
1768+
break;
1769+
}
1770+
}
1771+
} else {
1772+
if (F.user_empty())
1773+
continue;
1774+
Type *FPElemTy = GR->findDeducedElementType(&F);
1775+
if (!FPElemTy)
1776+
FPElemTy = getFunctionPointerElemType(&F, GR);
1777+
for (User *U : F.users()) {
1778+
IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1779+
if (!II || II->arg_size() != 3 || II->getOperand(0) != &F)
1780+
continue;
1781+
if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
1782+
II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
1783+
updateAssignType(II, &F, PoisonValue::get(FPElemTy));
1784+
break;
1785+
}
17301786
}
17311787
}
17321788
}
@@ -1765,6 +1821,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
17651821
InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
17661822
: SPIRV::InstructionSet::GLSL_std_450;
17671823

1824+
if (!F)
1825+
HaveFunPtrs =
1826+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
1827+
17681828
F = &Func;
17691829
IRBuilder<> B(Func.getContext());
17701830
AggrConsts.clear();
@@ -1910,7 +1970,8 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
19101970
}
19111971

19121972
Changed |= postprocessTypes();
1913-
Changed |= processFunctionPointers(M);
1973+
if (HaveFunPtrs)
1974+
Changed |= processFunctionPointers(M);
19141975

19151976
return Changed;
19161977
}

llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,47 @@
1-
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
1+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
22
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
33

4-
; CHECK: OpFunction
4+
; CHECK-DAG: OpName %[[I9:.*]] "_ZN13BaseIncrement9incrementEPi"
5+
; CHECK-DAG: OpName %[[I29:.*]] "_ZN12IncrementBy29incrementEPi"
6+
; CHECK-DAG: OpName %[[I49:.*]] "_ZN12IncrementBy49incrementEPi"
7+
; CHECK-DAG: OpName %[[I89:.*]] "_ZN12IncrementBy89incrementEPi"
58

6-
%classid = type { %arrayid }
7-
%arrayid = type { [1 x i64] }
8-
%struct.obj_storage_t = type { %storage }
9-
%storage = type { [8 x i8] }
9+
; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
10+
; CHECK-DAG: %[[TyArr:.*]] = OpTypeArray
11+
; CHECK-DAG: %[[TyStruct1:.*]] = OpTypeStruct %[[TyArr]]
12+
; CHECK-DAG: %[[TyStruct2:.*]] = OpTypeStruct %[[TyStruct1]]
13+
; CHECK-DAG: %[[TyPtrStruct2:.*]] = OpTypePointer Generic %[[TyStruct2]]
14+
; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrStruct2]] %[[#]]
15+
; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Generic %[[TyFun]]
16+
; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Generic %[[TyPtrFun]]
17+
18+
; CHECK: %[[I9]] = OpFunction
19+
; CHECK: %[[I29]] = OpFunction
20+
; CHECK: %[[I49]] = OpFunction
21+
; CHECK: %[[I89]] = OpFunction
22+
23+
; CHECK: %[[Arg1:.*]] = OpPhi %[[TyPtrStruct2]]
24+
; CHECK: %[[VTbl:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[#]]
25+
; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[VTbl]]
26+
; CHECK: %[[#]] = OpFunctionPointerCallINTEL %[[TyVoid]] %[[FP]] %[[Arg1]] %[[#]]
27+
28+
%"cls::id" = type { %"cls::detail::array" }
29+
%"cls::detail::array" = type { [1 x i64] }
30+
%struct.obj_storage_t = type { %"struct.aligned_storage<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>::type" }
31+
%"struct.aligned_storage<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>::type" = type { [8 x i8] }
1032

1133
@_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8
1234
@_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8
1335
@_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8
1436
@_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8
37+
@__spirv_BuiltInWorkgroupId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
38+
@__spirv_BuiltInGlobalLinearId = external dso_local local_unnamed_addr addrspace(1) constant i64, align 8
39+
@__spirv_BuiltInWorkgroupSize = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
1540

16-
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%classid) align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
41+
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%"cls::id") align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
1742
entry:
18-
%0 = load i64, ptr %_arg_StorageAcc3, align 8
19-
%add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %0
43+
%r0 = load i64, ptr %_arg_StorageAcc3, align 8
44+
%add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %r0
2045
%arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4)
2146
%cmp.i = icmp ugt i32 %_arg_TestCase, 3
2247
br i1 %cmp.i, label %entry.critedge, label %if.end.1
@@ -51,9 +76,9 @@ if.end.2: ; preds = %if.end.1
5176
exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge
5277
%vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ]
5378
%retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ]
54-
%1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
55-
%2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
56-
tail call spir_func addrspace(4) void %2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %1)
79+
%r1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
80+
%r2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
81+
tail call spir_func addrspace(4) void %r2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %r1)
5782
ret void
5883
}
5984

llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,39 @@
55
; CHECK-DAG: OpCapability FunctionPointersINTEL
66
; CHECK-DAG: OpCapability Int64
77
; CHECK: OpExtension "SPV_INTEL_function_pointers"
8-
; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
8+
99
; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
1010
; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
11-
; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]]
12-
; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42
13-
; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
14-
; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]]
15-
; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]]
16-
; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]]
17-
; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]]
18-
; CHECK: OpReturn
11+
; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyInt64]] %[[TyInt64]]
12+
; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
13+
; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Function %[[TyFun]]
14+
; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFun]] %[[DefFunFp:.*]]
15+
; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Function %[[TyPtrFun]]
16+
; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
17+
; CHECK-DAG: %[[TyPtrPtrInt8:.*]] = OpTypePointer Function %[[TyPtrInt8]]
18+
; CHECK: OpFunction
19+
; CHECK: %[[Var:.*]] = OpVariable %[[TyPtrPtrInt8]] Function
20+
; CHECK: %[[SAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]]
21+
; CHECK: OpStore %[[SAddr]] %[[ConstFunFp]]
22+
; CHECK: %[[LAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]]
23+
; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[LAddr]]
24+
; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FP]] %[[#]]
1925
; CHECK: OpFunctionEnd
20-
; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]]
26+
27+
; CHECK: %[[DefFunFp]] = OpFunction %[[TyInt64]] None %[[TyFun]]
2128

2229
target triple = "spir64-unknown-unknown"
2330

2431
define spir_kernel void @test() {
2532
entry:
26-
%0 = load ptr, ptr @foo
27-
%1 = call i64 %0(i64 42)
33+
%fp = alloca ptr
34+
store ptr @foo, ptr %fp
35+
%tocall = load ptr, ptr %fp
36+
%res = call i64 %tocall(i64 42)
2837
ret void
2938
}
3039

31-
define void @foo(i64 %a) {
40+
define i64 @foo(i64 %a) {
3241
entry:
33-
ret void
42+
ret i64 %a
3443
}

llvm/test/CodeGen/SPIRV/instructions/select-phi.ll

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
; This test case checks how phi-nodes with different operand types select
2+
; a result type. Majority of operands makes it i8* in this case.
3+
14
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
25
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
36

@@ -15,14 +18,13 @@
1518

1619
; CHECK: %[[Branch1:.*]] = OpLabel
1720
; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function
21+
; CHECK: %[[Res1Casted:.*]] = OpBitcast %[[CharPtr]] %[[Res1]]
1822
; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]]
1923
; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]]
20-
; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]]
2124
; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]]
2225
; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]]
23-
; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]]
2426
; CHECK: OpLabel
25-
; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]]
27+
; CHECK: OpPhi %[[CharPtr]] %[[Res1Casted]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]]
2628

2729
%struct = type { %array }
2830
%array = type { [1 x i64] }

0 commit comments

Comments
 (0)