Skip to content

Commit 12db79b

Browse files
Addressed PR feedback, fixed move semantics errors.
1 parent ae282b9 commit 12db79b

File tree

11 files changed

+246
-163
lines changed

11 files changed

+246
-163
lines changed

include/aws/crt/StlAllocator.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#pragma once
2+
/*
3+
* Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License").
6+
* You may not use this file except in compliance with the License.
7+
* A copy of the License is located at
8+
*
9+
* http://aws.amazon.com/apache2.0
10+
*
11+
* or in the "license" file accompanying this file. This file is distributed
12+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
13+
* express or implied. See the License for the specific language governing
14+
* permissions and limitations under the License.
15+
*/
16+
17+
#include <aws/crt/Types.h>
18+
19+
#include <memory>
20+
21+
namespace Aws
22+
{
23+
namespace Crt
24+
{
25+
extern Allocator* g_allocator;
26+
27+
template<typename T>
28+
class StlAllocator final : public std::allocator<T>
29+
{
30+
public:
31+
using Base = std::allocator<T>;
32+
33+
StlAllocator() noexcept : Base() {}
34+
StlAllocator(const StlAllocator<T>& a) noexcept : Base(a) {}
35+
36+
template<class U>
37+
StlAllocator(const StlAllocator<U>& a) noexcept : Base(a) {}
38+
39+
~StlAllocator() {}
40+
41+
using sizeType = std::size_t;
42+
43+
template<typename U>
44+
struct rebind
45+
{
46+
typedef StlAllocator<U> other;
47+
};
48+
49+
typename Base::pointer allocate(size_type n, const void* hint = nullptr)
50+
{
51+
(void)hint;
52+
assert(g_allocator);
53+
return reinterpret_cast<typename Base::pointer>(aws_mem_acquire(g_allocator, n * sizeof(T)));
54+
}
55+
56+
void deallocate(typename Base::pointer p, size_type)
57+
{
58+
assert(g_allocator);
59+
aws_mem_release(g_allocator, p);
60+
}
61+
};
62+
}
63+
}

