Skip to content

Commit c961ef7

Browse files
committed
[sycl-post-link] Support indirectly called assert
Per the design doc (https://github.com/intel/llvm/blob/sycl/sycl/doc/Assert.md): if the call graph of an indirectly callable function (marked with the "referenced-indirectly" attribute) contains a call to __devicelib_assert_fail, then all kernels in the module are conservatively marked as using asserts.
1 parent 7944294 commit c961ef7

File tree

2 files changed

+169
-5
lines changed

2 files changed

+169
-5
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
; This test checks that the post-link tool properly generates "assert used"
2+
; property for indirectly called assertions - it should include all the kernels
3+
; even if they do not call assertions in their call graph.
4+
; Per the design doc, if the call graph of an indirectly callable function
5+
; (marked with "referenced-indirectly" attribute in IR) has a call to
6+
; __devicelib_assert_fail, then all kernels in the module are conservatively
7+
; marked as using asserts.
8+
9+
; RUN: sycl-post-link -split=auto -symbols -S %s -o %t.table
10+
; RUN: FileCheck %s -input-file=%t_0.prop
11+
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
13+
target triple = "spir64-unknown-linux-sycldevice"
14+
15+
@_ZL2GV = internal addrspace(1) constant [1 x i32] [i32 42], align 4
16+
@.str = private unnamed_addr addrspace(1) constant [2 x i8] c"0\00", align 1
17+
@.str.1 = private unnamed_addr addrspace(1) constant [11 x i8] c"assert.cpp\00", align 1
18+
@__PRETTY_FUNCTION__._Z3foov = private unnamed_addr addrspace(1) constant [11 x i8] c"void foo()\00", align 1
19+
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
20+
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
21+
@_ZL10assert_fmt = internal addrspace(2) constant [85 x i8] c"%s:%d: %s: global id: [%lu,%lu,%lu], local id: [%lu,%lu,%lu] Assertion `%s` failed.\0A\00", align 1
22+
23+
; CHECK: [SYCL/assert used]
24+
25+
; CHECK-DAG: main_TU0_kernel0
26+
define dso_local spir_kernel void @main_TU0_kernel0() #0 {
27+
entry:
28+
call spir_func void @_Z3foov()
29+
ret void
30+
}
31+
32+
define dso_local spir_func void @_Z3foov() {
33+
entry:
34+
%a = alloca i32, align 4
35+
%ptr = bitcast i32* %a to i32 (i32)*
36+
%call = call spir_func i32 %ptr(i32 1)
37+
%add = add nsw i32 2, %call
38+
store i32 %add, i32* %a, align 4
39+
tail call spir_func void @__assert_fail(i8 addrspace(4)* getelementptr inbounds ([2 x i8], [2 x i8] addrspace(4)* addrspacecast ([2 x i8] addrspace(1)* @.str to [2 x i8] addrspace(4)*), i64 0, i64 0), i8 addrspace(4)* getelementptr inbounds ([11 x i8], [11 x i8] addrspace(4)* addrspacecast ([11 x i8] addrspace(1)* @.str.1 to [11 x i8] addrspace(4)*), i64 0, i64 0), i32 8, i8 addrspace(4)* getelementptr inbounds ([11 x i8], [11 x i8] addrspace(4)* addrspacecast ([11 x i8] addrspace(1)* @__PRETTY_FUNCTION__._Z3foov to [11 x i8] addrspace(4)*), i64 0, i64 0))
40+
ret void
41+
}
42+
43+
; CHECK-DAG: main_TU0_kernel1
44+
define dso_local spir_kernel void @main_TU0_kernel1() #0 {
45+
entry:
46+
call spir_func void @_Z4foo1v()
47+
ret void
48+
}
49+
50+
; Function Attrs: nounwind
51+
define dso_local spir_func void @_Z4foo1v() {
52+
entry:
53+
%a = alloca i32, align 4
54+
store i32 2, i32* %a, align 4
55+
ret void
56+
}
57+
58+
; CHECK-DAG: main_TU1_kernel0
59+
define dso_local spir_kernel void @main_TU1_kernel0() #2 {
60+
entry:
61+
call spir_func void @_Z3foov()
62+
ret void
63+
}
64+
65+
; CHECK-DAG: main_TU1_kernel1
66+
define dso_local spir_kernel void @main_TU1_kernel1() #2 {
67+
entry:
68+
call spir_func void @_Z4foo2v()
69+
ret void
70+
}
71+
72+
; This function is marked with "referenced-indirectly"
73+
; Function Attrs: nounwind
74+
define dso_local spir_func void @_Z4foo2v() #1 {
75+
entry:
76+
%a = alloca i32, align 4
77+
%0 = load i32, i32 addrspace(4)* getelementptr inbounds ([1 x i32], [1 x i32] addrspace(4)* addrspacecast ([1 x i32] addrspace(1)* @_ZL2GV to [1 x i32] addrspace(4)*), i64 0, i64 0), align 4
78+
%add = add nsw i32 4, %0
79+
store i32 %add, i32* %a, align 4
80+
tail call spir_func void @__assert_fail(i8 addrspace(4)* getelementptr inbounds ([2 x i8], [2 x i8] addrspace(4)* addrspacecast ([2 x i8] addrspace(1)* @.str to [2 x i8] addrspace(4)*), i64 0, i64 0), i8 addrspace(4)* getelementptr inbounds ([11 x i8], [11 x i8] addrspace(4)* addrspacecast ([11 x i8] addrspace(1)* @.str.1 to [11 x i8] addrspace(4)*), i64 0, i64 0), i32 8, i8 addrspace(4)* getelementptr inbounds ([11 x i8], [11 x i8] addrspace(4)* addrspacecast ([11 x i8] addrspace(1)* @__PRETTY_FUNCTION__._Z3foov to [11 x i8] addrspace(4)*), i64 0, i64 0))
81+
ret void
82+
}
83+
84+
85+
; Function Attrs: convergent norecurse mustprogress
86+
define weak dso_local spir_func void @__assert_fail(i8 addrspace(4)* %expr, i8 addrspace(4)* %file, i32 %line, i8 addrspace(4)* %func) local_unnamed_addr {
87+
entry:
88+
%call = tail call spir_func i64 @_Z28__spirv_GlobalInvocationId_xv()
89+
%call1 = tail call spir_func i64 @_Z28__spirv_GlobalInvocationId_yv()
90+
%call2 = tail call spir_func i64 @_Z28__spirv_GlobalInvocationId_zv()
91+
%call3 = tail call spir_func i64 @_Z27__spirv_LocalInvocationId_xv()
92+
%call4 = tail call spir_func i64 @_Z27__spirv_LocalInvocationId_yv()
93+
%call5 = tail call spir_func i64 @_Z27__spirv_LocalInvocationId_zv()
94+
tail call spir_func void @__devicelib_assert_fail(i8 addrspace(4)* %expr, i8 addrspace(4)* %file, i32 %line, i8 addrspace(4)* %func, i64 %call, i64 %call1, i64 %call2, i64 %call3, i64 %call4, i64 %call5)
95+
ret void
96+
}
97+
98+
; Function Attrs: inlinehint norecurse mustprogress
99+
declare dso_local spir_func i64 @_Z28__spirv_GlobalInvocationId_xv() local_unnamed_addr
100+
101+
; Function Attrs: inlinehint norecurse mustprogress
102+
declare dso_local spir_func i64 @_Z28__spirv_GlobalInvocationId_yv() local_unnamed_addr
103+
104+
; Function Attrs: inlinehint norecurse mustprogress
105+
declare dso_local spir_func i64 @_Z28__spirv_GlobalInvocationId_zv() local_unnamed_addr
106+
107+
; Function Attrs: inlinehint norecurse mustprogress
108+
declare dso_local spir_func i64 @_Z27__spirv_LocalInvocationId_xv() local_unnamed_addr
109+
110+
; Function Attrs: inlinehint norecurse mustprogress
111+
declare dso_local spir_func i64 @_Z27__spirv_LocalInvocationId_yv() local_unnamed_addr
112+
113+
; Function Attrs: inlinehint norecurse mustprogress
114+
declare dso_local spir_func i64 @_Z27__spirv_LocalInvocationId_zv() local_unnamed_addr
115+
116+
; Function Attrs: convergent norecurse mustprogress
117+
define weak dso_local spir_func void @__devicelib_assert_fail(i8 addrspace(4)* %expr, i8 addrspace(4)* %file, i32 %line, i8 addrspace(4)* %func, i64 %gid0, i64 %gid1, i64 %gid2, i64 %lid0, i64 %lid1, i64 %lid2) local_unnamed_addr {
118+
entry:
119+
%call = tail call spir_func i32 (i8 addrspace(2)*, ...) @_Z18__spirv_ocl_printfPU3AS2Kcz(i8 addrspace(2)* getelementptr inbounds ([85 x i8], [85 x i8] addrspace(2)* @_ZL10assert_fmt, i64 0, i64 0), i8 addrspace(4)* %file, i32 %line, i8 addrspace(4)* %func, i64 %gid0, i64 %gid1, i64 %gid2, i64 %lid0, i64 %lid1, i64 %lid2, i8 addrspace(4)* %expr)
120+
ret void
121+
}
122+
123+
; Function Attrs: convergent
124+
declare dso_local spir_func i32 @_Z18__spirv_ocl_printfPU3AS2Kcz(i8 addrspace(2)*, ...) local_unnamed_addr
125+
126+
attributes #0 = { "sycl-module-id"="TU1.cpp" }
127+
attributes #1 = { "referenced-indirectly" "sycl-module-id"="TU2.cpp" }
128+
attributes #2 = { "sycl-module-id"="TU2.cpp" }
129+
130+
131+
!opencl.spir.version = !{!0, !0}
132+
!spirv.Source = !{!1, !1}
133+
134+
!0 = !{i32 1, i32 2}
135+
!1 = !{i32 4, i32 100000}

