Skip to content

Commit 39b1d18

Browse files
committed
Before changing ClientOperation::Close behavior
1 parent 4c60e50 commit 39b1d18

File tree

7 files changed

+89
-132
lines changed

7 files changed

+89
-132
lines changed

eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ namespace Aws
266266
EventStreamRpcStatusCode baseStatus;
267267
int crtError;
268268
operator bool() const noexcept { return baseStatus == EVENT_STREAM_RPC_SUCCESS; }
269+
Crt::String ErrorToString();
269270
};
270271

271272
template <class T> class ProtectedPromise
@@ -407,6 +408,7 @@ namespace Aws
407408
uint32_t messageFlags,
408409
OnMessageFlushCallback onMessageFlushCallback) noexcept;
409410
bool IsClosed() noexcept;
411+
void Close() noexcept;
410412
std::future<RpcError> SendMessage(
411413
const Crt::List<EventStreamHeader> &headers,
412414
const Crt::Optional<Crt::ByteBuf> &payload,
@@ -436,7 +438,7 @@ namespace Aws
436438
virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const = 0;
437439
virtual Crt::String GetModelName() const noexcept = 0;
438440

439-
private:
441+
protected:
440442
Crt::Allocator *m_allocator;
441443
};
442444

@@ -635,8 +637,7 @@ namespace Aws
635637
ClientContinuation m_clientContinuation;
636638
ProtectedPromise<TaggedResult> m_initialResponsePromise;
637639
/* ProtectedPromise not necessary because it's only ever being set by one thread. */
638-
std::promise<void> m_closedPromise;
639-
std::atomic<bool> m_isClosed;
640+
std::promise<RpcError> m_closedPromise;
640641
};
641642
} // namespace Eventstreamrpc
642643
} // namespace Aws

eventstream_rpc/source/EventStreamClient.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ namespace Aws
494494
return m_closedPromise.get_future();
495495
} else {
496496
std::promise<RpcError> closedPromise;
497-
closedPromise.set_value({EVENT_STREAM_RPC_CONNECTION_CLOSED, 0});
497+
closedPromise.set_value({EVENT_STREAM_RPC_SUCCESS, 0});
498498
return closedPromise.get_future();
499499
}
500500
}
@@ -749,7 +749,7 @@ namespace Aws
749749

750750
void AbstractShapeBase::s_customDeleter(AbstractShapeBase *shape) noexcept
751751
{
752-
if (shape->m_allocator)
752+
if (shape->m_allocator != nullptr)
753753
Crt::Delete<AbstractShapeBase>(shape, shape->m_allocator);
754754
}
755755

@@ -768,7 +768,13 @@ namespace Aws
768768
aws_event_stream_rpc_client_connection_new_stream(connection->m_underlyingConnection, &options);
769769
}
770770

771-
ClientContinuation::~ClientContinuation() noexcept {}
771+
void ClientContinuation::Close() noexcept {
772+
aws_event_stream_rpc_client_continuation_release(m_continuationToken);
773+
}
774+
775+
ClientContinuation::~ClientContinuation() noexcept {
776+
Close();
777+
}
772778

773779
void ClientContinuation::s_onContinuationMessage(
774780
struct aws_event_stream_rpc_client_continuation_token *continuationToken,
@@ -962,7 +968,7 @@ namespace Aws
962968
const OperationModelContext &operationModelContext,
963969
Crt::Allocator *allocator) noexcept
964970
: m_operationModelContext(operationModelContext), m_messageCount(0), m_allocator(allocator),
965-
m_streamHandler(streamHandler), m_clientContinuation(connection.NewStream(*this)), m_isClosed(false)
971+
m_streamHandler(streamHandler), m_clientContinuation(connection.NewStream(*this))
966972
{
967973
}
968974

@@ -973,7 +979,6 @@ namespace Aws
973979
m_initialResponsePromise(std::move(rhs.m_initialResponsePromise)),
974980
m_closedPromise(std::move(rhs.m_closedPromise))
975981
{
976-
m_isClosed.store(rhs.m_isClosed.load());
977982
}
978983

