Skip to content

[SYCL] Add group local memory call lowering pass #3329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/BackendUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
#include "llvm/Support/BuryPointer.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"
Expand Down Expand Up @@ -959,6 +960,10 @@ void EmitAssemblyHelper::EmitAssembly(BackendAction Action,
PerModulePasses.add(createSPIRITTAnnotationsPass());
}

// Allocate static local memory in SYCL kernel scope for each allocation call.
if (LangOpts.SYCLIsDevice)
PerModulePasses.add(createSYCLLowerWGLocalMemoryLegacyPass());

switch (Action) {
case Backend_EmitNothing:
break;
Expand Down
8 changes: 8 additions & 0 deletions clang/test/CodeGenSYCL/group-local-memory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown-sycldevice \
// RUN: -S -emit-llvm -mllvm -debug-pass=Structure -disable-llvm-passes \
// RUN: -o - %s 2>&1 | FileCheck %s
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown-sycldevice \
// RUN: -S -emit-llvm -mllvm -debug-pass=Structure -o - %s 2>&1 \
// RUN: | FileCheck %s

// CHECK: Replace __sycl_allocateLocalMemory with allocation of memory in local address space
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
void initializeSPIRITTAnnotationsLegacyPassPass(PassRegistry &);
void initializeESIMDLowerLoadStorePass(PassRegistry &);
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
void initializeSYCLLowerWGLocalMemoryLegacyPass(PassRegistry &);
void initializeTailCallElimPass(PassRegistry&);
void initializeTailDuplicatePass(PassRegistry&);
void initializeTargetLibraryInfoWrapperPassPass(PassRegistry&);
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/LinkAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/SYCLLowerIR/LowerESIMD.h"
#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
#include "llvm/SYCLLowerIR/LowerWGScope.h"
#include "llvm/Support/Valgrind.h"
#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
Expand Down Expand Up @@ -206,6 +207,7 @@ namespace {
(void)llvm::createESIMDLowerLoadStorePass();
(void)llvm::createESIMDLowerVecArgPass();
(void)llvm::createSPIRITTAnnotationsPass();
(void)llvm::createSYCLLowerWGLocalMemoryLegacyPass();
std::string buf;
llvm::raw_string_ostream os(buf);
(void) llvm::createPrintModulePass(os);
Expand Down
49 changes: 49 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/LowerWGLocalMemory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===-- LowerWGLocalMemory.h - SYCL kernel local memory allocation pass ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass does the following for each allocate call to
// __sycl_allocateLocalMemory(Size, Alignment) function at the kernel scope:
// - inserts a global (in scope of a program) byte array of Size bytes with
// specified alignment in work group local address space.
// - replaces allocate call with access to this memory.
//
// For example, the following IR code in a kernel function:
// define spir_kernel void @KernelA() {
// %0 = call spir_func i8 addrspace(3)* @__sycl_allocateLocalMemory(
// i64 128, i64 4)
// %1 = bitcast i8 addrspace(3)* %0 to i32 addrspace(3)*
// }
//
// is translated to the following:
// @WGLocalMem = internal addrspace(3) global [128 x i8] undef, align 4
// define spir_kernel void @KernelA() {
// %0 = bitcast i8 addrspace(3)* getelementptr inbounds (
// [128 x i8], [128 x i8] addrspace(3)* @WGLocalMem, i32 0, i32 0)
// to i32 addrspace(3)*
// }
//===----------------------------------------------------------------------===//

#ifndef LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H
#define LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H

#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"

namespace llvm {

class SYCLLowerWGLocalMemoryPass
: public PassInfoMixin<SYCLLowerWGLocalMemoryPass> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};

ModulePass *createSYCLLowerWGLocalMemoryLegacyPass();
void initializeSYCLLowerWGLocalMemoryLegacyPass(PassRegistry &);

} // namespace llvm

#endif // LLVM_SYCLLOWERIR_LOWERWGLOCALMEMORY_H
1 change: 1 addition & 0 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
LowerESIMD.cpp
LowerESIMDVLoadVStore.cpp
LowerESIMDVecArg.cpp
LowerWGLocalMemory.cpp

ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/SYCLLowerIR
Expand Down
118 changes: 118 additions & 0 deletions llvm/lib/SYCLLowerIR/LowerWGLocalMemory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===-- LowerWGLocalMemory.cpp - SYCL kernel local memory allocation pass -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// See intro comments in the header.
//===----------------------------------------------------------------------===//

#include "llvm/SYCLLowerIR/LowerWGLocalMemory.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Pass.h"

using namespace llvm;

#define DEBUG_TYPE "LowerWGLocalMemory"

static constexpr char SYCL_ALLOCLOCALMEM_CALL[] = "__sycl_allocateLocalMemory";
static constexpr char LOCALMEMORY_GV_PREF[] = "WGLocalMem";

namespace {
class SYCLLowerWGLocalMemoryLegacy : public ModulePass {
public:
static char ID;

SYCLLowerWGLocalMemoryLegacy() : ModulePass(ID) {
initializeSYCLLowerWGLocalMemoryLegacyPass(
*PassRegistry::getPassRegistry());
}

bool runOnModule(Module &M) override {
ModuleAnalysisManager DummyMAM;
auto PA = Impl.run(M, DummyMAM);
return !PA.areAllPreserved();
}

private:
SYCLLowerWGLocalMemoryPass Impl;
};
} // namespace

char SYCLLowerWGLocalMemoryLegacy::ID = 0;
INITIALIZE_PASS(SYCLLowerWGLocalMemoryLegacy, "sycllowerwglocalmemory",
"Replace __sycl_allocateLocalMemory with allocation of memory "
"in local address space",
false, false)

ModulePass *llvm::createSYCLLowerWGLocalMemoryLegacyPass() {
return new SYCLLowerWGLocalMemoryLegacy();
}

// Static local memory allocation should be allowed only in a scope of a kernel
// (not a device function) and shouldn't be called inside loop or if statement
// to make it consistent with OpenCL restriction.
// TODO: Relax that restriction for SYCL or modify this pass to move allocation
// of memory up to a kernel scope at the beginning for each nested device
// function call, loop or if statement.
static void lowerAllocaLocalMemCall(CallInst *CI, Module &M) {
assert(CI);

Value *ArgSize = CI->getArgOperand(0);
uint64_t Size = cast<llvm::ConstantInt>(ArgSize)->getZExtValue();
Value *ArgAlign = CI->getArgOperand(1);
uint64_t Alignment = cast<llvm::ConstantInt>(ArgAlign)->getZExtValue();

IRBuilder<> Builder(CI);
Type *LocalMemArrayTy = ArrayType::get(Builder.getInt8Ty(), Size);
unsigned LocalAS =
CI->getFunctionType()->getReturnType()->getPointerAddressSpace();
auto *LocalMemArrayGV =
new GlobalVariable(M, // module
LocalMemArrayTy, // type
false, // isConstant
GlobalValue::InternalLinkage, // Linkage
UndefValue::get(LocalMemArrayTy), // Initializer
LOCALMEMORY_GV_PREF, // Name prefix
nullptr, // InsertBefore
GlobalVariable::NotThreadLocal, // ThreadLocalMode
LocalAS // AddressSpace
);
LocalMemArrayGV->setAlignment(Align(Alignment));

Value *GVPtr =
Builder.CreatePointerCast(LocalMemArrayGV, Builder.getInt8PtrTy(LocalAS));
CI->replaceAllUsesWith(GVPtr);

assert(CI->use_empty() && "removing live instruction");
CI->eraseFromParent();
}

static bool allocaWGLocalMemory(Module &M) {
Function *ALMFunc = M.getFunction(SYCL_ALLOCLOCALMEM_CALL);
if (!ALMFunc)
return false;

assert(ALMFunc->isDeclaration() && "should have declaration only");

for (User *U : ALMFunc->users()) {
auto *CI = cast<CallInst>(U);
lowerAllocaLocalMemCall(CI, M);
}

// Remove __sycl_allocateLocalMemory declaration.
assert(ALMFunc->use_empty() && "__sycl_allocateLocalMemory is still in use");
ALMFunc->eraseFromParent();

return true;
}

PreservedAnalyses SYCLLowerWGLocalMemoryPass::run(Module &M,
ModuleAnalysisManager &) {
if (allocaWGLocalMemory(M))
return PreservedAnalyses::none();
return PreservedAnalyses::all();
}
50 changes: 50 additions & 0 deletions llvm/test/SYCLLowerIR/group_local_memory.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; RUN: opt -S -sycllowerwglocalmemory < %s | FileCheck %s

; CHECK-DAG: [[WGLOCALMEM_1:@WGLocalMem.*]] = internal addrspace(3) global [128 x i8] undef, align 4
; CHECK-DAG: [[WGLOCALMEM_2:@WGLocalMem.*]] = internal addrspace(3) global [4 x i8] undef, align 4
; CHECK-DAG: [[WGLOCALMEM_3:@WGLocalMem.*]] = internal addrspace(3) global [256 x i8] undef, align 8

; CHECK-NOT: __sycl_allocateLocalMemory

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown-sycldevice"

; Function Attrs: convergent norecurse
define weak_odr dso_local spir_kernel void @_ZTS7KernelA() local_unnamed_addr #0 {
entry:
%0 = tail call spir_func i8 addrspace(3)* @__sycl_allocateLocalMemory(i64 128, i64 4) #2
%1 = bitcast i8 addrspace(3)* %0 to i32 addrspace(3)*
; CHECK: i8 addrspace(3)* getelementptr inbounds ([128 x i8], [128 x i8] addrspace(3)* [[WGLOCALMEM_1]], i32 0, i32 0)
%2 = getelementptr inbounds i8, i8 addrspace(3)* %0, i64 4
; CHECK: i8 addrspace(3)* getelementptr inbounds ([128 x i8], [128 x i8] addrspace(3)* [[WGLOCALMEM_1]], i32 0, i32 0)
%3 = tail call spir_func i8 addrspace(3)* @__sycl_allocateLocalMemory(i64 4, i64 4) #2
%4 = bitcast i8 addrspace(3)* %3 to float addrspace(3)*
; CHECK: i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[WGLOCALMEM_2]], i32 0, i32 0)
ret void
}

