Skip to content

Commit 376ffaf

Browse files
authored
Force stop iteration after local response is sent (proxy-wasm#88)
Signed-off-by: mathetake <[email protected]>
1 parent 64313a6 commit 376ffaf

File tree

3 files changed

+72
-48
lines changed

3 files changed

+72
-48
lines changed

include/proxy-wasm/context.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ class ContextBase : public RootInterface,
167167
// Called before deleting the context.
168168
virtual void destroy();
169169

170+
// Called to raise the flag which indicates that the context should stop iteration regardless of
171+
// returned filter status from Proxy-Wasm extensions. For example, we ignore
172+
// FilterHeadersStatus::Continue after a local reponse is sent by the host.
173+
void stopIteration() { stop_iteration_ = true; };
174+
170175
/**
171176
* Calls into the VM.
172177
* These are implemented by the proxy-independent host code. They are virtual to support some
@@ -385,6 +390,14 @@ class ContextBase : public RootInterface,
385390
std::shared_ptr<PluginBase> plugin_;
386391
bool in_vm_context_created_ = false;
387392
bool destroyed_ = false;
393+
bool stop_iteration_ = false;
394+
395+
private:
396+
// helper functions
397+
FilterHeadersStatus convertVmCallResultToFilterHeadersStatus(uint64_t result);
398+
FilterDataStatus convertVmCallResultToFilterDataStatus(uint64_t result);
399+
FilterTrailersStatus convertVmCallResultToFilterTrailersStatus(uint64_t result);
400+
FilterMetadataStatus convertVmCallResultToFilterMetadataStatus(uint64_t result);
388401
};
389402

390403
class DeferAfterCallActions {

src/context.cc

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -477,93 +477,71 @@ FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_
477477
CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue,
478478
FilterHeadersStatus::StopIteration);
479479
DeferAfterCallActions actions(this);
480-
auto result = wasm_->on_request_headers_abi_01_
481-
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
482-
: wasm_
483-
->on_request_headers_abi_02_(this, id_, headers,
484-
static_cast<uint32_t>(end_of_stream))
485-
.u64_;
486-
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
487-
return FilterHeadersStatus::StopAllIterationAndWatermark;
488-
return static_cast<FilterHeadersStatus>(result);
480+
return convertVmCallResultToFilterHeadersStatus(
481+
wasm_->on_request_headers_abi_01_
482+
? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_
483+
: wasm_
484+
->on_request_headers_abi_02_(this, id_, headers,
485+
static_cast<uint32_t>(end_of_stream))
486+
.u64_);
489487
}
490488

491489
FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) {
492490
CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer);
493491
DeferAfterCallActions actions(this);
494-
auto result =
495-
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
496-
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
497-
return FilterDataStatus::StopIterationNoBuffer;
498-
return static_cast<FilterDataStatus>(result);
492+
return convertVmCallResultToFilterDataStatus(
493+
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_);
499494
}
500495

501496
FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) {
502497
CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue,
503498
FilterTrailersStatus::StopIteration);
504499
DeferAfterCallActions actions(this);
505-
if (static_cast<FilterTrailersStatus>(wasm_->on_request_trailers_(this, id_, trailers).u64_) ==
506-
FilterTrailersStatus::Continue) {
507-
return FilterTrailersStatus::Continue;
508-
}
509-
return FilterTrailersStatus::StopIteration;
500+
return convertVmCallResultToFilterTrailersStatus(
501+
wasm_->on_request_trailers_(this, id_, trailers).u64_);
510502
}
511503

512504
FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) {
513505
CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
514506
DeferAfterCallActions actions(this);
515-
if (static_cast<FilterMetadataStatus>(wasm_->on_request_metadata_(this, id_, elements).u64_) ==
516-
FilterMetadataStatus::Continue) {
517-
return FilterMetadataStatus::Continue;
518-
}
519-
return FilterMetadataStatus::Continue; // This is currently the only return code.
507+
return convertVmCallResultToFilterMetadataStatus(
508+
wasm_->on_request_metadata_(this, id_, elements).u64_);
520509
}
521510

522511
FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) {
523512
CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_,
524513
FilterHeadersStatus::Continue, FilterHeadersStatus::StopIteration);
525514
DeferAfterCallActions actions(this);
526-
auto result = wasm_->on_response_headers_abi_01_
527-
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
528-
: wasm_
529-
->on_response_headers_abi_02_(this, id_, headers,
530-
static_cast<uint32_t>(end_of_stream))
531-
.u64_;
532-
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
533-
return FilterHeadersStatus::StopAllIterationAndWatermark;
534-
return static_cast<FilterHeadersStatus>(result);
515+
return convertVmCallResultToFilterHeadersStatus(
516+
wasm_->on_response_headers_abi_01_
517+
? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_
518+
: wasm_
519+
->on_response_headers_abi_02_(this, id_, headers,
520+
static_cast<uint32_t>(end_of_stream))
521+
.u64_);
535522
}
536523

537524
FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) {
538525
CHECK_HTTP(on_response_body_, FilterDataStatus::Continue,
539526
FilterDataStatus::StopIterationNoBuffer);
540527
DeferAfterCallActions actions(this);
541-
auto result =
542-
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
543-
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
544-
return FilterDataStatus::StopIterationNoBuffer;
545-
return static_cast<FilterDataStatus>(result);
528+
return convertVmCallResultToFilterDataStatus(
529+
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_);
546530
}
547531

548532
FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) {
549533
CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue,
550534
FilterTrailersStatus::StopIteration);
551535
DeferAfterCallActions actions(this);
552-
if (static_cast<FilterTrailersStatus>(wasm_->on_response_trailers_(this, id_, trailers).u64_) ==
553-
FilterTrailersStatus::Continue) {
554-
return FilterTrailersStatus::Continue;
555-
}
556-
return FilterTrailersStatus::StopIteration;
536+
return convertVmCallResultToFilterTrailersStatus(
537+
wasm_->on_response_trailers_(this, id_, trailers).u64_);
557538
}
558539

559540
FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) {
560541
CHECK_HTTP(on_response_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue);
561542
DeferAfterCallActions actions(this);
562-
if (static_cast<FilterMetadataStatus>(wasm_->on_response_metadata_(this, id_, elements).u64_) ==
563-
FilterMetadataStatus::Continue) {
564-
return FilterMetadataStatus::Continue;
565-
}
566-
return FilterMetadataStatus::Continue; // This is currently the only return code.
543+
return convertVmCallResultToFilterMetadataStatus(
544+
wasm_->on_response_metadata_(this, id_, elements).u64_);
567545
}
568546

569547
void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size,
@@ -643,6 +621,38 @@ WasmResult ContextBase::setTimerPeriod(std::chrono::milliseconds period,
643621
return WasmResult::Ok;
644622
}
645623

624+
FilterHeadersStatus ContextBase::convertVmCallResultToFilterHeadersStatus(uint64_t result) {
625+
if (stop_iteration_ ||
626+
result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark)) {
627+
stop_iteration_ = false;
628+
return FilterHeadersStatus::StopAllIterationAndWatermark;
629+
}
630+
return static_cast<FilterHeadersStatus>(result);
631+
}
632+
633+
FilterDataStatus ContextBase::convertVmCallResultToFilterDataStatus(uint64_t result) {
634+
if (stop_iteration_ || result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer)) {
635+
stop_iteration_ = false;
636+
return FilterDataStatus::StopIterationNoBuffer;
637+
}
638+
return static_cast<FilterDataStatus>(result);
639+
}
640+
641+
FilterTrailersStatus ContextBase::convertVmCallResultToFilterTrailersStatus(uint64_t result) {
642+
if (stop_iteration_ || result > static_cast<uint64_t>(FilterTrailersStatus::StopIteration)) {
643+
stop_iteration_ = false;
644+
return FilterTrailersStatus::StopIteration;
645+
}
646+
return static_cast<FilterTrailersStatus>(result);
647+
}
648+
649+
FilterMetadataStatus ContextBase::convertVmCallResultToFilterMetadataStatus(uint64_t result) {
650+
if (static_cast<FilterMetadataStatus>(result) == FilterMetadataStatus::Continue) {
651+
return FilterMetadataStatus::Continue;
652+
}
653+
return FilterMetadataStatus::Continue; // This is currently the only return code.
654+
}
655+
646656
ContextBase::~ContextBase() {
647657
// Do not remove vm or root contexts which have the same lifetime as wasm_.
648658
if (parent_context_id_) {

src/exports.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ Word send_local_response(void *raw_context, Word response_code, Word response_co
186186
auto additional_headers = toPairs(additional_response_header_pairs.value());
187187
context->sendLocalResponse(response_code, body.value(), std::move(additional_headers), grpc_code,
188188
details.value());
189+
context->stopIteration();
189190
return WasmResult::Ok;
190191
}
191192

0 commit comments

Comments
 (0)