@@ -167,10 +167,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
167
167
// Suggests the maximum number of work-groups that can be launched at once for
// `hKernel` on the device its program was built for, given a work-group size
// and a dynamic shared memory requirement.
//
//   hKernel                 - kernel to query; must not be null.
//   localWorkSize           - number of work-items per work-group.
//   dynamicSharedMemorySize - dynamic shared memory per work-group, compared
//                             against UR_DEVICE_INFO_LOCAL_MEM_SIZE.
//   pGroupCountRet          - [out] receives the suggested group count; must
//                             not be null. Set to 0 when not even a single
//                             work-group fits on the device.
//
// Returns UR_RESULT_SUCCESS on success, UR_RESULT_ERROR_INVALID_KERNEL or
// UR_RESULT_ERROR_INVALID_NULL_POINTER for null arguments, and otherwise the
// error reported by whichever underlying query failed.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, size_t localWorkSize,
    size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
  UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
  UR_ASSERT(pGroupCountRet, UR_RESULT_ERROR_INVALID_NULL_POINTER);

  try {
    // We need to set the active current device for this kernel explicitly
    // here, because the occupancy querying API does not take a device
    // parameter. The scoped context is created inside the try block so that
    // any error it raises is converted into a return code rather than
    // propagating out of this C entry point.
    ur_device_handle_t Device = hKernel->getProgram()->getDevice();
    ScopedContext Active(Device);

    // We need to calculate max num of work-groups using per-device semantics.
    int MaxNumActiveGroupsPerCU{0};
    UR_CHECK_ERROR(cuOccupancyMaxActiveBlocksPerMultiprocessor(
        &MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize,
        dynamicSharedMemorySize));
    detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0);

    // Handle the case where we can't have all SMs active with at least 1 group
    // per SM. In that case, the device is still able to run 1 work-group, hence
    // we will manually check if it is possible with the available HW resources.
    if (MaxNumActiveGroupsPerCU == 0) {
      size_t MaxWorkGroupSize{};
      const ur_result_t GroupInfoResult = urKernelGetGroupInfo(
          hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
          sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
      // A failed query would leave the limit at 0 and silently report that no
      // group fits; surface the error instead.
      if (GroupInfoResult != UR_RESULT_SUCCESS)
        return GroupInfoResult;

      size_t MaxLocalSizeBytes{};
      const ur_result_t DeviceInfoResult =
          urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
                          sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr);
      if (DeviceInfoResult != UR_RESULT_SUCCESS)
        return DeviceInfoResult;

      if (localWorkSize > MaxWorkGroupSize ||
          dynamicSharedMemorySize > MaxLocalSizeBytes ||
          hasExceededMaxRegistersPerBlock(Device, hKernel, localWorkSize))
        *pGroupCountRet = 0;
      else
        *pGroupCountRet = 1;
    } else {
      // Multiply by the number of SMs (CUs = compute units) on the device in
      // order to retrieve the total number of groups/blocks that can be
      // launched.
      *pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
    }
  } catch (ur_result_t Err) {
    return Err;
  }
  return UR_RESULT_SUCCESS;
}
176
212
0 commit comments