Skip to content

Commit 256d2cc

Browse files
[SYCL] Reset PiFunctionTable in PiMock::~PiMock (#6128)
1 parent 858d04d commit 256d2cc

File tree

6 files changed

+48
-35
lines changed

6 files changed

+48
-35
lines changed

sycl/unittests/helpers/PiMock.hpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <detail/platform_impl.hpp>
3636

3737
#include <functional>
38+
#include <optional>
3839

3940
__SYCL_INLINE_NAMESPACE(cl) {
4041
namespace sycl {
@@ -133,14 +134,26 @@ class PiMock {
133134
MPiPluginMockPtr = &NewPluginPtr->getPiPlugin();
134135
// Save a copy of the platform resource
135136
MPlatform = OriginalPlatform;
137+
OrigFuncTable = OriginalPiPlugin.getPiPlugin().PiFunctionTable;
136138
}
137139

138140
/// Explicit construction from a host_selector is forbidden.
139141
PiMock(const cl::sycl::host_selector &HostSelector) = delete;
140142

143+
PiMock(PiMock &&Other) {
144+
MPlatform = std::move(Other.MPlatform);
145+
OrigFuncTable = std::move(Other.OrigFuncTable);
146+
Other.OrigFuncTable = {}; // Move above doesn't reset the optional.
147+
MPiPluginMockPtr = std::move(Other.MPiPluginMockPtr);
148+
}
141149
PiMock(const PiMock &) = delete;
142150
PiMock &operator=(const PiMock &) = delete;
143-
~PiMock() = default;
151+
~PiMock() {
152+
if (!OrigFuncTable)
153+
return;
154+
155+
MPiPluginMockPtr->PiFunctionTable = *OrigFuncTable;
156+
}
144157

145158
/// Returns a handle to the SYCL platform instance.
146159
///
@@ -184,6 +197,7 @@ class PiMock {
184197

185198
private:
186199
cl::sycl::platform MPlatform;
200+
std::optional<pi_plugin::FunctionPointers> OrigFuncTable;
187201
// Extracted at initialization for convenience purposes. The resource
188202
// itself is owned by the platform instance.
189203
RT::PiPlugin *MPiPluginMockPtr;

sycl/unittests/queue/EventClear.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ pi_result redefinedEventRelease(pi_event event) {
8282
return PI_SUCCESS;
8383
}
8484

85-
bool preparePiMock(platform &Plt) {
85+
std::optional<unittest::PiMock> preparePiMock(platform &Plt) {
8686
if (Plt.is_host()) {
8787
std::cout << "Not run on host - no PI events created in that case"
8888
<< std::endl;
89-
return false;
89+
return {};
9090
}
9191

9292
unittest::PiMock Mock{Plt};
@@ -98,14 +98,15 @@ bool preparePiMock(platform &Plt) {
9898
Mock.redefine<detail::PiApiKind::piEventGetInfo>(redefinedEventGetInfo);
9999
Mock.redefine<detail::PiApiKind::piEventRetain>(redefinedEventRetain);
100100
Mock.redefine<detail::PiApiKind::piEventRelease>(redefinedEventRelease);
101-
return true;
101+
return std::move(Mock);
102102
}
103103

104104
// Check that the USM events are cleared from the queue upon call to wait(),
105105
// so that they are not waited for multiple times.
106106
TEST(QueueEventClear, ClearOnQueueWait) {
107107
platform Plt{default_selector()};
108-
if (!preparePiMock(Plt))
108+
auto Mock = preparePiMock(Plt);
109+
if (!Mock)
109110
return;
110111

111112
context Ctx{Plt.get_devices()[0]};
@@ -126,7 +127,8 @@ TEST(QueueEventClear, ClearOnQueueWait) {
126127
// exceeds a threshold.
127128
TEST(QueueEventClear, CleanupOnThreshold) {
128129
platform Plt{default_selector()};
129-
if (!preparePiMock(Plt))
130+
auto Mock = preparePiMock(Plt);
131+
if (!Mock)
130132
return;
131133

132134
context Ctx{Plt.get_devices()[0]};

sycl/unittests/queue/GetProfilingInfo.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#include <helpers/PiImage.hpp>
1818
#include <helpers/PiMock.hpp>
1919

20+
#include <detail/context_impl.hpp>
21+
2022
class InfoTestKernel;
2123

2224
__SYCL_INLINE_NAMESPACE(cl) {
@@ -85,14 +87,14 @@ TEST(GetProfilingInfo, normal_pass_without_exception) {
8587
Mock.redefine<sycl::detail::PiApiKind::piEventGetProfilingInfo>(
8688
redefinedPiEventGetProfilingInfo);
8789
const sycl::device Dev = Plt.get_devices()[0];
90+
sycl::context Ctx{Dev};
8891
static sycl::unittest::PiImage DevImage_1 =
8992
generateTestImage<InfoTestKernel>();
9093

9194
static sycl::unittest::PiImageArray<1> DevImageArray = {&DevImage_1};
9295
auto KernelID_1 = sycl::get_kernel_id<InfoTestKernel>();
9396
sycl::queue Queue{
94-
Dev, sycl::property_list{sycl::property::queue::enable_profiling{}}};
95-
const sycl::context Ctx = Queue.get_context();
97+
Ctx, Dev, sycl::property_list{sycl::property::queue::enable_profiling{}}};
9698
auto KernelBundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(
9799
Ctx, {Dev}, {KernelID_1});
98100

@@ -139,13 +141,13 @@ TEST(GetProfilingInfo, command_exception_check) {
139141
redefinedPiEventGetProfilingInfo);
140142

141143
const sycl::device Dev = Plt.get_devices()[0];
144+
sycl::context Ctx{Dev};
142145
static sycl::unittest::PiImage DevImage_1 =
143146
generateTestImage<InfoTestKernel>();
144147

145148
static sycl::unittest::PiImageArray<1> DevImageArray = {&DevImage_1};
146149
auto KernelID_1 = sycl::get_kernel_id<InfoTestKernel>();
147-
sycl::queue Queue{Dev};
148-
const sycl::context Ctx = Queue.get_context();
150+
sycl::queue Queue{Ctx, Dev};
149151
auto KernelBundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(
150152
Ctx, {Dev}, {KernelID_1});
151153
const int globalWIs{512};
@@ -219,6 +221,7 @@ TEST(GetProfilingInfo, check_if_now_dead_queue_property_set) {
219221
Mock.redefine<sycl::detail::PiApiKind::piEventGetProfilingInfo>(
220222
redefinedPiEventGetProfilingInfo);
221223
const sycl::device Dev = Plt.get_devices()[0];
224+
sycl::context Ctx{Dev};
222225
static sycl::unittest::PiImage DevImage_1 =
223226
generateTestImage<InfoTestKernel>();
224227

@@ -228,8 +231,8 @@ TEST(GetProfilingInfo, check_if_now_dead_queue_property_set) {
228231
cl::sycl::event event;
229232
{
230233
sycl::queue Queue{
231-
Dev, sycl::property_list{sycl::property::queue::enable_profiling{}}};
232-
const sycl::context Ctx = Queue.get_context();
234+
Ctx, Dev,
235+
sycl::property_list{sycl::property::queue::enable_profiling{}}};
233236
auto KernelBundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(
234237
Ctx, {Dev}, {KernelID_1});
235238
event = Queue.submit([&](sycl::handler &cgh) {
@@ -274,6 +277,7 @@ TEST(GetProfilingInfo, check_if_now_dead_queue_property_not_set) {
274277
Mock.redefine<sycl::detail::PiApiKind::piEventGetProfilingInfo>(
275278
redefinedPiEventGetProfilingInfo);
276279
const sycl::device Dev = Plt.get_devices()[0];
280+
sycl::context Ctx{Dev};
277281
static sycl::unittest::PiImage DevImage_1 =
278282
generateTestImage<InfoTestKernel>();
279283

@@ -282,8 +286,7 @@ TEST(GetProfilingInfo, check_if_now_dead_queue_property_not_set) {
282286
const int globalWIs{512};
283287
cl::sycl::event event;
284288
{
285-
sycl::queue Queue{Dev};
286-
const sycl::context Ctx = Queue.get_context();
289+
sycl::queue Queue{Ctx, Dev};
287290
auto KernelBundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(
288291
Ctx, {Dev}, {KernelID_1});
289292
event = Queue.submit([&](sycl::handler &cgh) {
@@ -325,4 +328,6 @@ TEST(GetProfilingInfo, check_if_now_dead_queue_property_not_set) {
325328
"'enable_profiling' queue property");
326329
}
327330
}
331+
// The test passes without this, but keep it still, just in case.
332+
sycl::detail::getSyclObjImpl(Ctx)->getKernelProgramCache().reset();
328333
}

sycl/unittests/queue/USM.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,18 @@ pi_result redefinedUSMEnqueueMemset(pi_queue, void *, pi_int32, size_t,
5454
}
5555

5656
pi_result redefinedEventRelease(pi_event) { return PI_SUCCESS; }
57+
pi_result redefinedEventsWait(pi_uint32 /* num_events */,
58+
const pi_event * /* event_list */) {
59+
return PI_SUCCESS;
60+
}
5761

58-
bool preparePiMock(platform &Plt) {
62+
// Check that zero-length USM memset/memcpy use piEnqueueEventsWait.
63+
TEST(USM, NoOpPreservesDependencyChain) {
64+
platform Plt{default_selector()};
5965
if (Plt.is_host()) {
6066
std::cout << "Not run on host - no PI events created in that case"
6167
<< std::endl;
62-
return false;
68+
return;
6369
}
6470

6571
unittest::PiMock Mock{Plt};
@@ -70,14 +76,7 @@ bool preparePiMock(platform &Plt) {
7076
Mock.redefine<detail::PiApiKind::piextUSMEnqueueMemset>(
7177
redefinedUSMEnqueueMemset);
7278
Mock.redefine<detail::PiApiKind::piEventRelease>(redefinedEventRelease);
73-
return true;
74-
}
75-
76-
// Check that zero-length USM memset/memcpy use piEnqueueEventsWait.
77-
TEST(USM, NoOpPreservesDependencyChain) {
78-
platform Plt{default_selector()};
79-
if (!preparePiMock(Plt))
80-
return;
79+
Mock.redefine<detail::PiApiKind::piEventsWait>(redefinedEventsWait);
8180

8281
context Ctx{Plt.get_devices()[0]};
8382
queue Q{Ctx, default_selector()};
@@ -102,6 +101,6 @@ TEST(USM, NoOpPreservesDependencyChain) {
102101

103102
free(Src, Q);
104103
free(Dst, Q);
104+
TestContext.Deps.clear();
105105
}
106-
107106
} // namespace

sycl/unittests/queue/Wait.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,12 @@ pi_result redefinedEventRelease(pi_event event) {
8686
return PI_SUCCESS;
8787
}
8888

89-
bool preparePiMock(platform &Plt) {
89+
TEST(QueueWait, QueueWaitTest) {
90+
platform Plt{default_selector()};
9091
if (Plt.is_host()) {
9192
std::cout << "Not run on host - no PI events created in that case"
9293
<< std::endl;
93-
return false;
94+
return;
9495
}
9596

9697
unittest::PiMock Mock{Plt};
@@ -105,13 +106,6 @@ bool preparePiMock(platform &Plt) {
105106
Mock.redefine<detail::PiApiKind::piEventGetInfo>(redefinedEventGetInfo);
106107
Mock.redefine<detail::PiApiKind::piEventRetain>(redefinedEventRetain);
107108
Mock.redefine<detail::PiApiKind::piEventRelease>(redefinedEventRelease);
108-
return true;
109-
}
110-
111-
TEST(QueueWait, QueueWaitTest) {
112-
platform Plt{default_selector()};
113-
if (!preparePiMock(Plt))
114-
return;
115109
context Ctx{Plt.get_devices()[0]};
116110
queue Q{Ctx, default_selector()};
117111

sycl/unittests/scheduler/RequiredWGSize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,5 @@ TEST(RequiredWGSize, NoRequiredSize) {
238238
TEST(RequiredWGSize, HasRequiredSize) {
239239
reset();
240240
RequiredLocalSize = {1, 2, 3};
241-
return; // FIXME: Resolve post-commit failures.
242241
performChecks();
243242
}

0 commit comments

Comments
 (0)