Skip to content

Commit 92d819a

Browse files
Made TLsCtxOptions OOPish.
1 parent efbdf20 commit 92d819a

File tree

7 files changed

+88
-53
lines changed

7 files changed

+88
-53
lines changed

include/aws/crt/io/TlsOptions.h

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ namespace Aws
2626
{
2727
namespace Io
2828
{
29-
using TlsContextOptions = aws_tls_ctx_options;
3029
using TlsConnectionOptions = aws_tls_connection_options;
3130

3231
enum class TLSMode
@@ -35,6 +34,29 @@ namespace Aws
3534
SERVER,
3635
};
3736

37+
class AWS_CRT_CPP_API TlsContextOptions final
38+
{
39+
friend class TlsContext;
40+
public:
41+
TlsContextOptions(const TlsContextOptions&) noexcept = default;
42+
TlsContextOptions& operator=(const TlsContextOptions&) noexcept = default;
43+
44+
static TlsContextOptions InitDefaultClient() noexcept;
45+
static TlsContextOptions InitClientWithMtls(const char *cert_path, const char *pkey_path) noexcept;
46+
static TlsContextOptions InitClientWithMtlsPkcs12(const char *pkcs12_path,
47+
const char *pkcs12_pwd) noexcept;
48+
static bool IsAlpnSupported() noexcept;
49+
50+
void SetAlpnList(const char* alpnList) noexcept;
51+
void SetVerifyPeer(bool verifyPeer) noexcept;
52+
void OverrideDefaultTrustStore(const char* caPath, const char*caFile) noexcept;
53+
54+
private:
55+
aws_tls_ctx_options m_options;
56+
57+
TlsContextOptions() noexcept;
58+
};
59+
3860
class AWS_CRT_CPP_API TlsContext final
3961
{
4062
public:
@@ -55,20 +77,9 @@ namespace Aws
5577
int m_lastError;
5678
};
5779

58-
AWS_CRT_CPP_API void InitDefaultClient(TlsContextOptions& options) noexcept;
59-
AWS_CRT_CPP_API void InitClientWithMtls(TlsContextOptions &options,
60-
const char *cert_path, const char *pkey_path) noexcept;
61-
AWS_CRT_CPP_API void InitClientWithMtlsPkcs12(TlsContextOptions &options,
62-
const char *pkcs12_path, const char *pkcs12_pwd) noexcept;
63-
AWS_CRT_CPP_API void SetALPNList(TlsContextOptions& options, const char* alpn_list) noexcept;
64-
AWS_CRT_CPP_API void SetVerifyPeer(TlsContextOptions& options, bool verify_peer) noexcept;
65-
AWS_CRT_CPP_API void OverrideDefaultTrustStore(TlsContextOptions& options,
66-
const char* caPath, const char* caFile) noexcept;
67-
6880
AWS_CRT_CPP_API void InitTlsStaticState(Allocator *alloc) noexcept;
6981
AWS_CRT_CPP_API void CleanUpTlsStaticState() noexcept;
7082

71-
AWS_CRT_CPP_API bool IsAlpnSupported() noexcept;
7283
}
7384
}
7485
}

