Skip to content

Commit bfa6c1a

Browse files
committed
Fix allocation errors
1 parent 8b091c9 commit bfa6c1a

File tree

9 files changed

+81
-64
lines changed

9 files changed

+81
-64
lines changed

eventstreamrpc/include/aws/eventstreamrpc/EventStreamClient.h

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ namespace Aws
116116
class AWS_EVENTSTREAMRPC_API MessageAmendment final
117117
{
118118
public:
119-
MessageAmendment(const MessageAmendment &lhs) = default;
120-
MessageAmendment(MessageAmendment &&rhs) = default;
121-
MessageAmendment &operator=(const MessageAmendment &rhs) = default;
119+
MessageAmendment(const MessageAmendment &lhs);
120+
MessageAmendment(MessageAmendment &&rhs);
121+
MessageAmendment &operator=(const MessageAmendment &lhs);
122122
~MessageAmendment() noexcept;
123123
explicit MessageAmendment(Crt::Allocator *allocator = Crt::g_allocator) noexcept;
124124
MessageAmendment(
@@ -152,8 +152,7 @@ namespace Aws
152152
ConnectionConfig() noexcept : m_clientBootstrap(nullptr), m_connectRequestCallback(nullptr) {}
153153
Crt::Optional<Crt::String> GetHostName() const noexcept { return m_hostName; }
154154
Crt::Optional<uint16_t> GetPort() const noexcept { return m_port; }
155-
Crt::Optional<Crt::Io::SocketDomain> GetSocketDomain() const noexcept { return m_socketDomain; }
156-
Crt::Optional<Crt::Io::SocketType> GetSocketType() const noexcept { return m_socketType; }
155+
Crt::Optional<Crt::Io::SocketOptions> GetSocketOptions() const noexcept { return m_socketOptions; }
157156
Crt::Optional<MessageAmendment> GetConnectAmendment() const noexcept { return m_connectAmendment; }
158157
Crt::Optional<Crt::Io::TlsConnectionOptions> GetTlsConnectionOptions() const noexcept
159158
{
@@ -163,21 +162,13 @@ namespace Aws
163162
OnMessageFlushCallback GetConnectRequestCallback() const noexcept { return m_connectRequestCallback; }
164163
ConnectMessageAmender GetConnectMessageAmender() const noexcept
165164
{
166-
if (m_connectAmendment.has_value())
167-
{
168-
return [&](void) -> const MessageAmendment & { return m_connectAmendment.value(); };
169-
}
170-
else
171-
{
172-
return nullptr;
173-
}
165+
return [&](void) -> const MessageAmendment & { return m_connectAmendment; };
174166
}
175167

176168
void SetHostName(Crt::String hostName) noexcept { m_hostName = hostName; }
177169
void SetPort(uint16_t port) noexcept { m_port = port; }
178-
void SetSocketDomain(Crt::Io::SocketDomain socketDomain) noexcept { m_socketDomain = socketDomain; }
179-
void SetSocketType(Crt::Io::SocketType socketType) noexcept { m_socketType = socketType; }
180-
void SetConnectAmendment(MessageAmendment connectAmendment) noexcept
170+
void SetSocketOptions(const Crt::Io::SocketOptions &socketOptions) noexcept { m_socketOptions = socketOptions; }
171+
void SetConnectAmendment(const MessageAmendment &connectAmendment) noexcept
181172
{
182173
m_connectAmendment = connectAmendment;
183174
}
@@ -197,11 +188,10 @@ namespace Aws
197188
protected:
198189
Crt::Optional<Crt::String> m_hostName;
199190
Crt::Optional<uint16_t> m_port;
200-
Crt::Optional<Crt::Io::SocketDomain> m_socketDomain;
201-
Crt::Optional<Crt::Io::SocketType> m_socketType;
191+
Crt::Optional<Crt::Io::SocketOptions> m_socketOptions;
202192
Crt::Optional<Crt::Io::TlsConnectionOptions> m_tlsConnectionOptions;
203193
Crt::Io::ClientBootstrap *m_clientBootstrap;
204-
Crt::Optional<MessageAmendment> m_connectAmendment;
194+
MessageAmendment m_connectAmendment;
205195
OnMessageFlushCallback m_connectRequestCallback;
206196
};
207197

@@ -309,7 +299,8 @@ namespace Aws
309299
std::future<RpcError> Connect(
310300
const ConnectionConfig &connectionOptions,
311301
ConnectionLifecycleHandler *connectionLifecycleHandler,
312-
ConnectMessageAmender connectMessageAmender) noexcept;
302+
ConnectMessageAmender connectMessageAmender,
303+
Crt::Io::ClientBootstrap &clientBootstrap) noexcept;
313304

314305
std::future<RpcError> SendPing(
315306
const Crt::List<EventStreamHeader> &headers,

eventstreamrpc/source/EventStreamClient.cpp

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ namespace Aws
3838
}
3939

4040
MessageAmendment::MessageAmendment(const Crt::ByteBuf &payload, Crt::Allocator *allocator) noexcept
41-
: m_headers(), m_payload(payload), m_allocator(allocator)
41+
: m_headers(), m_payload(Crt::ByteBufNewCopy(allocator, payload.buffer, payload.len)),
42+
m_allocator(allocator)
4243
{
4344
}
4445

@@ -58,8 +59,42 @@ namespace Aws
5859
const Crt::List<EventStreamHeader> &headers,
5960
Crt::Optional<Crt::ByteBuf> &payload,
6061
Crt::Allocator *allocator) noexcept
61-
: m_headers(headers), m_payload(payload), m_allocator(allocator)
62+
: m_headers(headers), m_payload(), m_allocator(allocator)
63+
{
64+
if (payload.has_value())
65+
{
66+
m_payload = Crt::ByteBufNewCopy(allocator, payload.value().buffer, payload.value().len);
67+
}
68+
}
69+
70+
MessageAmendment::MessageAmendment(const MessageAmendment &lhs)
71+
: m_headers(lhs.m_headers), m_payload(), m_allocator(lhs.m_allocator)
6272
{
73+
if (lhs.m_payload.has_value())
74+
{
75+
m_payload =
76+
Crt::ByteBufNewCopy(lhs.m_allocator, lhs.m_payload.value().buffer, lhs.m_payload.value().len);
77+
}
78+
}
79+
80+
MessageAmendment &MessageAmendment::operator=(const MessageAmendment &lhs)
81+
{
82+
m_headers = lhs.m_headers;
83+
if (lhs.m_payload.has_value())
84+
{
85+
m_payload =
86+
Crt::ByteBufNewCopy(lhs.m_allocator, lhs.m_payload.value().buffer, lhs.m_payload.value().len);
87+
}
88+
m_allocator = lhs.m_allocator;
89+
90+
return *this;
91+
}
92+
93+
MessageAmendment::MessageAmendment(MessageAmendment &&rhs)
94+
: m_headers(std::move(rhs.m_headers)), m_payload(rhs.m_payload), m_allocator(rhs.m_allocator)
95+
{
96+
rhs.m_allocator = nullptr;
97+
rhs.m_payload = Crt::Optional<Crt::ByteBuf>();
6398
}
6499

