diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index 9d914b557..27f2b3727 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -54,7 +54,8 @@ class WasmBase : public std::enable_shared_from_this { bool initialize(const std::string &code, bool allow_precompiled = false); void startVm(ContextBase *root_context); bool configure(ContextBase *root_context, std::shared_ptr plugin); - ContextBase *start(std::shared_ptr plugin); // returns the root ContextBase. + // Returns the root ContextBase or nullptr if onStart returns false. + ContextBase *start(std::shared_ptr plugin); string_view vm_id() const { return vm_id_; } string_view vm_key() const { return vm_key_; } @@ -115,7 +116,8 @@ class WasmBase : public std::enable_shared_from_this { // For testing. void setContext(ContextBase *context) { contexts_[context->id()] = context; } - void startForTesting(std::unique_ptr root_context, + // Returns false if onStart returns false. + bool startForTesting(std::unique_ptr root_context, std::shared_ptr plugin); bool getEmscriptenVersion(uint32_t *emscripten_metadata_major_version, @@ -289,6 +291,9 @@ inline const std::string &WasmBase::vm_configuration() const { } inline void *WasmBase::allocMemory(uint64_t size, uint64_t *address) { + if (!malloc_) { + return nullptr; + } Word a = malloc_(vm_context(), size); if (!a.u64_) { return nullptr; diff --git a/include/proxy-wasm/word.h b/include/proxy-wasm/word.h index d2886a072..1e4ce6734 100644 --- a/include/proxy-wasm/word.h +++ b/include/proxy-wasm/word.h @@ -24,6 +24,7 @@ namespace proxy_wasm { // Represents a Wasm-native word-sized datum. On 32-bit VMs, the high bits are always zero. // The Wasm/VM API treats all bits as significant. struct Word { + Word() : u64_(0) {} Word(uint64_t w) : u64_(w) {} // Implicit conversion into Word. Word(WasmResult r) : u64_(static_cast(r)) {} // Implicit conversion into Word. uint32_t u32() const { return static_cast(u64_); } diff --git a/src/null/null_plugin.cc b/src/null/null_plugin.cc index b394c8719..149541dc4 100644 --- a/src/null/null_plugin.cc +++ b/src/null/null_plugin.cc @@ -39,6 +39,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<0> *f) { *f = nullptr; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -61,6 +62,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<1> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -88,6 +90,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<2> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -115,6 +118,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<3> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -128,6 +132,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<5> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -149,6 +154,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<1> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -201,6 +207,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<2> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -232,6 +239,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<3> *f) { }; } else { error("Missing getFunction for: " + std::string(function_name)); + *f = nullptr; } } @@ -244,6 +252,7 @@ null_plugin::Context *NullPlugin::ensureContext(uint64_t context_id, uint64_t ro auto factory = registry_->context_factories[root_id]; if (!factory) { error("no context factory for root_id: " + root_id); + return nullptr; } e.first->second = factory(context_id, root); } @@ -254,6 +263,7 @@ null_plugin::RootContext *NullPlugin::ensureRootContext(uint64_t context_id) { auto root_id_opt = null_plugin::getProperty({"plugin_root_id"}); if (!root_id_opt) { error("unable to get root_id"); + return nullptr; } auto root_id = std::move(root_id_opt.value()); auto it = context_map_.find(context_id); @@ -281,6 +291,7 @@ null_plugin::ContextBase *NullPlugin::getContextBase(uint64_t context_id) { auto it = context_map_.find(context_id); if (it == context_map_.end() || !(it->second->asContext() || it->second->asRoot())) { error("no base context context_id: " + std::to_string(context_id)); + return nullptr; } return it->second.get(); } @@ -289,6 +300,7 @@ null_plugin::Context *NullPlugin::getContext(uint64_t context_id) { auto it = context_map_.find(context_id); if (it == context_map_.end() || !it->second->asContext()) { error("no context context_id: " + std::to_string(context_id)); + return nullptr; } return it->second->asContext(); } @@ -297,6 +309,7 @@ null_plugin::RootContext *NullPlugin::getRootContext(uint64_t context_id) { auto it = context_map_.find(context_id); if (it == context_map_.end() || !it->second->asRoot()) { error("no root context_id: " + std::to_string(context_id)); + return nullptr; } return it->second->asRoot(); } diff --git a/src/v8/v8.cc b/src/v8/v8.cc index 7f9b2a871..3cca0a435 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -322,18 +322,21 @@ string_view V8::getCustomSection(string_view name) { const byte_t *end = source_.get() + source_.size(); while (pos < end) { if (pos + 1 > end) { - error("Failed to parse corrupted WASM module"); + error("Failed to parse corrupted Wasm module"); + return ""; } const auto section_type = *pos++; const auto section_len = parseVarint(pos, end); if (section_len == static_cast(-1) || pos + section_len > end) { - error("Failed to parse corrupted WASM module"); + error("Failed to parse corrupted Wasm module"); + return ""; } if (section_type == 0 /* custom section */) { const auto section_data_start = pos; const auto section_name_len = parseVarint(pos, end); if (section_name_len == static_cast(-1) || pos + section_name_len > end) { - error("Failed to parse corrupted WASM module"); + error("Failed to parse corrupted Wasm module"); + return ""; } if (section_name_len == name.size() && ::memcmp(pos, name.data(), section_name_len) == 0) { pos += section_name_len; @@ -379,25 +382,27 @@ void V8::link(string_view debug_name) { case wasm::EXTERN_FUNC: { auto it = host_functions_.find(std::string(module) + "." + std::string(name)); if (it == host_functions_.end()) { - error(std::string("Failed to load WASM module due to a missing import: ") + + error(std::string("Failed to load Wasm module due to a missing import: ") + std::string(module) + "." + std::string(name)); + break; } auto func = it->second.get()->callback_.get(); if (!equalValTypes(import_type->func()->params(), func->type()->params()) || !equalValTypes(import_type->func()->results(), func->type()->results())) { - error(std::string("Failed to load WASM module due to an import type mismatch: ") + + error(std::string("Failed to load Wasm module due to an import type mismatch: ") + std::string(module) + "." + std::string(name) + ", want: " + printValTypes(import_type->func()->params()) + " -> " + printValTypes(import_type->func()->results()) + ", but host exports: " + printValTypes(func->type()->params()) + " -> " + printValTypes(func->type()->results())); + break; } imports.push_back(func); } break; case wasm::EXTERN_GLOBAL: { // TODO(PiotrSikora): add support when/if needed. - error("Failed to load WASM module due to a missing import: " + std::string(module) + "." + + error("Failed to load Wasm module due to a missing import: " + std::string(module) + "." + std::string(name)); } break; @@ -558,6 +563,8 @@ void V8::getModuleFunctionImpl(string_view function_name, if (!equalValTypes(func->type()->params(), convertArgsTupleToValTypes>()) || !equalValTypes(func->type()->results(), convertArgsTupleToValTypes>())) { error(std::string("Bad function signature for: ") + std::string(function_name)); + *function = nullptr; + return; } *function = [func, function_name, this](ContextBase *context, Args... args) -> void { wasm::Val params[] = {makeVal(args)...}; @@ -582,6 +589,8 @@ void V8::getModuleFunctionImpl(string_view function_name, if (!equalValTypes(func->type()->params(), convertArgsTupleToValTypes>()) || !equalValTypes(func->type()->results(), convertArgsTupleToValTypes>())) { error("Bad function signature for: " + std::string(function_name)); + *function = nullptr; + return; } *function = [func, function_name, this](ContextBase *context, Args... args) -> R { wasm::Val params[] = {makeVal(args)...}; @@ -591,6 +600,7 @@ void V8::getModuleFunctionImpl(string_view function_name, if (trap) { error("Function: " + std::string(function_name) + " failed: " + std::string(trap->message().get(), trap->message().size())); + return R{}; } R rvalue = results[0].get::type>(); return rvalue; diff --git a/src/wasm.cc b/src/wasm.cc index a6ce33f02..786cba701 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -103,13 +103,6 @@ RegisterForeignFunction::RegisterForeignFunction(std::string name, WasmForeignFu (*foreign_functions)[name] = f; } -WasmBase::WasmBase(std::unique_ptr wasm_vm, string_view vm_id, string_view vm_configuration, - string_view vm_key) - : vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)), - vm_configuration_(std::string(vm_configuration)) {} - -WasmBase::~WasmBase() {} - void WasmBase::registerCallbacks() { #define _REGISTER(_fn) \ wasm_vm_->registerCallback( \ @@ -241,7 +234,7 @@ void WasmBase::getFunctions() { #undef _GET_PROXY if (!malloc_) { - error("WASM missing malloc"); + error("Wasm missing malloc"); } } @@ -255,12 +248,15 @@ WasmBase::WasmBase(const std::shared_ptr &base_wasm_handle, Wasm } else { wasm_vm_ = factory(); } - if (!initialize(base_wasm_handle->wasm()->code(), - base_wasm_handle->wasm()->allow_precompiled())) { - error("Failed to load WASM code"); - } } +WasmBase::WasmBase(std::unique_ptr wasm_vm, string_view vm_id, string_view vm_configuration, + string_view vm_key) + : vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)), + vm_configuration_(std::string(vm_configuration)) {} + +WasmBase::~WasmBase() {} + bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { if (!wasm_vm_) { return false; @@ -354,11 +350,13 @@ ContextBase *WasmBase::start(std::shared_ptr plugin) { auto context = std::unique_ptr(createContext(plugin)); auto context_ptr = context.get(); root_contexts_[root_id] = std::move(context); - context_ptr->onStart(plugin); + if (!context_ptr->onStart(plugin)) { + return nullptr; + } return context_ptr; }; -void WasmBase::startForTesting(std::unique_ptr context, +bool WasmBase::startForTesting(std::unique_ptr context, std::shared_ptr plugin) { auto context_ptr = context.get(); if (!context->wasm_) { @@ -367,7 +365,7 @@ void WasmBase::startForTesting(std::unique_ptr context, } root_contexts_[plugin->root_id_] = std::move(context); // Set the current plugin over the lifetime of the onConfigure call to the RootContext. - context_ptr->onStart(plugin); + return context_ptr->onStart(plugin) != 0; } uint32_t WasmBase::allocContextId() { @@ -445,7 +443,7 @@ std::shared_ptr createWasm(std::string vm_key, std::string code, std::shared_ptr plugin, WasmHandleFactory factory, bool allow_precompiled, std::unique_ptr root_context_for_testing) { - std::shared_ptr wasm; + std::shared_ptr wasm_handle; { std::lock_guard guard(base_wasms_mutex); if (!base_wasms) { @@ -453,36 +451,64 @@ std::shared_ptr createWasm(std::string vm_key, std::string code, } auto it = base_wasms->find(vm_key); if (it != base_wasms->end()) { - wasm = it->second.lock(); - if (!wasm) { + wasm_handle = it->second.lock(); + if (!wasm_handle) { base_wasms->erase(it); } } - if (wasm) - return wasm; - wasm = factory(vm_key); - (*base_wasms)[vm_key] = wasm; + if (wasm_handle) { + return wasm_handle; + } + wasm_handle = factory(vm_key); + if (!wasm_handle) { + return nullptr; + } + (*base_wasms)[vm_key] = wasm_handle; } - if (!wasm->wasm()->initialize(code, allow_precompiled)) { - wasm->wasm()->error("Failed to initialize WASM code"); + if (!wasm_handle->wasm()->initialize(code, allow_precompiled)) { + wasm_handle->wasm()->error("Failed to initialize Wasm code"); return nullptr; } + ContextBase *root_context = root_context_for_testing.get(); if (!root_context_for_testing) { - wasm->wasm()->start(plugin); + root_context = wasm_handle->wasm()->start(plugin); + if (!root_context) { + wasm_handle->wasm()->error("Failed to start base Wasm"); + return nullptr; + } } else { - wasm->wasm()->startForTesting(std::move(root_context_for_testing), plugin); + if (!wasm_handle->wasm()->startForTesting(std::move(root_context_for_testing), plugin)) { + wasm_handle->wasm()->error("Failed to start base Wasm"); + return nullptr; + } } - return wasm; + if (!wasm_handle->wasm()->configure(root_context, plugin)) { + wasm_handle->wasm()->error("Failed to configure base Wasm plugin"); + return nullptr; + } + return wasm_handle; }; static std::shared_ptr createThreadLocalWasm(std::shared_ptr &base_wasm, std::shared_ptr plugin, WasmHandleCloneFactory factory) { auto wasm_handle = factory(base_wasm); + if (!wasm_handle) { + return nullptr; + } + if (!wasm_handle->wasm()->initialize(base_wasm->wasm()->code(), + base_wasm->wasm()->allow_precompiled())) { + base_wasm->wasm()->error("Failed to load Wasm code"); + return nullptr; + } ContextBase *root_context = wasm_handle->wasm()->start(plugin); + if (!root_context) { + base_wasm->wasm()->error("Failed to start thread-local Wasm"); + return nullptr; + } if (!wasm_handle->wasm()->configure(root_context, plugin)) { - base_wasm->wasm()->error("Failed to configure WASM code"); + base_wasm->wasm()->error("Failed to configure thread-local Wasm plugin"); return nullptr; } local_wasms[std::string(wasm_handle->wasm()->vm_key())] = wasm_handle; @@ -508,7 +534,7 @@ getOrCreateThreadLocalWasm(std::shared_ptr base_wasm, if (wasm_handle) { auto root_context = wasm_handle->wasm()->getOrCreateRootContext(plugin); if (!wasm_handle->wasm()->configure(root_context, plugin)) { - base_wasm->wasm()->error("Failed to configure WASM code"); + base_wasm->wasm()->error("Failed to configure thread-local Wasm code"); return nullptr; } return wasm_handle; diff --git a/src/wavm/wavm.cc b/src/wavm/wavm.cc index 4fc8cf5e0..86ee52ba1 100644 --- a/src/wavm/wavm.cc +++ b/src/wavm/wavm.cc @@ -124,7 +124,7 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggableerror("Failed to load WASM module due to a missing import: " + std::string(module_name) + + vm_->error("Failed to load Wasm module due to a missing import: " + std::string(module_name) + "." + std::string(export_name) + " " + asString(type)); }