include/aws/crt/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace Aws
3535
AWS_CRT_CPP_API Allocator* DefaultAllocator() noexcept;
3636
AWS_CRT_CPP_API ByteBuf ByteBufFromCString(const char* str) noexcept;
3737
AWS_CRT_CPP_API ByteBuf ByteBufFromArray(const uint8_t *array, size_t len) noexcept;
38-
38+
3939
namespace Io
4040
{
4141
using SocketOptions = aws_socket_options;

include/aws/crt/io/TlsOptions.h

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace Aws
2828
{
2929
using TlsConnectionOptions = aws_tls_connection_options;
3030

31-
enum class TLSMode
31+
enum class TlsMode
3232
{
3333
CLIENT,
3434
SERVER,
@@ -41,15 +41,58 @@ namespace Aws
4141
TlsContextOptions(const TlsContextOptions&) noexcept = default;
4242
TlsContextOptions& operator=(const TlsContextOptions&) noexcept = default;
4343

44+
/**
45+
* Initializes TlsContextOptions with secure by default options, with
46+
* no client certificates.
47+
*/
4448
static TlsContextOptions InitDefaultClient() noexcept;
49+
/**
50+
* Initializes TlsContextOptions with secure by default options, with
51+
* client certificate and private key. These are paths to a file on disk. These
52+
* strings must remain in memory for the lifetime of the returned object. These files
53+
* must be in the PEM format.
54+
*/
4555
static TlsContextOptions InitClientWithMtls(const char *cert_path, const char *pkey_path) noexcept;
56+
57+
/**
58+
* Initializes TlsContextOptions with secure by default options, with
59+
* client certificateand private key in the PKCS#12 format.
60+
* This is a path to a file on disk. These
61+
* strings must remain in memory for the lifetime of the returned object.
62+
*/
4663
static TlsContextOptions InitClientWithMtlsPkcs12(const char *pkcs12_path,
4764
const char *pkcs12_pwd) noexcept;
65+
66+
/**
67+
* Returns true if alpn is supported by the underlying security provider, false
68+
* otherwise.
69+
*/
4870
static bool IsAlpnSupported() noexcept;
4971

72+
/**
73+
* Sets the list of alpn protocols, delimited by ';'. This string must remain in memory
74+
* for the lifetime of this object.
75+
*/
5076
void SetAlpnList(const char* alpnList) noexcept;
77+
78+
/**
79+
* In client mode, this turns off x.509 validation. Don't do this unless you're testing.
80+
* It's much better, to just override the default trust store and pass the self-signed
81+
* certificate as the caFile argument.
82+
*
83+
* In server mode, this defaults to false. If you want to support mutual TLS from the server,
84+
* you'll want to set this to true.
85+
*/
5186
void SetVerifyPeer(bool verifyPeer) noexcept;
52-
void OverrideDefaultTrustStore(const char* caPath, const char*caFile) noexcept;
87+
88+
/**
89+
* Overrides the default system trust store. caPath is only useful on Unix style systems where
90+
* all anchors are stored in a directory (like /etc/ssl/certs). caFile is for a single file containing
91+
* all trusted CAs. caFile must be in the PEM format.
92+
*
93+
* These strings must remain in memory for the lifetime of this object.
94+
*/
95+
void OverrideDefaultTrustStore(const char* caPath, const char* caFile) noexcept;
5396

5497
private:
5598
aws_tls_ctx_options m_options;
@@ -60,7 +103,7 @@ namespace Aws
60103
class AWS_CRT_CPP_API TlsContext final
61104
{
62105
public:
63-
TlsContext(TlsContextOptions& options, TLSMode mode, Allocator* allocator = DefaultAllocator()) noexcept;
106+
TlsContext(TlsContextOptions& options, TlsMode mode, Allocator* allocator = DefaultAllocator()) noexcept;
64107
~TlsContext();
65108
TlsContext(const TlsContext&) = delete;
66109
TlsContext& operator=(const TlsContext&) = delete;

include/aws/crt/mqtt/MqttClient.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ namespace Aws
4949
/**
5050
* Invoked Upon Connection failure.
5151
*/
52-
using OnConnectionFailedHandler = std::function<void(MqttConnection& connection)>;
52+
using OnConnectionFailedHandler = std::function<void(MqttConnection& connection, int error)>;
5353

5454
/**
5555
* Invoked when a connack message is received.
@@ -60,7 +60,7 @@ namespace Aws
6060
/**
6161
* Invoked when a disconnect message has been sent.
6262
*/
63-
using OnDisconnectHandler = std::function<bool(MqttConnection& connection)>;
63+
using OnDisconnectHandler = std::function<bool(MqttConnection& connection, int error)>;
6464

6565
/**
6666
* Invoked upon receipt of a Publish message on a subscribed topic.
@@ -70,18 +70,20 @@ namespace Aws
7070
using OnOperationCompleteHandler = std::function<void(MqttConnection& connection, uint16_t packetId)>;
7171

7272
/**
73-
* Represents a persistent Mqtt Connection. The memory is owned by MqttClient. This is a move only type.
74-
* To get a new instance of this class, see MqttClient::NewConnection.
73+
* Represents a persistent Mqtt Connection. The memory is owned by MqttClient.
74+
* To get a new instance of this class, see MqttClient::NewConnection. Unless
75+
* specified all function arguments need only to live through the duration of the
76+
* function call.
7577
*/
7678
class AWS_CRT_CPP_API MqttConnection final
7779
{
7880
friend class MqttClient;
7981
public:
8082
~MqttConnection();
8183
MqttConnection(const MqttConnection&) = delete;
82-
MqttConnection(MqttConnection&&);
84+
MqttConnection(MqttConnection&&) = delete;
8385
MqttConnection& operator =(const MqttConnection&) = delete;
84-
MqttConnection& operator =(MqttConnection&&);
86+
MqttConnection& operator =(MqttConnection&&) = delete;
8587

8688
operator bool() const noexcept;
8789
int LastError() const noexcept;
@@ -154,19 +156,18 @@ namespace Aws
154156
void Ping();
155157

156158
private:
157-
MqttConnection(MqttClient* client, const char* hostName, uint16_t port,
159+
MqttConnection(aws_mqtt_client* client, const char* hostName, uint16_t port,
158160
const Io::SocketOptions& socketOptions,
159161
Io::TlsConnectionOptions&& tlsConnOptions) noexcept;
160-
MqttConnection(MqttClient* client, const char* hostName, uint16_t port,
162+
MqttConnection(aws_mqtt_client* client, const char* hostName, uint16_t port,
161163
const Io::SocketOptions& socketOptions) noexcept;
162164

163-
MqttClient *m_owningClient;
165+
aws_mqtt_client* m_owningClient;
164166
aws_mqtt_client_connection *m_underlyingConnection;
165167

166168
OnConnectionFailedHandler m_onConnectionFailed;
167169
OnConnAckHandler m_onConnAck;
168170
OnDisconnectHandler m_onDisconnect;
169-
std::atomic<int> m_lastError;
170171
std::atomic<ConnectionState> m_connectionState;
171172

172173
static void s_onConnectionFailed(aws_mqtt_client_connection* connection, int errorCode, void* userData);
@@ -185,11 +186,12 @@ namespace Aws
185186
};
186187

187188
/**
188-
* An MQTT client. This is a move-only type.
189+
* An MQTT client. This is a move-only type. Unless otherwise specified,
190+
* all function arguments need only to live through the duration of the
191+
* function call.
189192
*/
190193
class AWS_CRT_CPP_API MqttClient final
191194
{
192-
friend class MqttConnection;
193195
public:
194196
/**
195197
* Initialize an MqttClient using bootstrap and allocator
@@ -209,19 +211,17 @@ namespace Aws
209211
* Create a new connection object using TLS from the client. The client must outlive
210212
* all of its connection instances.
211213
*/
212-
MqttConnection NewConnection(const char* hostName, uint16_t port,
214+
std::shared_ptr<MqttConnection> NewConnection(const char* hostName, uint16_t port,
213215
const Io::SocketOptions& socketOptions, Io::TlsConnectionOptions&& tlsConnOptions) noexcept;
214216
/**
215217
* Create a new connection object over plain text from the client. The client must outlive
216218
* all of its connection instances.
217219
*/
218-
MqttConnection NewConnection(const char* hostName, uint16_t port,
220+
std::shared_ptr<MqttConnection> NewConnection(const char* hostName, uint16_t port,
219221
const Io::SocketOptions& socketOptions) noexcept;
220222

221223
private:
222-
aws_mqtt_client m_client;
223-
int m_lastError;
224-
bool m_isInit;
224+
aws_mqtt_client* m_client;
225225
};
226226
}
227227
}

samples/mqtt_pub_sub/main.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ int main(int argc, char* argv[])
121121
port = 443;
122122
}
123123

124-
Io::TlsContext tlsCtx(tlsCtxOptions, Io::TLSMode::CLIENT);
124+
Io::TlsContext tlsCtx(tlsCtxOptions, Io::TlsMode::CLIENT);
125125

126126
if (!tlsCtx)
127127
{
@@ -173,13 +173,13 @@ int main(int argc, char* argv[])
173173
* Now create a connection object. Note: This type is move only
174174
* and its underlying memory is managed by the client.
175175
*/
176-
Mqtt::MqttConnection connection =
176+
auto connection =
177177
mqttClient.NewConnection(endpoint.c_str(), port, socketOptions, tlsCtx.NewConnectionOptions());
178178

179-
if (!connection)
179+
if (!*connection)
180180
{
181181
fprintf(stderr, "MQTT Connection Creation failed with error %s\n",
182-
ErrorDebugString(connection.LastError()));
182+
ErrorDebugString(connection->LastError()));
183183
exit(-1);
184184
}
185185

@@ -200,7 +200,7 @@ int main(int argc, char* argv[])
200200
{
201201
{
202202
fprintf(stdout, "Connection completed with return code %d\n", returnCode);
203-
fprintf(stdout, "Conneciton state %d\n", connection.GetConnectionState());
203+
fprintf(stdout, "Conneciton state %d\n", connection->GetConnectionState());
204204
std::lock_guard<std::mutex> lockGuard(mutex);
205205
connectionSucceeded = true;
206206
}
@@ -210,10 +210,10 @@ int main(int argc, char* argv[])
210210
/*
211211
* This will be invoked when the TCP connection fails.
212212
*/
213-
auto onConFailure = [&](Mqtt::MqttConnection& connection)
213+
auto onConFailure = [&](Mqtt::MqttConnection&, int error)
214214
{
215215
{
216-
fprintf(stdout, "Connection failed with %s\n", ErrorDebugString(connection.LastError()));
216+
fprintf(stdout, "Connection failed with %s\n", ErrorDebugString(error));
217217
std::lock_guard<std::mutex> lockGuard(mutex);
218218
connectionClosed = true;
219219
}
@@ -223,29 +223,29 @@ int main(int argc, char* argv[])
223223
/*
224224
* Invoked when a disconnect message has completed.
225225
*/
226-
auto onDisconnect = [&](Mqtt::MqttConnection&) -> bool
226+
auto onDisconnect = [&](Mqtt::MqttConnection& conn, int error) -> bool
227227
{
228228
{
229-
fprintf(stdout, "Connection closed\n");
230-
fprintf(stdout, "Conneciton state %d\n", connection.GetConnectionState());
229+
fprintf(stdout, "Connection closed with error %s\n", ErrorDebugString(error));
230+
fprintf(stdout, "Conneciton state %d\n", conn.GetConnectionState());
231231
std::lock_guard<std::mutex> lockGuard(mutex);
232232
connectionClosed = true;
233233
}
234234
conditionVariable.notify_one();
235235
return false;
236236
};
237237

238-
connection.SetOnConnAckHandler(std::move(onConAck));
239-
connection.SetOnConnectionFailedHandler(std::move(onConFailure));
240-
connection.SetOnDisconnectHandler(std::move(onDisconnect));
238+
connection->SetOnConnAckHandler(std::move(onConAck));
239+
connection->SetOnConnectionFailedHandler(std::move(onConFailure));
240+
connection->SetOnDisconnectHandler(std::move(onDisconnect));
241241

242242
/*
243243
* Actually perform the connect dance.
244244
*/
245-
if (!connection.Connect("client_id12335456", true, 0))
245+
if (!connection->Connect("client_id12335456", true, 0))
246246
{
247247
fprintf(stderr, "MQTT Connection failed with error %s\n",
248-
ErrorDebugString(connection.LastError()));
248+
ErrorDebugString(connection->LastError()));
249249
exit(-1);
250250
}
251251

@@ -283,7 +283,7 @@ int main(int argc, char* argv[])
283283
/*
284284
* Publish our message.
285285
*/
286-
auto packetId = connection.Publish("a/b", AWS_MQTT_QOS_AT_LEAST_ONCE,
286+
auto packetId = connection->Publish("a/b", AWS_MQTT_QOS_AT_LEAST_ONCE,
287287
false, helloWorldPayload, onOpComplete);
288288
(void)packetId;
289289
conditionVariable.wait(uniqueLock);
@@ -305,20 +305,20 @@ int main(int argc, char* argv[])
305305
/*
306306
* Subscribe for incoming publish messages on topic.
307307
*/
308-
packetId = connection.Subscribe("a/b", AWS_MQTT_QOS_AT_LEAST_ONCE, onPublish, onOpComplete);
308+
packetId = connection->Subscribe("a/b", AWS_MQTT_QOS_AT_LEAST_ONCE, onPublish, onOpComplete);
309309
conditionVariable.wait(uniqueLock);
310310

311311
waitForSub = false;
312312
/*
313313
* Unsubscribe from the topic.
314314
*/
315-
connection.Unsubscribe("a/b", onOpComplete);
315+
connection->Unsubscribe("a/b", onOpComplete);
316316
conditionVariable.wait(uniqueLock);
317317
}
318318

319319
if (!connectionClosed) {
320320
/* Disconnect */
321-
connection.Disconnect();
321+
connection->Disconnect();
322322
conditionVariable.wait(uniqueLock, [&]() { return connectionClosed; });
323323
}
324324
return 0;

0 commit comments

Comments
 (0)