Skip to content

Commit 9020190

Browse files
authored
Fix setting headers in eventstream-rpc (#788)
1 parent 0a3ddbc commit 9020190

File tree

2 files changed

+58
-21
lines changed

2 files changed

+58
-21
lines changed

eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ namespace Aws
101101
MessageAmendment(const MessageAmendment &lhs);
102102
MessageAmendment(MessageAmendment &&rhs);
103103
MessageAmendment &operator=(const MessageAmendment &lhs);
104+
MessageAmendment &operator=(MessageAmendment &&rhs);
104105
~MessageAmendment() noexcept;
105106
explicit MessageAmendment(Crt::Allocator *allocator = Crt::g_allocator) noexcept;
106107
MessageAmendment(
@@ -114,10 +115,22 @@ namespace Aws
114115
Crt::List<EventStreamHeader> &&headers,
115116
Crt::Allocator *allocator = Crt::g_allocator) noexcept;
116117
MessageAmendment(const Crt::ByteBuf &payload, Crt::Allocator *allocator = Crt::g_allocator) noexcept;
118+
119+
/**
120+
* Add a given header to the end of the header list.
121+
*/
117122
void AddHeader(EventStreamHeader &&header) noexcept;
123+
124+
/**
125+
* Add given headers to the beginning of the header list.
126+
*/
127+
void PrependHeaders(Crt::List<EventStreamHeader> &&headers);
118128
void SetPayload(const Crt::Optional<Crt::ByteBuf> &payload) noexcept;
119-
const Crt::List<EventStreamHeader> &GetHeaders() const noexcept;
120-
const Crt::Optional<Crt::ByteBuf> &GetPayload() const noexcept;
129+
void SetPayload(Crt::Optional<Crt::ByteBuf> &&payload);
130+
const Crt::List<EventStreamHeader> &GetHeaders() const &noexcept;
131+
Crt::List<EventStreamHeader> &&GetHeaders() &&;
132+
const Crt::Optional<Crt::ByteBuf> &GetPayload() const &noexcept;
133+
Crt::Optional<Crt::ByteBuf> &&GetPayload() &&;
121134

122135
private:
123136
Crt::List<EventStreamHeader> m_headers;

eventstream_rpc/source/EventStreamClient.cpp

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ namespace Aws
5151
}
5252

5353
MessageAmendment::MessageAmendment(Crt::List<EventStreamHeader> &&headers, Crt::Allocator *allocator) noexcept
54-
: m_headers(headers), m_payload(), m_allocator(allocator)
54+
: m_headers(std::move(headers)), m_payload(), m_allocator(allocator)
5555
{
5656
}
5757

@@ -79,14 +79,16 @@ namespace Aws
7979

8080
MessageAmendment &MessageAmendment::operator=(const MessageAmendment &lhs)
8181
{
82-
m_headers = lhs.m_headers;
83-
if (lhs.m_payload.has_value())
82+
if (this != &lhs)
8483
{
85-
m_payload =
86-
Crt::ByteBufNewCopy(lhs.m_allocator, lhs.m_payload.value().buffer, lhs.m_payload.value().len);
84+
m_headers = lhs.m_headers;
85+
if (lhs.m_payload.has_value())
86+
{
87+
m_payload =
88+
Crt::ByteBufNewCopy(lhs.m_allocator, lhs.m_payload.value().buffer, lhs.m_payload.value().len);
89+
}
90+
m_allocator = lhs.m_allocator;
8791
}
88-
m_allocator = lhs.m_allocator;
89-
9092
return *this;
9193
}
9294

@@ -97,9 +99,37 @@ namespace Aws
9799
rhs.m_payload = Crt::Optional<Crt::ByteBuf>();
98100
}
99101

