diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index bf770a7bf..c373cf324 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -389,6 +389,7 @@ class ContextBase : public RootInterface, std::shared_ptr temp_plugin_; // Remove once ABI v0.1.0 is gone. bool in_vm_context_created_ = false; bool destroyed_ = false; + bool stream_failed_ = false; // Set true after failStream is called in case of VM failure. private: // helper functions diff --git a/src/context.cc b/src/context.cc index 6825f1336..a13f79000 100644 --- a/src/context.cc +++ b/src/context.cc @@ -25,40 +25,22 @@ #include "src/shared_data.h" #include "src/shared_queue.h" -#define CHECK_FAIL(_call, _stream_type, _return_open, _return_closed) \ +#define CHECK_FAIL(_stream_type, _stream_type2, _return_open, _return_closed) \ if (isFailed()) { \ if (plugin_->fail_open_) { \ return _return_open; \ - } else { \ + } else if (!stream_failed_) { \ failStream(_stream_type); \ - return _return_closed; \ - } \ - } else { \ - if (!wasm_->_call) { \ - return _return_open; \ - } \ - } - -#define CHECK_FAIL2(_call1, _call2, _stream_type, _return_open, _return_closed) \ - if (isFailed()) { \ - if (plugin_->fail_open_) { \ - return _return_open; \ - } else { \ - failStream(_stream_type); \ - return _return_closed; \ - } \ - } else { \ - if (!wasm_->_call1 && !wasm_->_call2) { \ - return _return_open; \ + failStream(_stream_type2); \ + stream_failed_ = true; \ } \ + return _return_closed; \ } -#define CHECK_HTTP(_call, _return_open, _return_closed) \ - CHECK_FAIL(_call, WasmStreamType::Request, _return_open, _return_closed) -#define CHECK_HTTP2(_call1, _call2, _return_open, _return_closed) \ - CHECK_FAIL2(_call1, _call2, WasmStreamType::Request, _return_open, _return_closed) -#define CHECK_NET(_call, _return_open, _return_closed) \ - CHECK_FAIL(_call, WasmStreamType::Downstream, _return_open, _return_closed) +#define CHECK_FAIL_HTTP(_return_open, _return_closed) \ + CHECK_FAIL(WasmStreamType::Request, WasmStreamType::Response, _return_open, _return_closed) +#define CHECK_FAIL_NET(_return_open, _return_closed) \ + CHECK_FAIL(WasmStreamType::Downstream, WasmStreamType::Upstream, _return_open, _return_closed) namespace proxy_wasm { @@ -263,30 +245,40 @@ void ContextBase::onForeignFunction(uint32_t foreign_function_id, uint32_t data_ } FilterStatus ContextBase::onNetworkNewConnection() { - CHECK_NET(on_new_connection_, FilterStatus::Continue, FilterStatus::StopIteration); - DeferAfterCallActions actions(this); - if (wasm_->on_new_connection_(this, id_).u64_ == 0) { + CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration); + if (!wasm_->on_new_connection_) { return FilterStatus::Continue; } - return FilterStatus::StopIteration; + DeferAfterCallActions actions(this); + const auto result = wasm_->on_new_connection_(this, id_); + CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration); + return result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; } FilterStatus ContextBase::onDownstreamData(uint32_t data_length, bool end_of_stream) { - CHECK_NET(on_downstream_data_, FilterStatus::Continue, FilterStatus::StopIteration); + CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration); + if (!wasm_->on_downstream_data_) { + return FilterStatus::Continue; + } DeferAfterCallActions actions(this); auto result = wasm_->on_downstream_data_(this, id_, static_cast(data_length), static_cast(end_of_stream)); // TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values. - return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; + CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration); + return result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; } FilterStatus ContextBase::onUpstreamData(uint32_t data_length, bool end_of_stream) { - CHECK_NET(on_upstream_data_, FilterStatus::Continue, FilterStatus::StopIteration); + CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration); + if (!wasm_->on_upstream_data_) { + return FilterStatus::Continue; + } DeferAfterCallActions actions(this); auto result = wasm_->on_upstream_data_(this, id_, static_cast(data_length), static_cast(end_of_stream)); // TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values. - return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; + CHECK_FAIL_NET(FilterStatus::Continue, FilterStatus::StopIteration); + return result == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; } void ContextBase::onDownstreamConnectionClose(CloseType close_type) { @@ -307,74 +299,99 @@ void ContextBase::onUpstreamConnectionClose(CloseType close_type) { template static uint32_t headerSize(const P &p) { return p ? p->size() : 0; } FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_stream) { - CHECK_HTTP2(on_request_headers_abi_01_, on_request_headers_abi_02_, FilterHeadersStatus::Continue, - FilterHeadersStatus::StopAllIterationAndWatermark); + CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark); + if (!wasm_->on_request_headers_abi_01_ && !wasm_->on_request_headers_abi_02_) { + return FilterHeadersStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterHeadersStatus( - wasm_->on_request_headers_abi_01_ - ? wasm_->on_request_headers_abi_01_(this, id_, headers).u64_ - : wasm_ - ->on_request_headers_abi_02_(this, id_, headers, - static_cast(end_of_stream)) - .u64_); + const auto result = wasm_->on_request_headers_abi_01_ + ? wasm_->on_request_headers_abi_01_(this, id_, headers) + : wasm_->on_request_headers_abi_02_(this, id_, headers, + static_cast(end_of_stream)); + CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark); + return convertVmCallResultToFilterHeadersStatus(result); } FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) { - CHECK_HTTP(on_request_body_, FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); + CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); + if (!wasm_->on_request_body_) { + return FilterDataStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterDataStatus( - wasm_->on_request_body_(this, id_, data_length, static_cast(end_of_stream)).u64_); + const auto result = + wasm_->on_request_body_(this, id_, data_length, static_cast(end_of_stream)); + CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); + return convertVmCallResultToFilterDataStatus(result); } FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) { - CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue, - FilterTrailersStatus::StopIteration); + CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration); + if (!wasm_->on_request_trailers_) { + return FilterTrailersStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterTrailersStatus( - wasm_->on_request_trailers_(this, id_, trailers).u64_); + const auto result = wasm_->on_request_trailers_(this, id_, trailers); + CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration); + return convertVmCallResultToFilterTrailersStatus(result); } FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) { - CHECK_HTTP(on_request_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); + CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); + if (!wasm_->on_request_metadata_) { + return FilterMetadataStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterMetadataStatus( - wasm_->on_request_metadata_(this, id_, elements).u64_); + const auto result = wasm_->on_request_metadata_(this, id_, elements); + CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); + return convertVmCallResultToFilterMetadataStatus(result); } FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of_stream) { - CHECK_HTTP2(on_response_headers_abi_01_, on_response_headers_abi_02_, - FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark); + CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark); + if (!wasm_->on_response_headers_abi_01_ && !wasm_->on_response_headers_abi_02_) { + return FilterHeadersStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterHeadersStatus( - wasm_->on_response_headers_abi_01_ - ? wasm_->on_response_headers_abi_01_(this, id_, headers).u64_ - : wasm_ - ->on_response_headers_abi_02_(this, id_, headers, - static_cast(end_of_stream)) - .u64_); + const auto result = wasm_->on_response_headers_abi_01_ + ? wasm_->on_response_headers_abi_01_(this, id_, headers) + : wasm_->on_response_headers_abi_02_( + this, id_, headers, static_cast(end_of_stream)); + CHECK_FAIL_HTTP(FilterHeadersStatus::Continue, FilterHeadersStatus::StopAllIterationAndWatermark); + return convertVmCallResultToFilterHeadersStatus(result); } FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) { - CHECK_HTTP(on_response_body_, FilterDataStatus::Continue, - FilterDataStatus::StopIterationNoBuffer); + CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); + if (!wasm_->on_response_body_) { + return FilterDataStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterDataStatus( - wasm_->on_response_body_(this, id_, body_length, static_cast(end_of_stream)).u64_); + const auto result = + wasm_->on_response_body_(this, id_, body_length, static_cast(end_of_stream)); + CHECK_FAIL_HTTP(FilterDataStatus::Continue, FilterDataStatus::StopIterationNoBuffer); + return convertVmCallResultToFilterDataStatus(result); } FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) { - CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue, - FilterTrailersStatus::StopIteration); + CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration); + if (!wasm_->on_response_trailers_) { + return FilterTrailersStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterTrailersStatus( - wasm_->on_response_trailers_(this, id_, trailers).u64_); + const auto result = wasm_->on_response_trailers_(this, id_, trailers); + CHECK_FAIL_HTTP(FilterTrailersStatus::Continue, FilterTrailersStatus::StopIteration); + return convertVmCallResultToFilterTrailersStatus(result); } FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) { - CHECK_HTTP(on_response_metadata_, FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); + CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); + if (!wasm_->on_response_metadata_) { + return FilterMetadataStatus::Continue; + } DeferAfterCallActions actions(this); - return convertVmCallResultToFilterMetadataStatus( - wasm_->on_response_metadata_(this, id_, elements).u64_); + const auto result = wasm_->on_response_metadata_(this, id_, elements); + CHECK_FAIL_HTTP(FilterMetadataStatus::Continue, FilterMetadataStatus::Continue); + return convertVmCallResultToFilterMetadataStatus(result); } void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size,