Skip to content

Commit 747e14d

Browse files
committed
[SYCL] optimize getKernelNamesUsingAssert
It now traverses the reversed call graph with a BFS algorithm, starting from the __devicelib_assert_fail function and going up to the SPIR kernels.
1 parent fd0b108 commit 747e14d

File tree

4 files changed

+96
-136
lines changed

4 files changed

+96
-136
lines changed

llvm/test/tools/sycl-post-link/assert-indirect-with-split-2.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,6 @@ entry:
6363
ret void
6464
}
6565

66-
; CHECK-NOT: main_TU0_kernel1
67-
define dso_local spir_kernel void @main_TU0_kernel1() #0 {
68-
entry:
69-
call spir_func void @_Z4foo1v()
70-
ret void
71-
}
72-
7366
; Function Attrs: nounwind
7467
define dso_local spir_func void @_Z4foo1v() {
7568
entry:
@@ -85,6 +78,13 @@ entry:
8578
ret void
8679
}
8780

81+
; CHECK-NOT: main_TU0_kernel1
82+
define dso_local spir_kernel void @main_TU0_kernel1() #0 {
83+
entry:
84+
call spir_func void @_Z4foo1v()
85+
ret void
86+
}
87+
8888

8989
; This function is marked with "referenced-indirectly", but it doesn't call an assert
9090
; Function Attrs: nounwind

llvm/test/tools/sycl-post-link/assert-property-2.ll

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,15 @@ entry:
122122
ret void
123123
}
124124

125-
; CHECK: _ZTSZZ4mainENKUlRN2cl4sycl7handlerEE_clES2_E7Kernel9
125+
; CHECK-DAG: _ZTSZZ4mainENKUlRN2cl4sycl7handlerEE_clES2_E7Kernel9
126126
; Function Attrs: convergent noinline norecurse mustprogress
127127
define weak_odr dso_local spir_kernel void @_ZTSZZ4mainENKUlRN2cl4sycl7handlerEE_clES2_E7Kernel9() #0 {
128128
entry:
129129
call spir_func void @_Z1Jv()
130130
ret void
131131
}
132132

133-
; CHECK: _ZTSZZ4mainENKUlRN2cl4sycl7handlerEE_clES2_E8Kernel10
133+
; CHECK-DAG: _ZTSZZ4mainENKUlRN2cl4sycl7handlerEE_clES2_E8Kernel10
134134
; Function Attrs: convergent noinline norecurse optnone mustprogress
135135
define weak_odr dso_local spir_kernel void @_ZTSZZ4mainENKUlRN2cl4sycl7handlerEE_clES2_E8Kernel10() #0 {
136136
entry:
@@ -164,7 +164,7 @@ entry:
164164
ret void
165165
}
166166

167-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE6Kernel
167+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE6Kernel
168168
; Function Attrs: convergent norecurse mustprogress
169169
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE6Kernel"() local_unnamed_addr #0 {
170170
entry:
@@ -186,7 +186,7 @@ entry:
186186
ret void
187187
}
188188

189-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel2
189+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel2
190190
; Function Attrs: convergent norecurse mustprogress
191191
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel2"() local_unnamed_addr #0 {
192192
entry:
@@ -216,7 +216,7 @@ entry:
216216
ret void
217217
}
218218

219-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel3
219+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel3
220220
; Function Attrs: convergent norecurse mustprogress
221221
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel3"() local_unnamed_addr #0 {
222222
entry:
@@ -244,15 +244,15 @@ entry:
244244
ret void
245245
}
246246

247-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel4
247+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel4
248248
; Function Attrs: convergent norecurse mustprogress
249249
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel4"() local_unnamed_addr #0 {
250250
entry:
251251
call spir_func void @_Z7common2v()
252252
ret void
253253
}
254254

255-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel5
255+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel5
256256
; Function Attrs: convergent norecurse mustprogress
257257
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel5"() local_unnamed_addr #0 {
258258
entry:
@@ -267,23 +267,14 @@ entry:
267267
ret void
268268
}
269269

