Skip to content

Commit eb63d1a

Browse files
Merge pull request #1796 from GeorgeWeb/georgi/ur_kernel_max_active_wgs
[CUDA] Implement urKernelSuggestMaxCooperativeGroupCountExp for Cuda
2 parents e26bba5 + 45a781f commit eb63d1a

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

source/adapters/cuda/device.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
5757
return ReturnValue(4318u);
5858
}
5959
case UR_DEVICE_INFO_MAX_COMPUTE_UNITS: {
60-
int ComputeUnits = 0;
61-
UR_CHECK_ERROR(cuDeviceGetAttribute(
62-
&ComputeUnits, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT,
63-
hDevice->get()));
64-
detail::ur::assertion(ComputeUnits >= 0);
65-
return ReturnValue(static_cast<uint32_t>(ComputeUnits));
60+
return ReturnValue(hDevice->getNumComputeUnits());
6661
}
6762
case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS: {
6863
return ReturnValue(MaxWorkItemDimensions);

source/adapters/cuda/device.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct ur_device_handle_t_ {
3232
int MaxCapacityLocalMem{0};
3333
int MaxChosenLocalMem{0};
3434
bool MaxLocalMemSizeChosen{false};
35+
uint32_t NumComputeUnits{0};
3536

3637
public:
3738
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
@@ -54,6 +55,10 @@ struct ur_device_handle_t_ {
5455
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize,
5556
nullptr));
5657

58+
UR_CHECK_ERROR(cuDeviceGetAttribute(
59+
reinterpret_cast<int *>(&NumComputeUnits),
60+
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, cuDevice));
61+
5762
// Set local mem max size if env var is present
5863
static const char *LocalMemSizePtrUR =
5964
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE");
@@ -107,6 +112,8 @@ struct ur_device_handle_t_ {
107112
int getMaxChosenLocalMem() const noexcept { return MaxChosenLocalMem; };
108113

109114
bool maxLocalMemSizeChosen() { return MaxLocalMemSizeChosen; };
115+
116+
uint32_t getNumComputeUnits() const noexcept { return NumComputeUnits; };
110117
};
111118

112119
int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);

source/adapters/cuda/kernel.cpp

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
167167
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
168168
ur_kernel_handle_t hKernel, size_t localWorkSize,
169169
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
170-
(void)hKernel;
171-
(void)localWorkSize;
172-
(void)dynamicSharedMemorySize;
173-
*pGroupCountRet = 1;
170+
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
171+
172+
// We need to set the active current device for this kernel explicitly here,
173+
// because the occupancy querying API does not take a device parameter.
174+
ur_device_handle_t Device = hKernel->getProgram()->getDevice();
175+
ScopedContext Active(Device);
176+
try {
177+
// We need to calculate max num of work-groups using per-device semantics.
178+
179+
int MaxNumActiveGroupsPerCU{0};
180+
UR_CHECK_ERROR(cuOccupancyMaxActiveBlocksPerMultiprocessor(
181+
&MaxNumActiveGroupsPerCU, hKernel->get(), localWorkSize,
182+
dynamicSharedMemorySize));
183+
detail::ur::assertion(MaxNumActiveGroupsPerCU >= 0);
184+
// Handle the case where we can't have all SMs active with at least 1 group
185+
// per SM. In that case, the device is still able to run 1 work-group, hence
186+
// we will manually check if it is possible with the available HW resources.
187+
if (MaxNumActiveGroupsPerCU == 0) {
188+
size_t MaxWorkGroupSize{};
189+
urKernelGetGroupInfo(
190+
hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
191+
sizeof(MaxWorkGroupSize), &MaxWorkGroupSize, nullptr);
192+
size_t MaxLocalSizeBytes{};
193+
urDeviceGetInfo(Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
194+
sizeof(MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr);
195+
if (localWorkSize > MaxWorkGroupSize ||
196+
dynamicSharedMemorySize > MaxLocalSizeBytes ||
197+
hasExceededMaxRegistersPerBlock(Device, hKernel, localWorkSize))
198+
*pGroupCountRet = 0;
199+
else
200+
*pGroupCountRet = 1;
201+
} else {
202+
// Multiply by the number of SMs (CUs = compute units) on the device in
203+
// order to retrieve the total number of groups/blocks that can be
204+
// launched.
205+
*pGroupCountRet = Device->getNumComputeUnits() * MaxNumActiveGroupsPerCU;
206+
}
207+
} catch (ur_result_t Err) {
208+
return Err;
209+
}
174210
return UR_RESULT_SUCCESS;
175211
}
176212

0 commit comments

Comments
 (0)