65100
Crt::List<EventStreamHeader> &MessageAmendment::GetHeaders() noexcept { return m_headers; }
@@ -237,7 +272,8 @@ namespace Aws
237272
std::future<RpcError> ClientConnection::Connect(
238273
const ConnectionConfig &connectionConfig,
239274
ConnectionLifecycleHandler *connectionLifecycleHandler,
240-
ConnectMessageAmender connectMessageAmender) noexcept
275+
ConnectMessageAmender connectMessageAmender,
276+
Crt::Io::ClientBootstrap &clientBootstrap) noexcept
241277
{
242278
m_connectAckedPromise.Reset();
243279
m_closedPromise = {};
@@ -248,9 +284,11 @@ namespace Aws
248284

249285
struct aws_event_stream_rpc_client_connection_options connOptions;
250286
AWS_ZERO_STRUCT(connOptions);
287+
Crt::String hostName;
251288
if (connectionConfig.GetHostName().has_value())
252289
{
253-
connOptions.host_name = connectionConfig.GetHostName().value().c_str();
290+
hostName = connectionConfig.GetHostName().value();
291+
connOptions.host_name = hostName.c_str();
254292
}
255293
else
256294
{
@@ -265,28 +303,17 @@ namespace Aws
265303
baseError = EVENT_STREAM_RPC_NULL_PARAMETER;
266304
}
267305

268-
if (connectionConfig.GetClientBootstrap() != nullptr)
269-
{
270-
connOptions.bootstrap = connectionConfig.GetClientBootstrap()->GetUnderlyingHandle();
271-
}
272-
else
273-
{
274-
baseError = EVENT_STREAM_RPC_NULL_PARAMETER;
275-
}
306+
connOptions.bootstrap = clientBootstrap.GetUnderlyingHandle();
276307

277308
if (baseError)
278309
{
279310
m_connectAckedPromise.SetValue({baseError, 0});
280311
return m_connectAckedPromise.GetFuture();
281312
}
282313

283-
if (connectionConfig.GetSocketDomain().has_value())
314+
if (connectionConfig.GetSocketOptions().has_value())
284315
{
285-
m_socketOptions.SetSocketDomain(connectionConfig.GetSocketDomain().value());
286-
}
287-
if (connectionConfig.GetSocketType().has_value())
288-
{
289-
m_socketOptions.SetSocketType(connectionConfig.GetSocketType().value());
316+
m_socketOptions = connectionConfig.GetSocketOptions().value();
290317
}
291318
connOptions.socket_options = &m_socketOptions.GetImpl();
292319

@@ -566,6 +593,14 @@ namespace Aws
566593
messageAmendmentHeaders.splice(messageAmendmentHeaders.end(), amenderHeaderList);
567594
}
568595
messageAmendment.SetPayload(connectAmendment.GetPayload());
596+
if(messageAmendment.GetPayload().has_value())
597+
{
598+
std::cout << "wtf is going on" << std::endl;
599+
std::cout << Crt::String(
600+
(char *)messageAmendment.GetPayload().value().buffer,
601+
messageAmendment.GetPayload().value().len)
602+
<< std::endl;
603+
}
569604
}
570605

