diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 8b823cddd..4a75297b9 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -53,7 +53,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 - name: Mount bazel cache uses: actions/cache@v1 diff --git a/BUILD b/BUILD index 9a4986ce9..965e55a0a 100644 --- a/BUILD +++ b/BUILD @@ -22,6 +22,7 @@ cc_library( copts = ["-DWITHOUT_ZLIB=1"], deps = [ ":include", + "@com_google_protobuf//:protobuf_lite", "@proxy_wasm_cpp_sdk//:api_lib", ], ) @@ -45,6 +46,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf_lite", "@proxy_wasm_cpp_sdk//:api_lib", ], ) diff --git a/WORKSPACE b/WORKSPACE index 58f5ab533..c143e5c32 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -3,13 +3,10 @@ workspace(name = "proxy_wasm_cpp_host") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") -http_archive( +git_repository( name = "proxy_wasm_cpp_sdk", - sha256 = "14f66f67e8f37ec81d28d7f5307be4407d206ac5f0daaf6d22fa5536797bcac1", - strip_prefix = "proxy-wasm-cpp-sdk-31f1fc5b7e09f231fa532d2d296e479a113c3e10", - urls = ["https://github.com/proxy-wasm/proxy-wasm-cpp-sdk/archive/31f1fc5b7e09f231fa532d2d296e479a113c3e10.tar.gz"], - patch_cmds = ["rm BUILD"], - build_file = '//bazel/external:proxy-wasm-cpp-sdk.BUILD', + remote = "https://github.com/proxy-wasm/proxy-wasm-cpp-sdk", + commit = "c12553951d01bb60cb1448ba1fcfeb8f843aad62", ) http_archive( @@ -20,7 +17,6 @@ http_archive( urls = ["https://github.com/abseil/abseil-cpp/archive/37dd2562ec830d547a1524bb306be313ac3f2556.tar.gz"], ) - # required by com_google_protobuf http_archive( name = "bazel_skylib", diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index 2c9640fc8..10da5fb26 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -36,7 +36,6 @@ class WasmVm; using Pairs = std::vector>; using PairsWithStringValues = std::vector>; -using CallOnThreadFunction = std::function)>; struct BufferInterface { virtual ~BufferInterface() {} @@ -50,9 +49,12 @@ struct BufferInterface { * @param size_ptr is the location in the VM address space to place the size of the newly * allocated memory block which contains the copied bytes (e.g. length). * @return true on success. + * Length is guarantteed to be > 0 and the bounds have already been checked. */ - virtual bool copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr, - uint64_t size_ptr) const = 0; + virtual WasmResult copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr, + uint64_t size_ptr) const = 0; + + virtual WasmResult copyFrom(size_t start, size_t length, string_view data) = 0; }; // Opaque context object. @@ -94,7 +96,7 @@ class ContextBase { uint32_t id() const { return id_; } bool isVmContext() const { return id_ == 0; } bool isRootContext() const { return root_context_id_ == 0; } - ContextBase *root_context() { return root_context_; } + ContextBase *root_context() const { return root_context_; } string_view root_id() const { return isRootContext() ? root_id_ : plugin_->root_id_; } string_view log_prefix() const { return isRootContext() ? root_log_prefix_ : plugin_->log_prefix(); @@ -104,76 +106,50 @@ class ContextBase { // Called before deleting the context. virtual void destroy(); - // // VM level downcalls into the WASM code on Context(id == 0). - // virtual bool onStart(std::shared_ptr plugin); virtual bool onConfigure(std::shared_ptr plugin); - // + // Root Context downcalls into the WASM code Context(id != 0, root_context_id_ == 0); + virtual void onTick(); + // Stream downcalls on Context(id > 0). // // General stream downcall on a new stream. virtual void onCreate(uint32_t root_context_id); + // Network - virtual FilterStatus onNetworkNewConnection() { - unimplemented(); - return FilterStatus::Continue; - } - virtual FilterStatus onDownstreamData(int /* data_length */, bool /* end_of_stream */) { - unimplemented(); - return FilterStatus::Continue; - } - virtual FilterStatus onUpstreamData(int /* data_length */, bool /* end_of_stream */) { - unimplemented(); - return FilterStatus::Continue; - } + virtual FilterStatus onNetworkNewConnection(); + virtual FilterStatus onDownstreamData(int data_length, bool end_of_stream); + virtual FilterStatus onUpstreamData(int data_length, bool end_of_stream); enum class PeerType : uint32_t { Unknown = 0, Local = 1, Remote = 2, }; - virtual void onDownstreamConnectionClose(PeerType) { unimplemented(); } - virtual void onUpstreamConnectionClose(PeerType) { unimplemented(); } + virtual void onDownstreamConnectionClose(PeerType); + virtual void onUpstreamConnectionClose(PeerType); // HTTP Filter Stream Request Downcalls. - virtual FilterHeadersStatus onRequestHeaders() { - unimplemented(); - return FilterHeadersStatus::Continue; - } - virtual FilterDataStatus onRequestBody(int /* body_buffer_length */, bool /* end_of_stream */) { - unimplemented(); - return FilterDataStatus::Continue; - } - virtual FilterTrailersStatus onRequestTrailers() { - unimplemented(); - return FilterTrailersStatus::Continue; - } - virtual FilterMetadataStatus onRequestMetadata() { - unimplemented(); - return FilterMetadataStatus::Continue; - } + virtual FilterHeadersStatus onRequestHeaders(uint32_t headers); + virtual FilterDataStatus onRequestBody(uint32_t body_buffer_length, bool end_of_stream); + virtual FilterTrailersStatus onRequestTrailers(uint32_t trailers); + virtual FilterMetadataStatus onRequestMetadata(uint32_t elements); // HTTP Filter Stream Response Downcalls. - virtual FilterHeadersStatus onResponseHeaders() { - unimplemented(); - return FilterHeadersStatus::Continue; - } - virtual FilterDataStatus onResponseBody(int /* body_buffer_length */, bool /* end_of_stream */) { - unimplemented(); - return FilterDataStatus::Continue; - } - virtual FilterTrailersStatus onResponseTrailers() { - unimplemented(); - return FilterTrailersStatus::Continue; - } - virtual FilterMetadataStatus onResponseMetadata() { - unimplemented(); - return FilterMetadataStatus::Continue; - } + virtual FilterHeadersStatus onResponseHeaders(uint32_t headers); + virtual FilterDataStatus onResponseBody(uint32_t body_buffer_length, bool end_of_stream); + virtual FilterTrailersStatus onResponseTrailers(uint32_t trailers); + virtual FilterMetadataStatus onResponseMetadata(uint32_t elements); // Async call response. - virtual void onHttpCallResponse(uint32_t /* token */, uint32_t /* headers */, - uint32_t /* body_size */, uint32_t /* trailers */) {} + virtual void onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size, + uint32_t trailers); + // Grpc + virtual void onGrpcReceiveInitialMetadata(uint32_t token, uint32_t elements); + virtual void onGrpcReceiveTrailingMetadata(uint32_t token, uint32_t trailers); + virtual void onGrpcReceive(uint32_t token, uint32_t response_size); + virtual void onGrpcClose(uint32_t token, uint32_t status_code); + // Inter-VM shared queue message arrival. - virtual void onQueueReady(uint32_t /* token */) { unimplemented(); } + virtual void onQueueReady(uint32_t /* token */); // General stream downcall when the stream/vm has ended. virtual bool onDone(); // General stream downcall for logging. Occurs after onDone(). @@ -212,7 +188,7 @@ class ContextBase { } // Buffer - virtual const BufferInterface *getBuffer(WasmBufferType /* type */) { + virtual BufferInterface *getBuffer(WasmBufferType /* type */) { unimplemented(); return nullptr; } @@ -233,14 +209,16 @@ class ContextBase { // gRPC // Returns a token which will be used with the corresponding onGrpc and grpc calls. virtual WasmResult grpcCall(string_view /* grpc_service */, string_view /* service_name */, - string_view /* method_name */, string_view /* request */, + string_view /* method_name */, const Pairs & /* initial_metadata */, + string_view /* request */, const optional & /* timeout */, uint32_t * /* token_ptr */) { unimplemented(); return WasmResult::Unimplemented; } virtual WasmResult grpcStream(string_view /* grpc_service */, string_view /* service_name */, - string_view /* method_name */, uint32_t * /* token_ptr */) { + string_view /* method_name */, const Pairs & /* initial_metadata */, + uint32_t * /* token_ptr */) { unimplemented(); return WasmResult::Unimplemented; } @@ -342,7 +320,12 @@ class ContextBase { protected: friend class WasmBase; - virtual void initializeRoot(WasmBase *wasm, std::shared_ptr plugin); + // NB: initializeRootBase is non-virtual and can be called in the constructor without ambiguity. + void initializeRootBase(WasmBase *wasm, std::shared_ptr plugin); + // NB: initializeRoot is virtual and should be called only outside of the constructor. + virtual void initializeRoot(WasmBase *wasm, std::shared_ptr plugin) { + initializeRootBase(wasm, plugin); + } std::string makeRootLogPrefix(string_view vm_id) const; WasmBase *wasm_{nullptr}; diff --git a/include/proxy-wasm/exports.h b/include/proxy-wasm/exports.h index cdf6d191c..786350e6b 100644 --- a/include/proxy-wasm/exports.h +++ b/include/proxy-wasm/exports.h @@ -48,6 +48,8 @@ Word enqueue_shared_queue(void *raw_context, Word token, Word data_ptr, Word dat Word get_buffer_bytes(void *raw_context, Word type, Word start, Word length, Word ptr_ptr, Word size_ptr); Word get_buffer_status(void *raw_context, Word type, Word length_ptr, Word flags_ptr); +Word set_buffer_bytes(void *raw_context, Word type, Word start, Word length, Word data_ptr, + Word data_size); Word add_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size, Word value_ptr, Word value_size); Word get_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size, @@ -72,10 +74,11 @@ Word record_metric(void *raw_context, Word metric_id, uint64_t value); Word get_metric(void *raw_context, Word metric_id, Word result_uint64_ptr); Word grpc_call(void *raw_context, Word service_ptr, Word service_size, Word service_name_ptr, Word service_name_size, Word method_name_ptr, Word method_name_size, - Word request_ptr, Word request_size, Word timeout_milliseconds, Word token_ptr); + Word initial_metadata_ptr, Word initial_metadata_size, Word request_ptr, + Word request_size, Word timeout_milliseconds, Word token_ptr); Word grpc_stream(void *raw_context, Word service_ptr, Word service_size, Word service_name_ptr, Word service_name_size, Word method_name_ptr, Word method_name_size, - Word token_ptr); + Word initial_metadata_ptr, Word initial_metadata_size, Word token_ptr); Word grpc_cancel(void *raw_context, Word token); Word grpc_close(void *raw_context, Word token); Word grpc_send(void *raw_context, Word token, Word message_ptr, Word message_size, Word end_stream); @@ -92,8 +95,10 @@ Word call_foreign_function(void *raw_context, Word function_name, Word function_ Word wasi_unstable_fd_write(void *raw_context, Word fd, Word iovs, Word iovs_len, Word nwritten_ptr); +Word wasi_unstable_fd_read(void *, Word, Word, Word, Word); Word wasi_unstable_fd_seek(void *, Word, int64_t, Word, Word); Word wasi_unstable_fd_close(void *, Word); +Word wasi_unstable_fd_fdstat_get(void *, Word fd, Word statOut); Word wasi_unstable_environ_get(void *, Word, Word); Word wasi_unstable_environ_sizes_get(void *raw_context, Word count_ptr, Word buf_size_ptr); Word wasi_unstable_args_get(void *raw_context, Word argc_ptr, Word argv_buf_size_ptr); diff --git a/include/proxy-wasm/null_plugin.h b/include/proxy-wasm/null_plugin.h index 2e2337f4b..f6e1970ac 100644 --- a/include/proxy-wasm/null_plugin.h +++ b/include/proxy-wasm/null_plugin.h @@ -99,7 +99,6 @@ class NullPlugin : public NullVmPlugin { void onGrpcReceive(uint64_t context_id, uint64_t token, size_t body_size); void onGrpcClose(uint64_t context_id, uint64_t token, uint64_t status_code); - void onGrpcCreateInitialMetadata(uint64_t context_id, uint64_t token, uint64_t headers); void onGrpcReceiveInitialMetadata(uint64_t context_id, uint64_t token, uint64_t headers); void onGrpcReceiveTrailingMetadata(uint64_t context_id, uint64_t token, uint64_t trailers); diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index 2f84e053b..d07de3e85 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -34,19 +34,21 @@ namespace proxy_wasm { #include "proxy_wasm_common.h" +class ContextBase; class WasmBase; -class WasmHandle; +class WasmHandleBase; using WasmForeignFunction = std::function)>; using WasmVmFactory = std::function()>; +using CallOnThreadFunction = std::function)>; // Wasm execution instance. Manages the host side of the Wasm interface. class WasmBase : public std::enable_shared_from_this { public: WasmBase(std::unique_ptr wasm_vm, string_view vm_id, string_view vm_configuration, string_view vm_key); - WasmBase(const std::shared_ptr &other, WasmVmFactory factory); + WasmBase(const std::shared_ptr &other, WasmVmFactory factory); virtual ~WasmBase(); bool initialize(const std::string &code, bool allow_precompiled = false); @@ -78,11 +80,17 @@ class WasmBase : public std::enable_shared_from_this { unimplemented(); return nullptr; } + // NB: if plugin is nullptr, then a VM Context is returned. + virtual ContextBase *createContext(std::shared_ptr plugin) { + if (plugin) + return new ContextBase(this, plugin); + return new ContextBase(this); + } - void setTickPeriod(uint32_t /* root_context_id */, std::chrono::milliseconds /* tick_period */) { - unimplemented(); + virtual void setTickPeriod(uint32_t root_context_id, std::chrono::milliseconds tick_period) { + tick_period_[root_context_id] = tick_period; } - void tickHandler(uint32_t /* root_context_idl */) { unimplemented(); } + void tick(uint32_t root_context_id); void queueReady(uint32_t root_context_id, uint32_t token); void startShutdown(); @@ -216,7 +224,7 @@ class WasmBase : public std::enable_shared_from_this { WasmCallVoid<1> on_log_; WasmCallVoid<1> on_delete_; - std::shared_ptr base_wasm_handle_; + std::shared_ptr base_wasm_handle_; // Used by the base_wasm to enable non-clonable thread local Wasm(s) to be constructed. std::string code_; @@ -235,43 +243,40 @@ class WasmBase : public std::enable_shared_from_this { uint32_t next_gauge_metric_id_ = static_cast(MetricType::Gauge); uint32_t next_histogram_metric_id_ = static_cast(MetricType::Histogram); - // Foreign Functions. - std::unordered_map foreign_functions_; - // Actions to be done after the call into the VM returns. std::deque> after_vm_call_actions_; }; // Handle which enables shutdown operations to run post deletion (e.g. post listener drain). -class WasmHandle : public std::enable_shared_from_this { +class WasmHandleBase : public std::enable_shared_from_this { public: - explicit WasmHandle(std::shared_ptr wasm) : wasm_(wasm) {} - ~WasmHandle() { wasm_->startShutdown(); } + explicit WasmHandleBase(std::shared_ptr wasm_base) : wasm_base_(wasm_base) {} + ~WasmHandleBase() { wasm_base_->startShutdown(); } - std::shared_ptr &wasm() { return wasm_; } + std::shared_ptr &wasm() { return wasm_base_; } -private: - std::shared_ptr wasm_; +protected: + std::shared_ptr wasm_base_; }; std::string makeVmKey(string_view vm_id, string_view configuration, string_view code); -using WasmHandleFactory = std::function(string_view vm_id)>; +using WasmHandleFactory = std::function(string_view vm_id)>; using WasmHandleCloneFactory = - std::function(std::shared_ptr wasm)>; + std::function(std::shared_ptr wasm)>; // Returns nullptr on failure (i.e. initialization of the VM fails). -std::shared_ptr +std::shared_ptr createWasm(std::string vm_key, std::string code, std::shared_ptr plugin, WasmHandleFactory factory, bool allow_precompiled, std::unique_ptr root_context_for_testing = nullptr); // Get an existing ThreadLocal VM matching 'vm_id' or nullptr if there isn't one. -std::shared_ptr getThreadLocalWasm(string_view vm_id); +std::shared_ptr getThreadLocalWasm(string_view vm_id); // Get an existing ThreadLocal VM matching 'vm_id' or create one using 'base_wavm' by cloning or by // using it it as a template. -std::shared_ptr getOrCreateThreadLocalWasm(std::shared_ptr &base_wasm, - std::shared_ptr plugin, - WasmHandleCloneFactory factory); +std::shared_ptr +getOrCreateThreadLocalWasm(std::shared_ptr base_wasm, + std::shared_ptr plugin, WasmHandleCloneFactory factory); inline const std::string &WasmBase::vm_configuration() const { if (base_wasm_handle_) diff --git a/include/proxy-wasm/wasm_api_impl.h b/include/proxy-wasm/wasm_api_impl.h index 6471b2d52..a0b72b1e1 100644 --- a/include/proxy-wasm/wasm_api_impl.h +++ b/include/proxy-wasm/wasm_api_impl.h @@ -138,6 +138,12 @@ inline WasmResult proxy_get_buffer_status(WasmBufferType type, size_t *length_pt exports::get_buffer_status(current_context_, WS(type), WR(length_ptr), WR(flags_ptr))); } +inline WasmResult proxy_set_buffer_bytes(WasmBufferType type, uint64_t start, uint64_t length, + const char *data, size_t size) { + return wordToWasmResult(exports::set_buffer_bytes(current_context_, WS(type), WS(start), + WS(length), WR(data), WS(size))); +} + // Headers/Trailers/Metadata Maps inline WasmResult proxy_add_header_map_value(WasmHeaderMapType type, const char *key_ptr, size_t key_size, const char *value_ptr, @@ -191,20 +197,24 @@ inline WasmResult proxy_http_call(const char *uri_ptr, size_t uri_size, void *he inline WasmResult proxy_grpc_call(const char *service_ptr, size_t service_size, const char *service_name_ptr, size_t service_name_size, const char *method_name_ptr, size_t method_name_size, + void *initial_metadata_ptr, size_t initial_metadata_size, const char *request_ptr, size_t request_size, uint64_t timeout_milliseconds, uint32_t *token_ptr) { - return wordToWasmResult(exports::grpc_call( - current_context_, WR(service_ptr), WS(service_size), WR(service_name_ptr), - WS(service_name_size), WR(method_name_ptr), WS(method_name_size), WR(request_ptr), - WS(request_size), WS(timeout_milliseconds), WR(token_ptr))); + return wordToWasmResult( + exports::grpc_call(current_context_, WR(service_ptr), WS(service_size), WR(service_name_ptr), + WS(service_name_size), WR(method_name_ptr), WS(method_name_size), + WR(initial_metadata_ptr), WS(initial_metadata_size), WR(request_ptr), + WS(request_size), WS(timeout_milliseconds), WR(token_ptr))); } inline WasmResult proxy_grpc_stream(const char *service_ptr, size_t service_size, const char *service_name_ptr, size_t service_name_size, const char *method_name_ptr, size_t method_name_size, + void *initial_metadata_ptr, size_t initial_metadata_size, uint32_t *token_ptr) { return wordToWasmResult(exports::grpc_stream( current_context_, WR(service_ptr), WS(service_size), WR(service_name_ptr), - WS(service_name_size), WR(method_name_ptr), WS(method_name_size), WR(token_ptr))); + WS(service_name_size), WR(method_name_ptr), WS(method_name_size), WR(initial_metadata_ptr), + WS(initial_metadata_size), WR(token_ptr))); } inline WasmResult proxy_grpc_cancel(uint64_t token) { return wordToWasmResult(exports::grpc_cancel(current_context_, WS(token))); diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 25dce6dfa..381580920 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -86,11 +86,15 @@ using WasmCallback_dd = double (*)(void *, double); _f(proxy_wasm::WasmCallbackVoid<4>) _f(proxy_wasm::WasmCallbackWord<0>) \ _f(proxy_wasm::WasmCallbackWord<1>) _f(proxy_wasm::WasmCallbackWord<2>) \ _f(proxy_wasm::WasmCallbackWord<3>) _f(proxy_wasm::WasmCallbackWord<4>) \ - _f(proxy_wasm::WasmCallbackWord<5>) _f(proxy_wasm::WasmCallbackWord<6>) _f( \ - proxy_wasm::WasmCallbackWord<7>) _f(proxy_wasm::WasmCallbackWord<8>) \ - _f(proxy_wasm::WasmCallbackWord<9>) _f(proxy_wasm::WasmCallbackWord<10>) \ - _f(proxy_wasm::WasmCallback_WWl) _f(proxy_wasm::WasmCallback_WWlWW) \ - _f(proxy_wasm::WasmCallback_WWm) _f(proxy_wasm::WasmCallback_dd) + _f(proxy_wasm::WasmCallbackWord<5>) _f(proxy_wasm::WasmCallbackWord<6>) \ + _f(proxy_wasm::WasmCallbackWord<7>) _f(proxy_wasm::WasmCallbackWord<8>) \ + _f(proxy_wasm::WasmCallbackWord<9>) \ + _f(proxy_wasm::WasmCallbackWord<10>) \ + _f(proxy_wasm::WasmCallbackWord<12>) \ + _f(proxy_wasm::WasmCallback_WWl) \ + _f(proxy_wasm::WasmCallback_WWlWW) \ + _f(proxy_wasm::WasmCallback_WWm) \ + _f(proxy_wasm::WasmCallback_dd) enum class Cloneable { NotCloneable, // VMs can not be cloned and should be created from scratch. diff --git a/include/proxy-wasm/word.h b/include/proxy-wasm/word.h index b28b199a7..d2886a072 100644 --- a/include/proxy-wasm/word.h +++ b/include/proxy-wasm/word.h @@ -17,12 +17,15 @@ #include +#include "proxy_wasm_common.h" + namespace proxy_wasm { // Represents a Wasm-native word-sized datum. On 32-bit VMs, the high bits are always zero. // The Wasm/VM API treats all bits as significant. struct Word { - Word(uint64_t w) : u64_(w) {} // Implicit conversion into Word. + Word(uint64_t w) : u64_(w) {} // Implicit conversion into Word. + Word(WasmResult r) : u64_(static_cast(r)) {} // Implicit conversion into Word. uint32_t u32() const { return static_cast(u64_); } operator uint64_t() const { return u64_; } uint64_t u64_; diff --git a/src/context.cc b/src/context.cc index 6ab8f074f..2f840b04b 100644 --- a/src/context.cc +++ b/src/context.cc @@ -204,7 +204,7 @@ ContextBase::ContextBase(WasmBase *wasm) : wasm_(wasm), root_context_(this) { } ContextBase::ContextBase(WasmBase *wasm, std::shared_ptr plugin) { - initializeRoot(wasm, plugin); + initializeRootBase(wasm, plugin); } ContextBase::ContextBase(WasmBase *wasm, uint32_t root_context_id, @@ -216,7 +216,7 @@ ContextBase::ContextBase(WasmBase *wasm, uint32_t root_context_id, WasmVm *ContextBase::wasmVm() const { return wasm_->wasm_vm(); } -void ContextBase::initializeRoot(WasmBase *wasm, std::shared_ptr plugin) { +void ContextBase::initializeRootBase(WasmBase *wasm, std::shared_ptr plugin) { wasm_ = wasm; id_ = wasm->allocContextId(); root_id_ = plugin->root_id_; @@ -320,6 +320,218 @@ void ContextBase::destroy() { onDone(); } +void ContextBase::onTick() { + if (wasm_->on_tick_) { + DeferAfterCallActions actions(this); + wasm_->on_tick_(this, id_); + } +} + +FilterStatus ContextBase::onNetworkNewConnection() { + DeferAfterCallActions actions(this); + onCreate(root_context_id_); + if (!wasm_->on_new_connection_) { + return FilterStatus::Continue; + } + if (wasm_->on_new_connection_(this, id_).u64_ == 0) { + return FilterStatus::Continue; + } + return FilterStatus::StopIteration; +} + +FilterStatus ContextBase::onDownstreamData(int data_length, bool end_of_stream) { + if (!wasm_->on_downstream_data_) { + return FilterStatus::Continue; + } + DeferAfterCallActions actions(this); + auto result = wasm_->on_downstream_data_(this, id_, static_cast(data_length), + static_cast(end_of_stream)); + // TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values. + return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; +} + +FilterStatus ContextBase::onUpstreamData(int data_length, bool end_of_stream) { + if (!wasm_->on_upstream_data_) { + return FilterStatus::Continue; + } + DeferAfterCallActions actions(this); + auto result = wasm_->on_upstream_data_(this, id_, static_cast(data_length), + static_cast(end_of_stream)); + // TODO(PiotrSikora): pull Proxy-WASM's FilterStatus values. + return result.u64_ == 0 ? FilterStatus::Continue : FilterStatus::StopIteration; +} + +void ContextBase::onDownstreamConnectionClose(PeerType peer_type) { + if (wasm_->on_downstream_connection_close_) { + DeferAfterCallActions actions(this); + wasm_->on_downstream_connection_close_(this, id_, static_cast(peer_type)); + } +} + +void ContextBase::onUpstreamConnectionClose(PeerType peer_type) { + if (wasm_->on_upstream_connection_close_) { + DeferAfterCallActions actions(this); + wasm_->on_upstream_connection_close_(this, id_, static_cast(peer_type)); + } +} + +// Empty headers/trailers have zero size. +template static uint32_t headerSize(const P &p) { return p ? p->size() : 0; } + +FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers) { + DeferAfterCallActions actions(this); + onCreate(root_context_id_); + in_vm_context_created_ = true; + if (!wasm_->on_request_headers_) { + return FilterHeadersStatus::Continue; + } + if (static_cast(wasm_->on_request_headers_(this, id_, headers).u64_) == + FilterHeadersStatus::Continue) { + return FilterHeadersStatus::Continue; + } + return FilterHeadersStatus::StopIteration; +} + +FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_stream) { + if (!wasm_->on_request_body_) { + return FilterDataStatus::Continue; + } + DeferAfterCallActions actions(this); + auto result = + wasm_->on_request_body_(this, id_, data_length, static_cast(end_of_stream)).u64_; + if (result > static_cast(FilterDataStatus::StopIterationNoBuffer)) + return FilterDataStatus::StopIterationNoBuffer; + return static_cast(result); +} + +FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) { + if (!wasm_->on_request_trailers_) { + return FilterTrailersStatus::Continue; + } + DeferAfterCallActions actions(this); + if (static_cast(wasm_->on_request_trailers_(this, id_, trailers).u64_) == + FilterTrailersStatus::Continue) { + return FilterTrailersStatus::Continue; + } + return FilterTrailersStatus::StopIteration; +} + +FilterMetadataStatus ContextBase::onRequestMetadata(uint32_t elements) { + if (!wasm_->on_request_metadata_) { + return FilterMetadataStatus::Continue; + } + DeferAfterCallActions actions(this); + if (static_cast(wasm_->on_request_metadata_(this, id_, elements).u64_) == + FilterMetadataStatus::Continue) { + return FilterMetadataStatus::Continue; + } + return FilterMetadataStatus::Continue; // This is currently the only return code. +} + +FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers) { + DeferAfterCallActions actions(this); + if (!in_vm_context_created_) { + // If the request is invalid then onRequestHeaders() will not be called and neither will + // onCreate() then sendLocalReply be called which will call this function. In this case we + // need to call onCreate() so that the Context inside the VM is created before the + // onResponseHeaders() call. + onCreate(root_context_id_); + in_vm_context_created_ = true; + } + if (!wasm_->on_response_headers_) { + return FilterHeadersStatus::Continue; + } + if (static_cast(wasm_->on_response_headers_(this, id_, headers).u64_) == + FilterHeadersStatus::Continue) { + return FilterHeadersStatus::Continue; + } + return FilterHeadersStatus::StopIteration; +} + +FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_stream) { + if (!wasm_->on_response_body_) { + return FilterDataStatus::Continue; + } + DeferAfterCallActions actions(this); + auto result = + wasm_->on_response_body_(this, id_, body_length, static_cast(end_of_stream)).u64_; + if (result > static_cast(FilterDataStatus::StopIterationNoBuffer)) + return FilterDataStatus::StopIterationNoBuffer; + return static_cast(result); +} + +FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) { + if (!wasm_->on_response_trailers_) { + return FilterTrailersStatus::Continue; + } + DeferAfterCallActions actions(this); + if (static_cast(wasm_->on_response_trailers_(this, id_, trailers).u64_) == + FilterTrailersStatus::Continue) { + return FilterTrailersStatus::Continue; + } + return FilterTrailersStatus::StopIteration; +} + +FilterMetadataStatus ContextBase::onResponseMetadata(uint32_t elements) { + if (!wasm_->on_response_metadata_) { + return FilterMetadataStatus::Continue; + } + DeferAfterCallActions actions(this); + if (static_cast(wasm_->on_response_metadata_(this, id_, elements).u64_) == + FilterMetadataStatus::Continue) { + return FilterMetadataStatus::Continue; + } + return FilterMetadataStatus::Continue; // This is currently the only return code. +} + +void ContextBase::onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size, + uint32_t trailers) { + if (!wasm_->on_http_call_response_) { + return; + } + DeferAfterCallActions actions(this); + wasm_->on_http_call_response_(this, id_, token, headers, body_size, trailers); +} + +void ContextBase::onQueueReady(uint32_t token) { + if (wasm_->on_queue_ready_) { + DeferAfterCallActions actions(this); + wasm_->on_queue_ready_(this, id_, token); + } +} + +void ContextBase::onGrpcReceiveInitialMetadata(uint32_t token, uint32_t elements) { + if (!wasm_->on_grpc_receive_initial_metadata_) { + return; + } + DeferAfterCallActions actions(this); + wasm_->on_grpc_receive_initial_metadata_(this, id_, token, elements); +} + +void ContextBase::onGrpcReceiveTrailingMetadata(uint32_t token, uint32_t trailers) { + if (!wasm_->on_grpc_receive_trailing_metadata_) { + return; + } + DeferAfterCallActions actions(this); + wasm_->on_grpc_receive_trailing_metadata_(this, id_, token, trailers); +} + +void ContextBase::onGrpcReceive(uint32_t token, uint32_t response_size) { + if (!wasm_->on_grpc_receive_) { + return; + } + DeferAfterCallActions actions(this); + wasm_->on_grpc_receive_(this, id_, token, response_size); +} + +void ContextBase::onGrpcClose(uint32_t token, uint32_t status_code) { + if (!wasm_->on_grpc_close_) { + return; + } + DeferAfterCallActions actions(this); + wasm_->on_grpc_close_(this, id_, token, status_code); +} + bool ContextBase::onDone() { DeferAfterCallActions actions(this); if (wasm_->on_done_) { diff --git a/src/exports.cc b/src/exports.cc index cad3d9505..7586cff53 100644 --- a/src/exports.cc +++ b/src/exports.cc @@ -29,8 +29,6 @@ namespace exports { namespace { -inline Word wasmResultToWord(WasmResult r) { return Word(static_cast(r)); } - ContextBase *ContextOrEffectiveContext(ContextBase *context) { if (effective_context_id_ == 0) { return context; @@ -124,43 +122,43 @@ bool getPairs(ContextBase *context, const Pairs &result, uint64_t ptr_ptr, uint6 Word set_property(void *raw_context, Word key_ptr, Word key_size, Word value_ptr, Word value_size) { auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); - auto value = context->wasmVm()->getMemory(value_ptr.u64_, value_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); + auto value = context->wasmVm()->getMemory(value_ptr, value_size); if (!key || !value) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(context->setProperty(key.value(), value.value())); + return context->setProperty(key.value(), value.value()); } // Generic selector Word get_property(void *raw_context, Word path_ptr, Word path_size, Word value_ptr_ptr, Word value_size_ptr) { auto context = WASM_CONTEXT(raw_context); - auto path = context->wasmVm()->getMemory(path_ptr.u64_, path_size.u64_); + auto path = context->wasmVm()->getMemory(path_ptr, path_size); if (!path.has_value()) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } std::string value; auto result = context->getProperty(path.value(), &value); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->copyToPointerSize(value, value_ptr_ptr.u64_, value_size_ptr.u64_)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->copyToPointerSize(value, value_ptr_ptr, value_size_ptr)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word get_status(void *raw_context, Word code_ptr, Word value_ptr_ptr, Word value_size_ptr) { auto context = WASM_CONTEXT(raw_context); auto status = context->getStatus(); - if (!context->wasm()->setDatatype(code_ptr.u64_, status.first)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(code_ptr, status.first)) { + return WasmResult::InvalidMemoryAccess; } - if (!context->wasm()->copyToPointerSize(status.second, value_ptr_ptr.u64_, value_size_ptr.u64_)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->copyToPointerSize(status.second, value_ptr_ptr, value_size_ptr)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } // HTTP @@ -169,13 +167,13 @@ Word get_status(void *raw_context, Word code_ptr, Word value_ptr_ptr, Word value Word continue_request(void *raw_context) { auto context = WASM_CONTEXT(raw_context); context->continueRequest(); - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word continue_response(void *raw_context) { auto context = WASM_CONTEXT(raw_context); context->continueResponse(); - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word send_local_response(void *raw_context, Word response_code, Word response_code_details_ptr, @@ -184,49 +182,49 @@ Word send_local_response(void *raw_context, Word response_code, Word response_co Word additional_response_header_pairs_size, Word grpc_code) { auto context = WASM_CONTEXT(raw_context); auto details = - context->wasmVm()->getMemory(response_code_details_ptr.u64_, response_code_details_size.u64_); - auto body = context->wasmVm()->getMemory(body_ptr.u64_, body_size.u64_); + context->wasmVm()->getMemory(response_code_details_ptr, response_code_details_size); + auto body = context->wasmVm()->getMemory(body_ptr, body_size); auto additional_response_header_pairs = context->wasmVm()->getMemory( - additional_response_header_pairs_ptr.u64_, additional_response_header_pairs_size.u64_); + additional_response_header_pairs_ptr, additional_response_header_pairs_size); if (!details || !body || !additional_response_header_pairs) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } auto additional_headers = toPairs(additional_response_header_pairs.value()); - context->sendLocalResponse(response_code.u64_, body.value(), std::move(additional_headers), - grpc_code.u64_, details.value()); - return wasmResultToWord(WasmResult::Ok); + context->sendLocalResponse(response_code, body.value(), std::move(additional_headers), grpc_code, + details.value()); + return WasmResult::Ok; } Word set_effective_context(void *raw_context, Word context_id) { auto context = WASM_CONTEXT(raw_context); - uint32_t cid = static_cast(context_id.u64_); + uint32_t cid = static_cast(context_id); auto c = context->wasm()->getContext(cid); if (!c) { - return wasmResultToWord(WasmResult::BadArgument); + return WasmResult::BadArgument; } effective_context_id_ = cid; - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word done(void *raw_context) { auto context = WASM_CONTEXT(raw_context); - return wasmResultToWord(context->wasm()->done(context)); + return context->wasm()->done(context); } Word call_foreign_function(void *raw_context, Word function_name, Word function_name_size, Word arguments, Word arguments_size, Word results, Word results_size) { auto context = WASM_CONTEXT(raw_context); - auto function = context->wasmVm()->getMemory(function_name.u64_, function_name_size.u64_); + auto function = context->wasmVm()->getMemory(function_name, function_name_size); if (!function) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } - auto args_opt = context->wasmVm()->getMemory(arguments.u64_, arguments_size.u64_); + auto args_opt = context->wasmVm()->getMemory(arguments, arguments_size); if (!args_opt) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } auto f = context->wasm()->getForeignFunction(function.value()); if (!f) { - return wasmResultToWord(WasmResult::NotFound); + return WasmResult::NotFound; } auto &wasm = *context->wasm(); auto &args = args_opt.value(); @@ -238,60 +236,60 @@ Word call_foreign_function(void *raw_context, Word function_name, Word function_ result_size = s; return result; }); - if (!context->wasmVm()->setWord(results.u64_, Word(address))) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasmVm()->setWord(results, Word(address))) { + return WasmResult::InvalidMemoryAccess; } - if (!context->wasmVm()->setWord(results_size.u64_, Word(result_size))) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasmVm()->setWord(results_size, Word(result_size))) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(res); + return res; } // SharedData Word get_shared_data(void *raw_context, Word key_ptr, Word key_size, Word value_ptr_ptr, Word value_size_ptr, Word cas_ptr) { auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); if (!key) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } std::pair data; WasmResult result = context->getSharedData(key.value(), &data); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->copyToPointerSize(data.first, value_ptr_ptr.u64_, value_size_ptr.u64_)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->copyToPointerSize(data.first, value_ptr_ptr, value_size_ptr)) { + return WasmResult::InvalidMemoryAccess; } - if (!context->wasmVm()->setMemory(cas_ptr.u64_, sizeof(uint32_t), &data.second)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasmVm()->setMemory(cas_ptr, sizeof(uint32_t), &data.second)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word set_shared_data(void *raw_context, Word key_ptr, Word key_size, Word value_ptr, Word value_size, Word cas) { auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); - auto value = context->wasmVm()->getMemory(value_ptr.u64_, value_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); + auto value = context->wasmVm()->getMemory(value_ptr, value_size); if (!key || !value) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(context->setSharedData(key.value(), value.value(), cas.u64_)); + return context->setSharedData(key.value(), value.value(), cas); } Word register_shared_queue(void *raw_context, Word queue_name_ptr, Word queue_name_size, Word token_ptr) { auto context = WASM_CONTEXT(raw_context); - auto queue_name = context->wasmVm()->getMemory(queue_name_ptr.u64_, queue_name_size.u64_); + auto queue_name = context->wasmVm()->getMemory(queue_name_ptr, queue_name_size); if (!queue_name) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } uint32_t token = context->registerSharedQueue(queue_name.value()); - if (!context->wasm()->setDatatype(token_ptr.u64_, token)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(token_ptr, token)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word dequeue_shared_queue(void *raw_context, Word token, Word data_ptr_ptr, Word data_size_ptr) { @@ -299,315 +297,342 @@ Word dequeue_shared_queue(void *raw_context, Word token, Word data_ptr_ptr, Word std::string data; WasmResult result = context->dequeueSharedQueue(token.u32(), &data); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->copyToPointerSize(data, data_ptr_ptr.u64_, data_size_ptr.u64_)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->copyToPointerSize(data, data_ptr_ptr, data_size_ptr)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word resolve_shared_queue(void *raw_context, Word vm_id_ptr, Word vm_id_size, Word queue_name_ptr, Word queue_name_size, Word token_ptr) { auto context = WASM_CONTEXT(raw_context); - auto vm_id = context->wasmVm()->getMemory(vm_id_ptr.u64_, vm_id_size.u64_); - auto queue_name = context->wasmVm()->getMemory(queue_name_ptr.u64_, queue_name_size.u64_); + auto vm_id = context->wasmVm()->getMemory(vm_id_ptr, vm_id_size); + auto queue_name = context->wasmVm()->getMemory(queue_name_ptr, queue_name_size); if (!vm_id || !queue_name) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } uint32_t token = 0; auto result = context->resolveSharedQueue(vm_id.value(), queue_name.value(), &token); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->setDatatype(token_ptr.u64_, token)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(token_ptr, token)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word enqueue_shared_queue(void *raw_context, Word token, Word data_ptr, Word data_size) { auto context = WASM_CONTEXT(raw_context); - auto data = context->wasmVm()->getMemory(data_ptr.u64_, data_size.u64_); + auto data = context->wasmVm()->getMemory(data_ptr, data_size); if (!data) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(context->enqueueSharedQueue(token.u32(), data.value())); + return context->enqueueSharedQueue(token.u32(), data.value()); } // Header/Trailer/Metadata Maps Word add_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size, Word value_ptr, Word value_size) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); - auto value = context->wasmVm()->getMemory(value_ptr.u64_, value_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); + auto value = context->wasmVm()->getMemory(value_ptr, value_size); if (!key || !value) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } context->addHeaderMapValue(static_cast(type.u64_), key.value(), value.value()); - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word get_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size, Word value_ptr_ptr, Word value_size_ptr) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); if (!key) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } auto result = context->getHeaderMapValue(static_cast(type.u64_), key.value()); - context->wasm()->copyToPointerSize(result, value_ptr_ptr.u64_, value_size_ptr.u64_); - return wasmResultToWord(WasmResult::Ok); + context->wasm()->copyToPointerSize(result, value_ptr_ptr, value_size_ptr); + return WasmResult::Ok; } Word replace_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size, Word value_ptr, Word value_size) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); - auto value = context->wasmVm()->getMemory(value_ptr.u64_, value_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); + auto value = context->wasmVm()->getMemory(value_ptr, value_size); if (!key || !value) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } context->replaceHeaderMapValue(static_cast(type.u64_), key.value(), value.value()); - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word remove_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto key = context->wasmVm()->getMemory(key_ptr.u64_, key_size.u64_); + auto key = context->wasmVm()->getMemory(key_ptr, key_size); if (!key) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } context->removeHeaderMapValue(static_cast(type.u64_), key.value()); - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word get_header_map_pairs(void *raw_context, Word type, Word ptr_ptr, Word size_ptr) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); auto result = context->getHeaderMapPairs(static_cast(type.u64_)); - if (!getPairs(context, result, ptr_ptr.u64_, size_ptr.u64_)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!getPairs(context, result, ptr_ptr, size_ptr)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word set_header_map_pairs(void *raw_context, Word type, Word ptr, Word size) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto data = context->wasmVm()->getMemory(ptr.u64_, size.u64_); + auto data = context->wasmVm()->getMemory(ptr, size); if (!data) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } context->setHeaderMapPairs(static_cast(type.u64_), toPairs(data.value())); - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word get_header_map_size(void *raw_context, Word type, Word result_ptr) { - if (type.u64_ > static_cast(WasmHeaderMapType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmHeaderMapType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); size_t result = context->getHeaderMapSize(static_cast(type.u64_)); - if (!context->wasmVm()->setWord(result_ptr.u64_, Word(result))) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasmVm()->setWord(result_ptr, Word(result))) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } // Buffer Word get_buffer_bytes(void *raw_context, Word type, Word start, Word length, Word ptr_ptr, Word size_ptr) { - if (type.u64_ > static_cast(WasmBufferType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmBufferType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); auto buffer = context->getBuffer(static_cast(type.u64_)); if (!buffer) { - return wasmResultToWord(WasmResult::NotFound); + return WasmResult::NotFound; } - // NB: check for overflow. - if (buffer->size() < start.u64_ + length.u64_ || start.u64_ > start.u64_ + length.u64_) { - return wasmResultToWord(WasmResult::BadArgument); + // check for overflow. + if (buffer->size() < start + length || start > start + length) { + return WasmResult::BadArgument; } - if (!buffer->copyTo(context->wasm(), start.u64_, length.u64_, ptr_ptr.u64_, size_ptr.u64_)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (length > 0) { + return buffer->copyTo(context->wasm(), start, length, ptr_ptr, size_ptr); } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word get_buffer_status(void *raw_context, Word type, Word length_ptr, Word flags_ptr) { - if (type.u64_ > static_cast(WasmBufferType::MAX)) { - return wasmResultToWord(WasmResult::BadArgument); + if (type > static_cast(WasmBufferType::MAX)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); auto buffer = context->getBuffer(static_cast(type.u64_)); if (!buffer) { - return wasmResultToWord(WasmResult::NotFound); + return WasmResult::NotFound; } auto length = buffer->size(); uint32_t flags = 0; - if (!context->wasmVm()->setWord(length_ptr.u64_, Word(length))) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasmVm()->setWord(length_ptr, Word(length))) { + return WasmResult::InvalidMemoryAccess; } - if (!context->wasm()->setDatatype(flags_ptr.u64_, flags)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(flags_ptr, flags)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; +} + +Word set_buffer_bytes(void *raw_context, Word type, Word start, Word length, Word data_ptr, + Word data_size) { + if (type > static_cast(WasmBufferType::MAX)) { + return WasmResult::BadArgument; + } + auto context = WASM_CONTEXT(raw_context); + auto buffer = context->getBuffer(static_cast(type.u64_)); + if (!buffer) { + return WasmResult::NotFound; + } + auto data = context->wasmVm()->getMemory(data_ptr, data_size); + if (!data) { + return WasmResult::InvalidMemoryAccess; + } + // check for overflow. + if (buffer->size() < start + length || start > start + length) { + return WasmResult::BadArgument; + } + return buffer->copyFrom(start, length, data.value()); } Word http_call(void *raw_context, Word uri_ptr, Word uri_size, Word header_pairs_ptr, Word header_pairs_size, Word body_ptr, Word body_size, Word trailer_pairs_ptr, Word trailer_pairs_size, Word timeout_milliseconds, Word token_ptr) { auto context = WASM_CONTEXT(raw_context)->root_context(); - auto uri = context->wasmVm()->getMemory(uri_ptr.u64_, uri_size.u64_); - auto body = context->wasmVm()->getMemory(body_ptr.u64_, body_size.u64_); - auto header_pairs = context->wasmVm()->getMemory(header_pairs_ptr.u64_, header_pairs_size.u64_); - auto trailer_pairs = - context->wasmVm()->getMemory(trailer_pairs_ptr.u64_, trailer_pairs_size.u64_); + auto uri = context->wasmVm()->getMemory(uri_ptr, uri_size); + auto body = context->wasmVm()->getMemory(body_ptr, body_size); + auto header_pairs = context->wasmVm()->getMemory(header_pairs_ptr, header_pairs_size); + auto trailer_pairs = context->wasmVm()->getMemory(trailer_pairs_ptr, trailer_pairs_size); if (!uri || !body || !header_pairs || !trailer_pairs) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } auto headers = toPairs(header_pairs.value()); auto trailers = toPairs(trailer_pairs.value()); uint32_t token = 0; // NB: try to write the token to verify the memory before starting the async // operation. - if (!context->wasm()->setDatatype(token_ptr.u64_, token)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(token_ptr, token)) { + return WasmResult::InvalidMemoryAccess; } - auto result = context->httpCall(uri.value(), headers, body.value(), trailers, - timeout_milliseconds.u64_, &token); - context->wasm()->setDatatype(token_ptr.u64_, token); - return wasmResultToWord(result); + auto result = + context->httpCall(uri.value(), headers, body.value(), trailers, timeout_milliseconds, &token); + context->wasm()->setDatatype(token_ptr, token); + return result; } Word define_metric(void *raw_context, Word metric_type, Word name_ptr, Word name_size, Word metric_id_ptr) { - if (metric_type.u64_ > static_cast(MetricType::Max)) { - return wasmResultToWord(WasmResult::BadArgument); + if (metric_type > static_cast(MetricType::Max)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto name = context->wasmVm()->getMemory(name_ptr.u64_, name_size.u64_); + auto name = context->wasmVm()->getMemory(name_ptr, name_size); if (!name) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } uint32_t metric_id = 0; auto result = context->defineMetric(static_cast(metric_type.u64_), name.value(), &metric_id); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->setDatatype(metric_id_ptr.u64_, metric_id)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(metric_id_ptr, metric_id)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word increment_metric(void *raw_context, Word metric_id, int64_t offset) { auto context = WASM_CONTEXT(raw_context); - return wasmResultToWord(context->incrementMetric(metric_id.u64_, offset)); + return context->incrementMetric(metric_id, offset); } Word record_metric(void *raw_context, Word metric_id, uint64_t value) { auto context = WASM_CONTEXT(raw_context); - return wasmResultToWord(context->recordMetric(metric_id.u64_, value)); + return context->recordMetric(metric_id, value); } Word get_metric(void *raw_context, Word metric_id, Word result_uint64_ptr) { auto context = WASM_CONTEXT(raw_context); uint64_t value = 0; - auto result = context->getMetric(metric_id.u64_, &value); + auto result = context->getMetric(metric_id, &value); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->setDatatype(result_uint64_ptr.u64_, value)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(result_uint64_ptr, value)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word grpc_call(void *raw_context, Word service_ptr, Word service_size, Word service_name_ptr, Word service_name_size, Word method_name_ptr, Word method_name_size, - Word request_ptr, Word request_size, Word timeout_milliseconds, Word token_ptr) { + Word initial_metadata_ptr, Word initial_metadata_size, Word request_ptr, + Word request_size, Word timeout_milliseconds, Word token_ptr) { auto context = WASM_CONTEXT(raw_context)->root_context(); - auto service = context->wasmVm()->getMemory(service_ptr.u64_, service_size.u64_); - auto service_name = context->wasmVm()->getMemory(service_name_ptr.u64_, service_name_size.u64_); - auto method_name = context->wasmVm()->getMemory(method_name_ptr.u64_, method_name_size.u64_); - auto request = context->wasmVm()->getMemory(request_ptr.u64_, request_size.u64_); - if (!service || !service_name || !method_name || !request) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + auto service = context->wasmVm()->getMemory(service_ptr, service_size); + auto service_name = context->wasmVm()->getMemory(service_name_ptr, service_name_size); + auto method_name = context->wasmVm()->getMemory(method_name_ptr, method_name_size); + auto initial_metadata_pairs = + context->wasmVm()->getMemory(initial_metadata_ptr, initial_metadata_size); + auto request = context->wasmVm()->getMemory(request_ptr, request_size); + if (!service || !service_name || !method_name || !initial_metadata_pairs || !request) { + return WasmResult::InvalidMemoryAccess; } uint32_t token = 0; - auto result = - context->grpcCall(service.value(), service_name.value(), method_name.value(), request.value(), - std::chrono::milliseconds(timeout_milliseconds.u64_), &token); + auto initial_metadata = toPairs(initial_metadata_pairs.value()); + auto result = context->grpcCall(service.value(), service_name.value(), method_name.value(), + initial_metadata, request.value(), + std::chrono::milliseconds(timeout_milliseconds), &token); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->setDatatype(token_ptr.u64_, token)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(token_ptr, token)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word grpc_stream(void *raw_context, Word service_ptr, Word service_size, Word service_name_ptr, Word service_name_size, Word method_name_ptr, Word method_name_size, - Word token_ptr) { + Word initial_metadata_ptr, Word initial_metadata_size, Word token_ptr) { auto context = WASM_CONTEXT(raw_context)->root_context(); - auto service = context->wasmVm()->getMemory(service_ptr.u64_, service_size.u64_); - auto service_name = context->wasmVm()->getMemory(service_name_ptr.u64_, service_name_size.u64_); - auto method_name = context->wasmVm()->getMemory(method_name_ptr.u64_, method_name_size.u64_); - if (!service || !service_name || !method_name) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + auto service = context->wasmVm()->getMemory(service_ptr, service_size); + auto service_name = context->wasmVm()->getMemory(service_name_ptr, service_name_size); + auto method_name = context->wasmVm()->getMemory(method_name_ptr, method_name_size); + auto initial_metadata_pairs = + context->wasmVm()->getMemory(initial_metadata_ptr, initial_metadata_size); + if (!service || !service_name || !method_name || !initial_metadata_pairs) { + return WasmResult::InvalidMemoryAccess; } uint32_t token = 0; - auto result = - context->grpcStream(service.value(), service_name.value(), method_name.value(), &token); + auto initial_metadata = toPairs(initial_metadata_pairs.value()); + auto result = context->grpcStream(service.value(), service_name.value(), method_name.value(), + initial_metadata, &token); if (result != WasmResult::Ok) { - return wasmResultToWord(result); + return result; } - if (!context->wasm()->setDatatype(token_ptr.u64_, token)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(token_ptr, token)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word grpc_cancel(void *raw_context, Word token) { auto context = WASM_CONTEXT(raw_context)->root_context(); - return wasmResultToWord(context->grpcCancel(token.u64_)); + return context->grpcCancel(token); } Word grpc_close(void *raw_context, Word token) { auto context = WASM_CONTEXT(raw_context)->root_context(); - return wasmResultToWord(context->grpcClose(token.u64_)); + return context->grpcClose(token); } Word grpc_send(void *raw_context, Word token, Word message_ptr, Word message_size, Word end_stream) { auto context = WASM_CONTEXT(raw_context)->root_context(); - auto message = context->wasmVm()->getMemory(message_ptr.u64_, message_size.u64_); + auto message = context->wasmVm()->getMemory(message_ptr, message_size); if (!message) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(context->grpcSend(token.u64_, message.value(), end_stream.u64_)); + return context->grpcSend(token, message.value(), end_stream); } // Implementation of writev-like() syscall that redirects stdout/stderr to Envoy @@ -617,7 +642,7 @@ Word writevImpl(void *raw_context, Word fd, Word iovs, Word iovs_len, Word *nwri // Read syscall args. uint64_t log_level; - switch (fd.u64_) { + switch (fd) { case 1 /* stdout */: log_level = 2; // LogLevel::info break; @@ -629,9 +654,9 @@ Word writevImpl(void *raw_context, Word fd, Word iovs, Word iovs_len, Word *nwri } std::string s; - for (size_t i = 0; i < iovs_len.u64_; i++) { + for (size_t i = 0; i < iovs_len; i++) { auto memslice = - context->wasmVm()->getMemory(iovs.u64_ + i * 2 * sizeof(uint32_t), 2 * sizeof(uint32_t)); + context->wasmVm()->getMemory(iovs + i * 2 * sizeof(uint32_t), 2 * sizeof(uint32_t)); if (!memslice) { return 21; // __WASI_EFAULT } @@ -667,15 +692,22 @@ Word wasi_unstable_fd_write(void *raw_context, Word fd, Word iovs, Word iovs_len Word nwritten(0); auto result = writevImpl(raw_context, fd, iovs, iovs_len, &nwritten); - if (result.u64_ != 0) { // __WASI_ESUCCESS + if (result != 0) { // __WASI_ESUCCESS return result; } - if (!context->wasmVm()->setWord(nwritten_ptr.u64_, Word(nwritten))) { + if (!context->wasmVm()->setWord(nwritten_ptr, Word(nwritten))) { return 21; // __WASI_EFAULT } return 0; // __WASI_ESUCCESS } +// __wasi_errno_t __wasi_fd_read(_wasi_fd_t fd, const __wasi_iovec_t *iovs, +// size_t iovs_len, __wasi_size_t *nread); +Word wasi_unstable_fd_read(void *, Word, Word, Word, Word) { + // Don't support reading of any files. + return 52; // __WASI_ERRNO_ENOSYS +} + // __wasi_errno_t __wasi_fd_seek(__wasi_fd_t fd, __wasi_filedelta_t offset, // __wasi_whence_t whence,__wasi_filesize_t *newoffset); Word wasi_unstable_fd_seek(void *raw_context, Word, int64_t, Word, Word) { @@ -691,6 +723,26 @@ Word wasi_unstable_fd_close(void *raw_context, Word) { return 0; } +// __wasi_errno_t __wasi_fd_fdstat_get(__wasi_fd_t fd, __wasi_fdstat_t *stat) +Word wasi_unstable_fd_fdstat_get(void *raw_context, Word fd, Word statOut) { + // We will only support this interface on stdout and stderr + if (fd != 1 && fd != 2) { + return 8; // __WASI_EBADF; + } + + // The last word points to a 24-byte structure, which we + // are mostly going to zero out. + uint64_t wasi_fdstat[3]; + wasi_fdstat[0] = 0; + wasi_fdstat[1] = 64; // This sets "fs_rights_base" to __WASI_RIGHTS_FD_WRITE + wasi_fdstat[2] = 0; + + auto context = WASM_CONTEXT(raw_context); + context->wasmVm()->setMemory(statOut, 3 * sizeof(uint64_t), &wasi_fdstat); + + return 0; // __WASI_ESUCCESS +} + // __wasi_errno_t __wasi_environ_get(char **environ, char *environ_buf); Word wasi_unstable_environ_get(void *, Word, Word) { return 0; // __WASI_ESUCCESS @@ -700,10 +752,10 @@ Word wasi_unstable_environ_get(void *, Word, Word) { // *environ_buf_size); Word wasi_unstable_environ_sizes_get(void *raw_context, Word count_ptr, Word buf_size_ptr) { auto context = WASM_CONTEXT(raw_context); - if (!context->wasmVm()->setWord(count_ptr.u64_, Word(0))) { + if (!context->wasmVm()->setWord(count_ptr, Word(0))) { return 21; // __WASI_EFAULT } - if (!context->wasmVm()->setWord(buf_size_ptr.u64_, Word(0))) { + if (!context->wasmVm()->setWord(buf_size_ptr, Word(0))) { return 21; // __WASI_EFAULT } return 0; // __WASI_ESUCCESS @@ -717,10 +769,10 @@ Word wasi_unstable_args_get(void *, Word, Word) { // __wasi_errno_t __wasi_args_sizes_get(size_t *argc, size_t *argv_buf_size); Word wasi_unstable_args_sizes_get(void *raw_context, Word argc_ptr, Word argv_buf_size_ptr) { auto context = WASM_CONTEXT(raw_context); - if (!context->wasmVm()->setWord(argc_ptr.u64_, Word(0))) { + if (!context->wasmVm()->setWord(argc_ptr, Word(0))) { return 21; // __WASI_EFAULT } - if (!context->wasmVm()->setWord(argv_buf_size_ptr.u64_, Word(0))) { + if (!context->wasmVm()->setWord(argv_buf_size_ptr, Word(0))) { return 21; // __WASI_EFAULT } return 0; // __WASI_ESUCCESS @@ -732,33 +784,32 @@ void wasi_unstable_proc_exit(void *raw_context, Word) { context->error("wasi_unstable proc_exit"); } -Word pthread_equal(void *, Word left, Word right) { return left.u64_ == right.u64_; } +Word pthread_equal(void *, Word left, Word right) { return left == right; } Word set_tick_period_milliseconds(void *raw_context, Word tick_period_milliseconds) { - return wasmResultToWord( - WASM_CONTEXT(raw_context) - ->setTimerPeriod(std::chrono::milliseconds(tick_period_milliseconds.u64_))); + return WASM_CONTEXT(raw_context) + ->setTimerPeriod(std::chrono::milliseconds(tick_period_milliseconds)); } Word get_current_time_nanoseconds(void *raw_context, Word result_uint64_ptr) { auto context = WASM_CONTEXT(raw_context); uint64_t result = context->getCurrentTimeNanoseconds(); - if (!context->wasm()->setDatatype(result_uint64_ptr.u64_, result)) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + if (!context->wasm()->setDatatype(result_uint64_ptr, result)) { + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(WasmResult::Ok); + return WasmResult::Ok; } Word log(void *raw_context, Word level, Word address, Word size) { - if (level.u64_ > static_cast(LogLevel::Max)) { - return wasmResultToWord(WasmResult::BadArgument); + if (level > static_cast(LogLevel::Max)) { + return WasmResult::BadArgument; } auto context = WASM_CONTEXT(raw_context); - auto message = context->wasmVm()->getMemory(address.u64_, size.u64_); + auto message = context->wasmVm()->getMemory(address, size); if (!message) { - return wasmResultToWord(WasmResult::InvalidMemoryAccess); + return WasmResult::InvalidMemoryAccess; } - return wasmResultToWord(context->log(level.u64_, message.value())); + return context->log(level, message.value()); } } // namespace exports diff --git a/src/null/null_plugin.cc b/src/null/null_plugin.cc index dc64f979e..29f1b3b10 100644 --- a/src/null/null_plugin.cc +++ b/src/null/null_plugin.cc @@ -47,17 +47,17 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<1> *f) { if (function_name == "proxy_on_tick") { *f = [plugin](ContextBase *context, Word context_id) { SaveRestoreContext saved_context(context); - plugin->onTick(context_id.u64_); + plugin->onTick(context_id); }; } else if (function_name == "proxy_on_log") { *f = [plugin](ContextBase *context, Word context_id) { SaveRestoreContext saved_context(context); - plugin->onLog(context_id.u64_); + plugin->onLog(context_id); }; } else if (function_name == "proxy_on_delete") { *f = [plugin](ContextBase *context, Word context_id) { SaveRestoreContext saved_context(context); - plugin->onDelete(context_id.u64_); + plugin->onDelete(context_id); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -69,22 +69,22 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<2> *f) { if (function_name == "proxy_on_context_create") { *f = [plugin](ContextBase *context, Word context_id, Word parent_context_id) { SaveRestoreContext saved_context(context); - plugin->onCreate(context_id.u64_, parent_context_id.u64_); + plugin->onCreate(context_id, parent_context_id); }; } else if (function_name == "proxy_on_downstream_connection_close") { *f = [plugin](ContextBase *context, Word context_id, Word peer_type) { SaveRestoreContext saved_context(context); - plugin->onDownstreamConnectionClose(context_id.u64_, peer_type.u64_); + plugin->onDownstreamConnectionClose(context_id, peer_type); }; } else if (function_name == "proxy_on_upstream_connection_close") { *f = [plugin](ContextBase *context, Word context_id, Word peer_type) { SaveRestoreContext saved_context(context); - plugin->onUpstreamConnectionClose(context_id.u64_, peer_type.u64_); + plugin->onUpstreamConnectionClose(context_id, peer_type); }; } else if (function_name == "proxy_on_queue_ready") { *f = [plugin](ContextBase *context, Word context_id, Word token) { SaveRestoreContext saved_context(context); - plugin->onQueueReady(context_id.u64_, token.u64_); + plugin->onQueueReady(context_id, token); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -96,27 +96,22 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<3> *f) { if (function_name == "proxy_on_grpc_close") { *f = [plugin](ContextBase *context, Word context_id, Word token, Word status_code) { SaveRestoreContext saved_context(context); - plugin->onGrpcClose(context_id.u64_, token.u64_, status_code.u64_); + plugin->onGrpcClose(context_id, token, status_code); }; } else if (function_name == "proxy_on_grpc_receive") { *f = [plugin](ContextBase *context, Word context_id, Word token, Word body_size) { SaveRestoreContext saved_context(context); - plugin->onGrpcReceive(context_id.u64_, token.u64_, body_size.u64_); - }; - } else if (function_name == "proxy_on_grpc_create_initial_metadata") { - *f = [plugin](ContextBase *context, Word context_id, Word token, Word headers) { - SaveRestoreContext saved_context(context); - plugin->onGrpcCreateInitialMetadata(context_id.u64_, token.u64_, headers.u64_); + plugin->onGrpcReceive(context_id, token, body_size); }; } else if (function_name == "proxy_on_grpc_receive_initial_metadata") { *f = [plugin](ContextBase *context, Word context_id, Word token, Word headers) { SaveRestoreContext saved_context(context); - plugin->onGrpcReceiveInitialMetadata(context_id.u64_, token.u64_, headers.u64_); + plugin->onGrpcReceiveInitialMetadata(context_id, token, headers); }; } else if (function_name == "proxy_on_grpc_receive_trailing_metadata") { *f = [plugin](ContextBase *context, Word context_id, Word token, Word trailers) { SaveRestoreContext saved_context(context); - plugin->onGrpcReceiveTrailingMetadata(context_id.u64_, token.u64_, trailers.u64_); + plugin->onGrpcReceiveTrailingMetadata(context_id, token, trailers); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -129,8 +124,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<5> *f) { *f = [plugin](ContextBase *context, Word context_id, Word token, Word headers, Word body_size, Word trailers) { SaveRestoreContext saved_context(context); - plugin->onHttpCallResponse(context_id.u64_, token.u64_, headers.u64_, body_size.u64_, - trailers.u64_); + plugin->onHttpCallResponse(context_id, token, headers, body_size, trailers); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -141,17 +135,17 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<1> *f) { auto plugin = this; if (function_name == "malloc") { *f = [](ContextBase *, Word size) -> Word { - return Word(reinterpret_cast(::malloc(size.u64_))); + return Word(reinterpret_cast(::malloc(size))); }; } else if (function_name == "proxy_on_new_connection") { *f = [plugin](ContextBase *context, Word context_id) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onNewConnection(context_id.u64_)); + return Word(plugin->onNewConnection(context_id)); }; } else if (function_name == "proxy_on_done") { *f = [plugin](ContextBase *context, Word context_id) { SaveRestoreContext saved_context(context); - return Word(plugin->onDone(context_id.u64_)); + return Word(plugin->onDone(context_id)); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -163,47 +157,47 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<2> *f) { if (function_name == "proxy_on_vm_start") { *f = [plugin](ContextBase *context, Word context_id, Word configuration_size) { SaveRestoreContext saved_context(context); - return Word(plugin->onStart(context_id.u64_, configuration_size.u64_)); + return Word(plugin->onStart(context_id, configuration_size)); }; } else if (function_name == "proxy_on_configure") { *f = [plugin](ContextBase *context, Word context_id, Word configuration_size) { SaveRestoreContext saved_context(context); - return Word(plugin->onConfigure(context_id.u64_, configuration_size.u64_)); + return Word(plugin->onConfigure(context_id, configuration_size)); }; } else if (function_name == "proxy_validate_configuration") { *f = [plugin](ContextBase *context, Word context_id, Word configuration_size) { SaveRestoreContext saved_context(context); - return Word(plugin->validateConfiguration(context_id.u64_, configuration_size.u64_)); + return Word(plugin->validateConfiguration(context_id, configuration_size)); }; } else if (function_name == "proxy_on_request_headers") { *f = [plugin](ContextBase *context, Word context_id, Word headers) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onRequestHeaders(context_id.u64_, headers.u64_)); + return Word(plugin->onRequestHeaders(context_id, headers)); }; } else if (function_name == "proxy_on_request_trailers") { *f = [plugin](ContextBase *context, Word context_id, Word trailers) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onRequestTrailers(context_id.u64_, trailers.u64_)); + return Word(plugin->onRequestTrailers(context_id, trailers)); }; } else if (function_name == "proxy_on_request_metadata") { *f = [plugin](ContextBase *context, Word context_id, Word elements) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onRequestMetadata(context_id.u64_, elements.u64_)); + return Word(plugin->onRequestMetadata(context_id, elements)); }; } else if (function_name == "proxy_on_response_headers") { *f = [plugin](ContextBase *context, Word context_id, Word headers) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onResponseHeaders(context_id.u64_, headers.u64_)); + return Word(plugin->onResponseHeaders(context_id, headers)); }; } else if (function_name == "proxy_on_response_trailers") { *f = [plugin](ContextBase *context, Word context_id, Word trailers) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onResponseTrailers(context_id.u64_, trailers.u64_)); + return Word(plugin->onResponseTrailers(context_id, trailers)); }; } else if (function_name == "proxy_on_response_metadata") { *f = [plugin](ContextBase *context, Word context_id, Word elements) -> Word { SaveRestoreContext saved_context(context); - return Word(plugin->onResponseMetadata(context_id.u64_, elements.u64_)); + return Word(plugin->onResponseMetadata(context_id, elements)); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -216,29 +210,25 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<3> *f) { *f = [plugin](ContextBase *context, Word context_id, Word body_buffer_length, Word end_of_stream) -> Word { SaveRestoreContext saved_context(context); - return Word( - plugin->onDownstreamData(context_id.u64_, body_buffer_length.u64_, end_of_stream.u64_)); + return Word(plugin->onDownstreamData(context_id, body_buffer_length, end_of_stream)); }; } else if (function_name == "proxy_on_upstream_data") { *f = [plugin](ContextBase *context, Word context_id, Word body_buffer_length, Word end_of_stream) -> Word { SaveRestoreContext saved_context(context); - return Word( - plugin->onUpstreamData(context_id.u64_, body_buffer_length.u64_, end_of_stream.u64_)); + return Word(plugin->onUpstreamData(context_id, body_buffer_length, end_of_stream)); }; } else if (function_name == "proxy_on_request_body") { *f = [plugin](ContextBase *context, Word context_id, Word body_buffer_length, Word end_of_stream) -> Word { SaveRestoreContext saved_context(context); - return Word( - plugin->onRequestBody(context_id.u64_, body_buffer_length.u64_, end_of_stream.u64_)); + return Word(plugin->onRequestBody(context_id, body_buffer_length, end_of_stream)); }; } else if (function_name == "proxy_on_response_body") { *f = [plugin](ContextBase *context, Word context_id, Word body_buffer_length, Word end_of_stream) -> Word { SaveRestoreContext saved_context(context); - return Word( - plugin->onResponseBody(context_id.u64_, body_buffer_length.u64_, end_of_stream.u64_)); + return Word(plugin->onResponseBody(context_id, body_buffer_length, end_of_stream)); }; } else { error("Missing getFunction for: " + std::string(function_name)); @@ -432,11 +422,6 @@ void NullPlugin::onGrpcClose(uint64_t context_id, uint64_t token, uint64_t statu getRootContext(context_id)->onGrpcClose(token, static_cast(status_code)); } -void NullPlugin::onGrpcCreateInitialMetadata(uint64_t context_id, uint64_t token, - uint64_t headers) { - getRootContext(context_id)->onGrpcCreateInitialMetadata(token, headers); -} - void NullPlugin::onGrpcReceiveInitialMetadata(uint64_t context_id, uint64_t token, uint64_t headers) { getRootContext(context_id)->onGrpcReceiveInitialMetadata(token, headers); diff --git a/src/wasm.cc b/src/wasm.cc index b80d469da..a030d1fd7 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -36,12 +36,11 @@ thread_local uint32_t effective_context_id_ = 0; namespace { // Map from Wasm Key to the local Wasm instance. -thread_local std::unordered_map> local_wasms_; +thread_local std::unordered_map> local_wasms; // Map from Wasm Key to the base Wasm instance, using a pointer to avoid the initialization fiasco. -std::mutex base_wasms_mutex_; -std::unordered_map> *base_wasms_ = nullptr; - -std::unordered_map *foreign_functions_ = nullptr; +std::mutex base_wasms_mutex; +std::unordered_map> *base_wasms = nullptr; +std::unordered_map *foreign_functions = nullptr; const std::string INLINE_STRING = ""; @@ -98,10 +97,10 @@ class WasmBase::ShutdownHandle { }; RegisterForeignFunction::RegisterForeignFunction(std::string name, WasmForeignFunction f) { - if (!foreign_functions_) { - foreign_functions_ = new std::remove_reference::type; + if (!foreign_functions) { + foreign_functions = new std::remove_reference::type; } - (*foreign_functions_)[name] = f; + (*foreign_functions)[name] = f; } WasmBase::WasmBase(std::unique_ptr wasm_vm, string_view vm_id, string_view vm_configuration, @@ -130,8 +129,10 @@ void WasmBase::registerCallbacks() { &ConvertFunctionWordToUint32::convertFunctionWordToUint32) _REGISTER_WASI(fd_write); + _REGISTER_WASI(fd_read); _REGISTER_WASI(fd_seek); _REGISTER_WASI(fd_close); + _REGISTER_WASI(fd_fdstat_get); _REGISTER_WASI(environ_get); _REGISTER_WASI(environ_sizes_get); _REGISTER_WASI(args_get); @@ -174,6 +175,7 @@ void WasmBase::registerCallbacks() { _REGISTER_PROXY(get_buffer_status); _REGISTER_PROXY(get_buffer_bytes); + _REGISTER_PROXY(set_buffer_bytes); _REGISTER_PROXY(http_call); @@ -230,7 +232,6 @@ void WasmBase::getFunctions() { _GET_PROXY(on_http_call_response); _GET_PROXY(on_grpc_receive); _GET_PROXY(on_grpc_close); - _GET_PROXY(on_grpc_create_initial_metadata); _GET_PROXY(on_grpc_receive_initial_metadata); _GET_PROXY(on_grpc_receive_trailing_metadata); _GET_PROXY(on_queue_ready); @@ -244,7 +245,7 @@ void WasmBase::getFunctions() { } } -WasmBase::WasmBase(const std::shared_ptr &base_wasm_handle, WasmVmFactory factory) +WasmBase::WasmBase(const std::shared_ptr &base_wasm_handle, WasmVmFactory factory) : std::enable_shared_from_this(*base_wasm_handle->wasm()), vm_id_(base_wasm_handle->wasm()->vm_id_), vm_key_(base_wasm_handle->wasm()->vm_key_), started_from_(base_wasm_handle->wasm()->wasm_vm()->cloneable()), @@ -309,7 +310,7 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { wasm_vm_->link(vm_id_); } - vm_context_ = std::make_shared(this); + vm_context_.reset(createContext(nullptr)); getFunctions(); if (started_from_ != Cloneable::InstantiatedModule) { @@ -323,7 +324,7 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { ContextBase *WasmBase::getOrCreateRootContext(const std::shared_ptr &plugin) { auto root_context = getRootContext(plugin->root_id_); if (!root_context) { - auto context = std::make_unique(this, plugin); + auto context = std::unique_ptr(createContext(plugin)); root_context = context.get(); root_contexts_[plugin->root_id_] = std::move(context); } @@ -350,7 +351,7 @@ ContextBase *WasmBase::start(std::shared_ptr plugin) { it->second->onStart(plugin); return it->second.get(); } - auto context = std::make_unique(this, plugin); + auto context = std::unique_ptr(createContext(plugin)); auto context_ptr = context.get(); root_contexts_[root_id] = std::move(context); context_ptr->onStart(plugin); @@ -379,6 +380,16 @@ uint32_t WasmBase::allocContextId() { } } +void WasmBase::tick(uint32_t root_context_id) { + if (on_tick_) { + auto it = contexts_.find(root_context_id); + if (it == contexts_.end() || !it->second->isRootContext()) { + return; + } + it->second->onTick(); + } +} + void WasmBase::startShutdown() { bool all_done = true; for (auto &p : root_contexts_) { @@ -423,33 +434,33 @@ void WasmBase::queueReady(uint32_t root_context_id, uint32_t token) { } WasmForeignFunction WasmBase::getForeignFunction(string_view function_name) { - auto it = foreign_functions_.find(std::string(function_name)); - if (it != foreign_functions_.end()) { + auto it = foreign_functions->find(std::string(function_name)); + if (it != foreign_functions->end()) { return it->second; } return nullptr; } -std::shared_ptr createWasm(std::string vm_key, std::string code, - std::shared_ptr plugin, - WasmHandleFactory factory, bool allow_precompiled, - std::unique_ptr root_context_for_testing) { - std::shared_ptr wasm; +std::shared_ptr createWasm(std::string vm_key, std::string code, + std::shared_ptr plugin, + WasmHandleFactory factory, bool allow_precompiled, + std::unique_ptr root_context_for_testing) { + std::shared_ptr wasm; { - std::lock_guard guard(base_wasms_mutex_); - if (!base_wasms_) { - base_wasms_ = new std::remove_reference::type; + std::lock_guard guard(base_wasms_mutex); + if (!base_wasms) { + base_wasms = new std::remove_reference::type; } - auto it = base_wasms_->find(vm_key); - if (it != base_wasms_->end()) { + auto it = base_wasms->find(vm_key); + if (it != base_wasms->end()) { wasm = it->second.lock(); if (!wasm) { - base_wasms_->erase(it); + base_wasms->erase(it); } } if (!wasm) { wasm = factory(vm_key); - (*base_wasms_)[vm_key] = wasm; + (*base_wasms)[vm_key] = wasm; } } @@ -465,34 +476,34 @@ std::shared_ptr createWasm(std::string vm_key, std::string code, return wasm; }; -static std::shared_ptr createThreadLocalWasm(std::shared_ptr &base_wasm, - std::shared_ptr plugin, - WasmHandleCloneFactory factory) { +static std::shared_ptr +createThreadLocalWasm(std::shared_ptr &base_wasm, + std::shared_ptr plugin, WasmHandleCloneFactory factory) { auto wasm_handle = factory(base_wasm); ContextBase *root_context = wasm_handle->wasm()->start(plugin); if (!wasm_handle->wasm()->configure(root_context, plugin)) { base_wasm->wasm()->error("Failed to configure WASM code"); return nullptr; } - local_wasms_[std::string(wasm_handle->wasm()->vm_key())] = wasm_handle; + local_wasms[std::string(wasm_handle->wasm()->vm_key())] = wasm_handle; return wasm_handle; } -std::shared_ptr getThreadLocalWasm(string_view vm_key) { - auto it = local_wasms_.find(std::string(vm_key)); - if (it == local_wasms_.end()) { +std::shared_ptr getThreadLocalWasm(string_view vm_key) { + auto it = local_wasms.find(std::string(vm_key)); + if (it == local_wasms.end()) { return nullptr; } auto wasm = it->second.lock(); if (!wasm) { - local_wasms_.erase(std::string(vm_key)); + local_wasms.erase(std::string(vm_key)); } return wasm; } -std::shared_ptr getOrCreateThreadLocalWasm(std::shared_ptr &base_wasm, - std::shared_ptr plugin, - WasmHandleCloneFactory factory) { +std::shared_ptr +getOrCreateThreadLocalWasm(std::shared_ptr base_wasm, + std::shared_ptr plugin, WasmHandleCloneFactory factory) { auto wasm_handle = getThreadLocalWasm(base_wasm->wasm()->vm_key()); if (wasm_handle) { auto root_context = wasm_handle->wasm()->getOrCreateRootContext(plugin);