Skip to content

Commit 9dcdc62

Browse files
committed
Move CUDA device-specific resource limit checking logic from the SYCL runtime into the adapter backend.
This change is required in order to implement per-device semantics for the urKernelSuggestMaxCooperativeGroupCountExp query.
1 parent 77da3fa commit 9dcdc62

File tree

1 file changed

+25
-4
lines changed

1 file changed

+25
-4
lines changed

source/adapters/cuda/kernel.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
174174
ur_device_handle_t Device = hKernel->getProgram()->getDevice();
175175
ScopedContext Active(Device);
176176
try {
177+
// We need to calculate max num of work-groups using per-device semantics.
178+
177179
int MaxNumActiveGroupsPerCU{0};
178180
UR_CHECK_ERROR(cuOccupancyMaxActiveBlocksPerMultiprocessor(
179181
&MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize,
180182
dynamicSharedMemorySize));
181183
detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0);
182-
183-
// Multiply by the number of SMs (CUs = compute units) on the device in
184-
// order to retreive the total number of groups/blocks that can be launched.
185-
*pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
184+
// Handle the case where we can't have all SMs active with at least 1 group
185+
// per SM. In that case, the device is still able to run 1 work-group, hence
186+
// we will manually check if it is possible with the available HW resources.
187+
if (MaxNumActiveGroupsPerCU == 0) {
188+
size_t MaxWorkGroupSize{};
189+
urKernelGetGroupInfo(
190+
hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
191+
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
192+
size_t MaxLocalSizeBytes{};
193+
urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
194+
sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr);
195+
if (localWorkSize > MaxWorkGroupSize ||
196+
dynamicSharedMemorySize > MaxLocalSizeBytes ||
197+
hasExceededMaxRegistersPerBlock(Device, hKernel, localWorkSize))
198+
*pGroupCountRet = 0;
199+
else
200+
*pGroupCountRet = 1;
201+
} else {
202+
// Multiply by the number of SMs (CUs = compute units) on the device in
203+
// order to retrieve the total number of groups/blocks that can be
204+
// launched.
205+
*pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
206+
}
186207
} catch (ur_result_t Err) {
187208
return Err;
188209
}

0 commit comments

Comments
 (0)