Skip to content

Commit dc7d851

Browse files
authored
[SYCL] Simplify handing of builtins in LowerWGScope (#1399)
Implements a few code simplification/unification for LowerWGScope. Signed-off-by: Victor Lomuller <[email protected]>
1 parent 0408899 commit dc7d851

File tree

4 files changed

+31
-61
lines changed

4 files changed

+31
-61
lines changed

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ enum class MemorySemantics : unsigned {
185185
};
186186

187187
Instruction *genWGBarrier(Instruction &Before, const Triple &TT);
188-
Value *genLinearLocalID(Instruction &Before, const Triple &TT);
188+
Value *genPseudoLocalID(Instruction &Before, const Triple &TT);
189189
GlobalVariable *createWGLocalVariable(Module &M, Type *T, const Twine &Name);
190190
} // namespace spirv
191191

@@ -261,7 +261,7 @@ static void guardBlockWithIsLeaderCheck(BasicBlock *IfBB, BasicBlock *TrueBB,
261261
BasicBlock *MergeBB,
262262
const DebugLoc &DbgLoc,
263263
const Triple &TT) {
264-
Value *LinearLocalID = spirv::genLinearLocalID(*IfBB->getTerminator(), TT);
264+
Value *LinearLocalID = spirv::genPseudoLocalID(*IfBB->getTerminator(), TT);
265265
auto *Ty = LinearLocalID->getType();
266266
Value *Zero = Constant::getNullValue(Ty);
267267
IRBuilder<> Builder(IfBB->getContext());
@@ -861,11 +861,8 @@ GlobalVariable *spirv::createWGLocalVariable(Module &M, Type *T,
861861
// TODO generalize to support all SPIR-V intrinsic operations and builtin
862862
// variables
863863

864-
// extern "C" const __constant size_t __spirv_BuiltInLocalInvocationIndex;
865-
// Must correspond to the code in
866-
// llvm-spirv/lib/SPIRV/OCL20ToSPIRV.cpp
867-
// OCL20ToSPIRV::transWorkItemBuiltinsToVariables()
868-
Value *spirv::genLinearLocalID(Instruction &Before, const Triple &TT) {
864+
// Return a value equals to 0 if and only if the local linear id is 0.
865+
Value *spirv::genPseudoLocalID(Instruction &Before, const Triple &TT) {
869866
Module &M = *Before.getModule();
870867
if (TT.isNVPTX()) {
871868
LLVMContext &Ctx = Before.getContext();
@@ -874,35 +871,29 @@ Value *spirv::genLinearLocalID(Instruction &Before, const Triple &TT) {
874871
IRBuilder<> Bld(Ctx);
875872
Bld.SetInsertPoint(&Before);
876873

877-
AttributeList Attr;
878-
Attr = Attr.addAttribute(Ctx, AttributeList::FunctionIndex,
879-
Attribute::Convergent);
880-
881874
#define CREATE_CALLEE(NAME, FN_NAME) \
882-
FunctionCallee FnCallee##NAME = M.getOrInsertFunction(FN_NAME, Attr, RetTy); \
875+
FunctionCallee FnCallee##NAME = M.getOrInsertFunction(FN_NAME, RetTy); \
883876
assert(FnCallee##NAME && "spirv intrinsic creation failed"); \
884-
auto NAME = Bld.CreateCall(FnCallee##NAME, {}); \
885-
NAME->addAttribute(AttributeList::FunctionIndex, Attribute::Convergent);
877+
auto NAME = Bld.CreateCall(FnCallee##NAME, {});
886878

887879
CREATE_CALLEE(LocalInvocationId_X, "_Z27__spirv_LocalInvocationId_xv");
888880
CREATE_CALLEE(LocalInvocationId_Y, "_Z27__spirv_LocalInvocationId_yv");
889881
CREATE_CALLEE(LocalInvocationId_Z, "_Z27__spirv_LocalInvocationId_zv");
890-
CREATE_CALLEE(WorkgroupSize_Y, "_Z23__spirv_WorkgroupSize_yv");
891-
CREATE_CALLEE(WorkgroupSize_Z, "_Z23__spirv_WorkgroupSize_zv");
892882

893883
#undef CREATE_CALLEE
894884

895-
// 1: ((__spirv_WorkgroupSize_y() * __spirv_WorkgroupSize_z())
896-
// 2: * __spirv_LocalInvocationId_x())
897-
// 3: + (__spirv_WorkgroupSize_z() * __spirv_LocalInvocationId_y())
898-
// 4: + (__spirv_LocalInvocationId_z())
899-
return Bld.CreateAdd(
900-
Bld.CreateAdd(
901-
Bld.CreateMul(Bld.CreateMul(WorkgroupSize_Y, WorkgroupSize_Z), // 1
902-
LocalInvocationId_X), // 2
903-
Bld.CreateMul(WorkgroupSize_Z, LocalInvocationId_Y)), // 3
904-
LocalInvocationId_Z); // 4
885+
// 1: returns
886+
// __spirv_LocalInvocationId_x() |
887+
// __spirv_LocalInvocationId_y() |
888+
// __spirv_LocalInvocationId_z()
889+
//
890+
return Bld.CreateOr(LocalInvocationId_X,
891+
Bld.CreateOr(LocalInvocationId_Y, LocalInvocationId_Z));
905892
} else {
893+
// extern "C" const __constant size_t __spirv_BuiltInLocalInvocationIndex;
894+
// Must correspond to the code in
895+
// llvm-spirv/lib/SPIRV/OCL20ToSPIRV.cpp
896+
// OCL20ToSPIRV::transWorkItemBuiltinsToVariables()
906897
StringRef Name = "__spirv_BuiltInLocalInvocationIndex";
907898
GlobalVariable *G = M.getGlobalVariable(Name);
908899

@@ -932,11 +923,7 @@ Value *spirv::genLinearLocalID(Instruction &Before, const Triple &TT) {
932923
// uint32_t Semantics) noexcept;
933924
Instruction *spirv::genWGBarrier(Instruction &Before, const Triple &TT) {
934925
Module &M = *Before.getModule();
935-
StringRef Name;
936-
if (TT.isNVPTX())
937-
Name = "_Z22__spirv_ControlBarrierN5__spv5ScopeES0_j";
938-
else
939-
Name = "__spirv_ControlBarrier";
926+
StringRef Name = "_Z22__spirv_ControlBarrierjjj";
940927
LLVMContext &Ctx = Before.getContext();
941928
Type *ScopeTy = Type::getInt32Ty(Ctx);
942929
Type *SemanticsTy = Type::getInt32Ty(Ctx);

llvm/test/SYCLLowerIR/byval_arg.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ define internal spir_func void @wibble(%struct.baz* byval(%struct.baz) %arg1) !w
1818
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 8 bitcast (%struct.baz addrspace(3)* @[[SHADOW]] to i8 addrspace(3)*), i8* [[TMP2]], i64 8, i1 false)
1919
; CHECK-NEXT: br label [[MERGE]]
2020
; CHECK: merge:
21-
; CHECK-NEXT: call void @__spirv_ControlBarrier(i32 2, i32 2, i32 272)
21+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
2222
; CHECK-NEXT: ret void
2323
;
2424
ret void

llvm/test/SYCLLowerIR/convergent.ll

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,25 @@
11
; RUN: opt < %s -LowerWGScope -S | FileCheck %s
2-
; RUN: opt < %s -LowerWGScope --mtriple=nvptx -S | FileCheck %s -check-prefix=CHECK-PTX
2+
; RUN: opt < %s -LowerWGScope --mtriple=nvptx -S | FileCheck %s -check-prefix=CHECK -check-prefix=CHECK-PTX
33

44

55
%struct.baz = type { i64 }
66

77
define internal spir_func void @wibble(%struct.baz* byval(%struct.baz) %arg1) !work_group_scope !0 {
8-
; CHECK-PTX: %1 = call i64 @_Z27__spirv_LocalInvocationId_xv() #0
9-
; CHECK-PTX: %2 = call i64 @_Z27__spirv_LocalInvocationId_yv() #0
10-
; CHECK-PTX: %3 = call i64 @_Z27__spirv_LocalInvocationId_zv() #0
11-
; CHECK-PTX: %4 = call i64 @_Z23__spirv_WorkgroupSize_yv() #0
12-
; CHECK-PTX: %5 = call i64 @_Z23__spirv_WorkgroupSize_zv() #0
13-
; CHECK-PTX: call void @_Z22__spirv_ControlBarrierN5__spv5ScopeES0_j(i32 2, i32 2, i32 272) #0
14-
; CHECK: call void @__spirv_ControlBarrier(i32 2, i32 2, i32 272) #1
8+
; CHECK-PTX: call i64 @_Z27__spirv_LocalInvocationId_xv()
9+
; CHECK-PTX: call i64 @_Z27__spirv_LocalInvocationId_yv()
10+
; CHECK-PTX: call i64 @_Z27__spirv_LocalInvocationId_zv()
11+
; CHECK: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
1512
ret void
1613
}
1714

18-
; CHECK-PTX: ; Function Attrs: convergent
19-
; CHECK-PTX: declare i64 @_Z27__spirv_LocalInvocationId_xv() #0
15+
; CHECK-PTX: declare i64 @_Z27__spirv_LocalInvocationId_xv()
2016

21-
; CHECK-PTX: ; Function Attrs: convergent
22-
; CHECK-PTX: declare i64 @_Z27__spirv_LocalInvocationId_yv() #0
17+
; CHECK-PTX: declare i64 @_Z27__spirv_LocalInvocationId_yv()
2318

24-
; CHECK-PTX: ; Function Attrs: convergent
25-
; CHECK-PTX: declare i64 @_Z27__spirv_LocalInvocationId_zv() #0
26-
27-
; CHECK-PTX: ; Function Attrs: convergent
28-
; CHECK-PTX: declare i64 @_Z23__spirv_WorkgroupSize_yv() #0
29-
30-
; CHECK-PTX: ; Function Attrs: convergent
31-
; CHECK-PTX: declare i64 @_Z23__spirv_WorkgroupSize_zv() #0
32-
33-
; CHECK-PTX: ; Function Attrs: convergent
34-
; CHECK-PTX: declare void @_Z22__spirv_ControlBarrierN5__spv5ScopeES0_j(i32, i32, i32) #0
35-
36-
; CHECK-PTX: attributes #0 = { convergent }
19+
; CHECK-PTX: declare i64 @_Z27__spirv_LocalInvocationId_zv()
3720

3821
; CHECK: ; Function Attrs: convergent
39-
; CHECK: declare void @__spirv_ControlBarrier(i32, i32, i32) #1
22+
; CHECK: declare void @_Z22__spirv_ControlBarrierjjj(i32, i32, i32) #1
4023

4124
; CHECK: attributes #1 = { convergent }
4225

llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ define internal spir_func void @wibble(%struct.bar addrspace(4)* %arg, %struct.z
2828
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.zot addrspace(3)* @[[GROUP_SHADOW]] to i8 addrspace(3)*), i8* align 8 [[TMP1]], i64 96, i1 false)
2929
; CHECK-NEXT: br label [[MERGE]]
3030
; CHECK: merge:
31-
; CHECK-NEXT: call void @__spirv_ControlBarrier(i32 2, i32 2, i32 272)
31+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
3232
; CHECK-NEXT: [[TMP:%.*]] = alloca [[STRUCT_BAR:%.*]] addrspace(4)*, align 8
3333
; CHECK-NEXT: [[TMP2:%.*]] = alloca [[STRUCT_FOO_0:%.*]], align 1
3434
; CHECK-NEXT: [[ID:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex
@@ -51,12 +51,12 @@ define internal spir_func void @wibble(%struct.bar addrspace(4)* %arg, %struct.z
5151
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[MAT_LD]], [[STRUCT_BAR]] addrspace(4)* addrspace(3)* @[[PFWG_SHADOW]]
5252
; CHECK-NEXT: br label [[LEADERMAT]]
5353
; CHECK: LeaderMat:
54-
; CHECK-NEXT: call void @__spirv_ControlBarrier(i32 2, i32 2, i32 272)
54+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
5555
; CHECK-NEXT: [[MAT_LD1:%.*]] = load [[STRUCT_BAR]] addrspace(4)*, [[STRUCT_BAR]] addrspace(4)* addrspace(3)* @[[PFWG_SHADOW]]
5656
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[MAT_LD1]], [[STRUCT_BAR]] addrspace(4)** [[TMP]]
5757
; CHECK-NEXT: [[TMP5:%.*]] = bitcast %struct.foo.0* [[TMP2]] to i8*
5858
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 1 [[TMP5]], i8 addrspace(3)* align 8 getelementptr inbounds (%struct.foo.0, [[STRUCT_FOO_0]] addrspace(3)* @[[PFWI_SHADOW]], i32 0, i32 0), i64 1, i1 false)
59-
; CHECK-NEXT: call void @__spirv_ControlBarrier(i32 2, i32 2, i32 272)
59+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
6060
; CHECK-NEXT: [[WG_VAL_TMP4:%.*]] = load [[STRUCT_ZOT]] addrspace(4)*, [[STRUCT_ZOT]] addrspace(4)* addrspace(3)* @wibbleWG_tmp4
6161
; CHECK-NEXT: call spir_func void @bar(%struct.zot addrspace(4)* [[WG_VAL_TMP4]], %struct.foo.0* byval(%struct.foo.0) align 1 [[TMP2]])
6262
; CHECK-NEXT: ret void

0 commit comments

Comments
 (0)