Skip to content

Commit be7a9dc

Browse files
yxsamliuYaxun Liu
andauthored
[HIP] add test split-kernel-args (#230)
Add a test, split-kernel-args.hip, for the AMDGPUSplitKernelArguments pass, which splits byref struct-type kernel arguments so that they can be preloaded for better performance. This test checks that struct-type kernel arguments are passed correctly. Co-authored-by: Yaxun Liu <[email protected]>
1 parent 395115e commit be7a9dc

File tree

3 files changed

+242
-0
lines changed

3 files changed

+242
-0
lines changed

External/HIP/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,15 @@ macro(create_local_hip_tests VariantSuffix)
1313
# Set per-source compilation/link options
1414
set_source_files_properties(with-fopenmp.hip PROPERTIES
1515
COMPILE_FLAGS -fopenmp)
16+
# TODO: Add the flag after the kernel argument splitting pass is enabled.
17+
#set_source_files_properties(split-kernel-args.hip PROPERTIES
18+
# COMPILE_FLAGS "-mllvm -amdgpu-enable-split-kernel-args")
1619
# Add HIP tests to be added to hip-tests-simple
1720
list(APPEND HIP_LOCAL_TESTS empty)
1821
list(APPEND HIP_LOCAL_TESTS with-fopenmp)
1922
list(APPEND HIP_LOCAL_TESTS saxpy)
2023
list(APPEND HIP_LOCAL_TESTS memmove)
24+
list(APPEND HIP_LOCAL_TESTS split-kernel-args)
2125

2226
# TODO: Re-enable InOneWeekend after it is fixed
2327
#list(APPEND HIP_LOCAL_TESTS InOneWeekend)

External/HIP/split-kernel-args.hip

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
#include <hip/hip_runtime.h>
2+
#include <iostream>
3+
#include <cstring>
4+
#include <cassert>
5+
6+
// Simple error check macro
7+
#define HIP_CHECK(call) \
8+
do { \
9+
hipError_t err = call; \
10+
if (err != hipSuccess) { \
11+
std::cerr << "HIP error: " << hipGetErrorString(err) \
12+
<< " at " << __FILE__ << ":" << __LINE__ << std::endl; \
13+
std::exit(EXIT_FAILURE); \
14+
} \
15+
} while (0)
16+
17+
// Parametrized kernel that calls CalcF and stores the result.
// CalcF is a compile-time (non-type `auto`) template parameter, not a pointer
// passed at run time. R is the result type and Args... are the kernel
// argument types, all taken by value and forwarded unchanged to CalcF.
//   out  - location that receives the value returned by CalcF(args...)
//   args - values passed straight through to CalcF
// Only thread 0 of block 0 performs the call and the store; any other thread
// in the launch does nothing.
18+
template <auto CalcF, typename R, typename... Args>
19+
__global__ void calcWrapperKernel(R* out, Args... args) {
20+
if (threadIdx.x == 0 && blockIdx.x == 0) { // single writer
21+
*out = CalcF(args...);
22+
}
23+
}
24+
25+
// Host test launcher: computes expected on host, launches kernel, compares.
// Template parameters:
//   CalcF - function under test; called directly on the host for the
//           reference value and through calcWrapperKernel on the device
//   R     - result type; it does not appear in the parameter list, so callers
//           name it explicitly (e.g. runTest<calc, int>(args)); it is also
//           streamed with operator<< in the PASS/FAIL messages
//   Args  - argument types, deduced from `args`
// On a match prints "[PASS] Result = <value>" to std::cout. On a mismatch
// prints a "[FAIL] ..." message to std::cerr and exits with EXIT_FAILURE.
// Any failing HIP call also exits the process via HIP_CHECK.
26+
template <auto CalcF, typename R, typename... Args>
27+
void runTest(Args... args) {
28+
R expected = CalcF(args...); // reference value computed on the host
29+
30+
R* dOut; // device-side result slot (allocated with hipMalloc below)
31+
R hOut; // host copy of the device result
32+
HIP_CHECK(hipMalloc(&dOut, sizeof(R)));
33+
HIP_CHECK(hipMemset(dOut, 0, sizeof(R))); // zero the slot before the launch
34+
35+
calcWrapperKernel<CalcF, R, Args...><<<1, 1>>>(dOut, args...); // 1 block, 1 thread
36+
HIP_CHECK(hipGetLastError()); // report an error from the launch, if any
37+
HIP_CHECK(hipMemcpy(&hOut, dOut, sizeof(R), hipMemcpyDeviceToHost));
38+
HIP_CHECK(hipFree(dOut));
39+
40+
if (memcmp(&hOut, &expected, sizeof(R)) == 0) { // byte-wise comparison of the two R values
41+
std::cout << "[PASS] Result = " << hOut << std::endl;
42+
} else {
43+
std::cerr << "[FAIL] Expected " << expected << ", but got " << hOut << std::endl;
44+
std::exit(EXIT_FAILURE);
45+
}
46+
}
47+
48+
//
49+
// Test Case 1: Mixed types with padding
// The struct mixes int, float, char and double members; calc reads a, b and d
// and never touches c. With the values set in run(), 3 + 2.5f + 4.5 = 10.0,
// which is truncated to the int result 10.
50+
//
51+
namespace test_mixed_types {
52+
struct Args {
53+
int a;
54+
float b;
55+
char c; // unused
56+
double d;
57+
};
58+
59+
__host__ __device__
60+
int calc(Args args) {
61+
return static_cast<int>(args.a + args.b + args.d); // sum, then convert to int
62+
}
63+
64+
void run() {
65+
Args args = {3, 2.5f, 'x', 4.5}; // a = 3, b = 2.5f, c = 'x' (unused), d = 4.5
66+
runTest<calc, int>(args);
67+
}
68+
}
69+
70+
//
71+
// Test Case 2: Deeply nested struct
// Outer contains Middle, which contains Inner. calc reads one int from each
// nesting level (mid.inner.x, mid.z, w) and ignores Inner::y. With the values
// set in run(), 2 + 3 + 4 = 9.
72+
//
73+
namespace test_nested_struct {
74+
struct Inner {
75+
int x;
76+
float y; // unused
77+
};
78+
79+
struct Middle {
80+
Inner inner;
81+
int z;
82+
};
83+
84+
struct Outer {
85+
Middle mid;
86+
int w;
87+
};
88+
89+
__host__ __device__
90+
int calc(Outer args) {
91+
return args.mid.inner.x + args.mid.z + args.w;
92+
}
93+
94+
void run() {
95+
Outer args = {{{2, 1.0f}, 3}, 4}; // inner = {x = 2, y = 1.0f}, z = 3, w = 4
96+
runTest<calc, int>(args);
97+
}
98+
}
99+
100+
//
101+
// Test Case 3: Partial field usage
// calc reads only a and b; the double member c is never read. With the values
// set in run(), 4 * 2.0f = 8.0f, which converts to the int result 8.
102+
//
103+
namespace test_partial_use {
104+
struct Args {
105+
int a;
106+
float b;
107+
double c; // unused
108+
};
109+
110+
__host__ __device__
111+
int calc(Args args) {
112+
return static_cast<int>(args.a * args.b); // product, then convert to int
113+
}
114+
115+
void run() {
116+
Args args = {4, 2.0f, 9.9}; // a = 4, b = 2.0f, c = 9.9 (unused)
117+
runTest<calc, int>(args);
118+
}
119+
}
120+
121+
//
122+
// Test Case 4: Struct with array
// The struct holds a 4-element int array and an index; calc returns the
// element selected by that index. With the values set in run(), arr[2] = 30.
123+
//
124+
namespace test_array_member {
125+
struct Args {
126+
int arr[4];
127+
int idx;
128+
};
129+
130+
__host__ __device__
131+
int calc(Args args) {
132+
return args.arr[args.idx]; // index comes from the same struct; no bounds check
133+
}
134+
135+
void run() {
136+
Args args = {{10, 20, 30, 40}, 2}; // idx = 2 selects the value 30
137+
runTest<calc, int>(args);
138+
}
139+
}
140+
141+
//
142+
// Test Case 5: Struct with address taken (indirect access)
// calc takes the address of its by-value struct parameter and reads field a
// through that pointer in getA, while field b is read directly. With the
// values set in run(), 5 + 7 = 12.
143+
//
144+
namespace test_address_taken {
145+
struct Args {
146+
int a;
147+
int b;
148+
};
149+
150+
__device__ __host__
151+
int getA(const Args* p) {
152+
return p->a; // read a through the pointer
153+
}
154+
155+
__host__ __device__
156+
int calc(Args args) {
157+
return getA(&args) + args.b; // a via pointer, b via direct member access
158+
}
159+
160+
void run() {
161+
Args args = {5, 7}; // a = 5, b = 7
162+
runTest<calc, int>(args);
163+
}
164+
}
165+
166+
//
167+
// Test Case 6: Mixed struct and non-struct arguments with unused fields
// calc takes a char, a struct A, two ints and a struct B. The second int
// parameter is unnamed and ignored, and A::a2 and B::b2 are never read.
// With the values set in run(), 1 + 10 + 2 + 7 + 5 = 25.
168+
//
169+
namespace test_mixed_struct_and_scalars {
170+
struct A {
171+
int a1;
172+
float a2; // unused
173+
char a3;
174+
};
175+
176+
struct B {
177+
double b1;
178+
int b2; // unused
179+
};
180+
181+
__host__ __device__
182+
int calc(char c, A a, int i1, int /*i2*/, B b) {
183+
return static_cast<int>(c) + a.a1 + a.a3 + i1 + static_cast<int>(b.b1); // b1 is converted to int before adding
184+
}
185+
186+
void run() {
187+
A a = {10, 3.14f, 2}; // a1 = 10, a3 = 2, a2 = unused
188+
B b = {5.0, 42}; // b1 = 5.0, b2 = unused
189+
char c = 1;
190+
int i1 = 7;
191+
int i2 = 99; // unused
192+
runTest<calc, int>(c, a, i1, i2, b);
193+
}
194+
}
195+
196+
//
197+
// Test Case 7: Struct with vector type (float3) and other fields
// calc sums the x, y and z components of the float3 member, converts that sum
// to int, and adds the int member id; the trailing float member is never
// read. With the values set in run(), (1 + 2 + 3) + 4 = 10.
198+
//
199+
namespace test_struct_with_vector_types {
200+
struct A {
201+
float3 v; // Used
202+
int id; // Used
203+
float unused; // Unused
204+
};
205+
206+
__host__ __device__
207+
int calc(A a) {
208+
int sum = static_cast<int>(a.v.x + a.v.y + a.v.z); // component sum, then convert to int
209+
return sum + a.id;
210+
}
211+
212+
void run() {
213+
A a; // members are assigned one by one below
214+
a.v = make_float3(1.0f, 2.0f, 3.0f);
215+
a.id = 4;
216+
a.unused = 99.0f;
217+
runTest<calc, int>(a);
218+
}
219+
}
220+
221+
// Runs the seven test cases in order. Each run() goes through runTest, which
// prints one "[PASS] ..." line on success and exits the process with
// EXIT_FAILURE on a mismatch or HIP error, so reaching `return 0` means every
// test passed.
int main() {
222+
test_mixed_types::run();
223+
test_nested_struct::run();
224+
test_partial_use::run();
225+
test_array_member::run();
226+
test_address_taken::run();
227+
test_mixed_struct_and_scalars::run();
228+
test_struct_with_vector_types::run();
229+
return 0;
230+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[PASS] Result = 10
2+
[PASS] Result = 9
3+
[PASS] Result = 8
4+
[PASS] Result = 30
5+
[PASS] Result = 12
6+
[PASS] Result = 25
7+
[PASS] Result = 10
8+
exit 0

0 commit comments

Comments
 (0)