Skip to content

Commit 09c03e1

Browse files
author
Ivan Karachun
committed
Updated the test
Signed-off-by: Ivan Karachun <[email protected]>
1 parent bb423be commit 09c03e1

File tree

4 files changed

+58
-113
lines changed

4 files changed

+58
-113
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
add_sycl_unittest(ThreadSafetyTests
2-
ThreadUtils.cpp
32
HostAccessorDeadLock.cpp
43
)
Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//==---- SchedulerThreadSafety.cpp --- Thread Safety unit tests ------------==//
1+
//==----- HostAccessorDeadLock.cpp --- Thread Safety unit tests ------------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,40 +15,47 @@
1515
namespace {
1616
constexpr auto sycl_read_write = cl::sycl::access::mode::read_write;
1717

18-
template <typename T, int Dim> class TestDeadLock : public ParallelTask {
19-
public:
20-
TestDeadLock(T *Data, std::size_t Size)
21-
: MBuffer(Data, cl::sycl::range<Dim>(Size)), MBufferSize(Size) {}
22-
23-
void taskBody(size_t ThreadId) {
24-
auto acc = MBuffer.template get_access<sycl_read_write>();
25-
for (std::size_t i = 0; i < MBufferSize; ++i) {
26-
acc[i] = ThreadId;
27-
if (i == 0) {
28-
MMutex.lock();
29-
MThreadOrder.push_back(ThreadId);
30-
MMutex.unlock();
31-
}
32-
}
33-
}
18+
class HostAccessorDeadLockTest : public ::testing::Test {
19+
protected:
20+
HostAccessorDeadLockTest() : MPool() {}
21+
~HostAccessorDeadLockTest() override = default;
3422

35-
std::size_t getLastWorkingThread() { return MThreadOrder.back(); }
36-
37-
private:
38-
std::vector<std::size_t> MThreadOrder;
39-
cl::sycl::buffer<T, Dim> MBuffer;
40-
std::size_t MBufferSize;
41-
std::mutex MMutex;
23+
ThreadPool MPool;
4224
};
4325

44-
class HostAccessorDeadLockTest : public ::testing::Test {};
45-
4626
TEST_F(HostAccessorDeadLockTest, CheckThreadOrder) {
47-
constexpr size_t size = 1024;
48-
constexpr size_t threadCount = 4;
27+
constexpr std::size_t size = 1024;
28+
constexpr std::size_t threadCount = 4;
4929
std::size_t data[size];
50-
TestDeadLock<std::size_t, 1> Task(data, size);
51-
Task.execute(threadCount);
52-
EXPECT_EQ(data[size - 1], Task.getLastWorkingThread());
30+
std::size_t lastThreadNum = -1, launchCount = 5;
31+
32+
{
33+
std::vector<std::size_t> threadOrder;
34+
cl::sycl::buffer<std::size_t, 1> buffer(data, size);
35+
std::mutex mutex;
36+
37+
auto testLambda = [&](std::size_t threadId) {
38+
auto acc = buffer.get_access<sycl_read_write>();
39+
for (std::size_t i = 0; i < size; ++i) {
40+
acc[i] = threadId;
41+
if (i == 0) {
42+
mutex.lock();
43+
threadOrder.push_back(threadId);
44+
mutex.unlock();
45+
}
46+
}
47+
};
48+
49+
for (std::size_t k = 0; k < launchCount; ++k) {
50+
MPool.clear();
51+
for (std::size_t i = 0; i < threadCount; ++i)
52+
MPool.enqueue(testLambda, i);
53+
MPool.wait();
54+
}
55+
56+
lastThreadNum = threadOrder.back();
57+
}
58+
59+
EXPECT_EQ(data[size - 1], lastThreadNum);
5360
}
5461
} // namespace

sycl/unittests/thread_safety/ThreadUtils.cpp

Lines changed: 0 additions & 47 deletions
This file was deleted.
Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,28 @@
11
#include <thread>
22
#include <vector>
33

4-
class ParallelTask;
5-
6-
class Thread {
7-
public:
8-
Thread(ParallelTask *Ptr) : MTask(Ptr) {}
9-
void start(size_t id);
10-
void wait();
11-
void body(size_t id);
12-
13-
private:
14-
std::thread MThread;
15-
ParallelTask *MTask;
16-
};
17-
184
class ThreadPool {
195
public:
20-
ThreadPool(ParallelTask *p);
21-
void initialize(int size);
22-
void start();
23-
void wait();
24-
25-
private:
26-
std::vector<Thread *> MThreadPool;
27-
ParallelTask *MTask;
28-
};
29-
30-
class ParallelTask {
31-
friend class ThreadPool;
32-
33-
public:
34-
ParallelTask() : MPool(this) {}
35-
36-
void execute(int threadCount);
37-
38-
virtual void taskBody(std::size_t id) = 0;
6+
void clear() { MThreadPool.clear(); }
7+
8+
template <typename Func, typename... Args>
9+
void enqueueNTimes(std::size_t N, Func &&func, Args &&... args) {
10+
for (std::size_t i = 0; i < N; ++i)
11+
enqueue(std::forward<Func>(func), std::forward<Args>(args)...);
12+
}
13+
14+
template <typename Func, typename... Args>
15+
void enqueue(Func &&func, Args &&... args) {
16+
MThreadPool.push_back(
17+
std::thread(std::forward<Func>(func), std::forward<Args>(args)...));
18+
}
19+
20+
void wait() {
21+
for (auto &t : MThreadPool) {
22+
t.join();
23+
}
24+
}
3925

4026
private:
41-
ThreadPool MPool;
27+
std::vector<std::thread> MThreadPool;
4228
};

0 commit comments

Comments
 (0)