include/aws/crt/mqtt/MqttClient.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ namespace Aws
6767
{
6868
friend class MqttClient;
6969
public:
70-
~MqttConnection() = default;
70+
~MqttConnection();
7171
MqttConnection(const MqttConnection&) = delete;
7272
MqttConnection(MqttConnection&&) = default;
7373
MqttConnection& operator =(const MqttConnection&) = delete;

samples/mqtt_pub_sub/main.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,24 +100,24 @@ int main(int argc, char* argv[])
100100
/*
101101
* We're using Mutual TLS for Mqtt, so we need to load our client certificates
102102
*/
103-
Io::TlsContextOptions tlsCtxOptions;
104-
Io::InitClientWithMtls(tlsCtxOptions, certificatePath.c_str(), keyPath.c_str());
103+
Io::TlsContextOptions tlsCtxOptions =
104+
Io::TlsContextOptions::InitClientWithMtls(certificatePath.c_str(), keyPath.c_str());
105105
/*
106106
* If we have a custom CA, set that up here.
107107
*/
108108
if (!caFile.empty())
109109
{
110-
Io::OverrideDefaultTrustStore(tlsCtxOptions, nullptr, caFile.c_str());
110+
tlsCtxOptions.OverrideDefaultTrustStore(nullptr, caFile.c_str());
111111
}
112112

113113
uint16_t port = 8883;
114-
if (Io::IsAlpnSupported())
114+
if (Io::TlsContextOptions::IsAlpnSupported())
115115
{
116116
/*
117117
* Use ALPN to negotiate the mqtt protocol on a normal
118118
* TLS port if possible.
119119
*/
120-
Io::SetALPNList(tlsCtxOptions, "x-amzn-mqtt-ca");
120+
tlsCtxOptions.SetAlpnList("x-amzn-mqtt-ca");
121121
port = 443;
122122
}
123123

source/io/TlsOptions.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,51 @@ namespace Aws
2222
{
2323
namespace Io
2424
{
25-
void InitDefaultClient(TlsContextOptions& options) noexcept
25+
TlsContextOptions::TlsContextOptions() noexcept
2626
{
27-
aws_tls_ctx_options_init_default_client(&options);
27+
AWS_ZERO_STRUCT(m_options);
2828
}
2929

30-
void InitClientWithMtls(TlsContextOptions &options,
31-
const char *certPath, const char *pKeyPath) noexcept
30+
TlsContextOptions TlsContextOptions::InitDefaultClient() noexcept
3231
{
33-
aws_tls_ctx_options_init_client_mtls(&options, certPath, pKeyPath);
32+
TlsContextOptions ctxOptions;
33+
aws_tls_ctx_options_init_default_client(&ctxOptions.m_options);
34+
return ctxOptions;
3435
}
3536

36-
void InitClientWithMtlsPkcs12(TlsContextOptions &options,
37-
const char *pkcs12Path, const char *pkcs12Pwd) noexcept
37+
TlsContextOptions TlsContextOptions::InitClientWithMtls(const char *certPath, const char *pKeyPath) noexcept
3838
{
39-
aws_tls_ctx_options_init_client_mtls_pkcs12(&options, pkcs12Path, pkcs12Pwd);
39+
TlsContextOptions ctxOptions;
40+
aws_tls_ctx_options_init_client_mtls(&ctxOptions.m_options, certPath, pKeyPath);
41+
return ctxOptions;
4042
}
4143

42-
void SetALPNList(TlsContextOptions& options, const char* alpn_list) noexcept
44+
TlsContextOptions TlsContextOptions::InitClientWithMtlsPkcs12(const char *pkcs12Path,
45+
const char *pkcs12Pwd) noexcept
4346
{
44-
aws_tls_ctx_options_set_alpn_list(&options, alpn_list);
47+
TlsContextOptions ctxOptions;
48+
aws_tls_ctx_options_init_client_mtls_pkcs12(&ctxOptions.m_options, pkcs12Path, pkcs12Pwd);
49+
return ctxOptions;
4550
}
4651

47-
void SetVerifyPeer(TlsContextOptions& options, bool verify_peer) noexcept
52+
bool TlsContextOptions::IsAlpnSupported() noexcept
4853
{
49-
aws_tls_ctx_options_set_verify_peer(&options, verify_peer);
54+
return aws_tls_is_alpn_available();
55+
}
56+
57+
void TlsContextOptions::SetAlpnList(const char* alpn_list) noexcept
58+
{
59+
aws_tls_ctx_options_set_alpn_list(&m_options, alpn_list);
60+
}
61+
62+
void TlsContextOptions::SetVerifyPeer(bool verify_peer) noexcept
63+
{
64+
aws_tls_ctx_options_set_verify_peer(&m_options, verify_peer);
5065
}
5166

52-
void OverrideDefaultTrustStore(TlsContextOptions& options,
53-
const char* caPath, const char* caFile) noexcept
67+
void TlsContextOptions::OverrideDefaultTrustStore(const char* caPath, const char* caFile) noexcept
5468
{
55-
aws_tls_ctx_options_override_default_trust_store(&options, caPath, caFile);
69+
aws_tls_ctx_options_override_default_trust_store(&m_options, caPath, caFile);
5670
}
5771

5872
void InitTlsStaticState(Aws::Crt::Allocator *alloc) noexcept
@@ -65,21 +79,17 @@ namespace Aws
6579
aws_tls_clean_up_static_state();
6680
}
6781

68-
bool IsAlpnSupported() noexcept
69-
{
70-
return aws_tls_is_alpn_available();
71-
}
7282

7383
TlsContext::TlsContext(TlsContextOptions& options, TLSMode mode, Allocator* allocator) noexcept :
7484
m_ctx(nullptr), m_lastError(AWS_OP_SUCCESS)
7585
{
7686
if (mode == TLSMode::CLIENT)
7787
{
78-
m_ctx = aws_tls_client_ctx_new(allocator, &options);
88+
m_ctx = aws_tls_client_ctx_new(allocator, &options.m_options);
7989
}
8090
else
8191
{
82-
m_ctx = aws_tls_server_ctx_new(allocator, &options);
92+
m_ctx = aws_tls_server_ctx_new(allocator, &options.m_options);
8393
}
8494

8595
if (!m_ctx)

source/mqtt/MqttClient.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,22 @@ namespace Aws
6666
{
6767
MqttConnection* connection;
6868
OnPublishReceivedHandler onPublishReceived;
69+
Allocator* allocator;
6970
};
7071

72+
static void s_cleanUpOnPublishData(void *userData)
73+
{
74+
auto callbackData = reinterpret_cast<PubCallbackData*>(userData);
75+
callbackData->~PubCallbackData();
76+
aws_mem_release(callbackData->allocator, reinterpret_cast<void*>(callbackData));
77+
}
78+
7179
void MqttConnection::s_onPublish(aws_mqtt_client_connection*,
7280
const aws_byte_cursor* topic,
7381
const aws_byte_cursor* payload,
7482
void* userData)
7583
{
7684
auto callbackData = reinterpret_cast<PubCallbackData*>(userData);
77-
//TODO:
78-
// SDK-5312 gives us a callback to free this, for now let it leak. When it's fixed comeback and handle.
7985

8086
if (callbackData->onPublishReceived)
8187
{
@@ -146,6 +152,14 @@ namespace Aws
146152
}
147153
}
148154

155+
MqttConnection::~MqttConnection()
156+
{
157+
if (*this)
158+
{
159+
aws_mqtt_client_connection_destroy(m_underlyingConnection);
160+
}
161+
}
162+
149163
MqttConnection::operator bool() const noexcept
150164
{
151165
return m_isInit;
@@ -220,6 +234,7 @@ namespace Aws
220234

221235
pubCallbackData->connection = this;
222236
pubCallbackData->onPublishReceived = std::move(onPublish);
237+
pubCallbackData->allocator = m_owningClient->m_client.allocator;
223238

224239
OpCompleteCallbackData *opCompleteCallbackData =
225240
reinterpret_cast<OpCompleteCallbackData*>(aws_mem_acquire(m_owningClient->m_client.allocator,
@@ -232,6 +247,7 @@ namespace Aws
232247
m_lastError = aws_last_error();
233248
return 0;
234249
}
250+
235251
opCompleteCallbackData = new(opCompleteCallbackData)OpCompleteCallbackData;
236252

237253
opCompleteCallbackData->connection = this;
@@ -245,7 +261,7 @@ namespace Aws
245261

246262
uint16_t packetId = aws_mqtt_client_connection_subscribe(m_underlyingConnection,
247263
&topicFilterCur, qos, s_onPublish,
248-
pubCallbackData, s_onOpComplete, opCompleteCallbackData);
264+
pubCallbackData, s_cleanUpOnPublishData, s_onOpComplete, opCompleteCallbackData);
249265

250266
if (!packetId)
251267
{

tests/MqttClientTest.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
static int s_TestMqttClientResourceSafety(Aws::Crt::Allocator* allocator, void *)
2121
{
2222
Aws::Crt::ApiHandle apiHandle(allocator);
23-
Aws::Crt::Io::TlsContextOptions tlsCtxOptions;
24-
Aws::Crt::Io::InitDefaultClient(tlsCtxOptions);
23+
Aws::Crt::Io::TlsContextOptions tlsCtxOptions = Aws::Crt::Io::TlsContextOptions::InitDefaultClient();
2524

2625
Aws::Crt::Io::TlsContext tlsContext(tlsCtxOptions, Aws::Crt::Io::TLSMode::CLIENT, allocator);
2726
ASSERT_TRUE(tlsContext);
@@ -41,16 +40,16 @@ static int s_TestMqttClientResourceSafety(Aws::Crt::Allocator* allocator, void *
4140
Aws::Crt::Mqtt::MqttClient mqttClient(clientBootstrap, allocator);
4241
ASSERT_TRUE(mqttClient);
4342

44-
//Uncomment the next section once connection clean up code in the underlying c lib has been updated.
45-
//Aws::Crt::Mqtt::MqttConnection mqttConnection = mqttClient.NewConnection("www.example.com", 443,
46-
// socketOptions, tlsContext.NewConnectionOptions());
47-
//mqttConnection.Disconnect();
48-
//
49-
//ASSERT_TRUE(mqttConnection);
50-
5143
Aws::Crt::Mqtt::MqttClient mqttClientMoved = std::move(mqttClient);
5244
ASSERT_TRUE(mqttClientMoved);
5345

46+
Aws::Crt::Mqtt::MqttConnection mqttConnection = mqttClientMoved.NewConnection("www.example.com", 443,
47+
socketOptions, tlsContext.NewConnectionOptions());
48+
mqttConnection.Disconnect();
49+
50+
ASSERT_TRUE(mqttConnection);
51+
52+
5453
// NOLINTNEXTLINE
5554
ASSERT_FALSE(mqttClient);
5655

tests/TLSContextTest.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
static int s_TestTLSContextResourceSafety(Aws::Crt::Allocator* allocator, void *)
2121
{
2222
Aws::Crt::ApiHandle apiHandle(allocator);
23-
Aws::Crt::Io::TlsContextOptions tlsCtxOptions;
24-
Aws::Crt::Io::InitDefaultClient(tlsCtxOptions);
23+
Aws::Crt::Io::TlsContextOptions tlsCtxOptions = Aws::Crt::Io::TlsContextOptions::InitDefaultClient();
2524

2625
Aws::Crt::Io::TlsContext tlsContext(tlsCtxOptions, Aws::Crt::Io::TLSMode::CLIENT, allocator);
2726
ASSERT_TRUE(tlsContext);

0 commit comments

Comments
 (0)