Skip to content

Refactor around VM failure check on Http/Tcp callbacks. #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 28, 2021
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 100 additions & 76 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,21 @@
#include "src/shared_data.h"
#include "src/shared_queue.h"

#define CHECK_FAIL(_call, _stream_type, _return_open, _return_closed) \
#define CHECK_FAIL(_stream_type, _stream_type2, _return_open, _return_closed) \
if (isFailed()) { \
if (plugin_->fail_open_) { \
return _return_open; \
} else { \
failStream(_stream_type); \
failStream(_stream_type2); \
return _return_closed; \
} \
} else { \
if (!wasm_->_call) { \
return _return_open; \
} \
}

#define CHECK_FAIL2(_call1, _call2, _stream_type, _return_open, _return_closed) \
if (isFailed()) { \
if (plugin_->fail_open_) { \
return _return_open; \
} else { \
failStream(_stream_type); \
return _return_closed; \
} \
} else { \
if (!wasm_->_call1 && !wasm_->_call2) { \
return _return_open; \
} \
}

#define CHECK_HTTP(_call, _return_open, _return_closed) \
CHECK_FAIL(_call, WasmStreamType::Request, _return_open, _return_closed)
#define CHECK_HTTP2(_call1, _call2, _return_open, _return_closed) \
CHECK_FAIL2(_call1, _call2, WasmStreamType::Request, _return_open, _return_closed)
#define CHECK_NET(_call, _return_open, _return_closed) \
CHECK_FAIL(_call, WasmStreamType::Downstream, _return_open, _return_closed)
#define CHECK_FAIL_HTTP(_return_open, _return_closed) \
CHECK_FAIL(WasmStreamType::Request, WasmStreamType::Response, _return_open, _return_closed)
#define CHECK_FAIL_NET(_return_open, _return_closed) \
CHECK_FAIL(WasmStreamType::Downstream, WasmStreamType::Upstream, _return_open, _return_closed)

namespace proxy_wasm {

Expand Down Expand Up @@ -263,30 +244,44 @@ void ContextBase::onForeignFunction(uint32_t foreign_function_id, uint32_t data_
}

FilterStatus ContextBase::onNetworkNewConnection() {
CHECK_NET(on_new_connection_, FilterStatus::Continue, FilterStatus::StopIteration);
DeferAfterCallActions actions(this);
if (wasm_->on_new_connection_(this, id_).u64_ == 0) {
CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration);
if (!wasm_->on_new_connection_) {
return FilterStatus::Continue;
}
return FilterStatus::StopIteration;
DeferAfterCallActions actions(this);
const auto result = wasm_->on_new_connection_(this, id_).u64_;
CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration);
return result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
}

FilterStatus ContextBase::onDownstreamData(uint32_t data_length, bool end_of_stream) {
CHECK_NET(on_downstream_data_, FilterStatus::Continue, FilterStatus::StopIteration);
CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration);
if (!wasm_->on_downstream_data_) {
return FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
auto result = wasm_->on_downstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream));
auto result = wasm_
->on_downstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream))
.u64_;
// TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values.
return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration);
return result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
}

FilterStatus ContextBase::onUpstreamData(uint32_t data_length, bool end_of_stream) {
CHECK_NET(on_upstream_data_, FilterStatus::Continue, FilterStatus::StopIteration);
CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration);
if (!wasm_->on_upstream_data_) {
return FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
auto result = wasm_->on_upstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream));
auto result = wasm_
->on_upstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream))
.u64_;
// TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values.
return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration);
return result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
}

void ContextBase::onDownstreamConnectionClose(CloseType close_type) {
Expand All @@ -307,74 +302,103 @@ void ContextBase::onUpstreamConnectionClose(CloseType close_type) {
template <typename P> static uint32_t headerSize(const P &p) { return p ? p->size() : 0; }

FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_stream) {
CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue,
FilterHeadersStatus::StopAllIterationAndWatermark);
CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
if (!wasm_->on_request_headers_abi_01_ && !wasm_->on_request_headers_abi_02_) {
return FilterHeadersStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterHeadersStatus(
wasm_->on_request_headers_abi_01_
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_request_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_);
const auto result = wasm_->on_request_headers_abi_01_
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_request_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_;
CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
return convertVmCallResultToFilterHeadersStatus(result);
}

FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) {
CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
if (!wasm_->on_request_body_) {
return FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterDataStatus(
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_);
const auto result =
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
return convertVmCallResultToFilterDataStatus(result);
}

FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) {
CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration);
if (!wasm_->on_request_trailers_) {
return FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterTrailersStatus(
wasm_->on_request_trailers_(this, id_, trailers).u64_);
const auto result = wasm_->on_request_trailers_(this, id_, trailers).u64_;
CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration);
return convertVmCallResultToFilterTrailersStatus(result);
}

FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) {
CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
if (!wasm_->on_request_metadata_) {
return FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterMetadataStatus(
wasm_->on_request_metadata_(this, id_, elements).u64_);
const auto result = wasm_->on_request_metadata_(this, id_, elements).u64_;
CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
return convertVmCallResultToFilterMetadataStatus(result);
}

FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) {
CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_,
FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
if (!wasm_->on_response_headers_abi_01_ && !wasm_->on_response_headers_abi_02_) {
return FilterHeadersStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterHeadersStatus(
wasm_->on_response_headers_abi_01_
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_response_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_);
const auto result = wasm_->on_response_headers_abi_01_
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_response_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_;
CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
return convertVmCallResultToFilterHeadersStatus(result);
}

FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) {
CHECK_HTTP(on_response_body_, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
if (!wasm_->on_response_body_) {
return FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterDataStatus(
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_);
const auto result =
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
return convertVmCallResultToFilterDataStatus(result);
}

FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) {
CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration);
if (!wasm_->on_response_trailers_) {
return FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterTrailersStatus(
wasm_->on_response_trailers_(this, id_, trailers).u64_);
const auto result = wasm_->on_response_trailers_(this, id_, trailers).u64_;
CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration);
return convertVmCallResultToFilterTrailersStatus(result);
}

FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) {
CHECK_HTTP(on_response_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
if (!wasm_->on_response_metadata_) {
return FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterMetadataStatus(
wasm_->on_response_metadata_(this, id_, elements).u64_);
const auto result = wasm_->on_response_metadata_(this, id_, elements).u64_;
CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
return convertVmCallResultToFilterMetadataStatus(result);
}

void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size,
Expand Down