; Function Attrs: convergent
declare dso_local spir_func i8 addrspace(3)* @__sycl_allocateLocalMemory(i64, i64) local_unnamed_addr #1

; Function Attrs: convergent norecurse
define weak_odr dso_local spir_kernel void @_ZTS7KernelB() local_unnamed_addr #0 {
entry:
%0 = tail call spir_func i8 addrspace(3)* @__sycl_allocateLocalMemory(i64 256, i64 8) #2
%1 = bitcast i8 addrspace(3)* %0 to i64 addrspace(3)*
; CHECK: i8 addrspace(3)* getelementptr inbounds ([256 x i8], [256 x i8] addrspace(3)* [[WGLOCALMEM_3]], i32 0, i32 0)
ret void
}

attributes #0 = { convergent norecurse "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="true" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { convergent "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { convergent }

!llvm.module.flags = !{!0}
!opencl.spir.version = !{!1}
!spirv.Source = !{!2}
!llvm.ident = !{!3}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 2}
!2 = !{i32 4, i32 100000}
!3 = !{!"clang version 13.0.0"}
1 change: 1 addition & 0 deletions llvm/tools/opt/opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ int main(int argc, char **argv) {
initializeSPIRITTAnnotationsLegacyPassPass(Registry);
initializeESIMDLowerLoadStorePass(Registry);
initializeESIMDLowerVecArgLegacyPassPass(Registry);
initializeSYCLLowerWGLocalMemoryLegacyPass(Registry);

#ifdef BUILD_EXAMPLES
initializeExampleIRTransforms(Registry);
Expand Down