Skip to content

Greengrass IPC #270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
201 changes: 137 additions & 64 deletions eventstreamrpc/include/aws/eventstreamrpc/EventStreamClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
#include <aws/eventstreamrpc/Exports.h>

#include <aws/crt/DateTime.h>
#include <aws/crt/JsonObject.h>
#include <aws/crt/StlAllocator.h>
#include <aws/crt/Types.h>
#include <aws/crt/UUID.h>
#include <aws/crt/io/EventLoopGroup.h>
#include <aws/crt/io/SocketOptions.h>
#include <aws/crt/io/TlsOptions.h>
#include <aws/crt/JsonObject.h>

#include <aws/event-stream/event_stream_rpc_client.h>

Expand Down Expand Up @@ -190,44 +190,77 @@ namespace Aws
virtual void OnContinuationClosed() = 0;
};

enum EventStreamErrors
enum EventStreamRpcError
{
/* If error messages are added to `aws_event_stream_errors`, this will need to be updated. */
AWS_ERROR_EVENT_STREAM_RPC_UNKNOWN_PROTOCOL_MESSAGE = AWS_ERROR_EVENT_STREAM_RPC_STREAM_NOT_ACTIVATED + 1,
AWS_ERROR_EVENT_STREAM_RPC_UNMAPPED_DATA,
AWS_ERROR_EVENT_STREAM_RPC_UNSUPPORTED_CONTENT_TYPE,
AWS_ERROR_EVENT_STREAM_RPC_STREAM_CLOSED_ERROR
EVENT_STREAM_RPC_SUCCESS = 0,
EVENT_STREAM_RPC_NULL_PARAMETER,
EVENT_STREAM_RPC_INITIALIZATION_ERROR,
EVENT_STREAM_RPC_CONNECTION_CLOSED_BEFORE_CONNACK,
EVENT_STREAM_RPC_UNKNOWN_PROTOCOL_MESSAGE,
EVENT_STREAM_RPC_UNMAPPED_DATA,
EVENT_STREAM_RPC_UNSUPPORTED_CONTENT_TYPE,
EVENT_STREAM_RPC_STREAM_CLOSED_ERROR,
EVENT_STREAM_RPC_UNEXPECTED_ERROR,
EVENT_STREAM_RPC_CRT_ERROR
};

struct EventStreamRpcStatus
{
EventStreamRpcError baseStatus;
int crtError;
};

template <class T> class ProtectedPromise
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this type will never be public facing and is just a helper class. Ultimately, we still return an std::future.

{
/* A wrapper around std::promise so that it cannot be set twice without having to catch exceptions. */
public:
ProtectedPromise() noexcept;
ProtectedPromise(const ProtectedPromise &lhs) noexcept = delete;
ProtectedPromise(ProtectedPromise &&rhs) noexcept;
ProtectedPromise &operator=(ProtectedPromise &&) noexcept;
ProtectedPromise(std::promise<T> &&promise) noexcept;
void SetValue(T &&r) noexcept;
void SetValue(const T &r) noexcept;
std::future<T> GetFuture() noexcept;
void Reset() noexcept;

private:
bool m_fulfilled;
std::promise<T> m_promise;
std::mutex m_mutex;
};

