@@ -357,14 +357,14 @@ void groupEntryPoints(const Module &M, EntryPointGroupMap &EntryPointsGroups,
357
357
358
358
// This function traverses over reversed call graph by BFS algorithm.
359
359
// It means that an edge links some function @func with functions
360
- // which contain call of function @func.It starts from
360
+ // which contain call of function @func. It starts from
361
361
// @StartingFunction and lifts up until it reach all reachable functions
362
362
// or it reaches some function containing "referenced-indirectly" attribute.
363
- // If it reaches "referenced-indirectly" attribute than it returns true and
364
- // an empty list .
365
- // Otherwise, it returns false and a list of reached SPIR kernel function's
366
- // names.
367
- std::pair< bool , std::vector<StringRef>>
363
+ // If it reaches "referenced-indirectly" attribute than it returns an empty
364
+ // Optional .
365
+ // Otherwise, it returns an Optional containing a list of reached
366
+ // SPIR kernel function's names.
367
+ Optional< std::vector<StringRef>>
368
368
TraverseCGToFindSPIRKernels (const Function *StartingFunction) {
369
369
std::queue<const Function *> FunctionsToVisit;
370
370
std::unordered_set<const Function *> VisitedFunctions;
@@ -375,21 +375,20 @@ TraverseCGToFindSPIRKernels(const Function *StartingFunction) {
375
375
const Function *F = FunctionsToVisit.front ();
376
376
FunctionsToVisit.pop ();
377
377
378
+ auto InsertionResult = VisitedFunctions.insert (F);
378
379
// It is possible that we insert some particular function several
379
380
// times in functionsToVisit queue.
380
- if (VisitedFunctions. find (F) != VisitedFunctions. end () )
381
+ if (!InsertionResult. second )
381
382
continue ;
382
383
383
- VisitedFunctions.insert (F);
384
-
385
384
for (const auto *U : F->users ()) {
386
385
const Instruction *I = cast<const Instruction>(U);
387
386
const Function *ParentF = I->getFunction ();
388
- if (VisitedFunctions.find (ParentF) != VisitedFunctions. end ( ))
387
+ if (VisitedFunctions.count (ParentF))
389
388
continue ;
390
389
391
390
if (ParentF->hasFnAttribute (" referenced-indirectly" )) {
392
- return {true , {} };
391
+ return {};
393
392
}
394
393
395
394
if (ParentF->getCallingConv () == CallingConv::SPIR_KERNEL)
@@ -399,39 +398,30 @@ TraverseCGToFindSPIRKernels(const Function *StartingFunction) {
399
398
}
400
399
}
401
400
402
- return { false , std::move (KernelNames)} ;
401
+ return std::move (KernelNames);
403
402
}
404
403
405
404
std::vector<StringRef> getKernelNamesUsingAssert (const Module &M) {
406
- Optional<const Function *> DevicelibAssertFailFunction;
407
- std::vector<StringRef> SPIRKernelNames;
408
- // This loop finds all SPIR kernel's names and __devicelib_assert_fail
409
- // function if it is present.
410
- for (const Function &F : M) {
411
- if (F.getCallingConv () == CallingConv::SPIR_KERNEL)
412
- SPIRKernelNames.push_back (F.getName ());
413
-
414
- if (F.getName ().startswith (" __devicelib_assert_fail" )) {
415
- assert (!DevicelibAssertFailFunction.hasValue ());
416
- DevicelibAssertFailFunction = &F;
417
- }
418
- }
419
-
405
+ auto DevicelibAssertFailFunction = M.getFunction (" __devicelib_assert_fail" );
420
406
if (!DevicelibAssertFailFunction)
421
407
return {};
422
408
423
409
auto TraverseResult =
424
- TraverseCGToFindSPIRKernels (*DevicelibAssertFailFunction);
425
- std::vector<StringRef> Result;
426
- if (TraverseResult.first ) {
427
- // If assert is met in some indirectly callable function than
428
- // we return all kernels in Module due to the current assert's design.
429
- Result = std::move (SPIRKernelNames);
430
- } else {
431
- Result = std::move (TraverseResult.second );
410
+ TraverseCGToFindSPIRKernels (DevicelibAssertFailFunction);
411
+
412
+ if (TraverseResult.hasValue ()) {
413
+ return std::move (*TraverseResult);
414
+ }
415
+
416
+ // Here we reached "referenced-indirectly", so we need to find all kernels and
417
+ // return them.
418
+ std::vector<StringRef> SPIRKernelNames;
419
+ for (const Function &F : M) {
420
+ if (F.getCallingConv () == CallingConv::SPIR_KERNEL)
421
+ SPIRKernelNames.push_back (F.getName ());
432
422
}
433
423
434
- return Result ;
424
+ return SPIRKernelNames ;
435
425
}
436
426
437
427
// Gets reqd_work_group_size information for function Func.
0 commit comments