Skip to content

Commit 16f64b8

Browse files
authored
[SYCL] Share PFWG lambda object through shared memory (#1455)
In the current implementation the private address of the PFWG lambda object is shared by the leader work item through local memory with the other work items. This is not correct. That is why we perform a copy of the PFWG lambda object to shared memory and make the work items use the address of the object in shared memory, i.e. this case should be handled in a similar way as byval parameters. Signed-off-by: Artur Gainullin <[email protected]>
1 parent 52676dd commit 16f64b8

File tree

2 files changed

+94
-55
lines changed

2 files changed

+94
-55
lines changed

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -375,20 +375,29 @@ using LocalsSet = SmallPtrSet<AllocaInst *, 4>;
375375
static void copyBetweenPrivateAndShadow(Value *L, GlobalVariable *Shadow,
376376
IRBuilder<> &Builder, bool Loc2Shadow) {
377377
Type *T = nullptr;
378-
int LocAlignN = 0;
378+
MaybeAlign LocAlign(0);
379379

380380
if (const auto *AI = dyn_cast<AllocaInst>(L)) {
381381
T = AI->getAllocatedType();
382-
LocAlignN = AI->getAlignment();
382+
LocAlign = MaybeAlign(AI->getAlignment());
383383
} else {
384-
T = cast<Argument>(L)->getParamByValType();
385-
LocAlignN = cast<Argument>(L)->getParamAlignment();
384+
if (cast<Argument>(L)->hasByValAttr()) {
385+
T = cast<Argument>(L)->getParamByValType();
386+
LocAlign = MaybeAlign(cast<Argument>(L)->getParamAlignment());
387+
} else {
388+
Type *Ty = cast<Argument>(L)->getType();
389+
Module &M = *Shadow->getParent();
390+
LocAlign = M.getDataLayout().getValueOrABITypeAlignment(
391+
MaybeAlign(cast<Argument>(L)->getParamAlignment()), Ty);
392+
auto PtrTy = dyn_cast<PointerType>(cast<Argument>(L)->getType());
393+
assert(PtrTy && "Expected pointer type");
394+
T = PtrTy->getElementType();
395+
}
386396
}
387397

388398
if (T->isAggregateType()) {
389399
// TODO: we should use methods which directly return MaybeAlign once such
390400
// are added to LLVM for AllocaInst and GlobalVariable
391-
auto LocAlign = MaybeAlign(LocAlignN);
392401
auto ShdAlign = MaybeAlign(Shadow->getAlignment());
393402
Module &M = *Shadow->getParent();
394403
auto SizeVal = M.getDataLayout().getTypeStoreSize(T);
@@ -679,10 +688,25 @@ static void fixupPrivateMemoryPFWILambdaCaptures(CallInst *PFWICall) {
679688
// Go through "byval" parameters which are passed as AS(0) pointers
680689
// and: (1) create local shadows for them (2) and initialize them from the
681690
// leader's copy and (3) replace usages with pointer to the shadow
682-
static void shareByValParams(Function &F, const Triple &TT) {
683-
// split
691+
//
692+
// Do the same for the 'this' pointer which points to the PFWG lambda object which is
693+
// allocated in the caller. Caller is a kernel function which is generated by
694+
// SYCL frontend. The kernel function allocates the PFWG lambda object and initializes
695+
// captured objects (like accessors) using arguments of the kernel. After
696+
// initialization the kernel calls the PFWG function (which is the operator() of the PFWG
697+
// object). PFWG object captures all objects by value and all uses (except
698+
// initialization from kernel arguments) of these values can only be in the scope of the
699+
// PFWG function that is why copy back of PFWG object is not needed.
700+
static void sharePFWGPrivateObjects(Function &F, const Triple &TT) {
701+
// Skip alloca instructions and split. Alloca instructions must be in the
702+
// beginning of the function otherwise they are considered as dynamic which
703+
// can cause problems with inlining.
684704
BasicBlock *EntryBB = &F.getEntryBlock();
685-
BasicBlock *LeaderBB = EntryBB->splitBasicBlock(&EntryBB->front(), "leader");
705+
Instruction *SplitPoint = &*EntryBB->begin();
706+
for (; SplitPoint->getOpcode() == Instruction::Alloca;
707+
SplitPoint = SplitPoint->getNextNode())
708+
;
709+
BasicBlock *LeaderBB = EntryBB->splitBasicBlock(SplitPoint, "leader");
686710
BasicBlock *MergeBB = LeaderBB->splitBasicBlock(&LeaderBB->front(), "merge");
687711

688712
// 1) rewire the above basic blocks so that LeaderBB is executed only for the
@@ -692,38 +716,48 @@ static void shareByValParams(Function &F, const Triple &TT) {
692716
Instruction &At = LeaderBB->back();
693717

694718
for (auto &Arg : F.args()) {
695-
if (!Arg.hasByValAttr())
696-
continue;
697-
assert(Arg.getType()->getPointerAddressSpace() ==
698-
asUInt(spirv::AddrSpace::Private));
699-
Type *T = Arg.getParamByValType();
700-
701-
// 2) create the shared copy - "shadow" - for current byval arg
702-
GlobalVariable *Shadow =
703-
spirv::createWGLocalVariable(*F.getParent(), T, "ArgShadow");
719+
Type *T;
720+
LLVMContext &Ctx = At.getContext();
721+
IRBuilder<> Builder(Ctx);
722+
Builder.SetInsertPoint(&LeaderBB->front());
704723

705-
// 3) replace argument with shadow in all uses
706-
Value *RepVal = Shadow;
707-
if (TT.isNVPTX()) {
708-
// For NVPTX target address space inference for kernel arguments and
709-
// allocas is happening in the backend (NVPTXLowerArgs and
710-
// NVPTXLowerAlloca passes). After the frontend these pointers are in LLVM
711-
// default address space 0 which is the generic address space for NVPTX
712-
// target.
713-
assert(Arg.getType()->getPointerAddressSpace() == 0);
714-
715-
// Cast a pointer in the shared address space to the generic address
716-
// space.
724+
// 2) create the shared copy - "shadow" - for current arg
725+
GlobalVariable *Shadow;
726+
Value *RepVal;
727+
if (Arg.hasByValAttr()) {
728+
assert(Arg.getType()->getPointerAddressSpace() ==
729+
asUInt(spirv::AddrSpace::Private));
730+
T = Arg.getParamByValType();
731+
Shadow = spirv::createWGLocalVariable(*F.getParent(), T, "ArgShadow");
732+
RepVal = Shadow;
733+
if (TT.isNVPTX()) {
734+
// For NVPTX target address space inference for kernel arguments and
735+
// allocas is happening in the backend (NVPTXLowerArgs and
736+
// NVPTXLowerAlloca passes). After the frontend these pointers are in
737+
// LLVM default address space 0 which is the generic address space for
738+
// NVPTX target.
739+
assert(Arg.getType()->getPointerAddressSpace() == 0);
740+
741+
// Cast a pointer in the shared address space to the generic address
742+
// space.
743+
RepVal = ConstantExpr::getPointerBitCastOrAddrSpaceCast(Shadow,
744+
Arg.getType());
745+
}
746+
}
747+
// Process 'this' pointer which points to PFWG lambda object
748+
else if (Arg.getArgNo() == 0) {
749+
PointerType *PtrT = dyn_cast<PointerType>(Arg.getType());
750+
assert(PtrT && "Expected this pointer as the first argument");
751+
T = PtrT->getElementType();
752+
Shadow = spirv::createWGLocalVariable(*F.getParent(), T, "ArgShadow");
717753
RepVal =
718-
ConstantExpr::getPointerBitCastOrAddrSpaceCast(Shadow, Arg.getType());
754+
Builder.CreatePointerBitCastOrAddrSpaceCast(Shadow, Arg.getType());
719755
}
756+
757+
// 3) replace argument with shadow in all uses
720758
for (auto *U : Arg.users())
721759
U->replaceUsesOfWith(&Arg, RepVal);
722760

723-
// 4) fill the shadow from the argument for the leader WI only
724-
LLVMContext &Ctx = At.getContext();
725-
IRBuilder<> Builder(Ctx);
726-
Builder.SetInsertPoint(&LeaderBB->front());
727761
copyBetweenPrivateAndShadow(&Arg, Shadow, Builder,
728762
true /*private->shadow*/);
729763
}
@@ -832,8 +866,9 @@ PreservedAnalyses SYCLLowerWGScopePass::run(Function &F, const llvm::Triple &TT,
832866
for (auto *PFWICall : PFWICalls)
833867
fixupPrivateMemoryPFWILambdaCaptures(PFWICall);
834868

835-
// Finally, create shadows for and replace usages of byval pointer params
836-
shareByValParams(F, TT);
869+
// Finally, create shadows for and replace usages of byval pointer params and
870+
// PFWG lambda object ('this' pointer).
871+
sharePFWGPrivateObjects(F, TT);
837872

838873
#ifndef NDEBUG
839874
if (HaveChanges && Debug > 0)

llvm/test/SYCLLowerIR/pfwg_and_pfwi.ll

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,51 +13,55 @@
1313
%struct.foo = type { %struct.barney }
1414
%struct.foo.0 = type { i8 }
1515

16-
; CHECK: @[[PFWG_SHADOW:.*]] = internal unnamed_addr addrspace(3) global %struct.bar addrspace(4)*
16+
; CHECK: @[[GROUP_SHADOW_PTR:.*]] = internal unnamed_addr addrspace(3) global %struct.zot addrspace(4)*
17+
; CHECK: @[[PFWG_SHADOW_PTR:.*]] = internal unnamed_addr addrspace(3) global %struct.bar addrspace(4)*
1718
; CHECK: @[[PFWI_SHADOW:.*]] = internal unnamed_addr addrspace(3) global %struct.foo.0
19+
; CHECK: @[[PFWG_SHADOW:.*]] = internal unnamed_addr addrspace(3) global %struct.bar
1820
; CHECK: @[[GROUP_SHADOW:.*]] = internal unnamed_addr addrspace(3) global %struct.zot
1921

2022
define internal spir_func void @wibble(%struct.bar addrspace(4)* %arg, %struct.zot* byval(%struct.zot) align 8 %arg1) align 2 !work_group_scope !0 {
2123
; CHECK-LABEL: @wibble(
2224
; CHECK-NEXT: bb:
25+
; CHECK-NEXT: [[TMP:%.*]] = alloca [[STRUCT_BAR:%.*]] addrspace(4)*, align 8
26+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [[STRUCT_FOO_0:%.*]], align 1
2327
; CHECK-NEXT: [[TMP0:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex
2428
; CHECK-NEXT: [[CMPZ3:%.*]] = icmp eq i64 [[TMP0]], 0
2529
; CHECK-NEXT: br i1 [[CMPZ3]], label [[LEADER:%.*]], label [[MERGE:%.*]]
2630
; CHECK: leader:
2731
; CHECK-NEXT: [[TMP1:%.*]] = bitcast %struct.zot* [[ARG1:%.*]] to i8*
2832
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.zot addrspace(3)* @[[GROUP_SHADOW]] to i8 addrspace(3)*), i8* align 8 [[TMP1]], i64 96, i1 false)
33+
; CHECK-NEXT: [[ARG_CAST:%.*]] = bitcast [[STRUCT_BAR]] addrspace(4)* [[ARG:%.*]] to i8 addrspace(4)*
34+
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p4i8.i64(i8 addrspace(3)* align 8 getelementptr inbounds (%struct.bar, [[STRUCT_BAR]] addrspace(3)* @[[PFWG_SHADOW]], i32 0, i32 0), i8 addrspace(4)* align 8 [[ARG_CAST]], i64 1, i1 false)
2935
; CHECK-NEXT: br label [[MERGE]]
3036
; CHECK: merge:
31-
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
32-
; CHECK-NEXT: [[TMP:%.*]] = alloca [[STRUCT_BAR:%.*]] addrspace(4)*, align 8
33-
; CHECK-NEXT: [[TMP2:%.*]] = alloca [[STRUCT_FOO_0:%.*]], align 1
34-
; CHECK-NEXT: [[ID:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex
35-
; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[ID]], 0
37+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0
38+
; CHECK-NEXT: [[TMP3:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex
39+
; CHECK-NEXT: [[CMPZ:%.*]] = icmp eq i64 [[TMP3]], 0
3640
; CHECK-NEXT: br i1 [[CMPZ]], label [[WG_LEADER:%.*]], label [[WG_CF:%.*]]
3741
; CHECK: wg_leader:
38-
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[ARG:%.*]], [[STRUCT_BAR]] addrspace(4)** [[TMP]], align 8
42+
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* addrspacecast (%struct.bar addrspace(3)* @[[PFWG_SHADOW]] to [[STRUCT_BAR]] addrspace(4)*), [[STRUCT_BAR]] addrspace(4)** [[TMP]], align 8
3943
; CHECK-NEXT: [[TMP3:%.*]] = load [[STRUCT_BAR]] addrspace(4)*, [[STRUCT_BAR]] addrspace(4)** [[TMP]], align 8
4044
; CHECK-NEXT: [[TMP4:%.*]] = addrspacecast [[STRUCT_ZOT:%.*]] addrspace(3)* @[[GROUP_SHADOW]] to [[STRUCT_ZOT]] addrspace(4)*
41-
; CHECK-NEXT: store [[STRUCT_ZOT]] addrspace(4)* [[TMP4]], [[STRUCT_ZOT]] addrspace(4)* addrspace(3)* @wibbleWG_tmp4
45+
; CHECK-NEXT: store [[STRUCT_ZOT]] addrspace(4)* [[TMP4]], [[STRUCT_ZOT]] addrspace(4)* addrspace(3)* @[[GROUP_SHADOW_PTR]]
4246
; CHECK-NEXT: br label [[WG_CF]]
4347
; CHECK: wg_cf:
44-
; CHECK-NEXT: [[TMP3:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex
45-
; CHECK-NEXT: [[CMPZ2:%.*]] = icmp eq i64 [[TMP3]], 0
48+
; CHECK-NEXT: [[TMP4:%.*]] = load i64, i64 addrspace(1)* @__spirv_BuiltInLocalInvocationIndex
49+
; CHECK-NEXT: [[CMPZ2:%.*]] = icmp eq i64 [[TMP4]], 0
4650
; CHECK-NEXT: br i1 [[CMPZ2]], label [[TESTMAT:%.*]], label [[LEADERMAT:%.*]]
4751
; CHECK: TestMat:
48-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast %struct.foo.0* [[TMP2]] to i8*
49-
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 8 getelementptr inbounds (%struct.foo.0, [[STRUCT_FOO_0]] addrspace(3)* @[[PFWI_SHADOW]], i32 0, i32 0), i8* align 1 [[TMP4]], i64 1, i1 false)
52+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast %struct.foo.0* [[TMP2]] to i8*
53+
; CHECK-NEXT: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 8 getelementptr inbounds (%struct.foo.0, [[STRUCT_FOO_0]] addrspace(3)* @[[PFWI_SHADOW]], i32 0, i32 0), i8* align 1 [[TMP5]], i64 1, i1 false)
5054
; CHECK-NEXT: [[MAT_LD:%.*]] = load [[STRUCT_BAR]] addrspace(4)*, [[STRUCT_BAR]] addrspace(4)** [[TMP]]
51-
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[MAT_LD]], [[STRUCT_BAR]] addrspace(4)* addrspace(3)* @[[PFWG_SHADOW]]
55+
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[MAT_LD]], [[STRUCT_BAR]] addrspace(4)* addrspace(3)* @[[PFWG_SHADOW_PTR]]
5256
; CHECK-NEXT: br label [[LEADERMAT]]
5357
; CHECK: LeaderMat:
54-
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
55-
; CHECK-NEXT: [[MAT_LD1:%.*]] = load [[STRUCT_BAR]] addrspace(4)*, [[STRUCT_BAR]] addrspace(4)* addrspace(3)* @[[PFWG_SHADOW]]
58+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0
59+
; CHECK-NEXT: [[MAT_LD1:%.*]] = load [[STRUCT_BAR]] addrspace(4)*, [[STRUCT_BAR]] addrspace(4)* addrspace(3)* @[[PFWG_SHADOW_PTR]]
5660
; CHECK-NEXT: store [[STRUCT_BAR]] addrspace(4)* [[MAT_LD1]], [[STRUCT_BAR]] addrspace(4)** [[TMP]]
57-
; CHECK-NEXT: [[TMP5:%.*]] = bitcast %struct.foo.0* [[TMP2]] to i8*
58-
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 1 [[TMP5]], i8 addrspace(3)* align 8 getelementptr inbounds (%struct.foo.0, [[STRUCT_FOO_0]] addrspace(3)* @[[PFWI_SHADOW]], i32 0, i32 0), i64 1, i1 false)
59-
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272)
60-
; CHECK-NEXT: [[WG_VAL_TMP4:%.*]] = load [[STRUCT_ZOT]] addrspace(4)*, [[STRUCT_ZOT]] addrspace(4)* addrspace(3)* @wibbleWG_tmp4
61+
; CHECK-NEXT: [[TMP6:%.*]] = bitcast %struct.foo.0* [[TMP2]] to i8*
62+
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p3i8.i64(i8* align 1 [[TMP6]], i8 addrspace(3)* align 8 getelementptr inbounds (%struct.foo.0, [[STRUCT_FOO_0]] addrspace(3)* @[[PFWI_SHADOW]], i32 0, i32 0), i64 1, i1 false)
63+
; CHECK-NEXT: call void @_Z22__spirv_ControlBarrierjjj(i32 2, i32 2, i32 272) #0
64+
; CHECK-NEXT: [[WG_VAL_TMP4:%.*]] = load [[STRUCT_ZOT]] addrspace(4)*, [[STRUCT_ZOT]] addrspace(4)* addrspace(3)* @[[GROUP_SHADOW_PTR]]
6165
; CHECK-NEXT: call spir_func void @bar(%struct.zot addrspace(4)* [[WG_VAL_TMP4]], %struct.foo.0* byval(%struct.foo.0) align 1 [[TMP2]])
6266
; CHECK-NEXT: ret void
6367
;

0 commit comments

Comments
 (0)