Skip to content

Commit 162cc19

Browse files
committed
[SYCL] Add group local memory call lowering pass
SYCL headers insert the following call into work-group device code to declare a static local memory allocation in a SYCL kernel: '__attribute__((opencl_local)) uint8_t *Ptr = __sycl_allocateLocalMemory(sizeof(T), alignof(T))'. During compilation this call should be transformed into the equivalent of the following code: '__local alignas(alignof(T)) uint8_t Mem[sizeof(T)]; __local uint8_t *Ptr = &Mem;'. Add a pass that handles the creation of fixed-size allocations in the local address space at kernel scope. Signed-off-by: Mikhail Lychkov <[email protected]>
1 parent 90c79c7 commit 162cc19

File tree

12 files changed

+355
-0
lines changed

12 files changed

+355
-0
lines changed

clang/lib/CodeGen/BackendUtil.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "llvm/Passes/PassBuilder.h"
4242
#include "llvm/Passes/PassPlugin.h"
4343
#include "llvm/Passes/StandardInstrumentations.h"
44+
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
4445
#include "llvm/Support/BuryPointer.h"
4546
#include "llvm/Support/CommandLine.h"
4647
#include "llvm/Support/MemoryBuffer.h"
@@ -959,6 +960,10 @@ void EmitAssemblyHelper::EmitAssembly(BackendAction Action,
959960
PerModulePasses.add(createSPIRITTAnnotationsPass());
960961
}
961962

963+
// This pass should always be called for SYCL device code.
964+
if (LangOpts.SYCLIsDevice)
965+
PerModulePasses.add(createSYCLLowerWGLocalMemoryPass());
966+
962967
switch (Action) {
963968
case Backend_EmitNothing:
964969
break;

clang/test/CodeGenSYCL/Inputs/sycl.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
#pragma once
22

3+
typedef __UINT8_TYPE__ uint8_t;
4+
typedef __SIZE_TYPE__ size_t;
5+
36
#define ATTR_SYCL_KERNEL __attribute__((sycl_kernel))
47

8+
#ifndef __SYCL_ALWAYS_INLINE
9+
#if __has_attribute(always_inline)
10+
#define __SYCL_ALWAYS_INLINE __attribute__((always_inline))
11+
#else
12+
#define __SYCL_ALWAYS_INLINE
13+
#endif
14+
#endif // __SYCL_ALWAYS_INLINE
15+
516
// Dummy runtime classes to model SYCL API.
617
namespace cl {
718
namespace sycl {
@@ -494,5 +505,19 @@ class image {
494505
}
495506
};
496507

508+
extern "C" SYCL_EXTERNAL __attribute__((opencl_local)) uint8_t *
509+
__sycl_allocateLocalMemory(size_t Size, size_t Alignment);
510+
511+
template <typename T>
512+
__attribute__((opencl_local)) T *
513+
__SYCL_ALWAYS_INLINE
514+
group_local_memory() {
515+
#ifdef __SYCL_DEVICE_ONLY__
516+
__attribute__((opencl_local)) uint8_t *AllocatedMem =
517+
__sycl_allocateLocalMemory(sizeof(T), alignof(T));
518+
return (__attribute__((opencl_local)) T *)AllocatedMem;
519+
#endif
520+
}
521+
497522
} // namespace sycl
498523
} // namespace cl
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -disable-llvm-passes -S -emit-llvm %s -o - | FileCheck %s
2+
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -S -emit-llvm %s -o - | FileCheck %s
3+
4+
// CHECK: [[WGLOCALMEM_1:@WGLocalMem.*]] = internal addrspace(3) global [8 x i8] undef, align 8
5+
// CHECK: [[WGLOCALMEM_2:@WGLocalMem.*]] = internal addrspace(3) global [4 x i8] undef, align 4
6+
// CHECK: [[WGLOCALMEM_3:@WGLocalMem.*]] = internal addrspace(3) global [128 x i8] undef, align 4
7+
8+
#include "Inputs/sycl.hpp"
9+
10+
constexpr size_t WgSize = 32;
11+
constexpr size_t WgCount = 4;
12+
constexpr size_t Size = WgSize * WgCount;
13+
14+
class KernelA;
15+
class KernelB;
16+
17+
using namespace cl::sycl;
18+
19+
// Test driver: submits two kernels that request work-group local memory via
// group_local_memory<T>(). The FileCheck expectations at the top of this file
// list three WGLocalMem globals in order, so the order of the three calls
// below should not be changed without updating those expectations.
int main() {
20+
queue Q;
21+
{
22+
Q.submit([&](handler &cgh) {
23+
cgh.parallel_for<KernelA>(
24+
range<1>(Size), [=](item<1> Item) {
25+
// KernelA makes two requests, sized by long and then by float; these
// presumably correspond to the expected [8 x i8] align 8 and [4 x i8]
// align 4 globals.
auto *Ptr1 = group_local_memory<long>();
26+
auto *Ptr2 = group_local_memory<float>();
27+
});
28+
});
29+
}
30+
31+
{
32+
Q.submit([&](handler &cgh) {
33+
cgh.parallel_for<KernelB>(
34+
range<1>(Size), [=](item<1> Item) {
35+
// KernelB makes one request for an array of WgSize (32) ints; this
// presumably corresponds to the expected [128 x i8] align 4 global.
auto *Ptr3 = group_local_memory<int[WgSize]>();
36+
});
37+
});
38+
}
39+
}

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
432432
void initializeSPIRITTAnnotationsLegacyPassPass(PassRegistry &);
433433
void initializeESIMDLowerLoadStorePass(PassRegistry &);
434434
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
435+
void initializeSYCLLowerWGLocalMemoryLegacyPass(PassRegistry &);
435436
void initializeTailCallElimPass(PassRegistry&);
436437
void initializeTailDuplicatePass(PassRegistry&);
437438
void initializeTargetLibraryInfoWrapperPassPass(PassRegistry&);

