Skip to content

Commit bc05eca

Browse files
committed
Use conditional variable to prevent closing prematurely
1 parent f101660 commit bc05eca

File tree

3 files changed

+39
-37
lines changed

3 files changed

+39
-37
lines changed

eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,11 +551,9 @@ namespace Aws
551551
std::mutex m_continuationMutex;
552552
bool m_resultReceived;
553553
std::promise<TaggedResult> m_initialResponsePromise;
554-
CloseState m_closeState;
555-
std::atomic_int m_numCloses;
554+
std::atomic_int m_expectedCloses;
556555
std::atomic_bool m_streamClosedCalled;
557-
std::promise<void> m_closedPromise;
558-
std::condition_variable m_promiseReady;
556+
std::condition_variable m_closeReady;
559557
};
560558

561559
class AWS_EVENTSTREAMRPC_API ClientConnection final

eventstream_rpc/source/EventStreamClient.cpp

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ namespace Aws
10681068
Crt::Allocator *allocator) noexcept
10691069
: m_operationModelContext(operationModelContext), m_messageCount(0), m_allocator(allocator),
10701070
m_streamHandler(streamHandler), m_clientContinuation(connection.NewStream(*this)),
1071-
m_closeState(WONT_CLOSE), m_numCloses(0), m_streamClosedCalled(false)
1071+
m_expectedCloses(0), m_streamClosedCalled(false)
10721072
{
10731073
}
10741074

@@ -1083,10 +1083,8 @@ namespace Aws
10831083
ClientOperation::~ClientOperation() noexcept
10841084
{
10851085
Close().wait();
1086-
if (m_numCloses.load() > 0)
1087-
{
1088-
m_closedPromise.get_future().wait();
1089-
}
1086+
std::unique_lock<std::mutex> lock(m_continuationMutex);
1087+
m_closeReady.wait(lock, [this]{return m_expectedCloses.load() == 0;});
10901088
}
10911089

10921090
TaggedResult::TaggedResult(Crt::ScopedResource<AbstractShapeBase> operationResponse) noexcept
@@ -1326,8 +1324,7 @@ namespace Aws
13261324
if (messageFlags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM)
13271325
{
13281326
const std::lock_guard<std::mutex> lock(m_continuationMutex);
1329-
m_closeState = WILL_CLOSE;
1330-
m_numCloses.fetch_add(1);
1327+
m_expectedCloses.fetch_add(1);
13311328
}
13321329

13331330
m_messageCount += 1;
@@ -1434,12 +1431,10 @@ namespace Aws
14341431
/* Promises must be reset in case the client would like to send a subsequent request with the same
14351432
* `ClientOperation`. */
14361433
m_initialResponsePromise = {};
1437-
m_closedPromise = {};
1434+
//m_closedPromise = {};
14381435
{
14391436
const std::lock_guard<std::mutex> lock(m_continuationMutex);
14401437
m_resultReceived = false;
1441-
if (m_closeState != WILL_CLOSE)
1442-
m_closeState = WONT_CLOSE;
14431438
}
14441439

14451440
Crt::List<EventStreamHeader> headers;
@@ -1461,29 +1456,29 @@ namespace Aws
14611456

14621457
void ClientOperation::OnContinuationClosed()
14631458
{
1464-
std::unique_lock<std::mutex> lock(m_continuationMutex);
1459+
const std::lock_guard<std::mutex> lock(m_continuationMutex);
14651460
if (!m_resultReceived)
14661461
{
14671462
m_initialResponsePromise.set_value(TaggedResult({EVENT_STREAM_RPC_CONTINUATION_CLOSED, 0}));
14681463
m_resultReceived = true;
14691464
}
14701465

1471-
m_numCloses.fetch_sub(1);
1472-
if (m_numCloses.load() == 0 && m_closeState != ALREADY_CLOSED)
1466+
if (m_expectedCloses.load() > 0)
14731467
{
1474-
m_closedPromise.set_value();
1475-
m_closeState = ALREADY_CLOSED;
1468+
m_expectedCloses.fetch_sub(1);
14761469
if (!m_streamClosedCalled.load() && m_streamHandler)
14771470
{
14781471
m_streamHandler->OnStreamClosed();
14791472
m_streamClosedCalled.store(true);
14801473
}
1474+
m_closeReady.notify_one();
14811475
}
14821476
}
14831477

