@@ -174,15 +174,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
174
174
ur_device_handle_t Device = hKernel->getProgram ()->getDevice ();
175
175
ScopedContext Active (Device);
176
176
try {
177
+ // We need to calculate max num of work-groups using per-device semantics.
178
+
177
179
int MaxNumActiveGroupsPerCU{0 };
178
180
UR_CHECK_ERROR (cuOccupancyMaxActiveBlocksPerMultiprocessor (
179
181
&MaxNumActiveGroupsPerCU, hKernel->get (), localWorkSize,
180
182
dynamicSharedMemorySize));
181
183
detail::ur::assertion (MaxNumActiveGroupsPerCU >= 0 );
182
-
183
- // Multiply by the number of SMs (CUs = compute units) on the device in
184
- // order to retreive the total number of groups/blocks that can be launched.
185
- *pGroupCountRet = Device->getNumComputeUnits () * MaxNumActiveGroupsPerCU;
184
+ // Handle the case where we can't have all SMs active with at least 1 group
185
+ // per SM. In that case, the device is still able to run 1 work-group, hence
186
+ // we will manually check if it is possible with the available HW resources.
187
+ if (MaxNumActiveGroupsPerCU == 0 ) {
188
+ size_t MaxWorkGroupSize{};
189
+ urKernelGetGroupInfo (
190
+ hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
191
+ sizeof (MaxWorkGroupSize), &MaxWorkGroupSize, nullptr );
192
+ size_t MaxLocalSizeBytes{};
193
+ urDeviceGetInfo (Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
194
+ sizeof (MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr );
195
+ if (localWorkSize > MaxWorkGroupSize ||
196
+ dynamicSharedMemorySize > MaxLocalSizeBytes ||
197
+ hasExceededMaxRegistersPerBlock (Device, hKernel, localWorkSize))
198
+ *pGroupCountRet = 0 ;
199
+ else
200
+ *pGroupCountRet = 1 ;
201
+ } else {
202
+ // Multiply by the number of SMs (CUs = compute units) on the device in
203
+ // order to retreive the total number of groups/blocks that can be
204
+ // launched.
205
+ *pGroupCountRet = Device->getNumComputeUnits () * MaxNumActiveGroupsPerCU;
206
+ }
186
207
} catch (ur_result_t Err) {
187
208
return Err;
188
209
}
0 commit comments