Skip to content

Refactor around VM failure check on Http/Tcp callbacks. #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 28, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 88 additions & 58 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "src/shared_data.h"
#include "src/shared_queue.h"

#define CHECK_FAIL(_call, _stream_type, _return_open, _return_closed) \
#define PRECHECK_FAIL(_call, _stream_type, _return_open, _return_closed) \
if (isFailed()) { \
if (plugin_->fail_open_) { \
return _return_open; \
Expand All @@ -39,7 +39,7 @@
} \
}

#define CHECK_FAIL2(_call1, _call2, _stream_type, _return_open, _return_closed) \
#define PRECHECK_FAIL2(_call1, _call2, _stream_type, _return_open, _return_closed) \
if (isFailed()) { \
if (plugin_->fail_open_) { \
return _return_open; \
Expand All @@ -53,12 +53,15 @@
} \
}

#define CHECK_HTTP(_call, _return_open, _return_closed) \
CHECK_FAIL(_call, WasmStreamType::Request, _return_open, _return_closed)
#define CHECK_HTTP2(_call1, _call2, _return_open, _return_closed) \
CHECK_FAIL2(_call1, _call2, WasmStreamType::Request, _return_open, _return_closed)
#define CHECK_NET(_call, _return_open, _return_closed) \
CHECK_FAIL(_call, WasmStreamType::Downstream, _return_open, _return_closed)
#define POSTCHECK_FAIL(_stream_type, _return_open, _return_closed) \
if (isFailed()) { \
if (plugin_->fail_open_) { \
return _return_open; \
} else { \
failStream(_stream_type); \
return _return_closed; \
} \
}

namespace proxy_wasm {

Expand Down Expand Up @@ -263,30 +266,38 @@ void ContextBase::onForeignFunction(uint32_t foreign_function_id, uint32_t data_
}

FilterStatus ContextBase::onNetworkNewConnection() {
CHECK_NET(on_new_connection_, FilterStatus::Continue, FilterStatus::StopIteration);
PRECHECK_FAIL(on_new_connection_, WasmStreamType::Downstream, FilterStatus::Continue,
FilterStatus::StopIteration);
DeferAfterCallActions actions(this);
if (wasm_->on_new_connection_(this, id_).u64_ == 0) {
return FilterStatus::Continue;
}
return FilterStatus::StopIteration;
const auto call_result = wasm_->on_new_connection_(this, id_).u64_;
POSTCHECK_FAIL(WasmStreamType::Downstream, FilterStatus::Continue, FilterStatus::StopIteration);
return call_result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
}

FilterStatus ContextBase::onDownstreamData(uint32_t data_length, bool end_of_stream) {
CHECK_NET(on_downstream_data_, FilterStatus::Continue, FilterStatus::StopIteration);
PRECHECK_FAIL(on_downstream_data_, WasmStreamType::Downstream, FilterStatus::Continue,
FilterStatus::StopIteration);
DeferAfterCallActions actions(this);
auto result = wasm_->on_downstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream));
auto call_result = wasm_
->on_downstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream))
.u64_;
// TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values.
return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
POSTCHECK_FAIL(WasmStreamType::Downstream, FilterStatus::Continue, FilterStatus::StopIteration);
return call_result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
}

FilterStatus ContextBase::onUpstreamData(uint32_t data_length, bool end_of_stream) {
CHECK_NET(on_upstream_data_, FilterStatus::Continue, FilterStatus::StopIteration);
PRECHECK_FAIL(on_upstream_data_, WasmStreamType::Upstream, FilterStatus::Continue,
FilterStatus::StopIteration);
DeferAfterCallActions actions(this);
auto result = wasm_->on_upstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream));
auto call_result = wasm_
->on_upstream_data_(this, id_, static_cast<uint32_t>(data_length),
static_cast<uint32_t>(end_of_stream))
.u64_;
// TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values.
return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
POSTCHECK_FAIL(WasmStreamType::Upstream, FilterStatus::Continue, FilterStatus::StopIteration);
return call_result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration;
}

