Skip to content

Commit fd54d92

Browse files
[SYCL][CUDA][Test] Testing for use of CUDA primary context (#1174)
Signed-off-by: Steffen Larsen <[email protected]>
1 parent 9497f55 commit fd54d92

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// REQUIRES: cuda
2+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -I%opencl_include_dir -I%cuda_toolkit_include -o %t.out -lcuda -lsycl
3+
// RUN: env SYCL_DEVICE_TYPE=GPU %t.out
4+
// NOTE: OpenCL is required for the runtime, even when using the CUDA BE.
5+
6+
//==---------- primary_context.cpp - SYCL cuda primary context test --------==//
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include <CL/sycl.hpp>
15+
#include <CL/sycl/detail/pi_cuda.hpp>
16+
#include <cuda.h>
17+
#include <iostream>
18+
19+
using namespace cl::sycl;
20+
21+
void check(bool condition, const char *conditionString, const char *filename,
22+
const long line) noexcept {
23+
if (!condition) {
24+
std::cerr << "CHECK failed in " << filename << "#" << line << " "
25+
<< conditionString << "\n";
26+
std::abort();
27+
}
28+
}
29+
30+
#define CHECK(CONDITION) check(CONDITION, #CONDITION, __FILE__, __LINE__)
31+
32+
bool isCudaDevice(const device &dev) {
33+
const platform platform = dev.get_info<info::device::platform>();
34+
const std::string platformVersion =
35+
platform.get_info<info::platform::version>();
36+
// If using PI_CUDA, don't accept a non-CUDA device
37+
return platformVersion.find("CUDA") != std::string::npos;
38+
}
39+
40+
class cuda_device_selector : public device_selector {
41+
public:
42+
int operator()(const device &dev) const {
43+
return isCudaDevice(dev) ? 1 : -1;
44+
}
45+
};
46+
47+
class other_cuda_device_selector : public device_selector {
48+
public:
49+
other_cuda_device_selector(const device &dev) : excludeDevice{dev} {}
50+
51+
int operator()(const device &dev) const {
52+
if (!isCudaDevice(dev)) {
53+
return -1;
54+
}
55+
if (dev.get() == excludeDevice.get()) {
56+
// Return only this device if it is the only available
57+
return 0;
58+
}
59+
return 1;
60+
}
61+
62+
private:
63+
const device &excludeDevice;
64+
};
65+
66+
int main() {
67+
try {
68+
context c;
69+
} catch (device_error &e) {
70+
std::cout << "Failed to create device for context" << std::endl;
71+
}
72+
73+
device DeviceA = cuda_device_selector().select_device();
74+
device DeviceB = other_cuda_device_selector(DeviceA).select_device();
75+
76+
CHECK(isCudaDevice(DeviceA));
77+
78+
{
79+
std::cout << "create single context" << std::endl;
80+
context Context(DeviceA, async_handler{}, /*UsePrimaryContext=*/true);
81+
82+
CUdevice CudaDevice = reinterpret_cast<pi_device>(DeviceA.get())->get();
83+
CUcontext CudaContext = reinterpret_cast<pi_context>(Context.get())->get();
84+
85+
CUcontext PrimaryCudaContext;
86+
cuDevicePrimaryCtxRetain(&PrimaryCudaContext, CudaDevice);
87+
88+
CHECK(CudaContext == PrimaryCudaContext);
89+
90+
cuDevicePrimaryCtxRelease(CudaDevice);
91+
}
92+
{
93+
std::cout << "create multiple contexts for one device" << std::endl;
94+
context ContextA(DeviceA, async_handler{}, /*UsePrimaryContext=*/true);
95+
context ContextB(DeviceA, async_handler{}, /*UsePrimaryContext=*/true);
96+
97+
CUcontext CudaContextA =
98+
reinterpret_cast<pi_context>(ContextA.get())->get();
99+
CUcontext CudaContextB =
100+
reinterpret_cast<pi_context>(ContextB.get())->get();
101+
102+
CHECK(CudaContextA == CudaContextB);
103+
}
104+
if (isCudaDevice(DeviceB) && DeviceA.get() != DeviceB.get()) {
105+
std::cout << "create multiple contexts for multiple devices" << std::endl;
106+
context ContextA(DeviceA, async_handler{}, /*UsePrimaryContext=*/true);
107+
context ContextB(DeviceB, async_handler{}, /*UsePrimaryContext=*/true);
108+
109+
CUcontext CudaContextA =
110+
reinterpret_cast<pi_context>(ContextA.get())->get();
111+
CUcontext CudaContextB =
112+
reinterpret_cast<pi_context>(ContextB.get())->get();
113+
114+
CHECK(CudaContextA != CudaContextB);
115+
}
116+
}

sycl/test/lit.cfg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@
7171
config.substitutions.append( ('%sycl_include', config.sycl_include ) )
7272
config.substitutions.append( ('%opencl_libs_dir', config.opencl_libs_dir) )
7373
config.substitutions.append( ('%sycl_source_dir', config.sycl_source_dir) )
74+
config.substitutions.append( ('%opencl_include_dir', config.opencl_include_dir) )
75+
config.substitutions.append( ('%cuda_toolkit_include', config.cuda_toolkit_include) )
7476

7577
llvm_config.use_clang()
7678

sycl/test/lit.site.cfg.py.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ config.opencl_libs_dir = os.path.dirname("@OpenCL_LIBRARIES@")
1212
config.sycl_libs_dir = lit_config.params.get('SYCL_LIBS_DIR', "@LLVM_LIBS_DIR@")
1313
config.target_triple = "@TARGET_TRIPLE@"
1414
config.host_triple = "@LLVM_HOST_TRIPLE@"
15+
config.opencl_include_dir = "@OpenCL_INCLUDE_DIR@"
16+
config.cuda_toolkit_include = "@CUDA_TOOLKIT_INCLUDE@"
1517

1618
config.llvm_enable_projects = "@LLVM_ENABLE_PROJECTS@"
1719

0 commit comments

Comments
 (0)