571606
/* Send a CONNECT packet to the server. */

eventstreamrpc/tests/EventStreamClientTest.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
7878
auto messageAmender = [&](void) -> MessageAmendment & { return connectionAmendment; };
7979

8080
ConnectionConfig config;
81-
config.SetClientBootstrap(&clientBootstrap);
8281
config.SetHostName(Aws::Crt::String("127.0.0.1"));
8382
config.SetPort(8033U);
8483

@@ -88,7 +87,7 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
8887
ClientConnection connection(allocator);
8988
connectionAmendment.AddHeader(EventStreamHeader(
9089
Aws::Crt::String("client-name"), Aws::Crt::String("accepted.testy_mc_testerson"), allocator));
91-
auto future = connection.Connect(config, &lifecycleHandler, messageAmender);
90+
auto future = connection.Connect(config, &lifecycleHandler, messageAmender, clientBootstrap);
9291
ASSERT_TRUE(future.get().baseStatus == EVENT_STREAM_RPC_SUCCESS);
9392
lifecycleHandler.WaitOnCondition([&]() { return lifecycleHandler.isConnected; });
9493
/* Test all protocol messages. */
@@ -107,7 +106,7 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
107106
{
108107
TestLifecycleHandler lifecycleHandler;
109108
ClientConnection connection(allocator);
110-
auto future = connection.Connect(config, &lifecycleHandler, messageAmender);
109+
auto future = connection.Connect(config, &lifecycleHandler, messageAmender, clientBootstrap);
111110
ASSERT_TRUE(future.get().baseStatus == EVENT_STREAM_RPC_CONNECTION_CLOSED_BEFORE_CONNACK);
112111
lifecycleHandler.WaitOnCondition([&]() { return lifecycleHandler.lastErrorCode == AWS_OP_SUCCESS; });
113112
}
@@ -118,7 +117,7 @@ static int s_TestEventStreamConnect(struct aws_allocator *allocator, void *ctx)
118117
ClientConnection connection(allocator);
119118
connectionAmendment.AddHeader(EventStreamHeader(
120119
Aws::Crt::String("client-name"), Aws::Crt::String("rejected.testy_mc_testerson"), allocator));
121-
auto future = connection.Connect(config, &lifecycleHandler, messageAmender);
120+
auto future = connection.Connect(config, &lifecycleHandler, messageAmender, clientBootstrap);
122121
ASSERT_TRUE(future.get().baseStatus == EVENT_STREAM_RPC_CONNECTION_CLOSED_BEFORE_CONNACK);
123122
lifecycleHandler.WaitOnCondition([&]() { return lifecycleHandler.lastErrorCode == AWS_OP_SUCCESS; });
124123
}

