Skip to content

Commit e648d0b

Browse files
AlexeySachkovvladimirlaz
authored andcommitted
Fix handling of function pointers going through select
1 parent cc90c60 commit e648d0b

File tree

2 files changed

+170
-4
lines changed

2 files changed

+170
-4
lines changed

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,30 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
10791079
return mapValue(V, BI);
10801080
}
10811081

1082-
if (SelectInst *Sel = dyn_cast<SelectInst>(V))
1082+
if (SelectInst *Sel = dyn_cast<SelectInst>(V)) {
1083+
SPIRVValue *TrueValue = nullptr;
1084+
SPIRVValue *FalseValue = nullptr;
1085+
if (isa<Function>(Sel->getTrueValue())) {
1086+
if (!BM->checkExtension(ExtensionID::SPV_INTEL_function_pointers,
1087+
SPIRVEC_FunctionPointers, toString(Sel)))
1088+
return nullptr;
1089+
1090+
// select with function pointers
1091+
auto *TrueF = cast<Function>(Sel->getTrueValue());
1092+
TrueValue = BM->addFunctionPointerINTELInst(
1093+
transType(TrueF->getType()),
1094+
static_cast<SPIRVFunction *>(transValue(TrueF, BB)), BB);
1095+
auto *FalseF = cast<Function>(Sel->getFalseValue());
1096+
FalseValue = BM->addFunctionPointerINTELInst(
1097+
transType(FalseF->getType()),
1098+
static_cast<SPIRVFunction *>(transValue(FalseF, BB)), BB);
1099+
} else {
1100+
TrueValue = transValue(Sel->getTrueValue(), BB);
1101+
FalseValue = transValue(Sel->getFalseValue(), BB);
1102+
}
10831103
return mapValue(V, BM->addSelectInst(transValue(Sel->getCondition(), BB),
1084-
transValue(Sel->getTrueValue(), BB),
1085-
transValue(Sel->getFalseValue(), BB),
1086-
BB));
1104+
TrueValue, FalseValue, BB));
1105+
}
10871106

