Skip to content

Commit 8f11b0e

Browse files
authored
[SYCL][CUDA] Add tests for exceeding maximum number of work groups (intel/llvm-test-suite#952)
Add tests for checking the error for exceeding maximum number of work groups. Tests changes in intel#4563.
1 parent 91789b5 commit 8f11b0e

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

SYCL/Basic/cuda_max_wgs_error.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -fno-sycl-id-queries-fit-in-int
2+
// RUN: %GPU_RUN_PLACEHOLDER %t.out
3+
//
4+
// REQUIRES: cuda
5+
6+
#include <numeric>
7+
#include <sycl/sycl.hpp>
8+
9+
using namespace sycl;
10+
11+
const size_t lsize = 32;
12+
const std::string expected_msg =
13+
"Number of work-groups exceed limit for dimension ";
14+
15+
template <int N>
16+
void check(range<N> global, range<N> local, bool expect_fail = false) {
17+
queue q;
18+
try {
19+
q.submit([&](handler &cgh) {
20+
cgh.parallel_for(nd_range<N>(global, local), [=](nd_item<N> item) {});
21+
});
22+
} catch (nd_range_error e) {
23+
if (expect_fail) {
24+
std::string msg = e.what();
25+
assert(msg.rfind(expected_msg, 0) == 0);
26+
} else {
27+
throw e;
28+
}
29+
}
30+
}
31+
32+
int main() {
33+
queue q;
34+
device d = q.get_device();
35+
id<1> max_1 = d.get_info<sycl::info::device::ext_oneapi_max_work_groups_1d>();
36+
check(range<1>(max_1[0] * lsize), range<1>(lsize));
37+
check(range<1>((max_1[0] + 1) * lsize), range<1>(lsize), true);
38+
39+
id<2> max_2 = d.get_info<sycl::info::device::ext_oneapi_max_work_groups_2d>();
40+
check(range<2>(1, max_2[1] * lsize), range<2>(1, lsize));
41+
check(range<2>(1, (max_2[1] + 1) * lsize), range<2>(1, lsize), true);
42+
check(range<2>(max_2[0] * lsize, 1), range<2>(lsize, 1));
43+
check(range<2>((max_2[0] + 1) * lsize, 1), range<2>(lsize, 1), true);
44+
45+
id<3> max_3 = d.get_info<sycl::info::device::ext_oneapi_max_work_groups_3d>();
46+
check(range<3>(1, 1, max_3[2] * lsize), range<3>(1, 1, lsize));
47+
check(range<3>(1, 1, (max_3[2] + 1) * lsize), range<3>(1, 1, lsize), true);
48+
check(range<3>(1, max_3[1] * lsize, 1), range<3>(1, lsize, 1));
49+
check(range<3>(1, (max_3[1] + 1) * lsize, 1), range<3>(1, lsize, 1), true);
50+
check(range<3>(max_3[0] * lsize, 1, 1), range<3>(lsize, 1, 1));
51+
check(range<3>((max_3[0] + 1) * lsize, 1, 1), range<3>(lsize, 1, 1), true);
52+
}

0 commit comments

Comments
 (0)