20
20
#include < cuda.h>
21
21
#endif // USE_PI_CUDA
22
22
23
+ #ifdef USE_PI_ROCM
24
+ #include < hip/hip_runtime.h>
25
+ #endif // USE_PI_ROCM
26
+
23
27
#include < algorithm>
24
28
#include < cstdlib>
25
29
#include < cstring>
@@ -32,7 +36,7 @@ static const std::string help =
32
36
" Help\n "
33
37
" Example: ./get_device_count_by_type cpu opencl\n "
34
38
" Supported device types: cpu/gpu/accelerator/default/all\n "
35
- " Supported backends: PI_CUDA/PI_OPENCL/PI_LEVEL_ZERO \n "
39
+ " Supported backends: PI_CUDA/PI_ROCM/ PI_OPENCL/PI_LEVEL_ZERO \n "
36
40
" Output format: <number_of_devices>:<additional_Information>" ;
37
41
38
42
// Return the string with all characters translated to lower case.
@@ -224,6 +228,49 @@ static bool queryCUDA(cl_device_type deviceType, cl_uint &deviceCount,
224
228
#endif
225
229
}
226
230
231
+ static bool queryROCm (cl_device_type deviceType, cl_uint &deviceCount,
232
+ std::string &msg) {
233
+ deviceCount = 0u ;
234
+ #ifdef USE_PI_ROCM
235
+ switch (deviceType) {
236
+ case CL_DEVICE_TYPE_DEFAULT: // Fall through.
237
+ case CL_DEVICE_TYPE_ALL: // Fall through.
238
+ case CL_DEVICE_TYPE_GPU: {
239
+ int count = 0 ;
240
+ hipError_t err = hipGetDeviceCount (&count);
241
+ if (err != hipSuccess || count < 0 ) {
242
+ msg = " ERROR: ROCm error querying device count" ;
243
+ return false ;
244
+ }
245
+ if (count < 1 ) {
246
+ msg = " ERROR: ROCm no device found" ;
247
+ return false ;
248
+ }
249
+ deviceCount = static_cast <cl_uint>(count);
250
+ #if defined(__HIP_PLATFORM_AMD__)
251
+ msg = " rocm-amd " ;
252
+ #elif defined(__HIP_PLATFORM_NVIDIA__)
253
+ msg = " rocm-nvidia " ;
254
+ #else
255
+ #error ("Must define one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
256
+ #endif
257
+ msg += deviceTypeToString (deviceType);
258
+ return true ;
259
+ } break ;
260
+ default :
261
+ msg = " WARNING: ROCm unsupported device type " ;
262
+ msg += deviceTypeToString (deviceType);
263
+ return true ;
264
+ }
265
+ #else
266
+ (void )deviceType;
267
+ msg = " ERROR: ROCm not supported" ;
268
+ deviceCount = 0u ;
269
+
270
+ return false ;
271
+ #endif
272
+ }
273
+
227
274
int main (int argc, char *argv[]) {
228
275
if (argc < 3 ) {
229
276
std::cout << " 0:ERROR: Please set a device type and backend to find"
@@ -264,6 +311,8 @@ int main(int argc, char *argv[]) {
264
311
querySuccess = queryLevelZero (deviceType, deviceCount, msg);
265
312
} else if (backend == " cuda" || backend == " pi_cuda" ) {
266
313
querySuccess = queryCUDA (deviceType, deviceCount, msg);
314
+ } else if (backend == " rocm" || backend == " pi_rocm" ) {
315
+ querySuccess = queryROCm (deviceType, deviceCount, msg);
267
316
} else {
268
317
msg = " ERROR: Unknown backend " + backend + " \n " + help + " \n " ;
269
318
}
0 commit comments