10881107
if (AllocaInst *Alc = dyn_cast<AllocaInst>(V))
10891108
return mapValue(
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_function_pointers -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
; RUN: llvm-spirv -r %t.spv -o %t.r.bc
6+
; RUN: llvm-dis %t.r.bc -o %t.r.ll
7+
; RUN: FileCheck < %t.r.ll %s --check-prefix=CHECK-LLVM
8+
9+
; CHECK-SPIRV: EntryPoint 6 [[#KERNEL_ID:]] "_ZTS6kernel"
10+
; CHECK-SPIRV-DAG: Name [[#BAR:]] "_Z3barii"
11+
; CHECK-SPIRV-DAG: Name [[#BAZ:]] "_Z3bazii"
12+
; CHECK-SPIRV: TypeInt [[#INT32:]] 32
13+
; CHECK-SPIRV: TypeFunction [[#FUNC_TYPE:]] [[#INT32]] [[#INT32]]
14+
; CHECK-SPIRV: TypePointer [[#FUNC_PTR_TYPE:]] [[#]] [[#FUNC_TYPE]]
15+
; CHECK-SPIRV: TypePointer [[#FUNC_PTR_ALLOCA_TYPE:]] [[#]] [[#FUNC_PTR_TYPE]]
16+
; CHECK-SPIRV: Function [[#]] [[#KERNEL_ID]]
17+
; CHECK-SPIRV: Variable [[#FUNC_PTR_ALLOCA_TYPE]] [[#FPTR:]]
18+
; CHECK-SPIRV-DAG: FunctionPointerINTEL [[#FUNC_PTR_TYPE]] [[#BARPTR:]] [[#BAR]]
19+
; CHECK-SPIRV-DAG: FunctionPointerINTEL [[#FUNC_PTR_TYPE]] [[#BAZPTR:]] [[#BAZ]]
20+
; CHECK-SPIRV: Select [[#FUNC_PTR_TYPE]] [[#SELECT:]] [[#]] [[#BARPTR]] [[#BAZPTR]]
21+
; CHECK-SPIRV: Store [[#FPTR]] [[#SELECT]]
22+
; CHECK-SPIRV: Load [[#FUNC_PTR_TYPE]] [[#LOAD:]] [[#FPTR]]
23+
; CHECK-SPIRV: FunctionPointerCallINTEL [[#]] [[#]] [[#LOAD]]
24+
25+
; CHECK-LLVM: define spir_kernel void @_ZTS6kernel
26+
; CHECK-LLVM: %[[FPTR_ALLOCA:.*]] = alloca i32 (i32, i32)*
27+
; CHECK-LLVM: %[[SELECT:.*]] = select i1 %{{.*}}, i32 (i32, i32)* @_Z3barii, i32 (i32, i32)* @_Z3bazii
28+
; CHECK-LLVM: store i32 (i32, i32)* %[[SELECT]], i32 (i32, i32)** %[[FPTR_ALLOCA]]
29+
; CHECK-LLVM: %[[FPTR:.*]] = load i32 (i32, i32)*, i32 (i32, i32)** %[[FPTR_ALLOCA]]
30+
; CHECK-LLVM: call spir_func i32 %[[FPTR]](
31+
32+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
33+
target triple = "spir64-unknown-unknown"
34+
35+
%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" = type { %"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" }
36+
%"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" = type { [1 x i64] }
37+
%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" = type { %"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" }
38+
39+
$_ZTS6kernel = comdat any
40+
41+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
42+
43+
; Function Attrs: norecurse
44+
define weak_odr dso_local spir_kernel void @_ZTS6kernel(i32 addrspace(1)* %_arg_, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_1, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_2, %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* byval(%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id") align 8 %_arg_3) local_unnamed_addr #0 comdat !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
45+
entry:
46+
%fptr.alloca = alloca i32 (i32, i32)*, align 8
47+
%ref.tmp.i = alloca %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", align 8
48+
%agg.tmp2.i = alloca %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range", align 8
49+
%agg.tmp3.i = alloca %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range", align 8
50+
%agg.tmp6 = alloca %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", align 8
51+
%0 = bitcast %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* %agg.tmp2.i to i8*
52+
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0)
53+
%1 = bitcast %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* %agg.tmp3.i to i8*
54+
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %1)
55+
%2 = addrspacecast %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* %agg.tmp2.i to %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" addrspace(4)*
56+
%ptrint4.i = ptrtoint %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" addrspace(4)* %2 to i64
57+
%maskedptr5.i = and i64 %ptrint4.i, 7
58+
%maskcond6.i = icmp eq i64 %maskedptr5.i, 0
59+
%3 = addrspacecast %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* %agg.tmp3.i to %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" addrspace(4)*
60+
%ptrint.i = ptrtoint %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" addrspace(4)* %3 to i64
61+
%maskedptr.i = and i64 %ptrint.i, 7
62+
%maskcond.i = icmp eq i64 %maskedptr.i, 0
63+
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0)
64+
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %1)
65+
%4 = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %_arg_3, i64 0, i32 0, i32 0, i64 0
66+
%5 = load i64, i64* %4, align 8
67+
%add.ptr.i = getelementptr inbounds i32, i32 addrspace(1)* %_arg_, i64 %5
68+
%6 = addrspacecast %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %agg.tmp6 to %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" addrspace(4)*
69+
%ptrint = ptrtoint %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" addrspace(4)* %6 to i64
70+
%maskedptr = and i64 %ptrint, 7
71+
%maskcond = icmp eq i64 %maskedptr, 0
72+
%7 = load <3 x i64>, <3 x i64> addrspace(4)* addrspacecast (<3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId to <3 x i64> addrspace(4)*), align 32, !noalias !8
73+
%8 = extractelement <3 x i64> %7, i64 0
74+
%arrayinit.begin.i.i.i.i.i = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" addrspace(4)* %6, i64 0, i32 0, i32 0, i64 0
75+
store i64 %8, i64 addrspace(4)* %arrayinit.begin.i.i.i.i.i, align 8, !tbaa !15, !alias.scope !8
76+
%9 = bitcast %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %ref.tmp.i to i8*
77+
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %9) #4
78+
%10 = addrspacecast %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %ref.tmp.i to %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" addrspace(4)*
79+
%ptrint.i2 = ptrtoint %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" addrspace(4)* %10 to i64
80+
%maskedptr.i3 = and i64 %ptrint.i2, 7
81+
%maskcond.i4 = icmp eq i64 %maskedptr.i3, 0
82+
%rem.i.i = and i64 %8, 1
83+
%cmp.i.i = icmp eq i64 %rem.i.i, 0
84+
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %9) #4
85+
%_Z3barii._Z3bazii.i = select i1 %cmp.i.i, i32 (i32, i32)* @_Z3barii, i32 (i32, i32)* @_Z3bazii
86+
store i32 (i32, i32)* %_Z3barii._Z3bazii.i, i32 (i32, i32)** %fptr.alloca, align 8
87+
%fptr = load i32 (i32, i32)*, i32 (i32, i32)** %fptr.alloca, align 8
88+
%call4.i = call spir_func i32 %fptr(i32 10, i32 10), !callees !19
89+
%arrayidx.i3.i = getelementptr inbounds i32, i32 addrspace(1)* %add.ptr.i, i64 %8
90+
%arrayidx.ascast.i.i = addrspacecast i32 addrspace(1)* %arrayidx.i3.i to i32 addrspace(4)*
91+
store i32 %call4.i, i32 addrspace(4)* %arrayidx.ascast.i.i, align 4, !tbaa !20
92+
ret void
93+
}
94+
95+
; Function Attrs: argmemonly nounwind willreturn
96+
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1
97+
98+
; Function Attrs: argmemonly nounwind willreturn
99+
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1
100+
101+
; Function Attrs: norecurse nounwind readnone
102+
define dso_local spir_func i32 @_Z3barii(i32 %a, i32 %b) local_unnamed_addr #2 {
103+
entry:
104+
%add = add nsw i32 %b, %a
105+
ret i32 %add
106+
}
107+
108+
; Function Attrs: norecurse nounwind readnone
109+
define dso_local spir_func i32 @_Z3bazii(i32 %a, i32 %b) local_unnamed_addr #2 {
110+
entry:
111+
%sub = sub nsw i32 %a, %b
112+
ret i32 %sub
113+
}
114+
115+
attributes #0 = { norecurse "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "sycl-module-id"="f.cpp" "uniform-work-group-size"="true" "unsafe-fp-math"="false" "use-soft-float"="false" }
116+
attributes #1 = { argmemonly nounwind willreturn }
117+
attributes #2 = { norecurse nounwind readnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
118+
attributes #3 = { nounwind willreturn }
119+
attributes #4 = { nounwind }
120+
121+
!llvm.module.flags = !{!0}
122+
!opencl.spir.version = !{!1}
123+
!spirv.Source = !{!2}
124+
!llvm.ident = !{!3}
125+
126+
!0 = !{i32 1, !"wchar_size", i32 4}
127+
!1 = !{i32 1, i32 2}
128+
!2 = !{i32 4, i32 100000}
129+
!3 = !{!"clang version 11.0.0 "}
130+
!4 = !{i32 1, i32 0, i32 0, i32 0}
131+
!5 = !{!"none", !"none", !"none", !"none"}
132+
!6 = !{!"int*", !"cl::sycl::range<1>", !"cl::sycl::range<1>", !"cl::sycl::id<1>"}
133+
!7 = !{!"", !"", !"", !""}
134+
!8 = !{!9, !11, !13}
135+
!9 = distinct !{!9, !10, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEE8initSizeEv: %agg.result"}
136+
!10 = distinct !{!10, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEE8initSizeEv"}
137+
!11 = distinct !{!11, !12, !"_ZN7__spirvL22initGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEEET0_v: %agg.result"}
138+
!12 = distinct !{!12, !"_ZN7__spirvL22initGlobalInvocationIdILi1EN2cl4sycl2idILi1EEEEET0_v"}
139+
!13 = distinct !{!13, !14, !"_ZN2cl4sycl6detail7Builder5getIdILi1EEEKNS0_2idIXT_EEEv: %agg.result"}
140+
!14 = distinct !{!14, !"_ZN2cl4sycl6detail7Builder5getIdILi1EEEKNS0_2idIXT_EEEv"}
141+
!15 = !{!16, !16, i64 0}
142+
!16 = !{!"long", !17, i64 0}
143+
!17 = !{!"omnipotent char", !18, i64 0}
144+
!18 = !{!"Simple C++ TBAA"}
145+
!19 = !{i32 (i32, i32)* @_Z3barii, i32 (i32, i32)* @_Z3bazii}
146+
!20 = !{!21, !21, i64 0}
147+
!21 = !{!"int", !17, i64 0}

0 commit comments

Comments
 (0)