|
1 |
| -//==---- SchedulerThreadSafety.cpp --- Thread Safety unit tests ------------==// |
| 1 | +//==----- HostAccessorDeadLock.cpp --- Thread Safety unit tests ------------==// |
2 | 2 | //
|
3 | 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
4 | 4 | // See https://llvm.org/LICENSE.txt for license information.
|
|
15 | 15 | namespace {
|
16 | 16 | constexpr auto sycl_read_write = cl::sycl::access::mode::read_write;
|
17 | 17 |
|
18 |
| -template <typename T, int Dim> class TestDeadLock : public ParallelTask { |
19 |
| -public: |
20 |
| - TestDeadLock(T *Data, std::size_t Size) |
21 |
| - : MBuffer(Data, cl::sycl::range<Dim>(Size)), MBufferSize(Size) {} |
22 |
| - |
23 |
| - void taskBody(size_t ThreadId) { |
24 |
| - auto acc = MBuffer.template get_access<sycl_read_write>(); |
25 |
| - for (std::size_t i = 0; i < MBufferSize; ++i) { |
26 |
| - acc[i] = ThreadId; |
27 |
| - if (i == 0) { |
28 |
| - MMutex.lock(); |
29 |
| - MThreadOrder.push_back(ThreadId); |
30 |
| - MMutex.unlock(); |
31 |
| - } |
32 |
| - } |
33 |
| - } |
| 18 | +class HostAccessorDeadLockTest : public ::testing::Test { |
| 19 | +protected: |
| 20 | + HostAccessorDeadLockTest() : MPool() {} |
| 21 | + ~HostAccessorDeadLockTest() override = default; |
34 | 22 |
|
35 |
| - std::size_t getLastWorkingThread() { return MThreadOrder.back(); } |
36 |
| - |
37 |
| -private: |
38 |
| - std::vector<std::size_t> MThreadOrder; |
39 |
| - cl::sycl::buffer<T, Dim> MBuffer; |
40 |
| - std::size_t MBufferSize; |
41 |
| - std::mutex MMutex; |
| 23 | + ThreadPool MPool; |
42 | 24 | };
|
43 | 25 |
|
44 |
| -class HostAccessorDeadLockTest : public ::testing::Test {}; |
45 |
| - |
46 | 26 | TEST_F(HostAccessorDeadLockTest, CheckThreadOrder) {
|
47 |
| - constexpr size_t size = 1024; |
48 |
| - constexpr size_t threadCount = 4; |
| 27 | + constexpr std::size_t size = 1024; |
| 28 | + constexpr std::size_t threadCount = 4; |
49 | 29 | std::size_t data[size];
|
50 |
| - TestDeadLock<std::size_t, 1> Task(data, size); |
51 |
| - Task.execute(threadCount); |
52 |
| - EXPECT_EQ(data[size - 1], Task.getLastWorkingThread()); |
| 30 | + std::size_t lastThreadNum = -1, launchCount = 5; |
| 31 | + |
| 32 | + { |
| 33 | + std::vector<std::size_t> threadOrder; |
| 34 | + cl::sycl::buffer<std::size_t, 1> buffer(data, size); |
| 35 | + std::mutex mutex; |
| 36 | + |
| 37 | + auto testLambda = [&](std::size_t threadId) { |
| 38 | + auto acc = buffer.get_access<sycl_read_write>(); |
| 39 | + for (std::size_t i = 0; i < size; ++i) { |
| 40 | + acc[i] = threadId; |
| 41 | + if (i == 0) { |
| 42 | + mutex.lock(); |
| 43 | + threadOrder.push_back(threadId); |
| 44 | + mutex.unlock(); |
| 45 | + } |
| 46 | + } |
| 47 | + }; |
| 48 | + |
| 49 | + for (std::size_t k = 0; k < launchCount; ++k) { |
| 50 | + MPool.clear(); |
| 51 | + for (std::size_t i = 0; i < threadCount; ++i) |
| 52 | + MPool.enqueue(testLambda, i); |
| 53 | + MPool.wait(); |
| 54 | + } |
| 55 | + |
| 56 | + lastThreadNum = threadOrder.back(); |
| 57 | + } |
| 58 | + |
| 59 | + EXPECT_EQ(data[size - 1], lastThreadNum); |
53 | 60 | }
|
54 | 61 | } // namespace
|
0 commit comments