14841478
std::future<RpcError> ClientOperation::Close(OnMessageFlushCallback onMessageFlushCallback) noexcept
14851479
{
1486-
if (m_numCloses.load() > 0 || m_clientContinuation.IsClosed())
1480+
const std::lock_guard<std::mutex> lock(m_continuationMutex);
1481+
if (m_expectedCloses.load() > 0 || m_clientContinuation.IsClosed())
14871482
{
14881483
std::promise<RpcError> errorPromise;
14891484
errorPromise.set_value({EVENT_STREAM_RPC_CONTINUATION_CLOSED, 0});
@@ -1530,9 +1525,7 @@ namespace Aws
15301525
}
15311526
else
15321527
{
1533-
const std::lock_guard<std::mutex> lock(m_continuationMutex);
1534-
m_closeState = WILL_CLOSE;
1535-
m_numCloses.fetch_add(1);
1528+
m_expectedCloses.fetch_add(1);
15361529
return callbackContainer->onFlushPromise.get_future();
15371530
}
15381531

eventstream_rpc/tests/EventStreamClientTest.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,9 @@ class ThreadPool
380380
ThreadPool(int numThreads = std::thread::hardware_concurrency()) noexcept
381381
: m_numThreads(numThreads), m_stopped(false)
382382
{
383-
m_terminatePool = false;
384383
for (int i = 0; i < numThreads; i++)
385384
{
386-
m_threadPool.push_back(std::thread(&ThreadPool::TaskConsumer, this));
385+
m_threadPool.push_back(std::thread(&ThreadPool::TaskWorker, this));
387386
}
388387
m_taskErrorCode = AWS_OP_SUCCESS;
389388
}
@@ -399,11 +398,6 @@ class ThreadPool
399398

400399
void Shutdown() noexcept
401400
{
402-
{
403-
std::unique_lock<std::mutex> lock(m_poolMutex);
404-
m_terminatePool = true;
405-
}
406-
407401
/* Wake up all threads so that they can complete. */
408402
m_taskReady.notify_all();
409403

@@ -425,7 +419,14 @@ class ThreadPool
425419
{
426420
if (m_queue.empty())
427421
{
422+
m_stopped = true;
428423
m_queueMutex.unlock();
424+
/* Wait for all threads to complete. */
425+
m_taskReady.notify_all();
426+
for (std::thread &thread : m_threadPool)
427+
{
428+
thread.join();
429+
}
429430
break;
430431
}
431432
else
@@ -452,30 +453,30 @@ class ThreadPool
452453
std::mutex m_queueMutex;
453454
std::queue<std::function<int()>> m_queue;
454455
std::condition_variable m_taskReady;
455-
bool m_terminatePool;
456456
int m_taskErrorCode;
457457
bool m_stopped;
458458

459-
void TaskConsumer()
459+
void TaskWorker()
460460
{
461461
while (true)
462462
{
463463
{
464464
std::unique_lock<std::mutex> lock(m_queueMutex);
465465

466-
m_taskReady.wait(lock, [this] { return !m_queue.empty() || m_terminatePool; });
466+
m_taskReady.wait(lock, [this] { return !m_queue.empty() || m_stopped; });
467467
if (!m_queue.empty())
468468
{
469469
std::function<int()> currentJob = m_queue.front();
470+
m_queue.pop();
471+
lock.unlock();
470472
if (currentJob)
471473
{
472474
int errorCode = currentJob();
473475
if (errorCode)
474476
m_taskErrorCode = errorCode;
475477
}
476-
m_queue.pop();
477478
}
478-
if (m_terminatePool)
479+
else if (m_stopped)
479480
{
480481
break;
481482
}
@@ -504,7 +505,17 @@ static int s_TestStressClient(struct aws_allocator *allocator, void *ctx)
504505
messageData.SetStringMessage(expectedMessage);
505506
echoMessageRequest.SetMessage(messageData);
506507
auto requestFuture = echoMessage.Activate(echoMessageRequest, s_onMessageFlush);
507-
requestFuture.wait();
508+
std::future_status status = requestFuture.wait_for(std::chrono::seconds(1));
509+
if (status != std::future_status::ready)
510+
{
511+
return AWS_OP_SUCCESS;
512+
}
513+
auto resultFuture = echoMessage.GetResult();
514+
status = resultFuture.wait_for(std::chrono::seconds(1));
515+
if (status != std::future_status::ready)
516+
{
517+
return AWS_OP_SUCCESS;
518+
}
508519
auto result = echoMessage.GetResult().get();
509520
ASSERT_TRUE(result);
510521
auto response = result.GetOperationResponse();
@@ -514,7 +525,7 @@ static int s_TestStressClient(struct aws_allocator *allocator, void *ctx)
514525
return AWS_OP_SUCCESS;
515526
};
516527

517-
for (int i = 0; i < 100; i++)
528+
for (int i = 0; i < 10000; i++)
518529
threadPool.AddTask(invokeOperation);
519530

520531
threadPool.BlockUntilTasksFinish();

0 commit comments

Comments
 (0)