|
46 | 46 | #include <algorithm>
|
47 | 47 | #include <map>
|
48 | 48 | #include <memory>
|
| 49 | +#include <queue> |
49 | 50 | #include <string>
|
| 51 | +#include <unordered_set> |
| 52 | +#include <utility> |
50 | 53 | #include <vector>
|
51 | 54 |
|
52 | 55 | using namespace llvm;
|
@@ -352,124 +355,81 @@ void groupEntryPoints(const Module &M, EntryPointGroupMap &EntryPointsGroups,
|
352 | 355 | EntryPointsGroups[GLOBAL_SCOPE_NAME] = {};
|
353 | 356 | }
|
354 | 357 |
|
355 |
| -enum HasAssertStatus { No_Assert, Assert, Assert_Indirect }; |
356 |
| - |
357 |
| -// Go through function call graph searching for assert call. |
358 |
| -HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) { |
359 |
| - // Map holds the info about assertions in already examined functions: |
360 |
| - // true - if there is an assertion in underlying functions, |
361 |
| - // false - if there are definetely no assertions in underlying functions. |
362 |
| - static std::map<const Function *, bool> hasAssertionInCallGraphMap; |
363 |
| - std::vector<const Function *> FuncCallStack; |
364 |
| - |
365 |
| - static std::vector<const Function *> isIndirectlyCalledInGraph; |
366 |
| - |
367 |
| - std::vector<const Function *> Workstack; |
368 |
| - Workstack.push_back(Func); |
369 |
| - |
370 |
| - while (!Workstack.empty()) { |
371 |
| - const Function *F = Workstack.back(); |
372 |
| - Workstack.pop_back(); |
373 |
| - if (F != Func) |
374 |
| - FuncCallStack.push_back(F); |
375 |
| - |
376 |
| - bool HasIndirectlyCalledAttr = false; |
377 |
| - if (std::find(isIndirectlyCalledInGraph.begin(), |
378 |
| - isIndirectlyCalledInGraph.end(), |
379 |
| - F) != isIndirectlyCalledInGraph.end()) |
380 |
| - HasIndirectlyCalledAttr = true; |
381 |
| - else if (F->hasFnAttribute("referenced-indirectly")) { |
382 |
| - HasIndirectlyCalledAttr = true; |
383 |
| - isIndirectlyCalledInGraph.push_back(F); |
384 |
| - } |
| 358 | +// This function traverses over reversed call graph by BFS algorithm. |
| 359 | +// It means that an edge links some function @func with functions |
| 360 | +// which contain call of function @func.It starts from |
| 361 | +// @StartingFunction and lifts up until it reach all reachable functions |
| 362 | +// or it reaches some function containing "referenced-indirectly" attribute. |
| 363 | +// If it reaches "referenced-indirectly" attribute than it returns true and |
| 364 | +// an empty list. |
| 365 | +// Otherwise, it returns false and a list of reached SPIR kernel function's |
| 366 | +// names. |
| 367 | +std::pair<bool, std::vector<StringRef>> |
| 368 | +TraverseCGToFindSPIRKernels(const Function *StartingFunction) { |
| 369 | + std::queue<const Function *> FunctionsToVisit; |
| 370 | + std::unordered_set<const Function *> VisitedFunctions; |
| 371 | + FunctionsToVisit.push(StartingFunction); |
| 372 | + std::vector<StringRef> KernelNames; |
| 373 | + |
| 374 | + while (!FunctionsToVisit.empty()) { |
| 375 | + const Function *F = FunctionsToVisit.front(); |
| 376 | + FunctionsToVisit.pop(); |
| 377 | + |
| 378 | + // It is possible that we insert some particular function several |
| 379 | + // times in functionsToVisit queue. |
| 380 | + if (VisitedFunctions.find(F) != VisitedFunctions.end()) |
| 381 | + continue; |
385 | 382 |
|
386 |
| - bool IsLeaf = true; |
387 |
| - for (const auto &I : instructions(F)) { |
388 |
| - if (!isa<CallBase>(&I)) |
389 |
| - continue; |
| 383 | + VisitedFunctions.insert(F); |
390 | 384 |
|
391 |
| - const Function *CF = cast<CallBase>(&I)->getCalledFunction(); |
392 |
| - if (!CF) |
| 385 | + for (const auto *U : F->users()) { |
| 386 | + const Instruction *I = cast<const Instruction>(U); |
| 387 | + const Function *ParentF = I->getFunction(); |
| 388 | + if (VisitedFunctions.find(ParentF) != VisitedFunctions.end()) |
393 | 389 | continue;
|
394 | 390 |
|
395 |
| - bool IsIndirectlyCalled = |
396 |
| - HasIndirectlyCalledAttr || |
397 |
| - std::find(isIndirectlyCalledInGraph.begin(), |
398 |
| - isIndirectlyCalledInGraph.end(), |
399 |
| - CF) != isIndirectlyCalledInGraph.end(); |
400 |
| - |
401 |
| - // Return if we've already discovered if there are asserts in the |
402 |
| - // function call graph. |
403 |
| - auto HasAssert = hasAssertionInCallGraphMap.find(CF); |
404 |
| - if (HasAssert != hasAssertionInCallGraphMap.end()) { |
405 |
| - // If we know, that this function does not contain assert, we still |
406 |
| - // should investigate another instructions in the function. |
407 |
| - if (!HasAssert->second) |
408 |
| - continue; |
409 |
| - |
410 |
| - return IsIndirectlyCalled ? Assert_Indirect : Assert; |
| 391 | + if (ParentF->hasFnAttribute("referenced-indirectly")) { |
| 392 | + return {true, {}}; |
411 | 393 | }
|
412 | 394 |
|
413 |
| - if (CF->getName().startswith("__devicelib_assert_fail")) { |
414 |
| - // Mark all the functions above in call graph as ones that can call |
415 |
| - // assert. |
416 |
| - for (const auto *It : FuncCallStack) |
417 |
| - hasAssertionInCallGraphMap[It] = true; |
418 |
| - |
419 |
| - hasAssertionInCallGraphMap[Func] = true; |
420 |
| - hasAssertionInCallGraphMap[CF] = true; |
421 |
| - |
422 |
| - return IsIndirectlyCalled ? Assert_Indirect : Assert; |
423 |
| - } |
| 395 | + if (ParentF->getCallingConv() == CallingConv::SPIR_KERNEL) |
| 396 | + KernelNames.push_back(ParentF->getName()); |
424 | 397 |
|
425 |
| - if (!CF->isDeclaration()) { |
426 |
| - Workstack.push_back(CF); |
427 |
| - IsLeaf = false; |
428 |
| - if (HasIndirectlyCalledAttr) |
429 |
| - isIndirectlyCalledInGraph.push_back(CF); |
430 |
| - } |
431 |
| - } |
432 |
| - |
433 |
| - if (IsLeaf && !FuncCallStack.empty()) { |
434 |
| - // Mark the leaf function as one that definetely does not call assert. |
435 |
| - hasAssertionInCallGraphMap[FuncCallStack.back()] = false; |
436 |
| - FuncCallStack.clear(); |
| 398 | + FunctionsToVisit.push(ParentF); |
437 | 399 | }
|
438 | 400 | }
|
439 |
| - return No_Assert; |
| 401 | + |
| 402 | + return {false, std::move(KernelNames)}; |
440 | 403 | }
|
441 | 404 |
|
442 | 405 | std::vector<StringRef> getKernelNamesUsingAssert(const Module &M) {
|
443 |
| - std::vector<StringRef> Result; |
444 |
| - |
445 |
| - bool HasIndirectlyCalledAssert = false; |
446 |
| - EntryPointGroup Kernels; |
447 |
| - for (const auto &F : M.functions()) { |
448 |
| - // TODO: handle SYCL_EXTERNAL functions for dynamic linkage. |
449 |
| - // TODO: handle function pointers. |
450 |
| - if (F.getCallingConv() != CallingConv::SPIR_KERNEL) |
451 |
| - continue; |
452 |
| - |
453 |
| - Kernels.push_back(&F); |
454 |
| - if (HasIndirectlyCalledAssert) |
455 |
| - continue; |
456 |
| - |
457 |
| - HasAssertStatus HasAssert = hasAssertInFunctionCallGraph(&F); |
458 |
| - switch (HasAssert) { |
459 |
| - case Assert: |
460 |
| - Result.push_back(F.getName()); |
461 |
| - break; |
462 |
| - case Assert_Indirect: |
463 |
| - HasIndirectlyCalledAssert = true; |
464 |
| - break; |
465 |
| - case No_Assert: |
466 |
| - break; |
| 406 | + Optional<const Function *> DevicelibAssertFailFunction; |
| 407 | + std::vector<StringRef> SPIRKernelNames; |
| 408 | + // This loop finds all SPIR kernel's names and __devicelib_assert_fail |
| 409 | + // function if it is present. |
| 410 | + for (const Function &F : M) { |
| 411 | + if (F.getCallingConv() == CallingConv::SPIR_KERNEL) |
| 412 | + SPIRKernelNames.push_back(F.getName()); |
| 413 | + |
| 414 | + if (F.getName().startswith("__devicelib_assert_fail")) { |
| 415 | + assert(!DevicelibAssertFailFunction.hasValue()); |
| 416 | + DevicelibAssertFailFunction = &F; |
467 | 417 | }
|
468 | 418 | }
|
469 | 419 |
|
470 |
| - if (HasIndirectlyCalledAssert) |
471 |
| - for (const auto *F : Kernels) |
472 |
| - Result.push_back(F->getName()); |
| 420 | + if (!DevicelibAssertFailFunction) |
| 421 | + return {}; |
| 422 | + |
| 423 | + auto TraverseResult = |
| 424 | + TraverseCGToFindSPIRKernels(*DevicelibAssertFailFunction); |
| 425 | + std::vector<StringRef> Result; |
| 426 | + if (TraverseResult.first) { |
| 427 | + // If assert is met in some indirectly callable function than |
| 428 | + // we return all kernels in Module due to the current assert's design. |
| 429 | + Result = std::move(SPIRKernelNames); |
| 430 | + } else { |
| 431 | + Result = std::move(TraverseResult.second); |
| 432 | + } |
473 | 433 |
|
474 | 434 | return Result;
|
475 | 435 | }
|
|
0 commit comments