Skip to content

Commit c98559b

Browse files
authored
[SYCL] Cast address spaces before replacing byval argument usages (#1405)
For the NVPTX target, address space inference for kernel arguments and allocas happens in the backend (the NVPTXLowerArgs and NVPTXLowerAlloca passes). After the frontend, these pointers are in LLVM's default address space 0, which is the generic address space for the NVPTX target. Therefore, cast the pointer to the shadow global variable from the local (shared, addrspace 3) address space to the generic address space before replacing all uses of a byval argument. Signed-off-by: Artur Gainullin <[email protected]>
1 parent dc7d851 commit c98559b

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

llvm/lib/SYCLLowerIR/LowerWGScope.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,8 +703,22 @@ static void shareByValParams(Function &F, const Triple &TT) {
703703
spirv::createWGLocalVariable(*F.getParent(), T, "ArgShadow");
704704

705705
// 3) replace argument with shadow in all uses
706+
Value *RepVal = Shadow;
707+
if (TT.isNVPTX()) {
708+
// For NVPTX target address space inference for kernel arguments and
709+
// allocas is happening in the backend (NVPTXLowerArgs and
710+
// NVPTXLowerAlloca passes). After the frontend these pointers are in LLVM
711+
// default address space 0 which is the generic address space for NVPTX
712+
// target.
713+
assert(Arg.getType()->getPointerAddressSpace() == 0);
714+
715+
// Cast a pointer in the shared address space to the generic address
716+
// space.
717+
RepVal =
718+
ConstantExpr::getPointerBitCastOrAddrSpaceCast(Shadow, Arg.getType());
719+
}
706720
for (auto *U : Arg.users())
707-
U->replaceUsesOfWith(&Arg, Shadow);
721+
U->replaceUsesOfWith(&Arg, RepVal);
708722

709723
// 4) fill the shadow from the argument for the leader WI only
710724
LLVMContext &Ctx = At.getContext();

llvm/test/SYCLLowerIR/cast_shadow.ll

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2+
; RUN: opt < %s -LowerWGScope -verify -S | FileCheck %s
3+
4+
target triple = "nvptx64-nvidia-cuda-sycldevice"
5+
6+
%struct.baz = type { i8 }
7+
%struct.spam = type { %struct.wobble, %struct.wobble, %struct.wobble, %struct.wombat.0 }
8+
%struct.wobble = type { %struct.wombat }
9+
%struct.wombat = type { [1 x i64] }
10+
%struct.wombat.0 = type { %struct.wombat }
11+
%struct.quux = type { i8 }
12+
13+
; CHECK: @[[SHADOW:[a-zA-Z0-9]+]] = internal unnamed_addr addrspace(3) global %struct.spam undef
14+
15+
define internal void @wobble(%struct.baz* %arg, %struct.spam* byval(%struct.spam) %arg1) !work_group_scope !0 {
16+
; CHECK: [[TMP10:%.*]] = bitcast %struct.spam* [[ARG1:%.*]] to i8*
17+
; CHECK: call void @llvm.memcpy.p3i8.p0i8.i64(i8 addrspace(3)* align 16 bitcast (%struct.spam addrspace(3)* @[[SHADOW]] to i8 addrspace(3)*), i8* [[TMP10]], i64 32, i1 false)
18+
; CHECK: call void @widget(%struct.spam* addrspacecast (%struct.spam addrspace(3)* @[[SHADOW]] to %struct.spam*), %struct.quux* byval(%struct.quux) [[TMP2:%.*]])
19+
;
20+
bb:
21+
%tmp = alloca %struct.baz*
22+
%tmp2 = alloca %struct.quux
23+
store %struct.baz* %arg, %struct.baz** %tmp
24+
%tmp3 = load %struct.baz*, %struct.baz** %tmp
25+
call void @widget(%struct.spam* %arg1, %struct.quux* byval(%struct.quux) %tmp2)
26+
ret void
27+
}
28+
29+
define internal void @widget(%struct.spam* %arg, %struct.quux* byval(%struct.quux) %arg1) !work_item_scope !0 !parallel_for_work_item !0 {
30+
bb:
31+
ret void
32+
}
33+
34+
!0 = !{}

0 commit comments

Comments
 (0)