100-
const Crt::List<EventStreamHeader> &MessageAmendment::GetHeaders() const noexcept { return m_headers; }
102+
MessageAmendment &MessageAmendment::operator=(MessageAmendment &&rhs)
103+
{
104+
if (this != &rhs)
105+
{
106+
m_headers = std::move(rhs.m_headers);
107+
m_payload = std::move(rhs.m_payload);
108+
m_allocator = rhs.m_allocator;
109+
110+
rhs.m_allocator = nullptr;
111+
rhs.m_payload = Crt::Optional<Crt::ByteBuf>();
112+
}
113+
return *this;
114+
}
101115

102-
const Crt::Optional<Crt::ByteBuf> &MessageAmendment::GetPayload() const noexcept { return m_payload; }
116+
const Crt::List<EventStreamHeader> &MessageAmendment::GetHeaders() const &noexcept { return m_headers; }
117+
118+
Crt::List<EventStreamHeader> &&MessageAmendment::GetHeaders() && { return std::move(m_headers); }
119+
120+
const Crt::Optional<Crt::ByteBuf> &MessageAmendment::GetPayload() const &noexcept { return m_payload; }
121+
122+
Crt::Optional<Crt::ByteBuf> &&MessageAmendment::GetPayload() && { return std::move(m_payload); }
123+
124+
void MessageAmendment::AddHeader(EventStreamHeader &&header) noexcept
125+
{
126+
m_headers.emplace_back(std::move(header));
127+
}
128+
129+
void MessageAmendment::PrependHeaders(Crt::List<EventStreamHeader> &&headers)
130+
{
131+
m_headers.splice(m_headers.begin(), std::move(headers));
132+
}
103133

104134
void MessageAmendment::SetPayload(const Crt::Optional<Crt::ByteBuf> &payload) noexcept
105135
{
@@ -109,6 +139,8 @@ namespace Aws
109139
}
110140
}
111141

142+
void MessageAmendment::SetPayload(Crt::Optional<Crt::ByteBuf> &&payload) { m_payload = std::move(payload); }
143+
112144
MessageAmendment::~MessageAmendment() noexcept
113145
{
114146
if (m_payload.has_value())
@@ -663,23 +695,17 @@ namespace Aws
663695
thisConnection->m_clientState = WAITING_FOR_CONNECT_ACK;
664696
thisConnection->m_underlyingConnection = connection;
665697
MessageAmendment messageAmendment;
666-
Crt::List<EventStreamHeader> messageAmendmentHeaders = messageAmendment.GetHeaders();
667698

668699
if (thisConnection->m_connectMessageAmender)
669700
{
670701
MessageAmendment connectAmendment(thisConnection->m_connectMessageAmender());
671-
Crt::List<EventStreamHeader> amenderHeaderList = connectAmendment.GetHeaders();
672702
/* The version header is necessary for establishing the connection. */
673703
messageAmendment.AddHeader(EventStreamHeader(
674704
Crt::String(EVENTSTREAM_VERSION_HEADER),
675705
Crt::String(EVENTSTREAM_VERSION_STRING),
676706
thisConnection->m_allocator));
677-
/* Note that we are prepending headers from the user-provided amender. */
678-
if (amenderHeaderList.size() > 0)
679-
{
680-
messageAmendmentHeaders.splice(messageAmendmentHeaders.end(), amenderHeaderList);
681-
}
682-
messageAmendment.SetPayload(connectAmendment.GetPayload());
707+
messageAmendment.PrependHeaders(std::move(connectAmendment).GetHeaders());
708+
messageAmendment.SetPayload(std::move(connectAmendment).GetPayload());
683709
}
684710

685711
/* Send a CONNECT packet to the server. */
@@ -695,8 +721,6 @@ namespace Aws
695721
thisConnection->m_connectionSetupPromise.set_value();
696722
}
697723

698-
void MessageAmendment::AddHeader(EventStreamHeader &&header) noexcept { m_headers.emplace_back(header); }
699-
700724
void ClientConnection::s_onConnectionShutdown(
701725
struct aws_event_stream_rpc_client_connection *connection,
702726
int errorCode,

0 commit comments

Comments
 (0)