8
8
9
9
#include " gtest/gtest.h"
10
10
11
+ #include " TestGetPlatforms.hpp"
11
12
#include < CL/sycl.hpp>
12
13
#include < CL/sycl/backend/cuda.hpp>
13
14
#include < cuda.h>
14
15
#include < iostream>
15
16
16
17
using namespace cl ::sycl;
17
18
18
- struct CudaInteropGetNativeTests : public ::testing::Test {
19
+ struct CudaInteropGetNativeTests : public ::testing::TestWithParam<platform> {
19
20
20
21
protected:
21
22
queue syclQueue_;
22
23
context syclContext_;
23
24
device syclDevice_;
24
25
25
- CudaInteropGetNativeTests ()
26
- : syclQueue_(cuda_device_selector()),
27
- syclContext_ (syclQueue_.get_context()),
28
- syclDevice_(syclQueue_.get_device()) {}
29
-
30
- static bool isCudaDevice (const device &dev) {
31
- const platform platform = dev.get_info <info::device::platform>();
32
- const std::string platformVersion =
33
- platform.get_info <info::platform::version>();
34
- const std::string platformName = platform.get_info <info::platform::name>();
35
- // If using PI_CUDA, don't accept a non-CUDA device
36
- return platformVersion.find (" CUDA" ) != std::string::npos &&
37
- platformName.find (" NVIDIA CUDA" ) != std::string::npos;
26
+ void SetUp () override {
27
+ syclDevice_ = GetParam ().get_devices ()[0 ];
28
+ syclQueue_ = queue{syclDevice_};
29
+ syclContext_ = syclQueue_.get_context ();
38
30
}
39
31
40
- class cuda_device_selector : public device_selector {
41
- public:
42
- int operator ()(const device &dev) const {
43
- return isCudaDevice (dev) ? 1000 : -1000 ;
44
- }
45
- };
46
-
47
- void SetUp () override {}
48
-
49
32
void TearDown () override {}
50
33
};
51
34
52
- TEST_F (CudaInteropGetNativeTests, getNativeDevice) {
35
+ TEST_P (CudaInteropGetNativeTests, getNativeDevice) {
53
36
CUdevice cudaDevice = get_native<backend::cuda>(syclDevice_);
54
37
char cudaDeviceName[2 ] = {0 , 0 };
55
38
CUresult result = cuDeviceGetName (cudaDeviceName, 2 , cudaDevice);
56
39
ASSERT_EQ (result, CUDA_SUCCESS);
57
40
ASSERT_NE (cudaDeviceName[0 ], 0 );
58
41
}
59
42
60
- TEST_F (CudaInteropGetNativeTests, getNativeContext) {
43
+ TEST_P (CudaInteropGetNativeTests, getNativeContext) {
61
44
CUcontext cudaContext = get_native<backend::cuda>(syclContext_);
62
45
ASSERT_NE (cudaContext, nullptr );
63
46
}
64
47
65
- TEST_F (CudaInteropGetNativeTests, getNativeQueue) {
48
+ TEST_P (CudaInteropGetNativeTests, getNativeQueue) {
66
49
CUstream cudaStream = get_native<backend::cuda>(syclQueue_);
67
50
ASSERT_NE (cudaStream, nullptr );
68
51
@@ -74,21 +57,25 @@ TEST_F(CudaInteropGetNativeTests, getNativeQueue) {
74
57
ASSERT_EQ (streamContext, cudaContext);
75
58
}
76
59
77
- TEST_F (CudaInteropGetNativeTests, interopTaskGetMem) {
60
+ TEST_P (CudaInteropGetNativeTests, interopTaskGetMem) {
78
61
buffer<int , 1 > syclBuffer (range<1 >{1 });
79
62
syclQueue_.submit ([&](handler &cgh) {
80
63
auto syclAccessor = syclBuffer.get_access <access::mode::read>(cgh);
81
64
cgh.interop_task ([=](interop_handler ih) {
82
65
CUdeviceptr cudaPtr = ih.get_mem <backend::cuda>(syclAccessor);
83
66
CUdeviceptr cudaPtrBase;
84
67
size_t cudaPtrSize = 0 ;
85
- cuMemGetAddressRange (&cudaPtrBase, &cudaPtrSize, cudaPtr);
86
- ASSERT_EQ (cudaPtrSize, sizeof (int ));
68
+ CUcontext cudaContext = get_native<backend::cuda>(syclContext_);
69
+ ASSERT_EQ (CUDA_SUCCESS, cuCtxPushCurrent (cudaContext));
70
+ ASSERT_EQ (CUDA_SUCCESS,
71
+ cuMemGetAddressRange (&cudaPtrBase, &cudaPtrSize, cudaPtr));
72
+ ASSERT_EQ (CUDA_SUCCESS, cuCtxPopCurrent (nullptr ));
73
+ ASSERT_EQ (sizeof (int ), cudaPtrSize);
87
74
});
88
75
});
89
76
}
90
77
91
- TEST_F (CudaInteropGetNativeTests, interopTaskGetBufferMem) {
78
+ TEST_P (CudaInteropGetNativeTests, interopTaskGetBufferMem) {
92
79
CUstream cudaStream = get_native<backend::cuda>(syclQueue_);
93
80
syclQueue_.submit ([&](handler &cgh) {
94
81
cgh.interop_task ([=](interop_handler ih) {
@@ -97,3 +84,7 @@ TEST_F(CudaInteropGetNativeTests, interopTaskGetBufferMem) {
97
84
});
98
85
});
99
86
}
87
+
88
+ INSTANTIATE_TEST_CASE_P (
89
+ OnCudaPlatform, CudaInteropGetNativeTests,
90
+ ::testing::ValuesIn (pi::getPlatformsWithName(" CUDA BACKEND" )), );
0 commit comments