Skip to content

Commit 58dbd29

Browse files
authored
[SYCL][PI] Add interoperability with generic handles to device and program classes (#1244)
Signed-off-by: Garima Gupta <[email protected]>
1 parent a083318 commit 58dbd29

File tree

13 files changed

+174
-33
lines changed

13 files changed

+174
-33
lines changed

sycl/include/CL/sycl/detail/pi.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
_PI_API(piPlatformsGet)
1919
_PI_API(piPlatformGetInfo)
2020
// Device
21+
_PI_API(piextDeviceConvert)
2122
_PI_API(piDevicesGet)
2223
_PI_API(piDeviceGetInfo)
2324
_PI_API(piDevicePartition)
@@ -45,6 +46,7 @@ _PI_API(piMemRetain)
4546
_PI_API(piMemRelease)
4647
_PI_API(piMemBufferPartition)
4748
// Program
49+
_PI_API(piextProgramConvert)
4850
_PI_API(piProgramCreate)
4951
_PI_API(piclProgramCreateWithSource)
5052
_PI_API(piclProgramCreateWithBinary)

sycl/include/CL/sycl/detail/pi.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,16 @@ pi_result piPlatformGetInfo(pi_platform platform, pi_platform_info param_name,
711711
//
712712
// Device
713713
//
714+
///
715+
/// Create PI device from the given raw device handle (if the "device"
716+
/// points to null), or, vice versa, extract the raw device handle into
717+
/// the "handle" (if it was pointing to a null) from the given PI device.
718+
/// NOTE: The instance of the PI device created is retained.
719+
///
720+
pi_result piextDeviceConvert(
721+
pi_device *device, ///< [in,out] the pointer to PI device
722+
void **handle); ///< [in,out] the pointer to the raw device handle
723+
714724
pi_result piDevicesGet(pi_platform platform, pi_device_type device_type,
715725
pi_uint32 num_entries, pi_device *devices,
716726
pi_uint32 *num_devices);
@@ -811,6 +821,17 @@ pi_result piMemBufferPartition(pi_mem buffer, pi_mem_flags flags,
811821
//
812822
// Program
813823
//
824+
///
825+
/// Create PI program from the given raw program handle (if the "program"
826+
/// points to null), or, vice versa, extract the raw program handle into
827+
/// the "handle" (if it was pointing to a null) from the given PI program.
828+
/// NOTE: The instance of the PI program created is retained.
829+
///
830+
pi_result piextProgramConvert(
831+
pi_context context, ///< [in] the PI context of the program
832+
pi_program *program, ///< [in,out] the pointer to PI program
833+
void **handle); ///< [in,out] the pointer to the raw program handle
834+
814835
pi_result piProgramCreate(pi_context context, const void *il, size_t length,
815836
pi_program *res_program);
816837

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,23 @@ namespace RT = cl::sycl::detail::pi;
171171

172172
// Want all the needed casts be explicit, do not define conversion
173173
// operators.
174-
template <class To, class From> To pi::cast(From value) {
174+
template <class To, class From> To inline pi::cast(From value) {
175175
// TODO: see if more sanity checks are possible.
176176
RT::assertion((sizeof(From) == sizeof(To)), "assert: cast failed size check");
177177
return (To)(value);
178178
}
179179

180+
// These conversions should use PI interop API.
181+
template <> pi::PiProgram inline pi::cast(cl_program interop) {
182+
RT::assertion(false, "pi::cast -> use piextProgramConvert");
183+
return {};
184+
}
185+
186+
template <> pi::PiDevice inline pi::cast(cl_device_id interop) {
187+
RT::assertion(false, "pi::cast -> use piextDeviceConvert");
188+
return {};
189+
}
190+
180191
} // namespace detail
181192

182193
// For shortness of using PI from the top-level sycl files.

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,11 @@ pi_result cuda_piPlatformGetInfo(pi_platform platform,
680680
return {};
681681
}
682682

683+
pi_result cuda_piextDeviceConvert(pi_device *device, void **handle) {
684+
cl::sycl::detail::pi::die("cuda_piextDeviceConvert not implemented");
685+
return {};
686+
}
687+
683688
pi_result cuda_piDevicesGet(pi_platform platform, pi_device_type device_type,
684689
pi_uint32 num_entries, pi_device *devices,
685690
pi_uint32 *num_devices) {
@@ -2138,6 +2143,15 @@ pi_result cuda_piMemRetain(pi_mem mem) {
21382143
//
21392144
// Program
21402145
//
2146+
pi_result cuda_piextProgramConvert(
2147+
pi_context context, ///< [in] the PI context of the program
2148+
pi_program *program, ///< [in,out] the pointer to PI program
2149+
void **handle) ///< [in,out] the pointer to the raw program handle
2150+
{
2151+
cl::sycl::detail::pi::die("cuda_piextProgramConvert not implemented");
2152+
return {};
2153+
}
2154+
21412155
pi_result cuda_piProgramCreate(pi_context context, const void *il,
21422156
size_t length, pi_program *res_program) {
21432157
cl::sycl::detail::pi::die("cuda_piProgramCreate not implemented");
@@ -3480,6 +3494,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
34803494
_PI_CL(piPlatformsGet, cuda_piPlatformsGet)
34813495
_PI_CL(piPlatformGetInfo, cuda_piPlatformGetInfo)
34823496
// Device
3497+
_PI_CL(piextDeviceConvert, cuda_piextDeviceConvert)
34833498
_PI_CL(piDevicesGet, cuda_piDevicesGet)
34843499
_PI_CL(piDeviceGetInfo, cuda_piDeviceGetInfo)
34853500
_PI_CL(piDevicePartition, cuda_piDevicePartition)
@@ -3507,6 +3522,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
35073522
_PI_CL(piMemRelease, cuda_piMemRelease)
35083523
_PI_CL(piMemBufferPartition, cuda_piMemBufferPartition)
35093524
// Program
3525+
_PI_CL(piextProgramConvert, cuda_piextProgramConvert)
35103526
_PI_CL(piProgramCreate, cuda_piProgramCreate)
35113527
_PI_CL(piclProgramCreateWithSource, cuda_piclProgramCreateWithSource)
35123528
_PI_CL(piclProgramCreateWithBinary, cuda_piclProgramCreateWithBinary)

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,24 @@ pi_result OCL(piPlatformsGet)(pi_uint32 num_entries, pi_platform *platforms,
176176
return static_cast<pi_result>(result);
177177
}
178178

179+
pi_result OCL(piextDeviceConvert)(pi_device *device, void **handle) {
180+
// The PI device is the same as OpenCL device handle.
181+
assert(device);
182+
assert(handle);
183+
184+
if (*device == nullptr) {
185+
// unitialized *device.
186+
assert(*handle);
187+
*device = cast<pi_device>(*handle);
188+
} else {
189+
assert(*handle == nullptr);
190+
*handle = *device;
191+
}
192+
193+
cl_int result = clRetainDevice(cast<cl_device_id>(*handle));
194+
return cast<pi_result>(result);
195+
}
196+
179197
// Example of a PI interface that does not map exactly to an OpenCL one.
180198
pi_result OCL(piDevicesGet)(pi_platform platform, pi_device_type device_type,
181199
pi_uint32 num_entries, pi_device *devices,
@@ -305,6 +323,27 @@ pi_result OCL(piQueueCreate)(pi_context context, pi_device device,
305323
return cast<pi_result>(ret_err);
306324
}
307325

326+
pi_result OCL(piextProgramConvert)(
327+
pi_context context, ///< [in] the PI context of the program
328+
pi_program *program, ///< [in,out] the pointer to PI program
329+
void **handle) ///< [in,out] the pointer to the raw program handle
330+
{
331+
// The PI program is the same as OpenCL program handle.
332+
assert(program);
333+
assert(handle);
334+
335+
if (*program == nullptr) {
336+
// uninitialized *program.
337+
assert(*handle);
338+
*program = cast<pi_program>(*handle);
339+
} else {
340+
assert(*handle == nullptr);
341+
*handle = *program;
342+
}
343+
cl_int result = clRetainProgram(cast<cl_program>(*handle));
344+
return cast<pi_result>(result);
345+
}
346+
308347
pi_result OCL(piProgramCreate)(pi_context context, const void *il,
309348
size_t length, pi_program *res_program) {
310349

@@ -992,6 +1031,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
9921031
_PI_CL(piPlatformsGet, OCL(piPlatformsGet))
9931032
_PI_CL(piPlatformGetInfo, clGetPlatformInfo)
9941033
// Device
1034+
_PI_CL(piextDeviceConvert, OCL(piextDeviceConvert))
9951035
_PI_CL(piDevicesGet, OCL(piDevicesGet))
9961036
_PI_CL(piDeviceGetInfo, clGetDeviceInfo)
9971037
_PI_CL(piDevicePartition, clCreateSubDevices)
@@ -1019,6 +1059,7 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
10191059
_PI_CL(piMemRelease, clReleaseMemObject)
10201060
_PI_CL(piMemBufferPartition, OCL(piMemBufferPartition))
10211061
// Program
1062+
_PI_CL(piextProgramConvert, OCL(piextProgramConvert))
10221063
_PI_CL(piProgramCreate, OCL(piProgramCreate))
10231064
_PI_CL(piclProgramCreateWithSource, OCL(piclProgramCreateWithSource))
10241065
_PI_CL(piclProgramCreateWithBinary, OCL(piclProgramCreateWithBinary))

sycl/source/detail/device_impl.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,30 @@ device_impl::device_impl()
1919
: MIsHostDevice(true),
2020
MPlatform(std::make_shared<platform_impl>(platform_impl())) {}
2121

22+
device_impl::device_impl(device_interop_handle_t InteropDeviceHandle,
23+
const plugin &Plugin)
24+
: device_impl(InteropDeviceHandle, nullptr, nullptr, Plugin) {}
25+
2226
device_impl::device_impl(RT::PiDevice Device, PlatformImplPtr Platform)
23-
: device_impl(Device, Platform, Platform->getPlugin()) {}
27+
: device_impl(nullptr, Device, Platform, Platform->getPlugin()) {}
2428

2529
device_impl::device_impl(RT::PiDevice Device, const plugin &Plugin)
26-
: device_impl(Device, nullptr, Plugin) {}
30+
: device_impl(nullptr, Device, nullptr, Plugin) {}
2731

28-
device_impl::device_impl(RT::PiDevice Device, PlatformImplPtr Platform,
32+
device_impl::device_impl(device_interop_handle_t InteropDeviceHandle,
33+
RT::PiDevice Device, PlatformImplPtr Platform,
2934
const plugin &Plugin)
3035
: MDevice(Device), MIsHostDevice(false) {
36+
37+
bool InteroperabilityConstructor = false;
38+
if (Device == nullptr) {
39+
assert(InteropDeviceHandle != nullptr);
40+
// Get PI device from the raw device handle.
41+
Plugin.call<PiApiKind::piextDeviceConvert>(&MDevice,
42+
(void **)&InteropDeviceHandle);
43+
InteroperabilityConstructor = true;
44+
}
45+
3146
// TODO catch an exception and put it to list of asynchronous exceptions
3247
Plugin.call<PiApiKind::piDeviceGetInfo>(
3348
MDevice, PI_DEVICE_INFO_TYPE, sizeof(RT::PiDeviceType), &MType, nullptr);
@@ -38,16 +53,18 @@ device_impl::device_impl(RT::PiDevice Device, PlatformImplPtr Platform,
3853
MDevice, PI_DEVICE_INFO_PARENT_DEVICE, sizeof(RT::PiDevice), &parent, nullptr);
3954

4055
MIsRootDevice = (nullptr == parent);
41-
if (!MIsRootDevice) {
56+
if (!MIsRootDevice && !InteroperabilityConstructor) {
4257
// TODO catch an exception and put it to list of asynchronous exceptions
58+
// Interoperability Constructor already calls DeviceRetain in
59+
// piextDeviceConvert.
4360
Plugin.call<PiApiKind::piDeviceRetain>(MDevice);
4461
}
4562

4663
// set MPlatform
4764
if (!Platform) {
4865
RT::PiPlatform plt = nullptr; // TODO catch an exception and put it to list
4966
// of asynchronous exceptions
50-
Plugin.call<PiApiKind::piDeviceGetInfo>(Device, PI_DEVICE_INFO_PLATFORM,
67+
Plugin.call<PiApiKind::piDeviceGetInfo>(MDevice, PI_DEVICE_INFO_PLATFORM,
5168
sizeof(plt), &plt, nullptr);
5269
Platform = std::make_shared<platform_impl>(plt, Plugin);
5370
}
@@ -75,13 +92,15 @@ cl_device_id device_impl::get() const {
7592
throw invalid_object_error("This instance of device is a host instance",
7693
PI_INVALID_DEVICE);
7794

95+
const detail::plugin &Plugin = getPlugin();
7896
if (!MIsRootDevice) {
7997
// TODO catch an exception and put it to list of asynchronous exceptions
80-
const detail::plugin &Plugin = getPlugin();
8198
Plugin.call<PiApiKind::piDeviceRetain>(MDevice);
8299
}
83-
// TODO: check that device is an OpenCL interop one
84-
return pi::cast<cl_device_id>(MDevice);
100+
void *handle = nullptr;
101+
Plugin.call<PiApiKind::piextDeviceConvert>(
102+
const_cast<RT::PiDevice *>(&MDevice), &handle);
103+
return pi::cast<cl_device_id>(handle);
85104
}
86105

87106
platform device_impl::get_platform() const {

sycl/source/detail/device_impl.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,20 @@ namespace detail {
2727
class platform_impl;
2828
using PlatformImplPtr = std::shared_ptr<platform_impl>;
2929

30+
// TODO: SYCL BE generalization will change this to something better.
31+
// For now this saves us from unwanted implicit casts.
32+
struct _device_interop_handle_t;
33+
using device_interop_handle_t = _device_interop_handle_t *;
34+
3035
// TODO: Make code thread-safe
3136
class device_impl {
3237
public:
3338
/// Constructs a SYCL device instance as a host device.
3439
device_impl();
3540

41+
/// Constructs a SYCL device instance using the provided raw device handle.
42+
explicit device_impl(device_interop_handle_t, const plugin &Plugin);
43+
3644
/// Constructs a SYCL device instance using the provided
3745
/// PI device instance.
3846
explicit device_impl(RT::PiDevice Device, PlatformImplPtr Platform);
@@ -196,7 +204,8 @@ class device_impl {
196204
is_affinity_supported(info::partition_affinity_domain AffinityDomain) const;
197205

198206
private:
199-
explicit device_impl(RT::PiDevice Device, PlatformImplPtr Platform,
207+
explicit device_impl(device_interop_handle_t InteropDevice,
208+
RT::PiDevice Device, PlatformImplPtr Platform,
200209
const plugin &Plugin);
201210
RT::PiDevice MDevice = 0;
202211
RT::PiDeviceType MType;

sycl/source/detail/program_impl.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,32 @@ program_impl::program_impl(
8080
}
8181
}
8282

83-
program_impl::program_impl(ContextImplPtr Context, RT::PiProgram Program)
83+
program_impl::program_impl(ContextImplPtr Context,
84+
program_interop_handle_t InteropProgram)
85+
: program_impl(Context, InteropProgram, nullptr) {}
86+
87+
program_impl::program_impl(ContextImplPtr Context,
88+
program_interop_handle_t InteropProgram,
89+
RT::PiProgram Program)
8490
: MProgram(Program), MContext(Context), MLinkable(true) {
8591

92+
const detail::plugin &Plugin = getPlugin();
93+
if (MProgram == nullptr) {
94+
assert(InteropProgram != nullptr &&
95+
"No InteropProgram/PiProgram defined with piextProgramConvert");
96+
// Translate the raw program handle into PI program.
97+
Plugin.call<PiApiKind::piextProgramConvert>(
98+
Context->getHandleRef(), &MProgram, (void **)&InteropProgram);
99+
} else
100+
Plugin.call<PiApiKind::piProgramRetain>(Program);
101+
86102
// TODO handle the case when cl_program build is in progress
87103
pi_uint32 NumDevices;
88-
const detail::plugin &Plugin = getPlugin();
89-
Plugin.call<PiApiKind::piProgramGetInfo>(Program, PI_PROGRAM_INFO_NUM_DEVICES,
90-
sizeof(pi_uint32), &NumDevices,
91-
nullptr);
104+
Plugin.call<PiApiKind::piProgramGetInfo>(
105+
MProgram, PI_PROGRAM_INFO_NUM_DEVICES, sizeof(pi_uint32), &NumDevices,
106+
nullptr);
92107
vector_class<RT::PiDevice> PiDevices(NumDevices);
93-
Plugin.call<PiApiKind::piProgramGetInfo>(Program, PI_PROGRAM_INFO_DEVICES,
108+
Plugin.call<PiApiKind::piProgramGetInfo>(MProgram, PI_PROGRAM_INFO_DEVICES,
94109
sizeof(RT::PiDevice) * NumDevices,
95110
PiDevices.data(), nullptr);
96111
vector_class<device> SyclContextDevices =
@@ -109,16 +124,17 @@ program_impl::program_impl(ContextImplPtr Context, RT::PiProgram Program)
109124
SyclContextDevices.erase(NewEnd, SyclContextDevices.end());
110125
MDevices = SyclContextDevices;
111126
RT::PiDevice Device = getSyclObjImpl(MDevices[0])->getHandleRef();
127+
assert(!MDevices.empty() && "No device found for this program");
112128
// TODO check build for each device instead
113129
cl_program_binary_type BinaryType;
114130
Plugin.call<PiApiKind::piProgramGetBuildInfo>(
115-
Program, Device, CL_PROGRAM_BINARY_TYPE, sizeof(cl_program_binary_type),
131+
MProgram, Device, CL_PROGRAM_BINARY_TYPE, sizeof(cl_program_binary_type),
116132
&BinaryType, nullptr);
117133
size_t Size = 0;
118134
Plugin.call<PiApiKind::piProgramGetBuildInfo>(
119-
Program, Device, CL_PROGRAM_BUILD_OPTIONS, 0, nullptr, &Size);
135+
MProgram, Device, CL_PROGRAM_BUILD_OPTIONS, 0, nullptr, &Size);
120136
std::vector<char> OptionsVector(Size);
121-
Plugin.call<PiApiKind::piProgramGetBuildInfo>(Program, Device,
137+
Plugin.call<PiApiKind::piProgramGetBuildInfo>(MProgram, Device,
122138
CL_PROGRAM_BUILD_OPTIONS, Size,
123139
OptionsVector.data(), nullptr);
124140
string_class Options(OptionsVector.begin(), OptionsVector.end());
@@ -137,12 +153,11 @@ program_impl::program_impl(ContextImplPtr Context, RT::PiProgram Program)
137153
MLinkOptions = "";
138154
MBuildOptions = Options;
139155
}
140-
Plugin.call<PiApiKind::piProgramRetain>(Program);
141156
}
142157

143158
program_impl::program_impl(ContextImplPtr Context, RT::PiKernel Kernel)
144-
: program_impl(Context,
145-
ProgramManager::getInstance().getClProgramFromClKernel(
159+
: program_impl(Context, nullptr,
160+
ProgramManager::getInstance().getPiProgramFromPiKernel(
146161
Kernel, Context)) {}
147162

148163
program_impl::~program_impl() {

0 commit comments

Comments
 (0)