class AWS_EVENTSTREAMRPC_API ClientConnection final
{
public:
ClientConnection(Crt::Allocator *allocator) noexcept;
ClientConnection(
ConnectionLifecycleHandler &connectionLifecycleHandler,
Crt::Allocator *allocator) noexcept;
~ClientConnection() noexcept;
ClientConnection(const ClientConnection &) noexcept = delete;
ClientConnection &operator=(const ClientConnection &) noexcept = delete;
ClientConnection(ClientConnection &&) noexcept = default;
ClientConnection &operator=(ClientConnection &&) noexcept = default;
ClientConnection(ClientConnection &&) noexcept;
ClientConnection &operator=(ClientConnection &&) noexcept;

bool Connect(
std::future<EventStreamRpcStatus> Connect(
const ClientConnectionOptions &connectionOptions,
ConnectionLifecycleHandler *connectionLifecycleHandler,
ConnectionLifecycleHandler &connectionLifecycleHandler,
ConnectMessageAmender connectMessageAmender) noexcept;

void SendPing(
std::future<EventStreamRpcStatus> SendPing(
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
OnMessageFlushCallback onMessageFlushCallback) noexcept;

void SendPingResponse(
std::future<EventStreamRpcStatus> SendPingResponse(
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
OnMessageFlushCallback onMessageFlushCallback) noexcept;

ClientContinuation NewStream(ClientContinuationHandler *clientContinuationHandler) noexcept;
ClientContinuation NewStream(ClientContinuationHandler &clientContinuationHandler) noexcept;

void Close() noexcept;
void Close(int errorCode) noexcept;

/**
* @return true if the instance is in a valid state, false otherwise.
Expand All @@ -248,13 +281,17 @@ namespace Aws
CONNECTED,
DISCONNECTING,
};
std::mutex m_stateMutex;
Crt::Allocator *m_allocator;
struct aws_event_stream_rpc_client_connection *m_underlyingConnection;
ClientState m_clientState;
ConnectionLifecycleHandler *m_lifecycleHandler;
ConnectionLifecycleHandler &m_lifecycleHandler;
ConnectMessageAmender m_connectMessageAmender;
ProtectedPromise<EventStreamRpcStatus> m_connectAckedPromise;
std::promise<EventStreamRpcStatus> m_closedPromise;
OnMessageFlushCallback m_onConnectRequestCallback;
static void s_customDeleter(ClientConnection *connection) noexcept;
void SendProtocolMessage(
std::future<EventStreamRpcStatus> SendProtocolMessage(
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
MessageType messageType,
Expand All @@ -275,21 +312,21 @@ namespace Aws
void *userData) noexcept;

static void s_protocolMessageCallback(int errorCode, void *userData) noexcept;
static void s_sendProtocolMessage(
static std::future<EventStreamRpcStatus> s_sendProtocolMessage(
ClientConnection *connection,
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
MessageType messageType,
uint32_t messageFlags,
OnMessageFlushCallback onMessageFlushCallback) noexcept;

static void s_sendPing(
static std::future<EventStreamRpcStatus> s_sendPing(
ClientConnection *connection,
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
OnMessageFlushCallback onMessageFlushCallback) noexcept;

static void s_sendPingResponse(
static std::future<EventStreamRpcStatus> s_sendPingResponse(
ClientConnection *connection,
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
Expand All @@ -301,17 +338,17 @@ namespace Aws
public:
ClientContinuation(
ClientConnection *connection,
ClientContinuationHandler *handler,
ClientContinuationHandler &continuationHandler,
Crt::Allocator *allocator) noexcept;
void Activate(
std::future<EventStreamRpcStatus> Activate(
const Crt::String &operation,
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
MessageType messageType,
uint32_t messageFlags,
OnMessageFlushCallback onMessageFlushCallback) noexcept;
bool IsClosed() noexcept;
void SendMessage(
std::future<EventStreamRpcStatus> SendMessage(
const Crt::List<EventStreamHeader> &headers,
const Crt::Optional<Crt::ByteBuf> &payload,
MessageType messageType,
Expand All @@ -320,7 +357,7 @@ namespace Aws

private:
Crt::Allocator *m_allocator;
ClientContinuationHandler *m_handler;
ClientContinuationHandler &m_continuationHandler;
struct aws_event_stream_rpc_client_continuation_token *m_continuationToken;
static void s_onContinuationMessage(
struct aws_event_stream_rpc_client_continuation_token *continuationToken,
Expand All @@ -334,20 +371,22 @@ namespace Aws
class AWS_EVENTSTREAMRPC_API AbstractShapeBase
{
public:
AbstractShapeBase(Crt::Allocator* allocator = Crt::g_allocator) noexcept;
AbstractShapeBase(Crt::Allocator *allocator = Crt::g_allocator) noexcept;
static void s_customDeleter(AbstractShapeBase *shape) noexcept;
virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const = 0;

protected:
virtual Crt::String GetModelName() const noexcept;
friend class ClientOperation;

private:
Crt::Allocator *m_allocator;
};

class AWS_EVENTSTREAMRPC_API OperationResponse : public AbstractShapeBase
{
public:
OperationResponse(Crt::Allocator* allocator = Crt::g_allocator) noexcept;
OperationResponse(Crt::Allocator *allocator = Crt::g_allocator) noexcept;
static void s_customDeleter(OperationResponse *shape) noexcept;
/* A response does not necessarily have to be serialized so provide a default implementation. */
virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const override;
Expand All @@ -356,18 +395,19 @@ namespace Aws
class AWS_EVENTSTREAMRPC_API OperationRequest : public AbstractShapeBase
{
public:
OperationRequest(Crt::Allocator* allocator = Crt::g_allocator) noexcept;
OperationRequest(Crt::Allocator *allocator = Crt::g_allocator) noexcept;
};

class AWS_EVENTSTREAMRPC_API OperationError : public AbstractShapeBase
{
public:
OperationError(Crt::Allocator* allocator = Crt::g_allocator) noexcept;
OperationError(int errorCode, Crt::Allocator* allocator) noexcept;
OperationError(Crt::Allocator *allocator = Crt::g_allocator) noexcept;
OperationError(int errorCode, Crt::Allocator *allocator) noexcept;
const Crt::Optional<int> &GetErrorCode() const noexcept;
void SetErrorCode(int errorCode) noexcept;
static void s_customDeleter(OperationError *shape) noexcept;
virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const override;

private:
Crt::Optional<int> m_errorCode;
};
Expand All @@ -384,6 +424,7 @@ namespace Aws
* Invoked when stream is closed, so no more messages will be receivied.
*/
virtual void OnStreamClosed();

protected:
friend class ClientOperation;
/**
Expand All @@ -398,56 +439,85 @@ namespace Aws
virtual bool OnStreamError(Crt::ScopedResource<OperationError> error);
};

union AWS_EVENTSTREAMRPC_API ResponseResult
{
ResponseResult() {}
~ResponseResult() {}
Crt::ScopedResource<OperationResponse> response;
Crt::ScopedResource<OperationError> error;
};

enum AWS_EVENTSTREAMRPC_API ResponseType
{
EXPECTED_RESPONSE,
ERROR_RESPONSE
};

struct TaggedResponse
union AWS_EVENTSTREAMRPC_API ResponseResult
{
TaggedResponse() {}
TaggedResponse(TaggedResponse &&taggedResponse)
{
if (responseType == EXPECTED_RESPONSE)
{
responseResult.response = std::move(taggedResponse.responseResult.response);
}
else if (responseType == ERROR_RESPONSE)
{
responseResult.error = std::move(taggedResponse.responseResult.error);
}
}
ResponseType responseType;
ResponseResult responseResult;
ResponseResult(Crt::ScopedResource<OperationResponse> &&response) { m_response = std::move(response); }
ResponseResult(Crt::ScopedResource<OperationError> &&error) { m_error = std::move(error); }
ResponseResult() : m_error(nullptr) {}
~ResponseResult() noexcept {};
Crt::ScopedResource<OperationResponse> m_response;
Crt::ScopedResource<OperationError> m_error;
};

class AWS_EVENTSTREAMRPC_API TaggedResponse
{
public:
TaggedResponse(Crt::ScopedResource<OperationResponse> response) noexcept;
TaggedResponse(Crt::ScopedResource<OperationError> error) noexcept;
TaggedResponse(TaggedResponse &&taggedResponse) noexcept;
~TaggedResponse() noexcept = default;
/**
* @return true if the response is associated with an expected response;
* false if the response is associated with an error.
*/
operator bool() const noexcept;

OperationResponse *GetResponse();
OperationError *GetError();

private:
ResponseType m_responseType;
ResponseResult m_responseResult;
};

using ExpectedResponseFactory =
std::function<Crt::ScopedResource<OperationResponse>(const Crt::StringView &payload, Crt::Allocator* allocator)>;
using ErrorResponseFactory = std::function<Crt::ScopedResource<OperationError>(const Crt::StringView &payload, Crt::Allocator* allocator)>;
using ExpectedResponseFactory = std::function<
Crt::ScopedResource<OperationResponse>(const Crt::StringView &payload, Crt::Allocator *allocator)>;
using ErrorResponseFactory = std::function<
Crt::ScopedResource<OperationError>(const Crt::StringView &payload, Crt::Allocator *allocator)>;

using LoneResponseRetriever = std::function<ExpectedResponseFactory(const Crt::String &modelName)>;
using StreamingResponseRetriever = std::function<ExpectedResponseFactory(const Crt::String &modelName)>;
using ErrorResponseRetriever = std::function<ErrorResponseFactory(const Crt::String &modelName)>;

class AWS_EVENTSTREAMRPC_API ResponseRetriever
{
/* An interface shared by all operations for retrieving the response object given the model name. */
public:
virtual ExpectedResponseFactory GetLoneResponseFromModelName(
const Crt::String &modelName) const noexcept = 0;
virtual ExpectedResponseFactory GetStreamingResponseFromModelName(
const Crt::String &modelName) const noexcept = 0;
virtual ErrorResponseFactory GetErrorResponseFromModelName(const Crt::String &modelName) const noexcept = 0;
};

class AWS_EVENTSTREAMRPC_API ClientOperation : private ClientContinuationHandler
{
public:
ClientOperation(ClientConnection &connection, StreamResponseHandler *streamHandler, Crt::Allocator* allocator) noexcept;
std::future<void> Close(OnMessageFlushCallback onMessageFlushCallback = nullptr) noexcept;
ClientOperation(
ClientConnection &connection,
StreamResponseHandler *streamHandler,
const ResponseRetriever &responseRetriever,
Crt::Allocator *allocator) noexcept;
~ClientOperation() noexcept;
ClientOperation(const ClientOperation &clientOperation) noexcept = default;
ClientOperation(ClientOperation &&clientOperation) noexcept;
std::future<EventStreamRpcStatus> Close(OnMessageFlushCallback onMessageFlushCallback = nullptr) noexcept;
std::future<TaggedResponse> GetResponse() noexcept;
// virtual bool IsStreaming() = 0;

protected:
void Activate(const OperationRequest *shape, OnMessageFlushCallback onMessageFlushCallback) noexcept;
void SendStreamEvent(OperationRequest *shape, OnMessageFlushCallback onMessageFlushCallback) noexcept;
Crt::Map<Crt::String, ExpectedResponseFactory> m_ModelNameToSingleResponseObject;
Crt::Map<Crt::String, ExpectedResponseFactory> m_ModelNameToStreamingResponseObject;
Crt::Map<Crt::String, ErrorResponseFactory> m_ErrorNameToObject;
std::future<EventStreamRpcStatus> Activate(
const OperationRequest *shape,
OnMessageFlushCallback onMessageFlushCallback) noexcept;
std::future<EventStreamRpcStatus> SendStreamEvent(
OperationRequest *shape,
OnMessageFlushCallback onMessageFlushCallback) noexcept;
virtual Crt::String GetModelName() const noexcept = 0;

private:
Expand Down Expand Up @@ -478,10 +548,13 @@ namespace Aws
const Crt::String &name) noexcept;
uint32_t m_messageCount;
Crt::Allocator *m_allocator;
StreamResponseHandler *m_streamHandler;
StreamResponseHandler* m_streamHandler;
const ResponseRetriever& m_responseRetriever;
ClientContinuation m_clientContinuation;
std::promise<TaggedResponse> m_initialResponsePromise;
ProtectedPromise<TaggedResponse> m_initialResponsePromise;
/* ProtectedPromise not necessary because it's only ever being set by one thread. */
std::promise<void> m_closedPromise;
std::atomic<bool> m_isClosed;
};
} // namespace Eventstreamrpc
} // namespace Aws
Loading