270-
; CHECK-NOT: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel
271-
; Function Attrs: convergent norecurse mustprogress
272-
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel6"() local_unnamed_addr #0 {
273-
entry:
274-
call spir_func void @_Z6E_exclv()
275-
call spir_func void @_Z6E_exclv()
276-
ret void
277-
}
278-
279270
; Function Attrs: convergent norecurse nounwind mustprogress
280271
define dso_local spir_func void @_Z6F_inclv() local_unnamed_addr {
281272
entry:
282273
call spir_func void @_Z11assert_funcv()
283274
ret void
284275
}
285276

286-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel7
277+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel7
287278
; Function Attrs: convergent norecurse mustprogress
288279
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel7"() local_unnamed_addr #0 {
289280
entry:
@@ -328,14 +319,23 @@ entry:
328319
ret void
329320
}
330321

331-
; CHECK: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel8
322+
; CHECK-DAG: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel8
332323
; Function Attrs: convergent norecurse mustprogress
333324
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel8"() local_unnamed_addr #0 {
334325
call spir_func void @_Z1Gv()
335326
call spir_func void @_Z1Hv()
336327
ret void
337328
}
338329

330+
; CHECK-NOT: _ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel6
331+
; Function Attrs: convergent norecurse mustprogress
332+
define weak_odr dso_local spir_kernel void @"_ZTSZZ4mainENK3$_0clERN2cl4sycl7handlerEE7Kernel6"() local_unnamed_addr #0 {
333+
entry:
334+
call spir_func void @_Z6E_exclv()
335+
call spir_func void @_Z6E_exclv()
336+
ret void
337+
}
338+
339339
; Function Attrs: convergent norecurse mustprogress
340340
define weak dso_local spir_func void @__assert_fail(i8 addrspace(4)* %expr, i8 addrspace(4)* %file, i32 %line, i8 addrspace(4)* %func) local_unnamed_addr {
341341
entry:

llvm/test/tools/sycl-post-link/assert-property-with-split.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ target triple = "spir64-unknown-linux"
1818

1919
; CHECK: [SYCL/assert used]
2020

21-
; CHECK: _ZTSZ4mainE11TU0_kernel0
21+
; CHECK-DAG: _ZTSZ4mainE11TU0_kernel0
2222
define dso_local spir_kernel void @_ZTSZ4mainE11TU0_kernel0() #0 {
2323
entry:
2424
call spir_func void @_Z3foov()
@@ -36,6 +36,13 @@ entry:
3636
ret void
3737
}
3838

39+
; CHECK-DAG: _ZTSZ4mainE10TU1_kernel
40+
define dso_local spir_kernel void @_ZTSZ4mainE10TU1_kernel() #1 {
41+
entry:
42+
call spir_func void @_Z4foo2v()
43+
ret void
44+
}
45+
3946
; CHECK-NOT: _ZTSZ4mainE11TU0_kernel1
4047
define dso_local spir_kernel void @_ZTSZ4mainE11TU0_kernel1() #0 {
4148
entry:
@@ -51,13 +58,6 @@ entry:
5158
ret void
5259
}
5360

