@@ -53,6 +53,20 @@ ModulePass *llvm::createSYCLLowerWGLocalMemoryLegacyPass() {
53
53
}
54
54
55
55
static void lowerAllocaLocalMemCall (CallInst *CI, Module &M) {
56
+ assert (CI);
57
+
58
+ // Static local memory allocation should be allowed only in a scope of a spir
59
+ // kernel (not a spir function) to make it consistent with OpenCL restriction.
60
+ // However, __sycl_allocateLocalMemory is invoked in a scope of kernel lambda
61
+ // call operator, which is technically not a SPIR-V kernel scope.
62
+ // TODO: Relax that restriction for SYCL or modify this pass to move
63
+ // allocation of memory up to a spir kernel scope for each nested device
64
+ // function call.
65
+ CallingConv::ID CC = CI->getCaller ()->getCallingConv ();
66
+ assert ((CC == llvm::CallingConv::SPIR_FUNC ||
67
+ CC == llvm::CallingConv::SPIR_KERNEL) &&
68
+ " WG static local memory can be allocated only in kernel scope" );
69
+
56
70
Value *ArgSize = CI->getArgOperand (0 );
57
71
uint64_t Size = cast<llvm::ConstantInt>(ArgSize)->getZExtValue ();
58
72
Value *ArgAlign = CI->getArgOperand (1 );
@@ -84,41 +98,22 @@ static void lowerAllocaLocalMemCall(CallInst *CI, Module &M) {
84
98
}
85
99
86
100
static bool allocaWGLocalMemory (Module &M) {
87
- for (Function &F : M) {
88
- if (!F.isDeclaration () || F.getName () != SYCL_ALLOCLOCALMEM_CALL)
89
- continue ;
90
-
91
- SmallVector<CallInst *, 4 > ALMCalls;
92
- for (auto *U : F.users ()) {
93
- if (auto *CI = dyn_cast<CallInst>(U))
94
- ALMCalls.push_back (CI);
95
- }
96
-
97
- for (auto &CI : ALMCalls) {
98
- // Static local memory allocation should be requested only in
99
- // spir kernel scope (not a spir function) in accordance to OpenCL
100
- // restriction. However, __sycl_allocateLocalMemory is invoced in kernel
101
- // lambda call operator's scope, which is technically not SPIR-V kernel
102
- // scope.
103
- // TODO: Check if restriction may be relaxed for SYCL or imrpove pass
104
- // to move allocation of memory up to a spir kernel scope for each nested
105
- // device function call.
106
- CallingConv::ID CC = CI->getCaller ()->getCallingConv ();
107
- assert ((CC == llvm::CallingConv::SPIR_FUNC ||
108
- CC == llvm::CallingConv::SPIR_KERNEL) &&
109
- " WG static local memory can be allocated only in kernel scope" );
110
-
111
- lowerAllocaLocalMemCall (CI, M);
112
- }
113
-
114
- // Remove __sycl_allocateLocalMemory declaration.
115
- assert (F.use_empty () && " __sycl_allocateLocalMemory is still in use" );
116
- F.eraseFromParent ();
117
-
118
- return true ;
101
+ Function *ALMFunc = M.getFunction (SYCL_ALLOCLOCALMEM_CALL);
102
+ if (!ALMFunc)
103
+ return false ;
104
+
105
+ assert (ALMFunc->isDeclaration () && " should have declaration only" );
106
+
107
+ for (User *U : ALMFunc->users ()) {
108
+ auto *CI = cast<CallInst>(U);
109
+ lowerAllocaLocalMemCall (CI, M);
119
110
}
120
111
121
- return false ;
112
+ // Remove __sycl_allocateLocalMemory declaration.
113
+ assert (ALMFunc->use_empty () && " __sycl_allocateLocalMemory is still in use" );
114
+ ALMFunc->eraseFromParent ();
115
+
116
+ return true ;
122
117
}
123
118
124
119
PreservedAnalyses SYCLLowerWGLocalMemoryPass::run (Module &M,
0 commit comments