Skip to content

Commit adebd42

Browse files
[SYCL] Fix a race condition when enqueueing an interop kernel (#8111)
piKernelSetArg isn't supposed to be called from different threads on the same kernel.
1 parent 2864eac commit adebd42

File tree

5 files changed

+214
-127
lines changed

5 files changed

+214
-127
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2106,7 +2106,12 @@ pi_int32 enqueueImpKernel(
21062106
auto ContextImpl = Queue->getContextImplPtr();
21072107
auto DeviceImpl = Queue->getDeviceImplPtr();
21082108
RT::PiKernel Kernel = nullptr;
2109-
std::mutex *KernelMutex = nullptr;
2109+
// Cacheable kernels use per-kernel mutexes that will be fetched from the
2110+
// cache, others (e.g. interoperability kernels) share a single mutex.
2111+
// TODO: consider adding a PiKernel -> mutex map to allow enqueueing
2112+
// different PiKernels in parallel.
2113+
static std::mutex NoncacheableEnqueueMutex;
2114+
std::mutex *KernelMutex = &NoncacheableEnqueueMutex;
21102115
RT::PiProgram Program = nullptr;
21112116

21122117
std::shared_ptr<kernel_impl> SyclKernelImpl;
@@ -2179,18 +2184,13 @@ pi_int32 enqueueImpKernel(
21792184
detail::ProgramManager::getInstance().getEliminatedKernelArgMask(
21802185
OSModuleHandle, Program, KernelName);
21812186
}
2182-
if (KernelMutex != nullptr) {
2183-
// For cacheable kernels, we use per-kernel mutex
2187+
{
2188+
assert(KernelMutex);
21842189
std::lock_guard<std::mutex> Lock(*KernelMutex);
21852190
Error = SetKernelParamsAndLaunch(Queue, Args, DeviceImageImpl, Kernel,
21862191
NDRDesc, EventsWaitList, OutEvent,
21872192
EliminatedArgMask, getMemAllocationFunc);
2188-
} else {
2189-
Error = SetKernelParamsAndLaunch(Queue, Args, DeviceImageImpl, Kernel,
2190-
NDRDesc, EventsWaitList, OutEvent,
2191-
EliminatedArgMask, getMemAllocationFunc);
21922193
}
2193-
21942194
if (PI_SUCCESS != Error) {
21952195
// If we have got non-success error code, let's analyze it to emit nice
21962196
// exception explaining what was wrong
Lines changed: 14 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1+
//==------- SetArgForLocalAccessor.cpp --- Handler unit tests --------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
19
#include <gtest/gtest.h>
10+
#include <helpers/KernelInteropCommon.hpp>
211
#include <helpers/PiMock.hpp>
312

413
#include <sycl/sycl.hpp>
@@ -9,135 +18,20 @@
918

1019
namespace {
1120

12-
struct TestContext {
13-
size_t localBufferArgSize = 0;
14-
15-
// SYCL RT has number of checks that all devices and contexts are consistent
16-
// between kernel, kernel_bundle and other objects.
17-
//
18-
// To ensure that those checks pass, we intercept some PI calls to extract
19-
// the exact PI handles of device and context used in queue creation to later
20-
// return them when program/context/kernel info is requested.
21-
pi_device deviceHandle;
22-
pi_context contextHandle;
23-
24-
pi_program programHandle = createDummyHandle<pi_program>();
25-
26-
~TestContext() { releaseDummyHandle<pi_program>(programHandle); }
27-
};
28-
29-
TestContext GlobalContext;
30-
31-
} // namespace
21+
size_t LocalBufferArgSize = 0;
3222

3323
pi_result redefined_piKernelSetArg(pi_kernel kernel, pi_uint32 arg_index,
3424
size_t arg_size, const void *arg_value) {
35-
GlobalContext.localBufferArgSize = arg_size;
36-
37-
return PI_SUCCESS;
38-
}
39-
40-
pi_result after_piContextGetInfo(pi_context context, pi_context_info param_name,
41-
size_t param_value_size, void *param_value,
42-
size_t *param_value_size_ret) {
43-
switch (param_name) {
44-
case PI_CONTEXT_INFO_DEVICES:
45-
if (param_value)
46-
*static_cast<pi_device *>(param_value) = GlobalContext.deviceHandle;
47-
if (param_value_size_ret)
48-
*param_value_size_ret = sizeof(GlobalContext.deviceHandle);
49-
break;
50-
default:;
51-
}
52-
53-
return PI_SUCCESS;
54-
}
55-
56-
pi_result after_piProgramGetInfo(pi_program program, pi_program_info param_name,
57-
size_t param_value_size, void *param_value,
58-
size_t *param_value_size_ret) {
59-
60-
switch (param_name) {
61-
case PI_PROGRAM_INFO_DEVICES:
62-
if (param_value_size_ret)
63-
*param_value_size_ret = sizeof(GlobalContext.deviceHandle);
64-
if (param_value)
65-
*static_cast<pi_device *>(param_value) = GlobalContext.deviceHandle;
66-
break;
67-
default:;
68-
}
69-
70-
return PI_SUCCESS;
71-
}
72-
73-
pi_result redefined_piProgramGetBuildInfo(pi_program program, pi_device device,
74-
_pi_program_build_info param_name,
75-
size_t param_value_size,
76-
void *param_value,
77-
size_t *param_value_size_ret) {
78-
switch (param_name) {
79-
case PI_PROGRAM_BUILD_INFO_BINARY_TYPE:
80-
if (param_value_size_ret)
81-
*param_value_size_ret = sizeof(pi_program_binary_type);
82-
if (param_value)
83-
*static_cast<pi_program_binary_type *>(param_value) =
84-
PI_PROGRAM_BINARY_TYPE_EXECUTABLE;
85-
break;
86-
default:;
87-
}
88-
89-
return PI_SUCCESS;
90-
}
91-
92-
pi_result after_piContextCreate(const pi_context_properties *properties,
93-
pi_uint32 num_devices, const pi_device *devices,
94-
void (*pfn_notify)(const char *errinfo,
95-
const void *private_info,
96-
size_t cb, void *user_data),
97-
void *user_data, pi_context *ret_context) {
98-
if (ret_context)
99-
GlobalContext.contextHandle = *ret_context;
100-
GlobalContext.deviceHandle = *devices;
101-
return PI_SUCCESS;
102-
}
103-
104-
pi_result after_piKernelGetInfo(pi_kernel kernel, pi_kernel_info param_name,
105-
size_t param_value_size, void *param_value,
106-
size_t *param_value_size_ret) {
107-
switch (param_name) {
108-
case PI_KERNEL_INFO_CONTEXT:
109-
if (param_value_size_ret)
110-
*param_value_size_ret = sizeof(GlobalContext.contextHandle);
111-
if (param_value)
112-
*static_cast<pi_context *>(param_value) = GlobalContext.contextHandle;
113-
break;
114-
case PI_KERNEL_INFO_PROGRAM:
115-
if (param_value_size_ret)
116-
*param_value_size_ret = sizeof(GlobalContext.programHandle);
117-
if (param_value)
118-
*(pi_program *)param_value = GlobalContext.programHandle;
119-
break;
120-
default:;
121-
}
25+
LocalBufferArgSize = arg_size;
12226

12327
return PI_SUCCESS;
12428
}
12529

12630
TEST(HandlerSetArg, LocalAccessor) {
12731
sycl::unittest::PiMock Mock;
128-
32+
redefineMockForKernelInterop(Mock);
12933
Mock.redefine<sycl::detail::PiApiKind::piKernelSetArg>(
13034
redefined_piKernelSetArg);
131-
Mock.redefineAfter<sycl::detail::PiApiKind::piContextCreate>(
132-
after_piContextCreate);
133-
Mock.redefineAfter<sycl::detail::PiApiKind::piProgramGetInfo>(
134-
after_piProgramGetInfo);
135-
Mock.redefineAfter<sycl::detail::PiApiKind::piContextGetInfo>(
136-
after_piContextGetInfo);
137-
Mock.redefineAfter<sycl::detail::PiApiKind::piKernelGetInfo>(
138-
after_piKernelGetInfo);
139-
Mock.redefine<sycl::detail::PiApiKind::piProgramGetBuildInfo>(
140-
redefined_piProgramGetBuildInfo);
14135

14236
constexpr size_t Size = 128;
14337
sycl::queue Q;
@@ -154,5 +48,6 @@ TEST(HandlerSetArg, LocalAccessor) {
15448
CGH.single_task(Kernel);
15549
}).wait();
15650

157-
ASSERT_EQ(GlobalContext.localBufferArgSize, Size * sizeof(float));
51+
ASSERT_EQ(LocalBufferArgSize, Size * sizeof(float));
15852
}
53+
} // namespace
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//==-- KernelInteropCommon.hpp --- Common kernel interop redefinitions -----==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <helpers/PiMock.hpp>
10+
11+
struct TestContext {
12+
13+
// SYCL RT has number of checks that all devices and contexts are consistent
14+
// between kernel, kernel_bundle and other objects.
15+
//
16+
// To ensure that those checks pass, we intercept some PI calls to extract
17+
// the exact PI handles of device and context used in queue creation to later
18+
// return them when program/context/kernel info is requested.
19+
pi_device deviceHandle;
20+
pi_context contextHandle;
21+
22+
pi_program programHandle = createDummyHandle<pi_program>();
23+
24+
~TestContext() { releaseDummyHandle<pi_program>(programHandle); }
25+
};
26+
27+
// Shared state read and written by the mock hooks in this header. Declared
// `inline` because this is a header: a plain definition would yield one
// GlobalContext per including translation unit and a multiple-definition error
// as soon as two of them are linked together.
inline TestContext GlobalContext;
28+
29+
/// Hook meant to be installed after piContextGetInfo.
///
/// For PI_CONTEXT_INFO_DEVICES it writes the device handle stored in
/// GlobalContext into \p param_value and its size into
/// \p param_value_size_ret (each only when the pointer is non-null). All other
/// queries are left untouched.
///
/// Declared `inline` because it is defined in a header; a non-inline
/// definition breaks linking once two translation units include this file.
///
/// \return always PI_SUCCESS.
inline pi_result after_piContextGetInfo(pi_context context,
                                        pi_context_info param_name,
                                        size_t param_value_size,
                                        void *param_value,
                                        size_t *param_value_size_ret) {
  switch (param_name) {
  case PI_CONTEXT_INFO_DEVICES:
    if (param_value)
      *static_cast<pi_device *>(param_value) = GlobalContext.deviceHandle;
    if (param_value_size_ret)
      *param_value_size_ret = sizeof(GlobalContext.deviceHandle);
    break;
  default:;
  }

  return PI_SUCCESS;
}
44+
45+
/// Hook meant to be installed after piProgramGetInfo.
///
/// For PI_PROGRAM_INFO_DEVICES it reports the device handle stored in
/// GlobalContext: its size goes to \p param_value_size_ret and the handle
/// itself to \p param_value (each only when the pointer is non-null). All
/// other queries are left untouched.
///
/// Declared `inline` because it is defined in a header; a non-inline
/// definition breaks linking once two translation units include this file.
///
/// \return always PI_SUCCESS.
inline pi_result after_piProgramGetInfo(pi_program program,
                                        pi_program_info param_name,
                                        size_t param_value_size,
                                        void *param_value,
                                        size_t *param_value_size_ret) {
  switch (param_name) {
  case PI_PROGRAM_INFO_DEVICES:
    if (param_value_size_ret)
      *param_value_size_ret = sizeof(GlobalContext.deviceHandle);
    if (param_value)
      *static_cast<pi_device *>(param_value) = GlobalContext.deviceHandle;
    break;
  default:;
  }

  return PI_SUCCESS;
}
61+
62+
/// Replacement for piProgramGetBuildInfo.
///
/// For PI_PROGRAM_BUILD_INFO_BINARY_TYPE it reports
/// PI_PROGRAM_BINARY_TYPE_EXECUTABLE: the size goes to
/// \p param_value_size_ret and the value to \p param_value (each only when the
/// pointer is non-null). All other queries write nothing.
///
/// Declared `inline` because it is defined in a header; a non-inline
/// definition breaks linking once two translation units include this file.
///
/// \return always PI_SUCCESS.
inline pi_result redefined_piProgramGetBuildInfo(
    pi_program program, pi_device device, _pi_program_build_info param_name,
    size_t param_value_size, void *param_value, size_t *param_value_size_ret) {
  switch (param_name) {
  case PI_PROGRAM_BUILD_INFO_BINARY_TYPE:
    if (param_value_size_ret)
      *param_value_size_ret = sizeof(pi_program_binary_type);
    if (param_value)
      *static_cast<pi_program_binary_type *>(param_value) =
          PI_PROGRAM_BINARY_TYPE_EXECUTABLE;
    break;
  default:;
  }

  return PI_SUCCESS;
}
80+
81+
/// Hook meant to be installed after piContextCreate.
///
/// Records the freshly created context (\p *ret_context) and the first entry
/// of \p devices in GlobalContext so that the other hooks in this header can
/// report the same handles later.
///
/// Declared `inline` because it is defined in a header; a non-inline
/// definition breaks linking once two translation units include this file.
///
/// \return always PI_SUCCESS.
inline pi_result after_piContextCreate(
    const pi_context_properties *properties, pi_uint32 num_devices,
    const pi_device *devices,
    void (*pfn_notify)(const char *errinfo, const void *private_info, size_t cb,
                       void *user_data),
    void *user_data, pi_context *ret_context) {
  if (ret_context)
    GlobalContext.contextHandle = *ret_context;
  // Only read the device list when there is actually an element to read.
  if (devices && num_devices > 0)
    GlobalContext.deviceHandle = *devices;
  return PI_SUCCESS;
}
92+
93+
/// Hook meant to be installed after piKernelGetInfo.
///
/// Answers two queries from the handles stored in GlobalContext:
///  - PI_KERNEL_INFO_CONTEXT reports GlobalContext.contextHandle;
///  - PI_KERNEL_INFO_PROGRAM reports GlobalContext.programHandle.
/// In both cases the size goes to \p param_value_size_ret and the handle to
/// \p param_value (each only when the pointer is non-null). All other queries
/// are left untouched.
///
/// Declared `inline` because it is defined in a header; a non-inline
/// definition breaks linking once two translation units include this file.
///
/// \return always PI_SUCCESS.
inline pi_result after_piKernelGetInfo(pi_kernel kernel,
                                       pi_kernel_info param_name,
                                       size_t param_value_size,
                                       void *param_value,
                                       size_t *param_value_size_ret) {
  switch (param_name) {
  case PI_KERNEL_INFO_CONTEXT:
    if (param_value_size_ret)
      *param_value_size_ret = sizeof(GlobalContext.contextHandle);
    if (param_value)
      *static_cast<pi_context *>(param_value) = GlobalContext.contextHandle;
    break;
  case PI_KERNEL_INFO_PROGRAM:
    if (param_value_size_ret)
      *param_value_size_ret = sizeof(GlobalContext.programHandle);
    if (param_value)
      *static_cast<pi_program *>(param_value) = GlobalContext.programHandle;
    break;
  default:;
  }

  return PI_SUCCESS;
}
114+
115+
/// Installs on \p Mock every PI override from this header that a test needs in
/// order to create and use an interoperability kernel: the "after" hooks for
/// piContextCreate, piProgramGetInfo, piContextGetInfo and piKernelGetInfo,
/// and the replacement for piProgramGetBuildInfo.
///
/// Declared `inline` because it is defined in a header; a non-inline
/// definition breaks linking once two translation units include this file.
inline void redefineMockForKernelInterop(sycl::unittest::PiMock &Mock) {
  using sycl::detail::PiApiKind;

  Mock.redefineAfter<PiApiKind::piContextCreate>(after_piContextCreate);
  Mock.redefineAfter<PiApiKind::piProgramGetInfo>(after_piProgramGetInfo);
  Mock.redefineAfter<PiApiKind::piContextGetInfo>(after_piContextGetInfo);
  Mock.redefineAfter<PiApiKind::piKernelGetInfo>(after_piKernelGetInfo);
  Mock.redefine<PiApiKind::piProgramGetBuildInfo>(
      redefined_piProgramGetBuildInfo);
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
add_sycl_unittest(ThreadSafetyTests OBJECT
22
HostAccessorDeadLock.cpp
3+
InteropKernelEnqueue.cpp
34
)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//==-------- InteropKernelEnqueue.cpp --- Thread safety unit tests ---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <cstddef>
10+
#include <gtest/gtest.h>
11+
#include <helpers/KernelInteropCommon.hpp>
12+
#include <helpers/PiMock.hpp>
13+
#include <sycl/sycl.hpp>
14+
15+
#include "ThreadUtils.h"
16+
17+
namespace {
18+
using namespace sycl;
19+
20+
constexpr std::size_t NArgs = 16;
21+
constexpr std::size_t ThreadCount = 4;
22+
constexpr std::size_t LaunchCount = 8;
23+
24+
pi_uint32 LastArgSet = -1;
25+
std::size_t LastThread = -1;
26+
// Replacement for piKernelSetArg that checks argument-setting sequences from
// different submissions are not interleaved.
//
// Each call must carry the index following the previous call's index (wrapping
// around after NArgs - 1). The value of argument 0 is remembered, and every
// subsequent argument is expected to carry that same value.
pi_result redefined_piKernelSetArg(pi_kernel kernel, pi_uint32 arg_index,
                                   size_t arg_size, const void *arg_value) {
  EXPECT_EQ((LastArgSet + 1) % NArgs, arg_index);
  LastArgSet = arg_index;

  // The test below passes a std::size_t thread id as every argument.
  const auto *ValuePtr = static_cast<const std::size_t *>(arg_value);
  const std::size_t SubmittingThread = *ValuePtr;

  if (arg_index != 0) {
    EXPECT_EQ(LastThread, SubmittingThread);
    return PI_SUCCESS;
  }

  // Start of a new sequence: remember who owns it.
  LastThread = SubmittingThread;
  return PI_SUCCESS;
}
37+
38+
// Submits the same interoperability kernel from several threads at once and
// relies on the piKernelSetArg replacement to detect interleaved argument
// setting.
TEST(KernelEnqueue, InteropKernel) {
  unittest::PiMock Mock;
  redefineMockForKernelInterop(Mock);
  Mock.redefine<sycl::detail::PiApiKind::piKernelSetArg>(
      redefined_piKernelSetArg);

  // NOTE(review): Plt is not referenced below — confirm whether it is needed.
  platform Plt = Mock.getPlatform();
  queue Q;

  // Turn the address of a dummy handle into a SYCL interoperability kernel.
  using NativeKernelT = typename sycl::backend_traits<
      sycl::backend::opencl>::template input_type<sycl::kernel>;
  DummyHandleT Handle;
  auto KernelCL = reinterpret_cast<NativeKernelT>(&Handle);
  auto Kernel =
      sycl::make_kernel<sycl::backend::opencl>(KernelCL, Q.get_context());

  // Work done by every thread: set all NArgs arguments to the thread's own id,
  // launch the kernel and wait for it.
  auto SubmitFromThread = [&](std::size_t ThreadId) {
    Q.submit([&](sycl::handler &CGH) {
       for (std::size_t ArgIdx = 0; ArgIdx < NArgs; ++ArgIdx)
         CGH.set_arg(ArgIdx, ThreadId);
       CGH.single_task(Kernel);
     }).wait();
  };

  // A fresh pool per round; it goes out of scope at the end of each iteration.
  for (std::size_t Round = 0; Round < LaunchCount; ++Round) {
    ThreadPool Pool(ThreadCount, SubmitFromThread);
  }
}
65+
} // namespace

0 commit comments

Comments
 (0)