Skip to content

Commit 2814f7b

Browse files
committed
tmp
Signed-off-by: mathetake <[email protected]>
1 parent 31f3184 commit 2814f7b

File tree

3 files changed

+79
-36
lines changed

3 files changed

+79
-36
lines changed

include/proxy-wasm/context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ class ContextBase : public RootInterface,
170170
// Called before deleting the context.
171171
virtual void destroy();
172172

173+
// Called to raise the flag which indicates that the context should stop iteration regardless of
174+
// returned filter status from WASM extensions. For example, we ignore
175+
// FilterHeadersStatus::Continue after a local reponse is sent by the host.
176+
void stopIteration() { stop_iteration_ = true; };
177+
173178
/**
174179
* Calls into the VM.
175180
* These are implemented by the proxy-independent host code. They are virtual to support some
@@ -388,6 +393,7 @@ class ContextBase : public RootInterface,
388393
std::shared_ptr<PluginBase> temp_plugin_; // Remove once ABI v0.1.0 is gone.
389394
bool in_vm_context_created_ = false;
390395
bool destroyed_ = false;
396+
bool stop_iteration_ = false;
391397
};
392398

393399
class DeferAfterCallActions {

src/context.cc

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
// See the License for the specific language governing permissions and
1414
// limitations under the License.
1515

16+
#include "include/proxy-wasm/context.h"
17+
#include "include/proxy-wasm/wasm.h"
18+
#include <cassert>
1619
#include <deque>
1720
#include <map>
1821
#include <memory>
1922
#include <mutex>
2023
#include <unordered_map>
2124
#include <unordered_set>
2225

23-
#include "include/proxy-wasm/context.h"
24-
#include "include/proxy-wasm/wasm.h"
25-
2626
#define CHECK_FAIL(_call, _stream_type, _return_open, _return_closed) \
2727
if (isFailed()) { \
2828
if (plugin_->fail_open_) { \
@@ -469,44 +469,62 @@ template <typename P> static uint32_t headerSize(const P &p) { return p ? p->siz
469469
FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_stream) {
470470
CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue,
471471
FilterHeadersStatus::StopIteration);
472-
DeferAfterCallActions actions(this);
473-
auto result = wasm_->on_request_headers_abi_01_
474-
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
475-
: wasm_
476-
->on_request_headers_abi_02_(this, id_, headers,
477-
static_cast<uint32_t>(end_of_stream))
478-
.u64_;
479-
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
472+
uint64_t result;
473+
{
474+
DeferAfterCallActions actions(this);
475+
result = wasm_->on_request_headers_abi_01_
476+
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
477+
: wasm_
478+
->on_request_headers_abi_02_(this, id_, headers,
479+
static_cast<uint32_t>(end_of_stream))
480+
.u64_;
481+
}
482+
483+
if (stop_iteration_) {
484+
return FilterHeadersStatus::StopIteration;
485+
} else if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark)) {
480486
return FilterHeadersStatus::StopAllIterationAndWatermark;
487+
}
481488
return static_cast<FilterHeadersStatus>(result);
482489
}
483490

484491
FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) {
485492
CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
486-
DeferAfterCallActions actions(this);
487-
auto result =
488-
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
489-
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
493+
uint64_t result;
494+
{
495+
DeferAfterCallActions actions(this);
496+
result =
497+
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
498+
}
499+
if (stop_iteration_ || result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
490500
return FilterDataStatus::StopIterationNoBuffer;
491501
return static_cast<FilterDataStatus>(result);
492502
}
493503

494504
FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) {
495505
CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue,
496506
FilterTrailersStatus::StopIteration);
497-
DeferAfterCallActions actions(this);
498-
if (static_cast<FilterTrailersStatus>(wasm_->on_request_trailers_(this, id_, trailers).u64_) ==
499-
FilterTrailersStatus::Continue) {
507+
508+
uint64_t result;
509+
{
510+
DeferAfterCallActions actions(this);
511+
result = wasm_->on_request_trailers_(this, id_, trailers).u64_;
512+
}
513+
if (!stop_iteration_ &&
514+
static_cast<FilterTrailersStatus>(result) == FilterTrailersStatus::Continue) {
500515
return FilterTrailersStatus::Continue;
501516
}
502517
return FilterTrailersStatus::StopIteration;
503518
}
504519

505520
FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) {
506521
CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
507-
DeferAfterCallActions actions(this);
508-
if (static_cast<FilterMetadataStatus>(wasm_->on_request_metadata_(this, id_, elements).u64_) ==
509-
FilterMetadataStatus::Continue) {
522+
uint64_t result;
523+
{
524+
DeferAfterCallActions actions(this);
525+
result = wasm_->on_request_metadata_(this, id_, elements).u64_;
526+
}
527+
if (static_cast<FilterMetadataStatus>(result) == FilterMetadataStatus::Continue) {
510528
return FilterMetadataStatus::Continue;
511529
}
512530
return FilterMetadataStatus::Continue; // This is currently the only return code.
@@ -515,35 +533,53 @@ FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) {
515533
FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) {
516534
CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_,
517535
FilterHeadersStatus::Continue, FilterHeadersStatus::StopIteration);
518-
DeferAfterCallActions actions(this);
519-
auto result = wasm_->on_response_headers_abi_01_
520-
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
521-
: wasm_
522-
->on_response_headers_abi_02_(this, id_, headers,
523-
static_cast<uint32_t>(end_of_stream))
524-
.u64_;
525-
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
536+
537+
uint64_t result;
538+
{
539+
DeferAfterCallActions actions(this);
540+
result = wasm_->on_response_headers_abi_01_
541+
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
542+
: wasm_
543+
->on_response_headers_abi_02_(this, id_, headers,
544+
static_cast<uint32_t>(end_of_stream))
545+
.u64_;
546+
}
547+
548+
if (stop_iteration_) {
549+
return FilterHeadersStatus::StopIteration;
550+
} else if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark)) {
526551
return FilterHeadersStatus::StopAllIterationAndWatermark;
552+
}
527553
return static_cast<FilterHeadersStatus>(result);
528554
}
529555

530556
FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) {
531557
CHECK_HTTP(on_response_body_, FilterDataStatus::Continue,
532558
FilterDataStatus::StopIterationNoBuffer);
533-
DeferAfterCallActions actions(this);
534-
auto result =
535-
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
536-
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
559+
560+
uint64_t result;
561+
{
562+
563+
DeferAfterCallActions actions(this);
564+
result =
565+
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
566+
}
567+
568+
if (stop_iteration_ || result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
537569
return FilterDataStatus::StopIterationNoBuffer;
538570
return static_cast<FilterDataStatus>(result);
539571
}
540572

541573
FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) {
542574
CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue,
543575
FilterTrailersStatus::StopIteration);
544-
DeferAfterCallActions actions(this);
545-
if (static_cast<FilterTrailersStatus>(wasm_->on_response_trailers_(this, id_, trailers).u64_) ==
546-
FilterTrailersStatus::Continue) {
576+
uint64_t result;
577+
{
578+
DeferAfterCallActions actions(this);
579+
result = wasm_->on_response_trailers_(this, id_, trailers).u64_;
580+
}
581+
if (!stop_iteration_ &&
582+
static_cast<FilterTrailersStatus>(result) == FilterTrailersStatus::Continue) {
547583
return FilterTrailersStatus::Continue;
548584
}
549585
return FilterTrailersStatus::StopIteration;

src/exports.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ Word send_local_response(void *raw_context, Word response_code, Word response_co
186186
auto additional_headers = toPairs(additional_response_header_pairs.value());
187187
context->sendLocalResponse(response_code, body.value(), std::move(additional_headers), grpc_code,
188188
details.value());
189+
context->wasm()->addAfterVmCallAction([context] { context->stopIteration(); });
189190
return WasmResult::Ok;
190191
}
191192

0 commit comments

Comments
 (0)