Skip to content

Commit d221363

Browse files
author
Alexander Johnston
committed
[SYCL] Local Accessor Support for CUDA
Provides the LocalAccessorToSharedMemory compiler pass required for supporting SYCL local accessors in CUDA. Contributors Alexander Johnston <[email protected]> David Wood <[email protected]> Signed-off-by: Alexander Johnston <[email protected]>
1 parent 278d45d commit d221363

15 files changed

+613
-47
lines changed

clang/lib/CodeGen/BackendUtil.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,6 @@ void EmitAssemblyHelper::EmitAssembly(BackendAction Action,
842842
PerFunctionPasses.add(
843843
createTargetTransformInfoWrapperPass(getTargetIRAnalysis()));
844844

845-
if (LangOpts.SYCLIsDevice)
846-
PerFunctionPasses.add(createSYCLLowerWGScopePass());
847-
848845
CreatePasses(PerModulePasses, PerFunctionPasses);
849846

850847
legacy::PassManager CodeGenPasses;

clang/lib/CodeGen/CodeGenAction.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "CodeGenModule.h"
1111
#include "CoverageMappingGen.h"
1212
#include "MacroPPCallbacks.h"
13+
#include "SYCLLowerIR/LowerWGScope.h"
1314
#include "clang/AST/ASTConsumer.h"
1415
#include "clang/AST/ASTContext.h"
1516
#include "clang/AST/DeclCXX.h"
@@ -33,6 +34,7 @@
3334
#include "llvm/IR/GlobalValue.h"
3435
#include "llvm/IR/LLVMContext.h"
3536
#include "llvm/IR/LLVMRemarkStreamer.h"
37+
#include "llvm/IR/LegacyPassManager.h"
3638
#include "llvm/IR/Module.h"
3739
#include "llvm/IRReader/IRReader.h"
3840
#include "llvm/Linker/Linker.h"
@@ -326,6 +328,17 @@ namespace clang {
326328
CodeGenOpts.getProfileUse() != CodeGenOptions::ProfileNone)
327329
Ctx.setDiagnosticsHotnessRequested(true);
328330

331+
// The parallel_for_work_group legalization pass can emit calls to
332+
// builtins function. Definitions of those builtins can be provided in
333+
// LinkModule. We force the pass to legalize the code before the link
334+
// happens.
335+
if (LangOpts.SYCLIsDevice) {
336+
PrettyStackTraceString CrashInfo("Pre-linking SYCL passes");
337+
legacy::PassManager PreLinkingSyclPasses;
338+
PreLinkingSyclPasses.add(createSYCLLowerWGScopePass());
339+
PreLinkingSyclPasses.run(*getModule());
340+
}
341+
329342
// Link each LinkModule into our module.
330343
if (LinkInModules())
331344
return;

clang/lib/CodeGen/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 84 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ class SYCLLowerWGScopeLegacyPass : public FunctionPass {
124124
return false;
125125

126126
FunctionAnalysisManager FAM;
127-
auto PA = Impl.run(F, FAM);
127+
auto TT = llvm::Triple(F.getParent()->getTargetTriple());
128+
auto PA = Impl.run(F, TT, FAM);
128129
return !PA.areAllPreserved();
129130
}
130131

@@ -188,8 +189,8 @@ enum class MemorySemantics : unsigned {
188189
ImageMemory = 0x800,
189190
};
190191

191-
Instruction *genWGBarrier(Instruction &Before);
192-
Value *genLinearLocalID(Instruction &Before);
192+
Instruction *genWGBarrier(Instruction &Before, const Triple &TT);
193+
Value *genLinearLocalID(Instruction &Before, const Triple &TT);
193194
GlobalVariable *createWGLocalVariable(Module &M, Type *T, const Twine &Name);
194195
} // namespace spirv
195196