54-
; CHECK: _ZTSZ4mainE10TU1_kernel
55-
define dso_local spir_kernel void @_ZTSZ4mainE10TU1_kernel() #1 {
56-
entry:
57-
call spir_func void @_Z4foo2v()
58-
ret void
59-
}
60-
6161
; Function Attrs: nounwind
6262
define dso_local spir_func void @_Z4foo2v() {
6363
entry:

llvm/tools/sycl-post-link/sycl-post-link.cpp

Lines changed: 63 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
#include <algorithm>
4747
#include <map>
4848
#include <memory>
49+
#include <queue>
4950
#include <string>
51+
#include <unordered_set>
52+
#include <utility>
5053
#include <vector>
5154

5255
using namespace llvm;
@@ -352,124 +355,81 @@ void groupEntryPoints(const Module &M, EntryPointGroupMap &EntryPointsGroups,
352355
EntryPointsGroups[GLOBAL_SCOPE_NAME] = {};
353356
}
354357

355-
enum HasAssertStatus { No_Assert, Assert, Assert_Indirect };
356-
357-
// Go through function call graph searching for assert call.
358-
HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) {
359-
// Map holds the info about assertions in already examined functions:
360-
// true - if there is an assertion in underlying functions,
361-
// false - if there are definetely no assertions in underlying functions.
362-
static std::map<const Function *, bool> hasAssertionInCallGraphMap;
363-
std::vector<const Function *> FuncCallStack;
364-
365-
static std::vector<const Function *> isIndirectlyCalledInGraph;
366-
367-
std::vector<const Function *> Workstack;
368-
Workstack.push_back(Func);
369-
370-
while (!Workstack.empty()) {
371-
const Function *F = Workstack.back();
372-
Workstack.pop_back();
373-
if (F != Func)
374-
FuncCallStack.push_back(F);
375-
376-
bool HasIndirectlyCalledAttr = false;
377-
if (std::find(isIndirectlyCalledInGraph.begin(),
378-
isIndirectlyCalledInGraph.end(),
379-
F) != isIndirectlyCalledInGraph.end())
380-
HasIndirectlyCalledAttr = true;
381-
else if (F->hasFnAttribute("referenced-indirectly")) {
382-
HasIndirectlyCalledAttr = true;
383-
isIndirectlyCalledInGraph.push_back(F);
384-
}
358+
// This function traverses over reversed call graph by BFS algorithm.
359+
// It means that an edge links some function @func with functions
360+
// which contain a call of function @func. It starts from
361+
// @StartingFunction and walks up until it has reached all reachable functions
362+
// or it reaches some function containing "referenced-indirectly" attribute.
363+
// If it reaches "referenced-indirectly" attribute, then it returns true and
364+
// an empty list.
365+
// Otherwise, it returns false and a list of the reached SPIR kernel functions'
366+
// names.
367+
std::pair<bool, std::vector<StringRef>>
368+
TraverseCGToFindSPIRKernels(const Function *StartingFunction) {
369+
std::queue<const Function *> FunctionsToVisit;
370+
std::unordered_set<const Function *> VisitedFunctions;
371+
FunctionsToVisit.push(StartingFunction);
372+
std::vector<StringRef> KernelNames;
373+
374+
while (!FunctionsToVisit.empty()) {
375+
const Function *F = FunctionsToVisit.front();
376+
FunctionsToVisit.pop();
377+
378+
// It is possible that we insert some particular function several
379+
// times in the FunctionsToVisit queue.
380+
if (VisitedFunctions.find(F) != VisitedFunctions.end())
381+
continue;
385382

386-
bool IsLeaf = true;
387-
for (const auto &I : instructions(F)) {
388-
if (!isa<CallBase>(&I))
389-
continue;
383+
VisitedFunctions.insert(F);
390384

391-
const Function *CF = cast<CallBase>(&I)->getCalledFunction();
392-
if (!CF)
385+
for (const auto *U : F->users()) {
386+
const Instruction *I = cast<const Instruction>(U);
387+
const Function *ParentF = I->getFunction();
388+
if (VisitedFunctions.find(ParentF) != VisitedFunctions.end())
393389
continue;
394390

395-
bool IsIndirectlyCalled =
396-
HasIndirectlyCalledAttr ||
397-
std::find(isIndirectlyCalledInGraph.begin(),
398-
isIndirectlyCalledInGraph.end(),
399-
CF) != isIndirectlyCalledInGraph.end();
400-
401-
// Return if we've already discovered if there are asserts in the
402-
// function call graph.
403-
auto HasAssert = hasAssertionInCallGraphMap.find(CF);
404-
if (HasAssert != hasAssertionInCallGraphMap.end()) {
405-
// If we know, that this function does not contain assert, we still
406-
// should investigate another instructions in the function.
407-
if (!HasAssert->second)
408-
continue;
409-
410-
return IsIndirectlyCalled ? Assert_Indirect : Assert;
391+
if (ParentF->hasFnAttribute("referenced-indirectly")) {
392+
return {true, {}};
411393
}
412394

413-
if (CF->getName().startswith("__devicelib_assert_fail")) {
414-
// Mark all the functions above in call graph as ones that can call
415-
// assert.
416-
for (const auto *It : FuncCallStack)
417-
hasAssertionInCallGraphMap[It] = true;
418-
419-
hasAssertionInCallGraphMap[Func] = true;
420-
hasAssertionInCallGraphMap[CF] = true;
421-
422-
return IsIndirectlyCalled ? Assert_Indirect : Assert;
423-
}
395+
if (ParentF->getCallingConv() == CallingConv::SPIR_KERNEL)
396+
KernelNames.push_back(ParentF->getName());
424397

425-
if (!CF->isDeclaration()) {
426-
Workstack.push_back(CF);
427-
IsLeaf = false;
428-
if (HasIndirectlyCalledAttr)
429-
isIndirectlyCalledInGraph.push_back(CF);
430-
}
431-
}
432-
433-
if (IsLeaf && !FuncCallStack.empty()) {
434-
// Mark the leaf function as one that definetely does not call assert.
435-
hasAssertionInCallGraphMap[FuncCallStack.back()] = false;
436-
FuncCallStack.clear();
398+
FunctionsToVisit.push(ParentF);
437399
}
438400
}
439-
return No_Assert;
401+
402+
return {false, std::move(KernelNames)};
440403
}
441404

442405
std::vector<StringRef> getKernelNamesUsingAssert(const Module &M) {
443-
std::vector<StringRef> Result;
444-
445-
bool HasIndirectlyCalledAssert = false;
446-
EntryPointGroup Kernels;
447-
for (const auto &F : M.functions()) {
448-
// TODO: handle SYCL_EXTERNAL functions for dynamic linkage.
449-
// TODO: handle function pointers.
450-
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
451-
continue;
452-
453-
Kernels.push_back(&F);
454-
if (HasIndirectlyCalledAssert)
455-
continue;
456-
457-
HasAssertStatus HasAssert = hasAssertInFunctionCallGraph(&F);
458-
switch (HasAssert) {
459-
case Assert:
460-
Result.push_back(F.getName());
461-
break;
462-
case Assert_Indirect:
463-
HasIndirectlyCalledAssert = true;
464-
break;
465-
case No_Assert:
466-
break;
406+
Optional<const Function *> DevicelibAssertFailFunction;
407+
std::vector<StringRef> SPIRKernelNames;
408+
// This loop finds all SPIR kernels' names and the __devicelib_assert_fail
409+
// function if it is present.
410+
for (const Function &F : M) {
411+
if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
412+
SPIRKernelNames.push_back(F.getName());
413+
414+
if (F.getName().startswith("__devicelib_assert_fail")) {
415+
assert(!DevicelibAssertFailFunction.hasValue());
416+
DevicelibAssertFailFunction = &F;
467417
}
468418
}
469419

470-
if (HasIndirectlyCalledAssert)
471-
for (const auto *F : Kernels)
472-
Result.push_back(F->getName());
420+
if (!DevicelibAssertFailFunction)
421+
return {};
422+
423+
auto TraverseResult =
424+
TraverseCGToFindSPIRKernels(*DevicelibAssertFailFunction);
425+
std::vector<StringRef> Result;
426+
if (TraverseResult.first) {
427+
// If assert is met in some indirectly callable function, then
428+
// we return all kernels in Module due to the current assert's design.
429+
Result = std::move(SPIRKernelNames);
430+
} else {
431+
Result = std::move(TraverseResult.second);
432+
}
473433

474434
return Result;
475435
}

0 commit comments

Comments
 (0)