Skip to content

Commit c8b831f

Browse files
[SYCL][PI][L0] Add unit test for level zero plugin kernel batching implementation (#2738)
This adds a test that checks the level zero plugin batching implementation
1 parent cb5ddb4 commit c8b831f

File tree

1 file changed

+365
-0
lines changed

1 file changed

+365
-0
lines changed
Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
// REQUIRES: gpu, level_zero
2+
3+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
4+
5+
// Default batching should be 4
6+
// RUN: env SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB4 %s
7+
8+
// Set batching to 4 explicitly
9+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=4 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB4 %s
10+
11+
// Set batching to 1 explicitly
12+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=1 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB1 %s
13+
14+
// Set batching to 3 explicitly
15+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=3 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB3 %s
16+
17+
// Set batching to 5 explicitly
18+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=5 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB5 %s
19+
20+
// Set batching to 7 explicitly
21+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=7 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB7 %s
22+
23+
// Set batching to 8 explicitly
24+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=8 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB8 %s
25+
26+
// Set batching to 9 explicitly
27+
// RUN: env SYCL_PI_LEVEL_ZERO_BATCH_SIZE=9 SYCL_PI_TRACE=2 ZE_DEBUG=1 %GPU_RUN_PLACEHOLDER %t.out 2>&1 | FileCheck --check-prefixes=CKALL,CKB9 %s
28+
29+
// level_zero_batch_test.cpp
30+
//
31+
// This tests the level zero plugin's kernel batching code. The default
32+
// batching is 4, and exact batch size can be controlled with environment
33+
// variable SYCL_PI_LEVEL_ZEOR+BATCH_SIZE=N.
34+
// This test enqueues 8 kernels and then does a wait. And it does this 3 times.
35+
// Expected output is that for batching =1 you will see zeCommandListClose,
36+
// and zeCommandQueueExecuteCommandLists after every piEnqueueKernelLaunch.
37+
// For batching=3 you will see that after 3rd and 6th enqueues, and then after
38+
// piEventsWait. For 5, after 5th piEnqueue, and then after piEventsWait. For
39+
// 4 you will see these after 4th and 8th Enqueue, and for 8, only after the
40+
// 8th enqueue. And lastly for 9, you will see the Close and Execute calls
41+
// only after the piEventsWait.
42+
// Since the test does this 3 times, this pattern will repeat 2 more times,
43+
// and then the test will print Test Passed 8 times, once for each kernel
44+
// validation check.
45+
// Pattern starts first set of kernel executions.
46+
// CKALL: ---> piEnqueueKernelLaunch(
47+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
48+
// CKB1: ZE ---> zeCommandListClose(
49+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
50+
// CKALL: ---> piEnqueueKernelLaunch(
51+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
52+
// CKB1: ZE ---> zeCommandListClose(
53+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
54+
// CKALL: ---> piEnqueueKernelLaunch(
55+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
56+
// CKB1: ZE ---> zeCommandListClose(
57+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
58+
// CKB3: ZE ---> zeCommandListClose(
59+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
60+
// CKALL: ---> piEnqueueKernelLaunch(
61+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
62+
// CKB1: ZE ---> zeCommandListClose(
63+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
64+
// CKB4: ZE ---> zeCommandListClose(
65+
// CKB4: ZE ---> zeCommandQueueExecuteCommandLists(
66+
// CKALL: ---> piEnqueueKernelLaunch(
67+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
68+
// CKB1: ZE ---> zeCommandListClose(
69+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
70+
// CKB5: ZE ---> zeCommandListClose(
71+
// CKB5: ZE ---> zeCommandQueueExecuteCommandLists(
72+
// CKALL: ---> piEnqueueKernelLaunch(
73+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
74+
// CKB1: ZE ---> zeCommandListClose(
75+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
76+
// CKB3: ZE ---> zeCommandListClose(
77+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
78+
// CKALL: ---> piEnqueueKernelLaunch(
79+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
80+
// CKB1: ZE ---> zeCommandListClose(
81+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
82+
// CKB7: ZE ---> zeCommandListClose(
83+
// CKB7: ZE ---> zeCommandQueueExecuteCommandLists(
84+
// CKALL: ---> piEnqueueKernelLaunch(
85+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
86+
// CKB1: ZE ---> zeCommandListClose(
87+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
88+
// CKB4: ZE ---> zeCommandListClose(
89+
// CKB4: ZE ---> zeCommandQueueExecuteCommandLists(
90+
// CKB8: ZE ---> zeCommandListClose(
91+
// CKB8: ZE ---> zeCommandQueueExecuteCommandLists(
92+
// CKALL: ---> piEventsWait(
93+
// CKB3: ZE ---> zeCommandListClose(
94+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
95+
// CKB5: ZE ---> zeCommandListClose(
96+
// CKB5: ZE ---> zeCommandQueueExecuteCommandLists(
97+
// CKB7: ZE ---> zeCommandListClose(
98+
// CKB7: ZE ---> zeCommandQueueExecuteCommandLists(
99+
// CKB9: ZE ---> zeCommandListClose(
100+
// CKB9: ZE ---> zeCommandQueueExecuteCommandLists(
101+
// Pattern starts 2nd set of kernel executions
102+
// CKALL: ---> piEnqueueKernelLaunch(
103+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
104+
// CKB1: ZE ---> zeCommandListClose(
105+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
106+
// CKALL: ---> piEnqueueKernelLaunch(
107+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
108+
// CKB1: ZE ---> zeCommandListClose(
109+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
110+
// CKALL: ---> piEnqueueKernelLaunch(
111+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
112+
// CKB1: ZE ---> zeCommandListClose(
113+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
114+
// CKB3: ZE ---> zeCommandListClose(
115+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
116+
// CKALL: ---> piEnqueueKernelLaunch(
117+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
118+
// CKB1: ZE ---> zeCommandListClose(
119+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
120+
// CKB4: ZE ---> zeCommandListClose(
121+
// CKB4: ZE ---> zeCommandQueueExecuteCommandLists(
122+
// CKALL: ---> piEnqueueKernelLaunch(
123+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
124+
// CKB1: ZE ---> zeCommandListClose(
125+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
126+
// CKB5: ZE ---> zeCommandListClose(
127+
// CKB5: ZE ---> zeCommandQueueExecuteCommandLists(
128+
// CKALL: ---> piEnqueueKernelLaunch(
129+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
130+
// CKB1: ZE ---> zeCommandListClose(
131+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
132+
// CKB3: ZE ---> zeCommandListClose(
133+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
134+
// CKALL: ---> piEnqueueKernelLaunch(
135+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
136+
// CKB1: ZE ---> zeCommandListClose(
137+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
138+
// CKB7: ZE ---> zeCommandListClose(
139+
// CKB7: ZE ---> zeCommandQueueExecuteCommandLists(
140+
// CKALL: ---> piEnqueueKernelLaunch(
141+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
142+
// CKB1: ZE ---> zeCommandListClose(
143+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
144+
// CKB4: ZE ---> zeCommandListClose(
145+
// CKB4: ZE ---> zeCommandQueueExecuteCommandLists(
146+
// CKB8: ZE ---> zeCommandListClose(
147+
// CKB8: ZE ---> zeCommandQueueExecuteCommandLists(
148+
// CKALL: ---> piEventsWait(
149+
// CKB3: ZE ---> zeCommandListClose(
150+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
151+
// CKB5: ZE ---> zeCommandListClose(
152+
// CKB5: ZE ---> zeCommandQueueExecuteCommandLists(
153+
// CKB7: ZE ---> zeCommandListClose(
154+
// CKB7: ZE ---> zeCommandQueueExecuteCommandLists(
155+
// CKB9: ZE ---> zeCommandListClose(
156+
// CKB9: ZE ---> zeCommandQueueExecuteCommandLists(
157+
// Pattern starts 3rd set of kernel executions
158+
// CKALL: ---> piEnqueueKernelLaunch(
159+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
160+
// CKB1: ZE ---> zeCommandListClose(
161+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
162+
// CKALL: ---> piEnqueueKernelLaunch(
163+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
164+
// CKB1: ZE ---> zeCommandListClose(
165+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
166+
// CKALL: ---> piEnqueueKernelLaunch(
167+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
168+
// CKB1: ZE ---> zeCommandListClose(
169+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
170+
// CKB3: ZE ---> zeCommandListClose(
171+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
172+
// CKALL: ---> piEnqueueKernelLaunch(
173+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
174+
// CKB1: ZE ---> zeCommandListClose(
175+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
176+
// CKB4: ZE ---> zeCommandListClose(
177+
// CKB4: ZE ---> zeCommandQueueExecuteCommandLists(
178+
// CKALL: ---> piEnqueueKernelLaunch(
179+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
180+
// CKB1: ZE ---> zeCommandListClose(
181+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
182+
// CKB5: ZE ---> zeCommandListClose(
183+
// CKB5: ZE ---> zeCommandQueueExecuteCommandLists(
184+
// CKALL: ---> piEnqueueKernelLaunch(
185+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
186+
// CKB1: ZE ---> zeCommandListClose(
187+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
188+
// CKB3: ZE ---> zeCommandListClose(
189+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
190+
// CKALL: ---> piEnqueueKernelLaunch(
191+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
192+
// CKB1: ZE ---> zeCommandListClose(
193+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
194+
// CKB7: ZE ---> zeCommandListClose(
195+
// CKB7: ZE ---> zeCommandQueueExecuteCommandLists(
196+
// CKALL: ---> piEnqueueKernelLaunch(
197+
// CKALL: ZE ---> zeCommandListAppendLaunchKernel(
198+
// CKB1: ZE ---> zeCommandListClose(
199+
// CKB1: ZE ---> zeCommandQueueExecuteCommandLists(
200+
// CKB4: ZE ---> zeCommandListClose(
201+
// CKB4: ZE ---> zeCommandQueueExecuteCommandLists(
202+
// CKB8: ZE ---> zeCommandListClose(
203+
// CKB8: ZE ---> zeCommandQueueExecuteCommandLists(
204+
// CKALL: ---> piEventsWait(
205+
// CKB3: ZE ---> zeCommandListClose(
206+
// CKB3: ZE ---> zeCommandQueueExecuteCommandLists(
207+
// CKB5: ZE ---> zeCommandListClose(
208+
// CKB5: ZE ---> zeCommandQueueExecuteCommandLists(
209+
// CKB7: ZE ---> zeCommandListClose(
210+
// CKB7: ZE ---> zeCommandQueueExecuteCommandLists(
211+
// CKB9: ZE ---> zeCommandListClose(
212+
// CKB9: ZE ---> zeCommandQueueExecuteCommandLists(
213+
// Now just check for 8 Test Pass kernel validations.
214+
// CKALL: Test Pass
215+
// CKALL: Test Pass
216+
// CKALL: Test Pass
217+
// CKALL: Test Pass
218+
// CKALL: Test Pass
219+
// CKALL: Test Pass
220+
// CKALL: Test Pass
221+
// CKALL: Test Pass
222+
223+
#include "CL/sycl.hpp"
224+
#include <chrono>
225+
#include <cmath>
226+
#include <iostream>
227+
228+
namespace sycl = cl::sycl;
229+
230+
void validate(uint32_t *result, uint32_t *expect, size_t n) {
231+
int error = 0;
232+
for (int i = 0; i < n; i++) {
233+
if (result[i] != expect[i]) {
234+
error++;
235+
if (error < 10) {
236+
printf("Error: %d, expect: %d\n", result[i], expect[i]);
237+
}
238+
}
239+
}
240+
error > 0 ? printf("Error: %d\n", error) : printf("Test Pass\n");
241+
}
242+
243+
int main(int argc, char *argv[]) {
244+
size_t M = 65536;
245+
size_t N = 512 / 4;
246+
size_t AL = M * N * sizeof(uint32_t);
247+
248+
sycl::queue q(sycl::default_selector{});
249+
auto ctx = q.get_context();
250+
auto dev = q.get_device();
251+
252+
uint32_t *Y1 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
253+
uint32_t *Z1 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
254+
uint32_t *Z2 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
255+
uint32_t *Z3 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
256+
uint32_t *Z4 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
257+
uint32_t *Z5 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
258+
uint32_t *Z6 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
259+
uint32_t *Z7 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
260+
uint32_t *Z8 = static_cast<uint32_t *>(sycl::malloc_shared(AL, dev, ctx));
261+
262+
for (size_t i = 0; i < M * N; i++) {
263+
Y1[i] = i % 255;
264+
}
265+
266+
memset(Z1, '\0', AL);
267+
memset(Z2, '\0', AL);
268+
memset(Z3, '\0', AL);
269+
memset(Z4, '\0', AL);
270+
memset(Z5, '\0', AL);
271+
memset(Z6, '\0', AL);
272+
memset(Z7, '\0', AL);
273+
memset(Z8, '\0', AL);
274+
275+
{
276+
for (size_t j = 0; j < 3; j++) {
277+
q.submit([&](sycl::handler &h) {
278+
h.parallel_for<class u32_copy1>(sycl::range<2>{M, N},
279+
[=](sycl::id<2> it) {
280+
const int m = it[0];
281+
const int n = it[1];
282+
Z1[m * N + n] = Y1[m * N + n];
283+
});
284+
});
285+
q.submit([&](sycl::handler &h) {
286+
h.parallel_for<class u32_copy2>(sycl::range<2>{M, N},
287+
[=](sycl::id<2> it) {
288+
const int m = it[0];
289+
const int n = it[1];
290+
Z2[m * N + n] = Y1[m * N + n];
291+
});
292+
});
293+
q.submit([&](sycl::handler &h) {
294+
h.parallel_for<class u32_copy3>(sycl::range<2>{M, N},
295+
[=](sycl::id<2> it) {
296+
const int m = it[0];
297+
const int n = it[1];
298+
Z3[m * N + n] = Y1[m * N + n];
299+
});
300+
});
301+
q.submit([&](sycl::handler &h) {
302+
h.parallel_for<class u32_copy4>(sycl::range<2>{M, N},
303+
[=](sycl::id<2> it) {
304+
const int m = it[0];
305+
const int n = it[1];
306+
Z4[m * N + n] = Y1[m * N + n];
307+
});
308+
});
309+
q.submit([&](sycl::handler &h) {
310+
h.parallel_for<class u32_copy5>(sycl::range<2>{M, N},
311+
[=](sycl::id<2> it) {
312+
const int m = it[0];
313+
const int n = it[1];
314+
Z5[m * N + n] = Y1[m * N + n];
315+
});
316+
});
317+
q.submit([&](sycl::handler &h) {
318+
h.parallel_for<class u32_copy6>(sycl::range<2>{M, N},
319+
[=](sycl::id<2> it) {
320+
const int m = it[0];
321+
const int n = it[1];
322+
Z6[m * N + n] = Y1[m * N + n];
323+
});
324+
});
325+
q.submit([&](sycl::handler &h) {
326+
h.parallel_for<class u32_copy7>(sycl::range<2>{M, N},
327+
[=](sycl::id<2> it) {
328+
const int m = it[0];
329+
const int n = it[1];
330+
Z7[m * N + n] = Y1[m * N + n];
331+
});
332+
});
333+
q.submit([&](sycl::handler &h) {
334+
h.parallel_for<class u32_copy8>(sycl::range<2>{M, N},
335+
[=](sycl::id<2> it) {
336+
const int m = it[0];
337+
const int n = it[1];
338+
Z8[m * N + n] = Y1[m * N + n];
339+
});
340+
});
341+
342+
q.wait();
343+
}
344+
}
345+
validate(Y1, Z1, M * N);
346+
validate(Y1, Z2, M * N);
347+
validate(Y1, Z3, M * N);
348+
validate(Y1, Z4, M * N);
349+
validate(Y1, Z5, M * N);
350+
validate(Y1, Z6, M * N);
351+
validate(Y1, Z7, M * N);
352+
validate(Y1, Z8, M * N);
353+
354+
sycl::free(Y1, ctx);
355+
sycl::free(Z1, ctx);
356+
sycl::free(Z2, ctx);
357+
sycl::free(Z3, ctx);
358+
sycl::free(Z4, ctx);
359+
sycl::free(Z5, ctx);
360+
sycl::free(Z6, ctx);
361+
sycl::free(Z7, ctx);
362+
sycl::free(Z8, ctx);
363+
364+
return 0;
365+
}

0 commit comments

Comments
 (0)