@@ -263,8 +264,9 @@ static bool mayHaveSideEffects(const Instruction *I) {
263264
//
264265
static void guardBlockWithIsLeaderCheck(BasicBlock *IfBB, BasicBlock *TrueBB,
265266
BasicBlock *MergeBB,
266-
const DebugLoc &DbgLoc) {
267-
Value *LinearLocalID = spirv::genLinearLocalID(*IfBB->getTerminator());
267+
const DebugLoc &DbgLoc,
268+
const Triple &TT) {
269+
Value *LinearLocalID = spirv::genLinearLocalID(*IfBB->getTerminator(), TT);
268270
auto *Ty = LinearLocalID->getType();
269271
Value *Zero = Constant::getNullValue(Ty);
270272
IRBuilder<> Builder(IfBB->getContext());
@@ -341,7 +343,7 @@ using InstrRange = std::pair<Instruction *, Instruction *>;
341343
// ...
342344
// B
343345
// ... USE2(%I1_new) ...
344-
static void tformRange(const InstrRange &R) {
346+
static void tformRange(const InstrRange &R, const Triple &TT) {
345347
// Instructions seen between the first and the last
346348
SmallPtrSet<Instruction *, 16> Seen;
347349
Instruction *FirstSE = R.first;
@@ -360,15 +362,15 @@ static void tformRange(const InstrRange &R) {
360362

361363
// 1) insert the first "is work group leader" test (at the first split) for
362364
// the worker WIs to detour the side effects instructions
363-
guardBlockWithIsLeaderCheck(BBa, LeaderBB, BBb, FirstSE->getDebugLoc());
365+
guardBlockWithIsLeaderCheck(BBa, LeaderBB, BBb, FirstSE->getDebugLoc(), TT);
364366

365367
// 2) "Share" the output values of the instructions in the range
366368
for (auto *I : Seen)
367369
shareOutputViaLocalMem(*I, *BBa, *BBb, Seen);
368370

369371
// 3) Insert work group barrier so that workers further read valid data
370372
// (before the materialization reads inserted at step 2)
371-
spirv::genWGBarrier(BBb->front());
373+
spirv::genWGBarrier(BBb->front(), TT);
372374
}
373375

374376
namespace {
@@ -443,13 +445,13 @@ static void copyBetweenPrivateAndShadow(Value *L, GlobalVariable *Shadow,
443445
//
444446
static void materializeLocalsInWIScopeBlocksImpl(
445447
const DenseMap<BasicBlock *, std::unique_ptr<LocalsSet>> &BB2MatLocals,
446-
const DenseMap<AllocaInst *, GlobalVariable *> &Local2Shadow) {
448+
const DenseMap<AllocaInst *, GlobalVariable *> &Local2Shadow, const Triple &TT) {
447449
for (auto &P : BB2MatLocals) {
448450
// generate LeaderBB and private<->shadow copies in proper BBs
449451
BasicBlock *LeaderBB = P.first;
450452
BasicBlock *BB = LeaderBB->splitBasicBlock(&LeaderBB->front(), "LeaderMat");
451453
// Add a barrier to the original block:
452-
Instruction *At = spirv::genWGBarrier(*BB->getFirstNonPHI())->getNextNode();
454+
Instruction *At = spirv::genWGBarrier(*BB->getFirstNonPHI(), TT)->getNextNode();
453455

454456
for (AllocaInst *L : *P.second.get()) {
455457
auto MapEntry = Local2Shadow.find(L);
@@ -472,7 +474,7 @@ static void materializeLocalsInWIScopeBlocksImpl(
472474
BasicBlock *TestBB =
473475
LeaderBB->splitBasicBlock(&LeaderBB->front(), "TestMat");
474476
std::swap(TestBB, LeaderBB);
475-
guardBlockWithIsLeaderCheck(TestBB, LeaderBB, BB, At->getDebugLoc());
477+
guardBlockWithIsLeaderCheck(TestBB, LeaderBB, BB, At->getDebugLoc(), TT);
476478
}
477479
}
478480

@@ -536,7 +538,8 @@ static bool localMustBeMaterialized(const AllocaInst *L, const BasicBlock &BB) {
536538
//
537539
void materializeLocalsInWIScopeBlocks(
538540
SmallPtrSetImpl<AllocaInst *> &Locals,
539-
SmallPtrSetImpl<BasicBlock *> &WIScopeBBs) {
541+
SmallPtrSetImpl<BasicBlock *> &WIScopeBBs,
542+
const Triple &TT) {
540543
// maps local variable to its "shadow" workgroup-shared global:
541544
DenseMap<AllocaInst *, GlobalVariable *> Local2Shadow;
542545
// records which locals must be materialized at the beginning of a block:
@@ -567,7 +570,7 @@ void materializeLocalsInWIScopeBlocks(
567570
}
568571
}
569572
// perform the materialization
570-
materializeLocalsInWIScopeBlocksImpl(BB2MatLocals, Local2Shadow);
573+
materializeLocalsInWIScopeBlocksImpl(BB2MatLocals, Local2Shadow, TT);
571574
}
572575

573576
#ifndef NDEBUG
@@ -680,7 +683,7 @@ static void fixupPrivateMemoryPFWILambdaCaptures(CallInst *PFWICall) {
680683
// Go through "byval" parameters which are passed as AS(0) pointers
681684
// and: (1) create local shadows for them (2) and initialize them from the
682685
// leader's copy and (3) replace usages with pointer to the shadow
683-
static void shareByValParams(Function &F) {
686+
static void shareByValParams(Function &F, const Triple &TT) {
684687
// split
685688
BasicBlock *EntryBB = &F.getEntryBlock();
686689
BasicBlock *LeaderBB = EntryBB->splitBasicBlock(&EntryBB->front(), "leader");
@@ -689,7 +692,7 @@ static void shareByValParams(Function &F) {
689692
// 1) rewire the above basic blocks so that LeaderBB is executed only for the
690693
// leader workitem
691694
guardBlockWithIsLeaderCheck(EntryBB, LeaderBB, MergeBB,
692-
EntryBB->back().getDebugLoc());
695+
EntryBB->back().getDebugLoc(), TT);
693696
Instruction &At = LeaderBB->back();
694697

695698
for (auto &Arg : F.args()) {
@@ -715,10 +718,11 @@ static void shareByValParams(Function &F) {
715718
true /*private->shadow*/);
716719
}
717720
// 5) make sure workers use up-to-date shared values written by the leader
718-
spirv::genWGBarrier(MergeBB->front());
721+
spirv::genWGBarrier(MergeBB->front(), TT);
719722
}
720723

721724
PreservedAnalyses SYCLLowerWGScopePass::run(Function &F,
725+
const llvm::Triple &TT,
722726
FunctionAnalysisManager &FAM) {
723727
if (!F.getMetadata(WG_SCOPE_MD))
724728
return PreservedAnalyses::all();
@@ -796,7 +800,7 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F,
796800

797801
// Perform the transformation
798802
for (auto &R : Ranges) {
799-
tformRange(R);
803+
tformRange(R, TT);
800804
Changed = true;
801805
}
802806
// There can be allocas not corresponding to any variable declared in user
@@ -813,14 +817,14 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F,
813817
WIScopeBBs.insert(I->getParent());
814818

815819
// Now materialize the locals:
816-
materializeLocalsInWIScopeBlocks(Allocas, WIScopeBBs);
820+
materializeLocalsInWIScopeBlocks(Allocas, WIScopeBBs, TT);
817821

818822
// Fixup captured addresses of private_memory isntances in current WI
819823
for (auto *PFWICall : PFWICalls)
820824
fixupPrivateMemoryPFWILambdaCaptures(PFWICall);
821825

822826
// Finally, create shadows for and replace usages of byval pointer params
823-
shareByValParams(F);
827+
shareByValParams(F, TT);
824828

825829
#ifndef NDEBUG
826830
if (HaveChanges && Debug > 0)
@@ -866,37 +870,74 @@ GlobalVariable *spirv::createWGLocalVariable(Module &M, Type *T,
866870
// Must correspond to the code in
867871
// llvm-spirv/lib/SPIRV/OCL20ToSPIRV.cpp
868872
// OCL20ToSPIRV::transWorkItemBuiltinsToVariables()
869-
Value *spirv::genLinearLocalID(Instruction &Before) {
873+
Value *spirv::genLinearLocalID(Instruction &Before, const Triple &TT) {
870874
Module &M = *Before.getModule();
871-
StringRef Name = "__spirv_BuiltInLocalInvocationIndex";
872-
GlobalVariable *G = M.getGlobalVariable(Name);
873-
874-
if (!G) {
875-
Type *T = getSizeTTy(M);
876-
G = new GlobalVariable(M, // module
877-
T, // type
878-
true, // isConstant
879-
GlobalValue::ExternalLinkage, // Linkage
880-
nullptr, // Initializer
881-
Name, // Name
882-
nullptr, // InsertBefore
883-
GlobalVariable::NotThreadLocal, // ThreadLocalMode
884-
// TODO 'Input' crashes CPU Back-End
885-
// asUInt(spirv::AddrSpace::Input) // AddressSpace
886-
asUInt(spirv::AddrSpace::Global) // AddressSpace
887-
);
888-
unsigned Align = M.getDataLayout().getPreferredAlignment(G);
889-
G->setAlignment(MaybeAlign(Align));
875+
if (TT.isNVPTX()) {
876+
LLVMContext &Ctx = Before.getContext();
877+
Type *RetTy = getSizeTTy(M);
878+
879+
IRBuilder<> Bld(Ctx);
880+
Bld.SetInsertPoint(&Before);
881+
882+
#define CREATE_CALLEE(NAME, FN_NAME) \
883+
FunctionCallee FnCallee##NAME = M.getOrInsertFunction(FN_NAME, RetTy); \
884+
assert(FnCallee##NAME && "spirv intrinsic creation failed"); \
885+
auto NAME = Bld.CreateCall(FnCallee##NAME, {});
886+
887+
CREATE_CALLEE(LocalInvocationId_X, "_Z27__spirv_LocalInvocationId_xv");
888+
CREATE_CALLEE(LocalInvocationId_Y, "_Z27__spirv_LocalInvocationId_yv");
889+
CREATE_CALLEE(LocalInvocationId_Z, "_Z27__spirv_LocalInvocationId_zv");
890+
CREATE_CALLEE(WorkgroupSize_Y, "_Z23__spirv_WorkgroupSize_yv");
891+
CREATE_CALLEE(WorkgroupSize_Z, "_Z23__spirv_WorkgroupSize_zv");
892+
893+
#undef CREATE_CALLEE
894+
895+
// 1: ((__spirv_WorkgroupSize_y() * __spirv_WorkgroupSize_z())
896+
// 2: * __spirv_LocalInvocationId_x())
897+
// 3: + (__spirv_WorkgroupSize_z() * __spirv_LocalInvocationId_y())
898+
// 4: + (__spirv_LocalInvocationId_z())
899+
return Bld.CreateAdd(
900+
Bld.CreateAdd(
901+
Bld.CreateMul(
902+
Bld.CreateMul(WorkgroupSize_Y, WorkgroupSize_Z), // 1
903+
LocalInvocationId_X), // 2
904+
Bld.CreateMul(WorkgroupSize_Z, LocalInvocationId_Y)), // 3
905+
LocalInvocationId_Z); // 4
906+
} else {
907+
StringRef Name = "__spirv_BuiltInLocalInvocationIndex";
908+
GlobalVariable *G = M.getGlobalVariable(Name);
909+
910+
if (!G) {
911+
Type *T = getSizeTTy(M);
912+
G = new GlobalVariable(M, // module
913+
T, // type
914+
true, // isConstant
915+
GlobalValue::ExternalLinkage, // Linkage
916+
nullptr, // Initializer
917+
Name, // Name
918+
nullptr, // InsertBefore
919+
GlobalVariable::NotThreadLocal, // ThreadLocalMode
920+
// TODO 'Input' crashes CPU Back-End
921+
// asUInt(spirv::AddrSpace::Input) // AddressSpace
922+
asUInt(spirv::AddrSpace::Global) // AddressSpace
923+
);
924+
unsigned Align = M.getDataLayout().getPreferredAlignment(G);
925+
G->setAlignment(Align);
926+
}
927+
Value *Res = new LoadInst(G, "", &Before);
928+
return Res;
890929
}
891-
Value *Res = new LoadInst(G, "", &Before);
892-
return Res;
893930
}
894931

895932
// extern void __spirv_ControlBarrier(Scope Execution, Scope Memory,
896933
// uint32_t Semantics) noexcept;
897-
Instruction *spirv::genWGBarrier(Instruction &Before) {
934+
Instruction *spirv::genWGBarrier(Instruction &Before, const Triple &TT) {
898935
Module &M = *Before.getModule();
899-
StringRef Name = "__spirv_ControlBarrier";
936+
StringRef Name;
937+
if (TT.isNVPTX())
938+
Name = "_Z22__spirv_ControlBarrierN5__spv5ScopeES0_j";
939+
else
940+
Name = "__spirv_ControlBarrier";
900941
LLVMContext &Ctx = Before.getContext();
901942
Type *ScopeTy = Type::getInt32Ty(Ctx);
902943
Type *SemanticsTy = Type::getInt32Ty(Ctx);

clang/lib/CodeGen/SYCLLowerIR/LowerWGScope.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace llvm {
2121
/// execution model semantics - this code must be executed once per work group.
2222
class SYCLLowerWGScopePass : public PassInfoMixin<SYCLLowerWGScopePass> {
2323
public:
24-
PreservedAnalyses run(Function &F, FunctionAnalysisManager &);
24+
PreservedAnalyses run(Function &F, const Triple &TT, FunctionAnalysisManager &);
2525
};
2626

2727
FunctionPass *createSYCLLowerWGScopePass();

llvm/lib/Target/NVPTX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ set(NVPTXCodeGen_sources
3333
NVVMIntrRange.cpp
3434
NVVMReflect.cpp
3535
NVPTXProxyRegErasure.cpp
36+
SYCL/LocalAccessorToSharedMemory.cpp
3637
)
3738

3839
add_llvm_target(NVPTXCodeGen ${NVPTXCodeGen_sources})

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "NVPTXTargetObjectFile.h"
1818
#include "NVPTXTargetTransformInfo.h"
1919
#include "TargetInfo/NVPTXTargetInfo.h"
20+
#include "SYCL/LocalAccessorToSharedMemory.h"
2021
#include "llvm/ADT/STLExtras.h"
2122
#include "llvm/ADT/Triple.h"
2223
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -70,6 +71,8 @@ void initializeNVPTXLowerArgsPass(PassRegistry &);
7071
void initializeNVPTXLowerAllocaPass(PassRegistry &);
7172
void initializeNVPTXProxyRegErasurePass(PassRegistry &);
7273

74+
void initializeLocalAccessorToSharedMemoryPass(PassRegistry &);
75+
7376
} // end namespace llvm
7477

7578
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
@@ -89,6 +92,9 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXTarget() {
8992
initializeNVPTXLowerAllocaPass(PR);
9093
initializeNVPTXLowerAggrCopiesPass(PR);
9194
initializeNVPTXProxyRegErasurePass(PR);
95+
96+
// SYCL-specific passes, needed here to be available to `opt`.
97+
initializeLocalAccessorToSharedMemoryPass(PR);
9298
}
9399

94100
static std::string computeDataLayout(bool is64Bit, bool UseShortPointers) {
@@ -266,6 +272,11 @@ void NVPTXPassConfig::addIRPasses() {
266272
const NVPTXSubtarget &ST = *getTM<NVPTXTargetMachine>().getSubtargetImpl();
267273
addPass(createNVVMReflectPass(ST.getSmVersion()));
268274

275+
if (getTM<NVPTXTargetMachine>().getTargetTriple().getOS() == Triple::CUDA &&
276+
getTM<NVPTXTargetMachine>().getTargetTriple().getEnvironment() == Triple::SYCLDevice) {
277+
addPass(createLocalAccessorToSharedMemoryPass());
278+
}
279+
269280
if (getOptLevel() != CodeGenOpt::None)
270281
addPass(createNVPTXImageOptimizerPass());
271282
addPass(createNVPTXAssignValidGlobalNamesPass());

0 commit comments

Comments
 (0)