llvm/tools/sycl-post-link/sycl-post-link.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,16 @@ static void collectKernelModuleMap(
288288
}
289289
}
290290

291+
// Result of hasAssertInFunctionCallGraph() for a single starting function:
//   No_Assert       - no call to __devicelib_assert_fail was found.
//   Assert          - an assert call was found in the call graph.
//   Assert_Indirect - an assert call was found and a function visited during
//                     the walk carries the "referenced-indirectly" attribute;
//                     the caller then marks every kernel in the module as
//                     using asserts.
enum HasAssertStatus { No_Assert, Assert, Assert_Indirect };
292+
291293
// Go through function call graph searching for assert call.
292-
static bool hasAssertInFunctionCallGraph(llvm::Function *Func) {
294+
static HasAssertStatus hasAssertInFunctionCallGraph(llvm::Function *Func) {
293295
// Map holds the info about assertions in already examined functions:
294296
// true - if there is an assertion in underlying functions,
295297
// false - if there are definitely no assertions in underlying functions.
296298
static std::map<llvm::Function *, bool> hasAssertionInCallGraphMap;
297299
std::vector<llvm::Function *> FuncCallStack;
300+
bool HasIndirectlyCalledAssert = false;
298301

299302
std::vector<llvm::Function *> Workstack;
300303
Workstack.push_back(Func);
@@ -305,6 +308,10 @@ static bool hasAssertInFunctionCallGraph(llvm::Function *Func) {
305308
if (F != Func)
306309
FuncCallStack.push_back(F);
307310

311+
if (!HasIndirectlyCalledAssert &&
312+
F->hasFnAttribute("referenced-indirectly"))
313+
HasIndirectlyCalledAssert = true;
314+
308315
bool IsLeaf = true;
309316
for (auto &I : instructions(F)) {
310317
if (!isa<CallBase>(&I))
@@ -323,7 +330,8 @@ static bool hasAssertInFunctionCallGraph(llvm::Function *Func) {
323330
if (!HasAssert->second)
324331
continue;
325332

326-
return true;
333+
return HasIndirectlyCalledAssert ? Assert_Indirect : Assert;
334+
;
327335
}
328336

329337
if (CF->getName().startswith("__devicelib_assert_fail")) {
@@ -335,7 +343,7 @@ static bool hasAssertInFunctionCallGraph(llvm::Function *Func) {
335343
hasAssertionInCallGraphMap[Func] = true;
336344
hasAssertionInCallGraphMap[CF] = true;
337345

338-
return true;
346+
return HasIndirectlyCalledAssert ? Assert_Indirect : Assert;
339347
}
340348

341349
if (!CF->isDeclaration()) {
@@ -350,7 +358,7 @@ static bool hasAssertInFunctionCallGraph(llvm::Function *Func) {
350358
FuncCallStack.clear();
351359
}
352360
}
353-
return false;
361+
return No_Assert;
354362
}
355363

356364
// Input parameter KernelModuleMap is a map containing groups of kernels with
@@ -547,15 +555,36 @@ static string_vector saveDeviceImageProperty(
547555

548556
{
549557
Module *M = ResultModules[I].get();
558+
bool HasIndirectlyCalledAssert = false;
559+
std::vector<llvm::Function *> Kernels;
550560
for (auto &F : M->functions()) {
551561
// TODO: handle SYCL_EXTERNAL functions for dynamic linkage.
552562
// TODO: handle function pointers.
553563
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
554-
if (hasAssertInFunctionCallGraph(&F))
564+
Kernels.push_back(&F);
565+
if (HasIndirectlyCalledAssert)
566+
continue;
567+
568+
HasAssertStatus HasAssert = hasAssertInFunctionCallGraph(&F);
569+
switch (HasAssert) {
570+
case Assert:
555571
PropSet[llvm::util::PropertySetRegistry::SYCL_ASSERT_USED].insert(
556572
{F.getName(), true});
573+
break;
574+
case Assert_Indirect:
575+
HasIndirectlyCalledAssert = true;
576+
break;
577+
case No_Assert:
578+
break;
579+
}
557580
}
558581
}
582+
583+
if (HasIndirectlyCalledAssert) {
584+
for (auto *F : Kernels)
585+
PropSet[llvm::util::PropertySetRegistry::SYCL_ASSERT_USED].insert(
586+
{F->getName(), true});
587+
}
559588
}
560589

561590
std::error_code EC;

0 commit comments

Comments
 (0)