Skip to content

Commit 5998d7c

Browse files
[SYCL] Fix PI_KERNEL_MAX_SUB_GROUP_SIZE in OpenCL backend (#6849)
Currently PI_KERNEL_MAX_SUB_GROUP_SIZE in the PI OpenCL backend uses the max work item sizes as the input to the corresponding OpenCL query to avoid truncation. However, using the max work item sizes in all dimensions may exceed the total max work items limitations. To prevent this limit from being exceeded, this commit changes the query to only use the max work-item size in the first dimension and using 1s in the other dimensions. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 4d0df22 commit 5998d7c

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -865,14 +865,24 @@ pi_result piKernelGetSubGroupInfo(pi_kernel kernel, pi_device device,
865865
std::shared_ptr<void> implicit_input_value;
866866
if (param_name == PI_KERNEL_MAX_SUB_GROUP_SIZE && !input_value) {
867867
// OpenCL needs an input value for PI_KERNEL_MAX_SUB_GROUP_SIZE so if no
868-
// value is given we use the max work item sizes of the device to avoid
869-
// truncation of max sub-group size.
870-
implicit_input_value = std::shared_ptr<size_t[]>(new size_t[3]);
871-
pi_result pi_ret_err = piDeviceGetInfo(
872-
device, PI_DEVICE_INFO_MAX_WORK_ITEM_SIZES, 3 * sizeof(size_t),
873-
implicit_input_value.get(), nullptr);
868+
// value is given we use the max work item size of the device in the first
869+
// dimention to avoid truncation of max sub-group size.
870+
pi_uint32 max_dims = 0;
871+
pi_result pi_ret_err =
872+
piDeviceGetInfo(device, PI_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS,
873+
sizeof(pi_uint32), &max_dims, nullptr);
874874
if (pi_ret_err != PI_SUCCESS)
875875
return pi_ret_err;
876+
std::shared_ptr<size_t[]> WGSizes{new size_t[max_dims]};
877+
pi_ret_err =
878+
piDeviceGetInfo(device, PI_DEVICE_INFO_MAX_WORK_ITEM_SIZES,
879+
max_dims * sizeof(size_t), WGSizes.get(), nullptr);
880+
if (pi_ret_err != PI_SUCCESS)
881+
return pi_ret_err;
882+
for (size_t i = 1; i < max_dims; ++i)
883+
WGSizes.get()[i] = 1;
884+
implicit_input_value = std::move(WGSizes);
885+
input_value_size = max_dims * sizeof(size_t);
876886
input_value = implicit_input_value.get();
877887
}
878888

0 commit comments

Comments
 (0)