@@ -886,6 +886,79 @@ namespace dpct
886
886
return -1;
887
887
}
888
888
889
+ inline std::string get_preferred_gpu_platform_name() {
890
+ std::string result;
891
+
892
+ std::string filter = " " ;
893
+ char* env = getenv(" ONEAPI_DEVICE_SELECTOR" );
894
+ if (env) {
895
+ if (std::strstr(env, " level_zero" )) {
896
+ filter = " level-zero" ;
897
+ }
898
+ else if (std::strstr(env, " opencl" )) {
899
+ filter = " opencl" ;
900
+ }
901
+ else if (std::strstr(env, " cuda" )) {
902
+ filter = " cuda" ;
903
+ }
904
+ else if (std::strstr(env, " hip" )) {
905
+ filter = " hip" ;
906
+ }
907
+ else {
908
+ throw std::runtime_error(" invalid device filter: " + std::string(env));
909
+ }
910
+ } else {
911
+ auto default_device = sycl::device(sycl::default_selector_v);
912
+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
913
+
914
+ if (std::strstr(default_platform_name.c_str(), " Level-Zero" ) || default_device.is_cpu()) {
915
+ filter = " level-zero" ;
916
+ }
917
+ else if (std::strstr(default_platform_name.c_str(), " CUDA" )) {
918
+ filter = " cuda" ;
919
+ }
920
+ else if (std::strstr(default_platform_name.c_str(), " HIP" )) {
921
+ filter = " hip" ;
922
+ }
923
+ }
924
+
925
+ auto platform_list = sycl::platform::get_platforms();
926
+
927
+ for (const auto& platform : platform_list) {
928
+ auto devices = platform.get_devices();
929
+ auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
930
+ return d.is_gpu();
931
+ });
932
+
933
+ if (gpu_dev == devices.end()) {
934
+ // cout << " platform [" << platform_name
935
+ // << " ] does not contain GPU devices, skipping\n" ;
936
+ continue;
937
+ }
938
+
939
+ auto platform_name = platform.get_info<sycl::info::platform::name>();
940
+ std::string platform_name_low_case;
941
+ platform_name_low_case.resize(platform_name.size());
942
+
943
+ std::transform(
944
+ platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
945
+
946
+ if (platform_name_low_case.find(filter) == std::string::npos) {
947
+ // cout << " platform [" << platform_name
948
+ // << " ] does not match with requested "
949
+ // << filter << " , skipping\n" ;
950
+ continue;
951
+ }
952
+
953
+ result = platform_name;
954
+ }
955
+
956
+ if (result.empty())
957
+ throw std::runtime_error(" can not find preferred GPU platform" );
958
+
959
+ return result;
960
+ }
961
+
889
962
template <class DeviceSelector>
890
963
std::enable_if_t<
891
964
std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
0 commit comments