@@ -22,31 +22,8 @@ namespace detail::enqueue_kernel_launch {
22
22
23
23
void handleInvalidWorkGroupSize (const device_impl &DeviceImpl, pi_kernel Kernel,
24
24
const NDRDescT &NDRDesc) {
25
- const bool HasLocalSize = (NDRDesc.LocalSize [0 ] != 0 );
26
-
27
- const PluginPtr &Plugin = DeviceImpl.getPlugin ();
28
- sycl::detail::pi::PiDevice Device = DeviceImpl.getHandleRef ();
29
25
sycl::platform Platform = DeviceImpl.get_platform ();
30
26
31
- if (HasLocalSize) {
32
- size_t MaxThreadsPerBlock[3 ] = {};
33
- Plugin->call <PiApiKind::piDeviceGetInfo>(
34
- Device, PI_DEVICE_INFO_MAX_WORK_ITEM_SIZES, sizeof (MaxThreadsPerBlock),
35
- MaxThreadsPerBlock, nullptr );
36
-
37
- for (size_t I = 0 ; I < 3 ; ++I) {
38
- if (MaxThreadsPerBlock[I] < NDRDesc.LocalSize [I]) {
39
- throw sycl::nd_range_error (
40
- " The number of work-items in each dimension of a work-group cannot "
41
- " exceed {" +
42
- std::to_string (MaxThreadsPerBlock[0 ]) + " , " +
43
- std::to_string (MaxThreadsPerBlock[1 ]) + " , " +
44
- std::to_string (MaxThreadsPerBlock[2 ]) + " } for this device" ,
45
- PI_ERROR_INVALID_WORK_GROUP_SIZE);
46
- }
47
- }
48
- }
49
-
50
27
// Some of the error handling below is special for particular OpenCL
51
28
// versions. If this is an OpenCL backend, get the version.
52
29
bool IsOpenCL = false ; // Backend is any OpenCL version
@@ -68,6 +45,9 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
68
45
IsCuda = true ;
69
46
}
70
47
48
+ const PluginPtr &Plugin = DeviceImpl.getPlugin ();
49
+ sycl::detail::pi::PiDevice Device = DeviceImpl.getHandleRef ();
50
+
71
51
size_t CompileWGSize[3 ] = {0 };
72
52
Plugin->call <PiApiKind::piKernelGetGroupInfo>(
73
53
Kernel, Device, PI_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
@@ -77,6 +57,9 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
77
57
Plugin->call <PiApiKind::piDeviceGetInfo>(Device,
78
58
PI_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
79
59
sizeof (size_t ), &MaxWGSize, nullptr );
60
+
61
+ const bool HasLocalSize = (NDRDesc.LocalSize [0 ] != 0 );
62
+
80
63
if (CompileWGSize[0 ] != 0 ) {
81
64
if (CompileWGSize[0 ] > MaxWGSize || CompileWGSize[1 ] > MaxWGSize ||
82
65
CompileWGSize[2 ] > MaxWGSize)
@@ -111,6 +94,26 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl, pi_kernel Kernel,
111
94
std::to_string (CompileWGSize[0 ]) + " }" ,
112
95
PI_ERROR_INVALID_WORK_GROUP_SIZE);
113
96
}
97
+
98
+ if (HasLocalSize) {
99
+ size_t MaxThreadsPerBlock[3 ] = {};
100
+ Plugin->call <PiApiKind::piDeviceGetInfo>(
101
+ Device, PI_DEVICE_INFO_MAX_WORK_ITEM_SIZES, sizeof (MaxThreadsPerBlock),
102
+ MaxThreadsPerBlock, nullptr );
103
+
104
+ for (size_t I = 0 ; I < 3 ; ++I) {
105
+ if (MaxThreadsPerBlock[I] < NDRDesc.LocalSize [I]) {
106
+ throw sycl::nd_range_error (
107
+ " The number of work-items in each dimension of a work-group cannot "
108
+ " exceed {" +
109
+ std::to_string (MaxThreadsPerBlock[0 ]) + " , " +
110
+ std::to_string (MaxThreadsPerBlock[1 ]) + " , " +
111
+ std::to_string (MaxThreadsPerBlock[2 ]) + " } for this device" ,
112
+ PI_ERROR_INVALID_WORK_GROUP_SIZE);
113
+ }
114
+ }
115
+ }
116
+
114
117
if (IsOpenCLV1x) {
115
118
// OpenCL 1.x:
116
119
// PI_ERROR_INVALID_WORK_GROUP_SIZE if local_work_size is specified and
0 commit comments