Skip to content

Commit 21c0ae6

Browse files
sami-hatna66bb-sycl
authored andcommitted
[SYCL][CUDA][HIP] Expand Config/select_device.cpp for CUDA and HIP (intel#1086)
1 parent b4baf0c commit 21c0ae6

File tree

1 file changed

+97
-26
lines changed

1 file changed

+97
-26
lines changed

SYCL/Config/select_device.cpp

Lines changed: 97 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,6 @@
3737
// RUN: env READ_PLATVER_MALFORMED_INFO=1 %GPU_RUN_PLACEHOLDER %t.out
3838
//
3939
// REQUIRES: gpu
40-
//
41-
// XFAIL: cuda || hip
42-
//
43-
// TODO: Update this test when SYCL_DEVICE_FILTER support in enabled.
4440

4541
//==------------ select_device.cpp - SYCL_DEVICE_ALLOWLIST test ------------==//
4642
//
@@ -203,6 +199,18 @@ int main() {
203199
<< "}}" << std::endl;
204200
passed = true;
205201
break;
202+
} else if ((plt.get_backend() == backend::ext_oneapi_cuda) &&
203+
(sycl_be.find("cuda") != std::string::npos)) {
204+
fs << "DeviceName:{{" << name << "}},DriverVersion:{{" << ver
205+
<< "}}" << std::endl;
206+
passed = true;
207+
break;
208+
} else if ((plt.get_backend() == backend::ext_oneapi_hip) &&
209+
(sycl_be.find("hip") != std::string::npos)) {
210+
fs << "DeviceName:{{" << name << "}},DriverVersion:{{" << ver
211+
<< "}}" << std::endl;
212+
passed = true;
213+
break;
206214
}
207215
}
208216
}
@@ -254,7 +262,11 @@ int main() {
254262
if (((plt.get_backend() == backend::opencl) &&
255263
(sycl_be.find("opencl") != std::string::npos)) ||
256264
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
257-
(sycl_be.find("level_zero") != std::string::npos))) {
265+
(sycl_be.find("level_zero") != std::string::npos)) ||
266+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
267+
(sycl_be.find("cuda") != std::string::npos)) ||
268+
((plt.get_backend() == backend::ext_oneapi_hip) &&
269+
(sycl_be.find("hip") != std::string::npos))) {
258270
fs << "PlatformName:{{" << name << "}},PlatformVersion:{{" << ver
259271
<< "}}" << std::endl;
260272
passed = true;
@@ -310,7 +322,11 @@ int main() {
310322
if (((plt.get_backend() == backend::opencl) &&
311323
(sycl_be.find("opencl") != std::string::npos)) ||
312324
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
313-
(sycl_be.find("level_zero") != std::string::npos))) {
325+
(sycl_be.find("level_zero") != std::string::npos)) ||
326+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
327+
(sycl_be.find("cuda") != std::string::npos)) ||
328+
((plt.get_backend() == backend::ext_oneapi_hip) &&
329+
(sycl_be.find("hip") != std::string::npos))) {
314330
fs << "DeviceName:{{" << name << "}},DriverVersion:{{" << ver
315331
<< "}}" << std::endl;
316332
passed = true;
@@ -375,6 +391,20 @@ int main() {
375391
<< "}}" << std::endl;
376392
passed = true;
377393
break;
394+
} else if ((plt.get_backend() == backend::ext_oneapi_cuda) &&
395+
(sycl_be.find("cuda") != std::string::npos)) {
396+
std::string ver("CUDA 89.78");
397+
fs << "PlatformName:{{" << name << "}},PlatformVersion:{{" << ver
398+
<< "}}" << std::endl;
399+
passed = true;
400+
break;
401+
} else if ((plt.get_backend() == backend::ext_oneapi_hip) &&
402+
(sycl_be.find("hip") != std::string::npos)) {
403+
std::string ver("67.88.9");
404+
fs << "PlatformName:{{" << name << "}},PlatformVersion:{{" << ver
405+
<< "}}" << std::endl;
406+
passed = true;
407+
break;
378408
}
379409
}
380410
}
@@ -423,24 +453,37 @@ int main() {
423453
addEscapeSymbolToSpecialCharacters(name);
424454
std::string ver = dev.get_info<info::device::driver_version>();
425455
size_t pos = 0;
426-
if ((pos = ver.find(".")) == std::string::npos) {
427-
throw std::runtime_error("Malformed syntax in version string");
428-
}
429-
pos++;
430-
size_t start = pos;
431-
if ((pos = ver.find(".", pos)) == std::string::npos) {
432-
throw std::runtime_error("Malformed syntax in version string");
433-
}
434-
ver.replace(start, pos - start, "*");
435456
if (((plt.get_backend() == backend::opencl) &&
436457
(sycl_be.find("opencl") != std::string::npos)) ||
437458
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
438459
(sycl_be.find("level_zero") != std::string::npos))) {
439-
fs << "DeviceName:{{" << name << "}},DriverVersion:{{" << ver
440-
<< "}}" << std::endl;
441-
passed = true;
442-
break;
460+
if ((pos = ver.find(".")) == std::string::npos) {
461+
throw std::runtime_error(
462+
"Malformed syntax in version string");
463+
}
464+
pos++;
465+
size_t start = pos;
466+
if ((pos = ver.find(".", pos)) == std::string::npos) {
467+
throw std::runtime_error(
468+
"Malformed syntax in version string");
469+
}
470+
ver.replace(start, pos - start, "*");
471+
} else if (((plt.get_backend() == backend::ext_oneapi_cuda) &&
472+
(sycl_be.find("cuda") != std::string::npos)) ||
473+
((plt.get_backend() == backend::ext_oneapi_hip) &&
474+
(sycl_be.find("hip") != std::string::npos))) {
475+
if ((pos = ver.find(".")) == std::string::npos) {
476+
throw std::runtime_error(
477+
"Malformed syntax in version string");
478+
}
479+
pos++;
480+
ver.replace(pos, ver.length(), "*");
443481
}
482+
483+
fs << "DeviceName:{{" << name << "}},DriverVersion:{{" << ver
484+
<< "}}" << std::endl;
485+
passed = true;
486+
break;
444487
}
445488
}
446489
}
@@ -492,7 +535,11 @@ int main() {
492535
if (((plt.get_backend() == backend::opencl) &&
493536
(sycl_be.find("opencl") != std::string::npos)) ||
494537
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
495-
(sycl_be.find("level_zero") != std::string::npos))) {
538+
(sycl_be.find("level_zero") != std::string::npos)) ||
539+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
540+
(sycl_be.find("cuda") != std::string::npos)) ||
541+
((plt.get_backend() == backend::ext_oneapi_hip) &&
542+
(sycl_be.find("hip") != std::string::npos))) {
496543
fs << "DeviceName:{{" << name << "}}" << std::endl;
497544
passed = true;
498545
break;
@@ -541,7 +588,11 @@ int main() {
541588
if (((plt.get_backend() == backend::opencl) &&
542589
(sycl_be.find("opencl") != std::string::npos)) ||
543590
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
544-
(sycl_be.find("level_zero") != std::string::npos))) {
591+
(sycl_be.find("level_zero") != std::string::npos)) ||
592+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
593+
(sycl_be.find("cuda") != std::string::npos)) ||
594+
((plt.get_backend() == backend::ext_oneapi_hip) &&
595+
(sycl_be.find("hip") != std::string::npos))) {
545596
fs << "PlatformName:{{" << name << "}}" << std::endl;
546597
passed = true;
547598
break;
@@ -594,7 +645,11 @@ int main() {
594645
if (((plt.get_backend() == backend::opencl) &&
595646
(sycl_be.find("opencl") != std::string::npos)) ||
596647
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
597-
(sycl_be.find("level_zero") != std::string::npos))) {
648+
(sycl_be.find("level_zero") != std::string::npos)) ||
649+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
650+
(sycl_be.find("cuda") != std::string::npos)) ||
651+
((plt.get_backend() == backend::ext_oneapi_hip) &&
652+
(sycl_be.find("hip") != std::string::npos))) {
598653
if (count > 0) {
599654
ss << " | ";
600655
}
@@ -656,7 +711,11 @@ int main() {
656711
if (((plt.get_backend() == backend::opencl) &&
657712
(sycl_be.find("opencl") != std::string::npos)) ||
658713
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
659-
(sycl_be.find("level_zero") != std::string::npos))) {
714+
(sycl_be.find("level_zero") != std::string::npos)) ||
715+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
716+
(sycl_be.find("cuda") != std::string::npos)) ||
717+
((plt.get_backend() == backend::ext_oneapi_hip) &&
718+
(sycl_be.find("hip") != std::string::npos))) {
660719
fs << "DeviceName:HAHA{{" << name << "}}" << std::endl;
661720
passed = true;
662721
break;
@@ -717,7 +776,11 @@ int main() {
717776
if (((plt.get_backend() == backend::opencl) &&
718777
(sycl_be.find("opencl") != std::string::npos)) ||
719778
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
720-
(sycl_be.find("level_zero") != std::string::npos))) {
779+
(sycl_be.find("level_zero") != std::string::npos)) ||
780+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
781+
(sycl_be.find("cuda") != std::string::npos)) ||
782+
((plt.get_backend() == backend::ext_oneapi_hip) &&
783+
(sycl_be.find("hip") != std::string::npos))) {
721784
fs << "PlatformName:HAHA{{" << name << "}}" << std::endl;
722785
passed = true;
723786
break;
@@ -779,7 +842,11 @@ int main() {
779842
if (((plt.get_backend() == backend::opencl) &&
780843
(sycl_be.find("opencl") != std::string::npos)) ||
781844
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
782-
(sycl_be.find("level_zero") != std::string::npos))) {
845+
(sycl_be.find("level_zero") != std::string::npos)) ||
846+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
847+
(sycl_be.find("cuda") != std::string::npos)) ||
848+
((plt.get_backend() == backend::ext_oneapi_hip) &&
849+
(sycl_be.find("hip") != std::string::npos))) {
783850
fs << "DeviceName:{{" << name << "}},DriverVersion:HAHA{{"
784851
<< ver << "}}" << std::endl;
785852
passed = true;
@@ -842,7 +909,11 @@ int main() {
842909
if (((plt.get_backend() == backend::opencl) &&
843910
(sycl_be.find("opencl") != std::string::npos)) ||
844911
((plt.get_backend() == backend::ext_oneapi_level_zero) &&
845-
(sycl_be.find("level_zero") != std::string::npos))) {
912+
(sycl_be.find("level_zero") != std::string::npos)) ||
913+
((plt.get_backend() == backend::ext_oneapi_cuda) &&
914+
(sycl_be.find("cuda") != std::string::npos)) ||
915+
((plt.get_backend() == backend::ext_oneapi_hip) &&
916+
(sycl_be.find("hip") != std::string::npos))) {
846917
fs << "PlatformName:{{" << name << "}},PlatformVersion:HAHA{{"
847918
<< ver << "}}" << std::endl;
848919
passed = true;

0 commit comments

Comments
 (0)