@@ -185,7 +185,7 @@ enum class MemorySemantics : unsigned {
185
185
};
186
186
187
187
Instruction *genWGBarrier (Instruction &Before, const Triple &TT);
188
- Value *genLinearLocalID (Instruction &Before, const Triple &TT);
188
+ Value *genPseudoLocalID (Instruction &Before, const Triple &TT);
189
189
GlobalVariable *createWGLocalVariable (Module &M, Type *T, const Twine &Name);
190
190
} // namespace spirv
191
191
@@ -261,7 +261,7 @@ static void guardBlockWithIsLeaderCheck(BasicBlock *IfBB, BasicBlock *TrueBB,
261
261
BasicBlock *MergeBB,
262
262
const DebugLoc &DbgLoc,
263
263
const Triple &TT) {
264
- Value *LinearLocalID = spirv::genLinearLocalID (*IfBB->getTerminator (), TT);
264
+ Value *LinearLocalID = spirv::genPseudoLocalID (*IfBB->getTerminator (), TT);
265
265
auto *Ty = LinearLocalID->getType ();
266
266
Value *Zero = Constant::getNullValue (Ty);
267
267
IRBuilder<> Builder (IfBB->getContext ());
@@ -861,11 +861,8 @@ GlobalVariable *spirv::createWGLocalVariable(Module &M, Type *T,
861
861
// TODO generalize to support all SPIR-V intrinsic operations and builtin
862
862
// variables
863
863
864
- // extern "C" const __constant size_t __spirv_BuiltInLocalInvocationIndex;
865
- // Must correspond to the code in
866
- // llvm-spirv/lib/SPIRV/OCL20ToSPIRV.cpp
867
- // OCL20ToSPIRV::transWorkItemBuiltinsToVariables()
868
- Value *spirv::genLinearLocalID (Instruction &Before, const Triple &TT) {
864
+ // Return a value equals to 0 if and only if the local linear id is 0.
865
+ Value *spirv::genPseudoLocalID (Instruction &Before, const Triple &TT) {
869
866
Module &M = *Before.getModule ();
870
867
if (TT.isNVPTX ()) {
871
868
LLVMContext &Ctx = Before.getContext ();
@@ -874,35 +871,29 @@ Value *spirv::genLinearLocalID(Instruction &Before, const Triple &TT) {
874
871
IRBuilder<> Bld (Ctx);
875
872
Bld.SetInsertPoint (&Before);
876
873
877
- AttributeList Attr;
878
- Attr = Attr.addAttribute (Ctx, AttributeList::FunctionIndex,
879
- Attribute::Convergent);
880
-
881
874
#define CREATE_CALLEE (NAME, FN_NAME ) \
882
- FunctionCallee FnCallee##NAME = M.getOrInsertFunction (FN_NAME, Attr, RetTy); \
875
+ FunctionCallee FnCallee##NAME = M.getOrInsertFunction (FN_NAME, RetTy); \
883
876
assert (FnCallee##NAME && " spirv intrinsic creation failed" ); \
884
- auto NAME = Bld.CreateCall (FnCallee##NAME, {}); \
885
- NAME->addAttribute (AttributeList::FunctionIndex, Attribute::Convergent);
877
+ auto NAME = Bld.CreateCall (FnCallee##NAME, {});
886
878
887
879
CREATE_CALLEE (LocalInvocationId_X, " _Z27__spirv_LocalInvocationId_xv" );
888
880
CREATE_CALLEE (LocalInvocationId_Y, " _Z27__spirv_LocalInvocationId_yv" );
889
881
CREATE_CALLEE (LocalInvocationId_Z, " _Z27__spirv_LocalInvocationId_zv" );
890
- CREATE_CALLEE (WorkgroupSize_Y, " _Z23__spirv_WorkgroupSize_yv" );
891
- CREATE_CALLEE (WorkgroupSize_Z, " _Z23__spirv_WorkgroupSize_zv" );
892
882
893
883
#undef CREATE_CALLEE
894
884
895
- // 1: ((__spirv_WorkgroupSize_y() * __spirv_WorkgroupSize_z())
896
- // 2: * __spirv_LocalInvocationId_x())
897
- // 3: + (__spirv_WorkgroupSize_z() * __spirv_LocalInvocationId_y())
898
- // 4: + (__spirv_LocalInvocationId_z())
899
- return Bld.CreateAdd (
900
- Bld.CreateAdd (
901
- Bld.CreateMul (Bld.CreateMul (WorkgroupSize_Y, WorkgroupSize_Z), // 1
902
- LocalInvocationId_X), // 2
903
- Bld.CreateMul (WorkgroupSize_Z, LocalInvocationId_Y)), // 3
904
- LocalInvocationId_Z); // 4
885
+ // 1: returns
886
+ // __spirv_LocalInvocationId_x() |
887
+ // __spirv_LocalInvocationId_y() |
888
+ // __spirv_LocalInvocationId_z()
889
+ //
890
+ return Bld.CreateOr (LocalInvocationId_X,
891
+ Bld.CreateOr (LocalInvocationId_Y, LocalInvocationId_Z));
905
892
} else {
893
+ // extern "C" const __constant size_t __spirv_BuiltInLocalInvocationIndex;
894
+ // Must correspond to the code in
895
+ // llvm-spirv/lib/SPIRV/OCL20ToSPIRV.cpp
896
+ // OCL20ToSPIRV::transWorkItemBuiltinsToVariables()
906
897
StringRef Name = " __spirv_BuiltInLocalInvocationIndex" ;
907
898
GlobalVariable *G = M.getGlobalVariable (Name);
908
899
@@ -932,11 +923,7 @@ Value *spirv::genLinearLocalID(Instruction &Before, const Triple &TT) {
932
923
// uint32_t Semantics) noexcept;
933
924
Instruction *spirv::genWGBarrier (Instruction &Before, const Triple &TT) {
934
925
Module &M = *Before.getModule ();
935
- StringRef Name;
936
- if (TT.isNVPTX ())
937
- Name = " _Z22__spirv_ControlBarrierN5__spv5ScopeES0_j" ;
938
- else
939
- Name = " __spirv_ControlBarrier" ;
926
+ StringRef Name = " _Z22__spirv_ControlBarrierjjj" ;
940
927
LLVMContext &Ctx = Before.getContext ();
941
928
Type *ScopeTy = Type::getInt32Ty (Ctx);
942
929
Type *SemanticsTy = Type::getInt32Ty (Ctx);
0 commit comments