979984
ClientOperation::~ClientOperation() noexcept { Close().wait(); }
@@ -1146,11 +1151,6 @@ namespace Aws
11461151
{
11471152
bool streamAlreadyTerminated = messageFlags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM;
11481153

1149-
if (streamAlreadyTerminated)
1150-
{
1151-
m_isClosed.store(true);
1152-
}
1153-
11541154
Crt::StringView payloadStringView;
11551155
if (payload.has_value())
11561156
{
@@ -1169,7 +1169,7 @@ namespace Aws
11691169
m_initialResponsePromise.SetValue(std::move(taggedResult));
11701170
/* Close the stream unless the server already closed it for us. This condition is checked
11711171
* so that TERMINATE_STREAM messages aren't resent by the client. */
1172-
if (!streamAlreadyTerminated && !m_isClosed.exchange(true))
1172+
if (!streamAlreadyTerminated && !m_clientContinuation.IsClosed())
11731173
{
11741174
Close().wait();
11751175
}
@@ -1179,7 +1179,7 @@ namespace Aws
11791179
bool shouldCloseNow = true;
11801180
if (m_streamHandler)
11811181
shouldCloseNow = m_streamHandler->OnStreamError(std::move(error), {EVENT_STREAM_RPC_SUCCESS, 0});
1182-
if (!streamAlreadyTerminated && shouldCloseNow && !m_isClosed.exchange(true))
1182+
if (!streamAlreadyTerminated && shouldCloseNow && !m_clientContinuation.IsClosed())
11831183
{
11841184
Close().wait();
11851185
}
@@ -1222,10 +1222,11 @@ namespace Aws
12221222
contentHeader = GetHeaderByName(headers, Crt::String(CONTENT_TYPE_HEADER));
12231223
if (contentHeader == nullptr)
12241224
{
1225+
/* TODO: Log an error. */
12251226
/* Missing required content type header. */
12261227
errorCode = EVENT_STREAM_RPC_UNMAPPED_DATA;
12271228
}
1228-
else if (contentHeader->GetValueAsString(contentType) && contentType != CONTENT_TYPE_APPLICATION_JSON)
1229+
else if (contentHeader == nullptr && contentHeader->GetValueAsString(contentType) && contentType != CONTENT_TYPE_APPLICATION_JSON)
12291230
{
12301231
errorCode = EVENT_STREAM_RPC_UNSUPPORTED_CONTENT_TYPE;
12311232
}
@@ -1257,7 +1258,7 @@ namespace Aws
12571258
bool shouldClose = true;
12581259
if (m_streamHandler)
12591260
shouldClose = m_streamHandler->OnStreamError(nullptr, {errorCode, 0});
1260-
if (!m_isClosed.load() && shouldClose)
1261+
if (!m_clientContinuation.IsClosed() && shouldClose)
12611262
{
12621263
Close().wait();
12631264
}
@@ -1298,7 +1299,7 @@ namespace Aws
12981299
* potentially set by `OnContinuationMessage` will just do nothing. */
12991300
m_initialResponsePromise.SetValue(TaggedResult({EVENT_STREAM_RPC_CONTINUATION_CLOSED, 0}));
13001301

1301-
m_closedPromise.set_value();
1302+
m_closedPromise.set_value({EVENT_STREAM_RPC_SUCCESS, 0});
13021303

13031304
if (m_streamHandler)
13041305
{
@@ -1308,20 +1309,22 @@ namespace Aws
13081309

13091310
std::future<RpcError> ClientOperation::Close(OnMessageFlushCallback onMessageFlushCallback) noexcept
13101311
{
1311-
if (m_isClosed.load())
1312+
if (m_clientContinuation.IsClosed())
13121313
{
13131314
std::promise<RpcError> alreadyClosedPromise;
13141315
alreadyClosedPromise.set_value({EVENT_STREAM_RPC_CONTINUATION_CLOSED, 0});
13151316
return alreadyClosedPromise.get_future();
13161317
}
13171318
else
13181319
{
1319-
return m_clientContinuation.SendMessage(
1320+
auto rpcError = m_clientContinuation.SendMessage(
13201321
Crt::List<EventStreamHeader>(),
13211322
Crt::Optional<Crt::ByteBuf>(),
13221323
AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE,
13231324
AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM,
1324-
onMessageFlushCallback);
1325+
onMessageFlushCallback).get();
1326+
m_clientContinuation.Close();
1327+
return m_closedPromise.get_future();
13251328
}
13261329
}
13271330

eventstream_rpc/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ aws_use_package(aws-crt-cpp)
4444
aws_use_package(EventstreamRpc-cpp)
4545

4646
add_test_case(EventStreamConnect)
47+
add_test_case(EchoOperation)
4748
add_test_case(OperateWhileDisconnected)
4849
generate_cpp_test_driver(${TEST_BINARY_NAME})
4950
target_include_directories(${TEST_BINARY_NAME} PUBLIC

eventstream_rpc/tests/EchoTestRpcClient.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace Awstest
1010
{
1111
m_echoTestRpcServiceModel.AssignModelNameToErrorResponse(
1212
Aws::Crt::String("awstest#ServiceError"), ServiceError::s_allocateFromPayload);
13+
aws_event_stream_library_init(m_allocator);
1314
}
1415

1516
std::future<RpcError> EchoTestRpcClient::Connect(
@@ -21,7 +22,7 @@ namespace Awstest
2122

2223
std::future<RpcError> EchoTestRpcClient::Close() noexcept { return m_connection.Close(); }
2324

24-
EchoTestRpcClient::~EchoTestRpcClient() noexcept { Close().wait(); }
25+
EchoTestRpcClient::~EchoTestRpcClient() noexcept { Close().wait(); aws_event_stream_library_clean_up(); }
2526

2627
GetAllProductsOperation EchoTestRpcClient::NewGetAllProducts() noexcept
2728
{

eventstream_rpc/tests/EventStreamClientTest.cpp

Lines changed: 61 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using namespace Aws::Eventstreamrpc;
1515
using namespace Awstest;
1616

1717
static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx);
18+
static int s_TestEchoOperation(struct aws_allocator *allocator, void *ctx);
1819
static int s_TestOperationWhileDisconnected(struct aws_allocator *allocator, void *ctx);
1920

2021
class TestLifecycleHandler : public ConnectionLifecycleHandler
@@ -75,7 +76,7 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
7576
Aws::Crt::Io::EventLoopGroup eventLoopGroup(0, allocator);
7677
ASSERT_TRUE(eventLoopGroup);
7778

78-
Aws::Crt::Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 8, 30, allocator);
79+
Aws::Crt::Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 1, 1, allocator);
7980
ASSERT_TRUE(defaultHostResolver);
8081

8182
Aws::Crt::Io::ClientBootstrap clientBootstrap(eventLoopGroup, defaultHostResolver, allocator);
@@ -87,6 +88,7 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
8788
ConnectionConfig config;
8889
config.SetHostName(Aws::Crt::String("127.0.0.1"));
8990
config.SetPort(8033U);
91+
aws_event_stream_library_init(allocator);
9092

9193
/* Empty amendment headers. */
9294
{
@@ -115,6 +117,8 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
115117
ASSERT_TRUE(connectedStatus.get().baseStatus == EVENT_STREAM_RPC_SUCCESS);
116118
client.Close().wait();
117119
}
120+
121+
//aws_event_stream_library_clean_up();
118122
}
119123

120124
return AWS_OP_SUCCESS;
@@ -134,7 +138,7 @@ static int s_TestOperationWhileDisconnected(struct aws_allocator *allocator, voi
134138
Aws::Crt::Io::EventLoopGroup eventLoopGroup(0, allocator);
135139
ASSERT_TRUE(eventLoopGroup);
136140

137-
Aws::Crt::Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 8, 30, allocator);
141+
Aws::Crt::Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 1, 1, allocator);
138142
ASSERT_TRUE(defaultHostResolver);
139143

140144
Aws::Crt::Io::ClientBootstrap clientBootstrap(eventLoopGroup, defaultHostResolver, allocator);
@@ -148,30 +152,73 @@ static int s_TestOperationWhileDisconnected(struct aws_allocator *allocator, voi
148152
Awstest::EchoTestRpcClient client(clientBootstrap, allocator);
149153
auto connectedStatus = client.Connect(lifecycleHandler);
150154
ASSERT_TRUE(connectedStatus.get().baseStatus == EVENT_STREAM_RPC_SUCCESS);
155+
ASSERT_TRUE(client.Close().get().baseStatus == EVENT_STREAM_RPC_SUCCESS);
151156
auto echoMessage = client.NewEchoMessage();
152157
EchoMessageRequest echoMessageRequest;
153158
MessageData messageData;
154-
messageData.SetStringMessage("wtf");
159+
Aws::Crt::String expectedMessage("wtf");
160+
messageData.SetStringMessage(expectedMessage);
155161
echoMessageRequest.SetMessage(messageData);
156162
auto requestFuture = echoMessage.Activate(echoMessageRequest, s_onMessageFlush);
157163
ASSERT_TRUE(requestFuture.get().baseStatus == EVENT_STREAM_RPC_CONNECTION_CLOSED);
158164
auto result = echoMessage.GetResult().get();
159-
if (result) {
160-
std::cout << result.GetOperationResponse() << std::endl;
161-
} else {
162-
auto errorType = result.GetResultType();
163-
if (errorType == RPC_ERROR) {
164-
std::cout << result.GetRpcError() << std::endl;
165-
} else {
166-
auto *error = result.GetOperationError();
167-
(void)error;
168-
}
169-
}
165+
ASSERT_TRUE(result);
166+
auto error = result.GetRpcError();
167+
ASSERT_TRUE(error.baseStatus == EVENT_STREAM_RPC_CONNECTION_CLOSED);
168+
}
169+
}
170+
171+
return AWS_OP_SUCCESS;
172+
}
173+
174+
static int s_TestEchoOperation(struct aws_allocator *allocator, void *ctx)
175+
{
176+
(void)ctx;
177+
{
178+
Aws::Crt::ApiHandle apiHandle(allocator);
179+
Aws::Crt::Io::TlsContextOptions tlsCtxOptions = Aws::Crt::Io::TlsContextOptions::InitDefaultClient();
180+
Aws::Crt::Io::TlsContext tlsContext(tlsCtxOptions, Aws::Crt::Io::TlsMode::CLIENT, allocator);
181+
ASSERT_TRUE(tlsContext);
182+
183+
Aws::Crt::Io::TlsConnectionOptions tlsConnectionOptions = tlsContext.NewConnectionOptions();
184+
185+
Aws::Crt::Io::EventLoopGroup eventLoopGroup(0, allocator);
186+
ASSERT_TRUE(eventLoopGroup);
187+
188+
Aws::Crt::Io::DefaultHostResolver defaultHostResolver(eventLoopGroup, 8, 30, allocator);
189+
ASSERT_TRUE(defaultHostResolver);
190+
191+
Aws::Crt::Io::ClientBootstrap clientBootstrap(eventLoopGroup, defaultHostResolver, allocator);
192+
ASSERT_TRUE(clientBootstrap);
193+
194+
clientBootstrap.EnableBlockingShutdown();
195+
ConnectionLifecycleHandler lifecycleHandler;
196+
Awstest::EchoTestRpcClient client(clientBootstrap, allocator);
197+
198+
{
199+
//aws_event_stream_library_init(allocator);
200+
auto connectedStatus = client.Connect(lifecycleHandler);
201+
ASSERT_TRUE(connectedStatus.get().baseStatus == EVENT_STREAM_RPC_SUCCESS);
202+
auto echoMessage = client.NewEchoMessage();
203+
EchoMessageRequest echoMessageRequest;
204+
MessageData messageData;
205+
Aws::Crt::String expectedMessage("wtf");
206+
messageData.SetStringMessage(expectedMessage);
207+
echoMessageRequest.SetMessage(messageData);
208+
auto requestFuture = echoMessage.Activate(echoMessageRequest, s_onMessageFlush);
209+
requestFuture.wait();
210+
auto result = echoMessage.GetResult().get();
211+
ASSERT_TRUE(result);
212+
auto response = result.GetOperationResponse();
213+
ASSERT_NOT_NULL(response);
214+
ASSERT_TRUE(response->GetMessage().value().GetStringMessage().value() == expectedMessage);
215+
//aws_event_stream_library_clean_up();
170216
}
171217
}
172218

173219
return AWS_OP_SUCCESS;
174220
}
175221

176222
AWS_TEST_CASE(EventStreamConnect, s_TestEventStreamConnect)
223+
AWS_TEST_CASE(EchoOperation, s_TestEchoOperation)
177224
AWS_TEST_CASE(OperateWhileDisconnected, s_TestOperationWhileDisconnected)

0 commit comments

Comments
 (0)