Skip to content

Commit 2f64227

Browse files
[SYCL] Add support of multiple devices within a context (#2343)
This patch adds support of multiple devices within a context. Programs can be created from images or from SPIR-V binaries. Only kernels, created using invoking kernels functions (parallel_for, single_task, ...) are supported. Kernels, created in OpenCL interoperability mode (using sycl::program and sycl::kernel functions) are not supported.
1 parent 047e2ec commit 2f64227

File tree

7 files changed

+206
-146
lines changed

7 files changed

+206
-146
lines changed

sycl/source/detail/context_impl.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ class context_impl {
134134
/// more details.
135135
///
136136
/// \returns a map with device library programs.
137-
std::map<DeviceLibExt, RT::PiProgram> &getCachedLibPrograms() {
137+
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram> &
138+
getCachedLibPrograms() {
138139
return MCachedLibPrograms;
139140
}
140141

@@ -155,7 +156,8 @@ class context_impl {
155156
PlatformImplPtr MPlatform;
156157
bool MHostContext;
157158
bool MUseCUDAPrimaryContext;
158-
std::map<DeviceLibExt, RT::PiProgram> MCachedLibPrograms;
159+
std::map<std::pair<DeviceLibExt, RT::PiDevice>, RT::PiProgram>
160+
MCachedLibPrograms;
159161
mutable KernelProgramCache MKernelProgramCache;
160162
};
161163

sycl/source/detail/kernel_program_cache.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,17 @@ class KernelProgramCache {
6767
using PiProgramT = std::remove_pointer<RT::PiProgram>::type;
6868
using PiProgramPtrT = std::atomic<PiProgramT *>;
6969
using ProgramWithBuildStateT = BuildResult<PiProgramT>;
70-
using ProgramCacheKeyT = std::pair<SerializedObj, KernelSetId>;
70+
using ProgramCacheKeyT =
71+
std::pair<std::pair<SerializedObj, KernelSetId>, RT::PiDevice>;
7172
using ProgramCacheT = std::map<ProgramCacheKeyT, ProgramWithBuildStateT>;
7273
using ContextPtr = context_impl *;
7374

7475
using PiKernelT = std::remove_pointer<RT::PiKernel>::type;
7576

7677
using PiKernelPtrT = std::atomic<PiKernelT *>;
7778
using KernelWithBuildStateT = BuildResult<PiKernelT>;
78-
using KernelByNameT = std::map<string_class, KernelWithBuildStateT>;
79+
using KernelByNameT =
80+
std::map<std::pair<string_class, RT::PiDevice>, KernelWithBuildStateT>;
7981
using KernelCacheT = std::map<RT::PiProgram, KernelByNameT>;
8082

8183
~KernelProgramCache();

sycl/source/detail/program_impl.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,14 @@ program_impl::program_impl(ContextImplPtr Context)
2828

2929
program_impl::program_impl(ContextImplPtr Context,
3030
vector_class<device> DeviceList)
31-
: MContext(Context), MDevices(DeviceList) {}
31+
: MContext(Context), MDevices(DeviceList) {
32+
if (Context->getDevices().size() > 1) {
33+
throw feature_not_supported(
34+
"multiple devices within a context are not supported with "
35+
"sycl::program and sycl::kernel",
36+
PI_INVALID_OPERATION);
37+
}
38+
}
3239

3340
program_impl::program_impl(
3441
vector_class<shared_ptr_class<program_impl>> ProgramList,
@@ -51,6 +58,12 @@ program_impl::program_impl(
5158
}
5259

5360
MContext = ProgramList[0]->MContext;
61+
if (MContext->getDevices().size() > 1) {
62+
throw feature_not_supported(
63+
"multiple devices within a context are not supported with "
64+
"sycl::program and sycl::kernel",
65+
PI_INVALID_OPERATION);
66+
}
5467
MDevices = ProgramList[0]->MDevices;
5568
vector_class<device> DevicesSorted;
5669
if (!is_host()) {
@@ -105,6 +118,13 @@ program_impl::program_impl(ContextImplPtr Context,
105118
RT::PiProgram Program)
106119
: MProgram(Program), MContext(Context), MLinkable(true) {
107120

121+
if (Context->getDevices().size() > 1) {
122+
throw feature_not_supported(
123+
"multiple devices within a context are not supported with "
124+
"sycl::program and sycl::kernel",
125+
PI_INVALID_OPERATION);
126+
}
127+
108128
const detail::plugin &Plugin = getPlugin();
109129
if (MProgram == nullptr) {
110130
assert(InteropProgram &&
@@ -233,7 +253,7 @@ void program_impl::build_with_kernel_name(string_class KernelName,
233253
if (is_cacheable_with_options(BuildOptions)) {
234254
MProgramAndKernelCachingAllowed = true;
235255
MProgram = ProgramManager::getInstance().getBuiltPIProgram(
236-
Module, get_context(), KernelName, this,
256+
Module, get_context(), get_devices()[0], KernelName, this,
237257
/*JITCompilationIsRequired=*/(!BuildOptions.empty()));
238258
const detail::plugin &Plugin = getPlugin();
239259
Plugin.call<PiApiKind::piProgramRetain>(MProgram);
@@ -356,7 +376,7 @@ void program_impl::build(const string_class &Options) {
356376
check_device_feature_support<info::device::is_compiler_available>(MDevices);
357377
vector_class<RT::PiDevice> Devices(get_pi_devices());
358378
const detail::plugin &Plugin = getPlugin();
359-
ProgramManager::getInstance().flushSpecConstants(*this);
379+
ProgramManager::getInstance().flushSpecConstants(*this, get_pi_devices()[0]);
360380
RT::PiResult Err = Plugin.call_nocheck<PiApiKind::piProgramBuild>(
361381
MProgram, Devices.size(), Devices.data(), Options.c_str(), nullptr,
362382
nullptr);
@@ -404,7 +424,8 @@ RT::PiKernel program_impl::get_pi_kernel(const string_class &KernelName) const {
404424
if (is_cacheable()) {
405425
std::tie(Kernel, std::ignore) =
406426
ProgramManager::getInstance().getOrCreateKernel(
407-
MProgramModuleHandle, get_context(), KernelName, this);
427+
MProgramModuleHandle, get_context(), get_devices()[0], KernelName,
428+
this);
408429
getPlugin().call<PiApiKind::piKernelRetain>(Kernel);
409430
} else {
410431
const detail::plugin &Plugin = getPlugin();
@@ -453,9 +474,10 @@ void program_impl::create_pi_program_with_kernel_name(
453474
bool JITCompilationIsRequired) {
454475
assert(!MProgram && "This program already has an encapsulated PI program");
455476
ProgramManager &PM = ProgramManager::getInstance();
477+
const device FirstDevice = get_devices()[0];
456478
RTDeviceBinaryImage &Img = PM.getDeviceImage(
457-
Module, KernelName, get_context(), JITCompilationIsRequired);
458-
MProgram = PM.createPIProgram(Img, get_context());
479+
Module, KernelName, get_context(), FirstDevice, JITCompilationIsRequired);
480+
MProgram = PM.createPIProgram(Img, get_context(), FirstDevice);
459481
}
460482

461483
template <>

0 commit comments

Comments
 (0)