llvm/include/llvm/LinkAllPasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/IR/Function.h"
3939
#include "llvm/IR/IRPrintingPasses.h"
4040
#include "llvm/SYCLLowerIR/LowerESIMD.h"
41+
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
4142
#include "llvm/SYCLLowerIR/LowerWGScope.h"
4243
#include "llvm/Support/Valgrind.h"
4344
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
@@ -206,6 +207,7 @@ namespace {
206207
(void)llvm::createESIMDLowerLoadStorePass();
207208
(void)llvm::createESIMDLowerVecArgPass();
208209
(void)llvm::createSPIRITTAnnotationsPass();
210+
(void)llvm::createSYCLLowerWGLocalMemoryPass();
209211
std::string buf;
210212
llvm::raw_string_ostream os(buf);
211213
(void) llvm::createPrintModulePass(os);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//===-- LowerWGLocalMemory.h - SYCL kernel local memory allocation pass ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Replaces calls to __sycl_allocateLocalMemory(Size, Alignment) function with
10+
// allocation of memory in local address space at the kernel scope.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H
15+
#define LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H
16+
17+
#include "llvm/IR/Module.h"
18+
#include "llvm/IR/PassManager.h"
19+
20+
namespace llvm {
21+
22+
/// New pass manager module pass that replaces calls to
/// __sycl_allocateLocalMemory(Size, Alignment) with allocations of memory in
/// the local address space at the kernel scope.
class SYCLLowerWGLocalMemoryPass
23+
: public PassInfoMixin<SYCLLowerWGLocalMemoryPass> {
24+
public:
25+
/// Runs the lowering on module \p M. The analysis manager parameter is
/// unnamed, i.e. no analyses are requested through it.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
26+
};
27+
28+
ModulePass *createSYCLLowerWGLocalMemoryPass();
29+
void initializeSYCLLowerWGLocalMemoryLegacyPass(PassRegistry &);
30+
31+
} // namespace llvm
32+
33+
#endif // LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H

llvm/lib/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ add_llvm_component_library(LLVMCodeGen
211211
ProfileData
212212
Scalar
213213
Support
214+
SYCLLowerIR
214215
Target
215216
TransformUtils
216217
)

llvm/lib/CodeGen/CodeGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
107107
initializeStackProtectorPass(Registry);
108108
initializeStackSlotColoringPass(Registry);
109109
initializeStripDebugMachineModulePass(Registry);
110+
initializeSYCLLowerWGLocalMemoryLegacyPass(Registry);
110111
initializeTailDuplicatePass(Registry);
111112
initializeTargetPassConfigPass(Registry);
112113
initializeTwoAddressInstructionPassPass(Registry);

llvm/lib/SYCLLowerIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
3838
LowerESIMD.cpp
3939
LowerESIMDVLoadVStore.cpp
4040
LowerESIMDVecArg.cpp
41+
LowerWGLocalMemory.cpp
4142

4243
ADDITIONAL_HEADER_DIRS
4344
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
//===-- LowerWGLocalMemory.cpp - SYCL kernel local memory allocation pass -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This pass replaces calls to __sycl_allocateLocalMemory(Size, Alignment)
10+
// function with allocation of memory in local address space at the kernel
11+
// scope.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
16+
#include "llvm/IR/Function.h"
17+
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/IR/InstIterator.h"
19+
#include "llvm/InitializePasses.h"
20+
#include "llvm/Pass.h"
21+
22+
using namespace llvm;
23+
24+
#define DEBUG_TYPE "LowerWGLocalMemory"
25+
26+
static constexpr char SYCL_ALLOCLOCALMEM_CALL[] = "__sycl_allocateLocalMemory";
27+
static constexpr char LOCALMEMORY_GV_PREF[] = "WGLocalMem";
28+
29+
namespace {
30+
class SYCLLowerWGLocalMemoryLegacy : public ModulePass {
31+
public:
32+
static char ID;
33+
34+
SYCLLowerWGLocalMemoryLegacy() : ModulePass(ID) {
35+
initializeSYCLLowerWGLocalMemoryLegacyPass(
36+
*PassRegistry::getPassRegistry());
37+
}
38+
39+
bool runOnModule(Module &M) override {
40+
ModuleAnalysisManager DummyMAM;
41+
auto PA = Impl.run(M, DummyMAM);
42+
return !PA.areAllPreserved();
43+
}
44+
45+
private:
46+
SYCLLowerWGLocalMemoryPass Impl;
47+
};
48+
} // namespace
49+
50+
char SYCLLowerWGLocalMemoryLegacy::ID = 0;
51+
INITIALIZE_PASS(SYCLLowerWGLocalMemoryLegacy, "sycllowerwglocalmemory",
52+
"Replace __sycl_allocateLocalMemory with allocation of memory "
53+
"in local address space",
54+
false, false)
55+
56+
/// Factory for the legacy pass manager wrapper.
/// \returns a newly heap-allocated SYCLLowerWGLocalMemoryLegacy instance; the
/// function itself never frees it, so the caller takes over the pointer.
ModulePass *llvm::createSYCLLowerWGLocalMemoryPass() {
57+
return new SYCLLowerWGLocalMemoryLegacy();
58+
}
59+
60+
/// Lowers every direct call to __sycl_allocateLocalMemory(Size, Alignment) in
/// \p M into a reference to a newly created module-scope byte array.
///
/// For each such call a new internal global named with the "WGLocalMem"
/// prefix is created. Its type is [Size x i8], it lives in the address space
/// of the call's return pointer, it is initialized with undef and it is
/// aligned to Alignment. All uses of the call are redirected to an i8 pointer
/// to that global and the call itself is erased.
///
/// Both call arguments must be ConstantInt values; cast<> asserts otherwise.
///
/// \param M module to transform in place.
/// \returns true if at least one call was replaced, false if \p M was left
/// untouched.
static bool lowerAllocaLocalMem(Module &M) {
  // Cheap early exit: when the allocation function is not present in the
  // module there cannot be a direct call to it, so the scan over every
  // instruction is skipped entirely.
  Function *AllocFn = M.getFunction(SYCL_ALLOCLOCALMEM_CALL);
  if (!AllocFn)
    return false;

  // Collect first and rewrite afterwards, because erasing calls while walking
  // instructions(F) would invalidate the iteration. The walk is done in module
  // order on purpose so that the globals are created in the same order as the
  // calls appear.
  SmallVector<CallInst *, 8> ToReplace;
  for (Function &F : M) {
    for (Instruction &I : instructions(F)) {
      auto *CI = dyn_cast<CallInst>(&I);
      // Function names are unique within a module, so comparing the callee
      // pointer is equivalent to comparing its name.
      if (!CI || CI->getCalledFunction() != AllocFn)
        continue;

      // TODO: Static local memory allocation should be requested only in
      // spir kernel scope.
      // The calling convention is queried inside the assert so that builds
      // with assertions disabled do not end up with an unused local.
      assert((F.getCallingConv() == llvm::CallingConv::SPIR_FUNC ||
              F.getCallingConv() == llvm::CallingConv::SPIR_KERNEL) &&
             "WG static local memory can be allocated only in kernel scope");

      ToReplace.push_back(CI);
    }
  }

  if (ToReplace.empty())
    return false;

  for (CallInst *CI : ToReplace) {
    Value *ArgSize = CI->getArgOperand(0);
    uint64_t Size = cast<llvm::ConstantInt>(ArgSize)->getZExtValue();
    Value *ArgAlign = CI->getArgOperand(1);
    uint64_t Alignment = cast<llvm::ConstantInt>(ArgAlign)->getZExtValue();

    IRBuilder<> Builder(CI);
    Type *LocalMemArrayTy = ArrayType::get(Builder.getInt8Ty(), Size);
    // The global is placed in whatever address space the call returns a
    // pointer into.
    unsigned LocalAS =
        CI->getFunctionType()->getReturnType()->getPointerAddressSpace();
    auto *LocalMemArrayGV =
        new GlobalVariable(M,                                // module
                           LocalMemArrayTy,                  // type
                           false,                            // isConstant
                           GlobalValue::InternalLinkage,     // Linkage
                           UndefValue::get(LocalMemArrayTy), // Initializer
                           LOCALMEMORY_GV_PREF,              // Name prefix
                           nullptr,                          // InsertBefore
                           GlobalVariable::NotThreadLocal,   // ThreadLocalMode
                           LocalAS                           // AddressSpace
        );
    LocalMemArrayGV->setAlignment(Align(Alignment));

    // The call returns an i8 pointer, so hand its users an i8 pointer to the
    // array rather than a pointer to the array type.
    Value *LocalMemArrayGVPtr = Builder.CreatePointerCast(
        LocalMemArrayGV,
        Builder.getInt8PtrTy(LocalMemArrayGV->getAddressSpace()));
    CI->replaceAllUsesWith(LocalMemArrayGVPtr);
    CI->eraseFromParent();
  }
  return true;
}
118+
119+
/// New pass manager entry point: lowers the local memory allocation calls in
/// \p M and reports whether analyses survive.
/// \returns PreservedAnalyses::none() when the module was changed,
/// PreservedAnalyses::all() otherwise.
PreservedAnalyses SYCLLowerWGLocalMemoryPass::run(Module &M,
                                                  ModuleAnalysisManager &) {
  const bool Changed = lowerAllocaLocalMem(M);
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

0 commit comments

Comments
 (0)