void ContextBase::onDownstreamConnectionClose(CloseType close_type) {
Expand All @@ -307,74 +318,93 @@ void ContextBase::onUpstreamConnectionClose(CloseType close_type) {
template <typename P> static uint32_t headerSize(const P &p) { return p ? p->size() : 0; }

FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_stream) {
CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue,
FilterHeadersStatus::StopAllIterationAndWatermark);
PRECHECK_FAIL2(on_request_headers_abi_01_, on_request_headers_abi_02_, WasmStreamType::Request,
FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterHeadersStatus(
wasm_->on_request_headers_abi_01_
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_request_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_);
const auto call_result = wasm_->on_request_headers_abi_01_
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_request_headers_abi_02_(
this, id_, headers, static_cast<uint32_t>(end_of_stream))
.u64_;
POSTCHECK_FAIL(WasmStreamType::Request, FilterHeadersStatus::Continue,
FilterHeadersStatus::StopAllIterationAndWatermark);
return convertVmCallResultToFilterHeadersStatus(call_result);
}

FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) {
CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
PRECHECK_FAIL(on_request_body_, WasmStreamType::Request, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterDataStatus(
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_);
const auto call_result =
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
POSTCHECK_FAIL(WasmStreamType::Request, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
return convertVmCallResultToFilterDataStatus(call_result);
}

FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) {
CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
PRECHECK_FAIL(on_request_trailers_, WasmStreamType::Request, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterTrailersStatus(
wasm_->on_request_trailers_(this, id_, trailers).u64_);
const auto call_result = wasm_->on_request_trailers_(this, id_, trailers).u64_;
POSTCHECK_FAIL(WasmStreamType::Request, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
return convertVmCallResultToFilterTrailersStatus(call_result);
}

FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) {
CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
PRECHECK_FAIL(on_request_metadata_, WasmStreamType::Request, FilterMetadataStatus::Continue,
FilterMetadataStatus::Continue);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterMetadataStatus(
wasm_->on_request_metadata_(this, id_, elements).u64_);
}

FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) {
CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_,
FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
PRECHECK_FAIL2(on_response_headers_abi_01_, on_response_headers_abi_02_, WasmStreamType::Response,
FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterHeadersStatus(
wasm_->on_response_headers_abi_01_
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_response_headers_abi_02_(this, id_, headers,
static_cast<uint32_t>(end_of_stream))
.u64_);
const auto call_result = wasm_->on_response_headers_abi_01_
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
: wasm_
->on_response_headers_abi_02_(
this, id_, headers, static_cast<uint32_t>(end_of_stream))
.u64_;
POSTCHECK_FAIL(WasmStreamType::Response, FilterHeadersStatus::Continue,
FilterHeadersStatus::StopAllIterationAndWatermark);
return convertVmCallResultToFilterHeadersStatus(call_result);
}

FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) {
CHECK_HTTP(on_response_body_, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
PRECHECK_FAIL(on_response_body_, WasmStreamType::Response, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterDataStatus(
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_);
const auto call_result =
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
POSTCHECK_FAIL(WasmStreamType::Response, FilterDataStatus::Continue,
FilterDataStatus::StopIterationNoBuffer);
return convertVmCallResultToFilterDataStatus(call_result);
}

FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) {
CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
PRECHECK_FAIL(on_response_trailers_, WasmStreamType::Response, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterTrailersStatus(
wasm_->on_response_trailers_(this, id_, trailers).u64_);
const auto call_result = wasm_->on_response_trailers_(this, id_, trailers).u64_;
POSTCHECK_FAIL(WasmStreamType::Response, FilterTrailersStatus::Continue,
FilterTrailersStatus::StopIteration);
return convertVmCallResultToFilterTrailersStatus(call_result);
}

FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) {
CHECK_HTTP(on_response_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
PRECHECK_FAIL(on_response_metadata_, WasmStreamType::Response, FilterMetadataStatus::Continue,
FilterMetadataStatus::Continue);
DeferAfterCallActions actions(this);
return convertVmCallResultToFilterMetadataStatus(
wasm_->on_response_metadata_(this, id_, elements).u64_);
const auto call_result = wasm_->on_response_metadata_(this, id_, elements).u64_;
POSTCHECK_FAIL(WasmStreamType::Response, FilterMetadataStatus::Continue,
FilterMetadataStatus::Continue);
return convertVmCallResultToFilterMetadataStatus(call_result);
}

void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size,
Expand Down