ipc/include/aws/greengrass/GreengrassCoreIpcClient.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace Aws
2828
Aws::Crt::Allocator *allocator = Aws::Crt::g_allocator) noexcept;
2929
std::future<RpcError> Connect(
3030
ConnectionLifecycleHandler &lifecycleHandler,
31-
ConnectionConfig connectionConfig = DefaultConnectionConfig()) noexcept;
31+
const ConnectionConfig &connectionConfig = DefaultConnectionConfig()) noexcept;
3232
void Close() noexcept;
3333
SubscribeToIoTCoreOperation NewSubscribeToIoTCore(SubscribeToIoTCoreStreamHandler &) noexcept;
3434
PublishToIoTCoreOperation NewPublishToIoTCore() noexcept;

ipc/source/DefaultConnectionConfig.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ namespace Aws
2424
}
2525

2626
m_port = 0;
27-
m_socketDomain = Crt::Io::SocketDomain::Local;
28-
m_socketType = Crt::Io::SocketType::Stream;
27+
Crt::Io::SocketOptions socketOptions;
28+
socketOptions.SetSocketDomain(Crt::Io::SocketDomain::Local);
29+
socketOptions.SetSocketType(Crt::Io::SocketType::Stream);
30+
m_socketOptions = std::move(socketOptions);
2931
}
3032
} // namespace Greengrass
3133
} // namespace Aws

ipc/source/GreengrassCoreIpcClient.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,10 @@ namespace Aws
3838

3939
std::future<RpcError> GreengrassCoreIpcClient::Connect(
4040
ConnectionLifecycleHandler &lifecycleHandler,
41-
ConnectionConfig connectionConfig) noexcept
41+
const ConnectionConfig& connectionConfig) noexcept
4242
{
43-
/* If a client bootstrap has not been set in the config, use the one from the client. */
44-
if (connectionConfig.GetClientBootstrap() == nullptr)
45-
{
46-
connectionConfig.SetClientBootstrap(&m_clientBootstrap);
47-
}
48-
4943
return m_connection.Connect(
50-
connectionConfig, &lifecycleHandler, connectionConfig.GetConnectMessageAmender());
44+
connectionConfig, &lifecycleHandler, connectionConfig.GetConnectMessageAmender(), m_clientBootstrap);
5145
}
5246

5347
void GreengrassCoreIpcClient::Close() noexcept { m_connection.Close(); }

ipc/tests/DefaultConnectionConfig.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ namespace Awstest
66
{
77
m_hostName = Aws::Crt::String("127.0.0.1");
88
m_port = 8033;
9-
m_socketDomain = Aws::Crt::Io::SocketDomain::IPv4;
10-
m_socketType = Aws::Crt::Io::SocketType::Stream;
9+
Aws::Crt::Io::SocketOptions socketOptions;
10+
socketOptions.SetSocketDomain(Aws::Crt::Io::SocketDomain::IPv4);
11+
socketOptions.SetSocketType(Aws::Crt::Io::SocketType::Stream);
12+
m_socketOptions = std::move(socketOptions);
1113
}
1214
} // namespace Awstest

ipc/tests/EchoTestRpcClient.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,9 @@ namespace Awstest
1414

1515
std::future<RpcError> EchoTestRpcClient::Connect(
1616
ConnectionLifecycleHandler &lifecycleHandler,
17-
ConnectionConfig connectionConfig) noexcept
17+
const ConnectionConfig &connectionConfig) noexcept
1818
{
19-
/* If a client bootstrap has not been set in the config, use the one from the client. */
20-
if (connectionConfig.GetClientBootstrap() == nullptr)
21-
{
22-
connectionConfig.SetClientBootstrap(&m_clientBootstrap);
23-
}
24-
25-
return m_connection.Connect(connectionConfig, &lifecycleHandler, connectionConfig.GetConnectMessageAmender());
19+
return m_connection.Connect(connectionConfig, &lifecycleHandler, connectionConfig.GetConnectMessageAmender(), m_clientBootstrap);
2620
}
2721

2822
void EchoTestRpcClient::Close() noexcept { m_connection.Close(); }

ipc/tests/include/awstest/EchoTestRpcClient.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace Awstest
2626
Aws::Crt::Allocator *allocator = Aws::Crt::g_allocator) noexcept;
2727
std::future<RpcError> Connect(
2828
ConnectionLifecycleHandler &lifecycleHandler,
29-
ConnectionConfig connectionConfig = DefaultConnectionConfig()) noexcept;
29+
const ConnectionConfig &connectionConfig = DefaultConnectionConfig()) noexcept;
3030
void Close() noexcept;
3131
GetAllProductsOperation NewGetAllProducts() noexcept;
3232
CauseServiceErrorOperation NewCauseServiceError() noexcept;

0 commit comments

Comments
 (0)