Skip to content

Commit 748ef7e

Browse files
authored
[CUDA][HIP] Fix record layout on Windows (#87651)
On Windows, the device-side record layout must be consistent with the host side; otherwise, host code is not able to access the fields of the record correctly. Fixes: #51031. Fixes: SWDEV-446010.
1 parent 8888369 commit 748ef7e

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

clang/lib/AST/RecordLayoutBuilder.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,6 +2458,11 @@ static bool mustSkipTailPadding(TargetCXXABI ABI, const CXXRecordDecl *RD) {
24582458
}
24592459

24602460
static bool isMsLayout(const ASTContext &Context) {
2461+
// Check if it's CUDA device compilation; ensure layout consistency with host.
2462+
if (Context.getLangOpts().CUDA && Context.getLangOpts().CUDAIsDevice &&
2463+
Context.getAuxTargetInfo())
2464+
return Context.getAuxTargetInfo()->getCXXABI().isMicrosoft();
2465+
24612466
return Context.getTargetInfo().getCXXABI().isMicrosoft();
24622467
}
24632468

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
// RUN: %clang_cc1 -triple x86_64-pc-windows-msvc -fdump-record-layouts \
2+
// RUN: -emit-llvm -o %t -xhip %s 2>&1 | FileCheck %s --check-prefix=AST
3+
// RUN: cat %t | FileCheck --check-prefixes=CHECK,HOST %s
4+
// RUN: %clang_cc1 -fcuda-is-device -triple amdgcn-amd-amdhsa -target-cpu gfx1100 \
5+
// RUN: -emit-llvm -fdump-record-layouts -aux-triple x86_64-pc-windows-msvc \
6+
// RUN: -o %t -xhip %s | FileCheck %s --check-prefix=AST
7+
// RUN: cat %t | FileCheck --check-prefixes=CHECK,DEV %s
8+
9+
#include "Inputs/cuda.h"
10+
11+
// AST: *** Dumping AST Record Layout
12+
// AST-LABEL: 0 | struct C
13+
// AST-NEXT: 0 | struct A (base) (empty)
14+
// AST-NEXT: 1 | struct B (base) (empty)
15+
// AST-NEXT: 4 | int i
16+
// AST-NEXT: | [sizeof=8, align=4,
17+
// AST-NEXT: | nvsize=8, nvalign=4]
18+
19+
// CHECK: %struct.C = type { [4 x i8], i32 }
20+
21+
struct A {};
22+
struct B {};
23+
struct C : A, B {
24+
int i;
25+
};
26+
27+
// AST: *** Dumping AST Record Layout
28+
// AST-LABEL: 0 | struct I
29+
// AST-NEXT: 0 | (I vftable pointer)
30+
// AST-NEXT: 8 | int i
31+
// AST-NEXT: | [sizeof=16, align=8,
32+
// AST-NEXT: | nvsize=16, nvalign=8]
33+
34+
// AST: *** Dumping AST Record Layout
35+
// AST-LABEL: 0 | struct J
36+
// AST-NEXT: 0 | struct I (primary base)
37+
// AST-NEXT: 0 | (I vftable pointer)
38+
// AST-NEXT: 8 | int i
39+
// AST-NEXT: 16 | int j
40+
// AST-NEXT: | [sizeof=24, align=8,
41+
// AST-NEXT: | nvsize=24, nvalign=8]
42+
43+
// CHECK: %struct.I = type { ptr, i32 }
44+
// CHECK: %struct.J = type { %struct.I, i32 }
45+
46+
// HOST: @0 = private unnamed_addr constant { [4 x ptr] } { [4 x ptr] [ptr @"??_R4J@@6B@", ptr @"?f@J@@UEAAXXZ", ptr null, ptr @"?h@J@@UEAAXXZ"] }, comdat($"??_7J@@6B@")
47+
// HOST: @1 = private unnamed_addr constant { [4 x ptr] } { [4 x ptr] [ptr @"??_R4I@@6B@", ptr @_purecall, ptr null, ptr @_purecall] }, comdat($"??_7I@@6B@")
48+
// HOST: @"??_7J@@6B@" = unnamed_addr alias ptr, getelementptr inbounds ({ [4 x ptr] }, ptr @0, i32 0, i32 0, i32 1)
49+
// HOST: @"??_7I@@6B@" = unnamed_addr alias ptr, getelementptr inbounds ({ [4 x ptr] }, ptr @1, i32 0, i32 0, i32 1)
50+
51+
// DEV: @_ZTV1J = linkonce_odr unnamed_addr addrspace(1) constant { [5 x ptr addrspace(1)] } { [5 x ptr addrspace(1)] [ptr addrspace(1) null, ptr addrspace(1) null, ptr addrspace(1) null, ptr addrspace(1) addrspacecast (ptr @_ZN1J1gEv to ptr addrspace(1)), ptr addrspace(1) addrspacecast (ptr @_ZN1J1hEv to ptr addrspace(1))] }, comdat, align 8
52+
// DEV: @_ZTV1I = linkonce_odr unnamed_addr addrspace(1) constant { [5 x ptr addrspace(1)] } { [5 x ptr addrspace(1)] [ptr addrspace(1) null, ptr addrspace(1) null, ptr addrspace(1) null, ptr addrspace(1) addrspacecast (ptr @__cxa_pure_virtual to ptr addrspace(1)), ptr addrspace(1) addrspacecast (ptr @__cxa_pure_virtual to ptr addrspace(1))] }, comdat, align 8
53+
struct I {
54+
virtual void f() = 0;
55+
__device__ virtual void g() = 0;
56+
__device__ __host__ virtual void h() = 0;
57+
int i;
58+
};
59+
60+
struct J : I {
61+
void f() override {}
62+
__device__ void g() override {}
63+
__device__ __host__ void h() override {}
64+
int j;
65+
};
66+
67+
// DEV: define dso_local amdgpu_kernel void @_Z8C_kernel1C(ptr addrspace(4) noundef byref(%struct.C) align 4 %0)
68+
// DEV: %coerce = alloca %struct.C, align 4, addrspace(5)
69+
// DEV: %c = addrspacecast ptr addrspace(5) %coerce to ptr
70+
// DEV: call void @llvm.memcpy.p0.p4.i64(ptr align 4 %c, ptr addrspace(4) align 4 %0, i64 8, i1 false)
71+
// DEV: %i = getelementptr inbounds %struct.C, ptr %c, i32 0, i32 1
72+
// DEV: store i32 1, ptr %i, align 4
73+
74+
__global__ void C_kernel(C c)
75+
{
76+
c.i = 1;
77+
}
78+
79+
// HOST-LABEL: define dso_local void @"?test_C@@YAXXZ"()
80+
// HOST: %c = alloca %struct.C, align 4
81+
// HOST: %i = getelementptr inbounds %struct.C, ptr %c, i32 0, i32 1
82+
// HOST: store i32 11, ptr %i, align 4
83+
84+
void test_C() {
85+
C c;
86+
c.i = 11;
87+
C_kernel<<<1, 1>>>(c);
88+
}
89+
90+
// DEV: define dso_local void @_Z5J_devP1J(ptr noundef %j)
91+
// DEV: %j.addr = alloca ptr, align 8, addrspace(5)
92+
// DEV: %j.addr.ascast = addrspacecast ptr addrspace(5) %j.addr to ptr
93+
// DEV: store ptr %j, ptr %j.addr.ascast, align 8
94+
// DEV: %0 = load ptr, ptr %j.addr.ascast, align 8
95+
// DEV: %i = getelementptr inbounds %struct.I, ptr %0, i32 0, i32 1
96+
// DEV: store i32 2, ptr %i, align 8
97+
// DEV: %1 = load ptr, ptr %j.addr.ascast, align 8
98+
// DEV: %j1 = getelementptr inbounds %struct.J, ptr %1, i32 0, i32 1
99+
// DEV: store i32 3, ptr %j1, align 8
100+
// DEV: %2 = load ptr, ptr %j.addr.ascast, align 8
101+
// DEV: %vtable = load ptr addrspace(1), ptr %2, align 8
102+
// DEV: %vfn = getelementptr inbounds ptr addrspace(1), ptr addrspace(1) %vtable, i64 1
103+
// DEV: %3 = load ptr addrspace(1), ptr addrspace(1) %vfn, align 8
104+
// DEV: call addrspace(1) void %3(ptr noundef nonnull align 8 dereferenceable(24) %2)
105+
// DEV: %4 = load ptr, ptr %j.addr.ascast, align 8
106+
// DEV: %vtable2 = load ptr addrspace(1), ptr %4, align 8
107+
// DEV: %vfn3 = getelementptr inbounds ptr addrspace(1), ptr addrspace(1) %vtable2, i64 2
108+
// DEV: %5 = load ptr addrspace(1), ptr addrspace(1) %vfn3, align 8
109+
// DEV: call addrspace(1) void %5(ptr noundef nonnull align 8 dereferenceable(24) %4)
110+
111+
__device__ void J_dev(J *j) {
112+
j->i = 2;
113+
j->j = 3;
114+
j->g();
115+
j->h();
116+
}
117+
118+
// DEV: define dso_local amdgpu_kernel void @_Z8J_kernelv()
119+
// DEV: %j = alloca %struct.J, align 8, addrspace(5)
120+
// DEV: %j.ascast = addrspacecast ptr addrspace(5) %j to ptr
121+
// DEV: call void @_ZN1JC1Ev(ptr noundef nonnull align 8 dereferenceable(24) %j.ascast)
122+
// DEV: call void @_Z5J_devP1J(ptr noundef %j.ascast)
123+
124+
__global__ void J_kernel() {
125+
J j;
126+
J_dev(&j);
127+
}
128+
129+
// HOST-LABEL: define dso_local void @"?J_host@@YAXPEAUJ@@@Z"(ptr noundef %j)
130+
// HOST: %0 = load ptr, ptr %j.addr, align 8
131+
// HOST: %i = getelementptr inbounds %struct.I, ptr %0, i32 0, i32 1
132+
// HOST: store i32 12, ptr %i, align 8
133+
// HOST: %1 = load ptr, ptr %j.addr, align 8
134+
// HOST: %j1 = getelementptr inbounds %struct.J, ptr %1, i32 0, i32 1
135+
// HOST: store i32 13, ptr %j1, align 8
136+
// HOST: %2 = load ptr, ptr %j.addr, align 8
137+
// HOST: %vtable = load ptr, ptr %2, align 8
138+
// HOST: %vfn = getelementptr inbounds ptr, ptr %vtable, i64 0
139+
// HOST: %3 = load ptr, ptr %vfn, align 8
140+
// HOST: call void %3(ptr noundef nonnull align 8 dereferenceable(24) %2)
141+
// HOST: %4 = load ptr, ptr %j.addr, align 8
142+
// HOST: %vtable2 = load ptr, ptr %4, align 8
143+
// HOST: %vfn3 = getelementptr inbounds ptr, ptr %vtable2, i64 2
144+
// HOST: %5 = load ptr, ptr %vfn3, align 8
145+
// HOST: call void %5(ptr noundef nonnull align 8 dereferenceable(24) %4)
146+
147+
void J_host(J *j) {
148+
j->i = 12;
149+
j->j = 13;
150+
j->f();
151+
j->h();
152+
}
153+
154+
// HOST: define dso_local void @"?test_J@@YAXXZ"()
155+
// HOST: %j = alloca %struct.J, align 8
156+
// HOST: %call = call noundef ptr @"??0J@@QEAA@XZ"(ptr noundef nonnull align 8 dereferenceable(24) %j)
157+
// HOST: call void @"?J_host@@YAXPEAUJ@@@Z"(ptr noundef %j)
158+
159+
void test_J() {
160+
J j;
161+
J_host(&j);
162+
J_kernel<<<1, 1>>>();
163+
}
164+
165+
// HOST: define linkonce_odr dso_local noundef ptr @"??0J@@QEAA@XZ"(ptr noundef nonnull returned align 8 dereferenceable(24) %this)
166+
// HOST: %this.addr = alloca ptr, align 8
167+
// HOST: store ptr %this, ptr %this.addr, align 8
168+
// HOST: %this1 = load ptr, ptr %this.addr, align 8
169+
// HOST: %call = call noundef ptr @"??0I@@QEAA@XZ"(ptr noundef nonnull align 8 dereferenceable(16) %this1) #5
170+
// HOST: store ptr @"??_7J@@6B@", ptr %this1, align 8
171+
// HOST: ret ptr %this1
172+
173+
// HOST: define linkonce_odr dso_local noundef ptr @"??0I@@QEAA@XZ"(ptr noundef nonnull returned align 8 dereferenceable(16) %this)
174+
// HOST: %this.addr = alloca ptr, align 8
175+
// HOST: store ptr %this, ptr %this.addr, align 8
176+
// HOST: %this1 = load ptr, ptr %this.addr, align 8
177+
// HOST: store ptr @"??_7I@@6B@", ptr %this1, align 8
178+
// HOST: ret ptr %this1
179+
180+
// DEV: define linkonce_odr void @_ZN1JC1Ev(ptr noundef nonnull align 8 dereferenceable(24) %this)
181+
// DEV: %this.addr = alloca ptr, align 8, addrspace(5)
182+
// DEV: %this.addr.ascast = addrspacecast ptr addrspace(5) %this.addr to ptr
183+
// DEV: store ptr %this, ptr %this.addr.ascast, align 8
184+
// DEV: %this1 = load ptr, ptr %this.addr.ascast, align 8
185+
// DEV: call void @_ZN1JC2Ev(ptr noundef nonnull align 8 dereferenceable(24) %this1)
186+
187+
// DEV: define linkonce_odr void @_ZN1JC2Ev(ptr noundef nonnull align 8 dereferenceable(24) %this)
188+
// DEV: %this.addr = alloca ptr, align 8, addrspace(5)
189+
// DEV: %this.addr.ascast = addrspacecast ptr addrspace(5) %this.addr to ptr
190+
// DEV: store ptr %this, ptr %this.addr.ascast, align 8
191+
// DEV: %this1 = load ptr, ptr %this.addr.ascast, align 8
192+
// DEV: call void @_ZN1IC2Ev(ptr noundef nonnull align 8 dereferenceable(16) %this1)
193+
// DEV: store ptr addrspace(1) getelementptr inbounds inrange(-16, 24) ({ [5 x ptr addrspace(1)] }, ptr addrspace(1) @_ZTV1J, i32 0, i32 0, i32 2), ptr %this1, align 8
194+
195+
// DEV: define linkonce_odr void @_ZN1IC2Ev(ptr noundef nonnull align 8 dereferenceable(16) %this)
196+
// DEV: %this.addr = alloca ptr, align 8, addrspace(5)
197+
// DEV: %this.addr.ascast = addrspacecast ptr addrspace(5) %this.addr to ptr
198+
// DEV: store ptr %this, ptr %this.addr.ascast, align 8
199+
// DEV: %this1 = load ptr, ptr %this.addr.ascast, align 8
200+
// DEV: store ptr addrspace(1) getelementptr inbounds inrange(-16, 24) ({ [5 x ptr addrspace(1)] }, ptr addrspace(1) @_ZTV1I, i32 0, i32 0, i32 2), ptr %this1, align 8

0 commit comments

Comments
 (0)