Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions include/proxy-wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
bool initialize(const std::string &code, bool allow_precompiled = false);
void startVm(ContextBase *root_context);
bool configure(ContextBase *root_context, std::shared_ptr<PluginBase> plugin);
ContextBase *start(std::shared_ptr<PluginBase> plugin); // returns the root ContextBase.
// Returns the root ContextBase or nullptr if onStart returns false.
ContextBase *start(std::shared_ptr<PluginBase> plugin);

string_view vm_id() const { return vm_id_; }
string_view vm_key() const { return vm_key_; }
Expand Down Expand Up @@ -115,7 +116,8 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {

// For testing.
void setContext(ContextBase *context) { contexts_[context->id()] = context; }
void startForTesting(std::unique_ptr<ContextBase> root_context,
// Returns false if onStart returns false.
bool startForTesting(std::unique_ptr<ContextBase> root_context,
std::shared_ptr<PluginBase> plugin);

bool getEmscriptenVersion(uint32_t *emscripten_metadata_major_version,
Expand Down Expand Up @@ -289,6 +291,9 @@ inline const std::string &WasmBase::vm_configuration() const {
}

inline void *WasmBase::allocMemory(uint64_t size, uint64_t *address) {
if (!malloc_) {
return nullptr;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? We fail hard if malloc is not exported.

(I'm not saying it's not good to add it just to be safe, I'm just curious why did you add it now)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally I was trying to make the proxy-independent code resilient against unexpected behavior. An integrator could choose to not hard fail on error(). For envoy we throw an exception which we will catch at a higher level, but presumably someone code set a flag.

Word a = malloc_(vm_context(), size);
if (!a.u64_) {
return nullptr;
Expand Down
1 change: 1 addition & 0 deletions include/proxy-wasm/word.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace proxy_wasm {
// Represents a Wasm-native word-sized datum. On 32-bit VMs, the high bits are always zero.
// The Wasm/VM API treats all bits as significant.
struct Word {
Word() : u64_(0) {}
Word(uint64_t w) : u64_(w) {} // Implicit conversion into Word.
Word(WasmResult r) : u64_(static_cast<uint64_t>(r)) {} // Implicit conversion into Word.
uint32_t u32() const { return static_cast<uint32_t>(u64_); }
Expand Down
13 changes: 13 additions & 0 deletions src/null/null_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<0> *f) {
*f = nullptr;
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand All @@ -61,6 +62,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<1> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand Down Expand Up @@ -88,6 +90,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<2> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand Down Expand Up @@ -115,6 +118,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<3> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand All @@ -128,6 +132,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<5> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand All @@ -149,6 +154,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<1> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand Down Expand Up @@ -201,6 +207,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<2> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand Down Expand Up @@ -232,6 +239,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<3> *f) {
};
} else {
error("Missing getFunction for: " + std::string(function_name));
*f = nullptr;
}
}

Expand All @@ -244,6 +252,7 @@ null_plugin::Context *NullPlugin::ensureContext(uint64_t context_id, uint64_t ro
auto factory = registry_->context_factories[root_id];
if (!factory) {
error("no context factory for root_id: " + root_id);
return nullptr;
}
e.first->second = factory(context_id, root);
}
Expand All @@ -254,6 +263,7 @@ null_plugin::RootContext *NullPlugin::ensureRootContext(uint64_t context_id) {
auto root_id_opt = null_plugin::getProperty({"plugin_root_id"});
if (!root_id_opt) {
error("unable to get root_id");
return nullptr;
}
auto root_id = std::move(root_id_opt.value());
auto it = context_map_.find(context_id);
Expand Down Expand Up @@ -281,6 +291,7 @@ null_plugin::ContextBase *NullPlugin::getContextBase(uint64_t context_id) {
auto it = context_map_.find(context_id);
if (it == context_map_.end() || !(it->second->asContext() || it->second->asRoot())) {
error("no base context context_id: " + std::to_string(context_id));
return nullptr;
}
return it->second.get();
}
Expand All @@ -289,6 +300,7 @@ null_plugin::Context *NullPlugin::getContext(uint64_t context_id) {
auto it = context_map_.find(context_id);
if (it == context_map_.end() || !it->second->asContext()) {
error("no context context_id: " + std::to_string(context_id));
return nullptr;
}
return it->second->asContext();
}
Expand All @@ -297,6 +309,7 @@ null_plugin::RootContext *NullPlugin::getRootContext(uint64_t context_id) {
auto it = context_map_.find(context_id);
if (it == context_map_.end() || !it->second->asRoot()) {
error("no root context_id: " + std::to_string(context_id));
return nullptr;
}
return it->second->asRoot();
}
Expand Down
22 changes: 16 additions & 6 deletions src/v8/v8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,21 @@ string_view V8::getCustomSection(string_view name) {
const byte_t *end = source_.get() + source_.size();
while (pos < end) {
if (pos + 1 > end) {
error("Failed to parse corrupted WASM module");
error("Failed to parse corrupted Wasm module");
return "";
}
const auto section_type = *pos++;
const auto section_len = parseVarint(pos, end);
if (section_len == static_cast<uint32_t>(-1) || pos + section_len > end) {
error("Failed to parse corrupted WASM module");
error("Failed to parse corrupted Wasm module");
return "";
}
if (section_type == 0 /* custom section */) {
const auto section_data_start = pos;
const auto section_name_len = parseVarint(pos, end);
if (section_name_len == static_cast<uint32_t>(-1) || pos + section_name_len > end) {
error("Failed to parse corrupted WASM module");
error("Failed to parse corrupted Wasm module");
return "";
}
if (section_name_len == name.size() && ::memcmp(pos, name.data(), section_name_len) == 0) {
pos += section_name_len;
Expand Down Expand Up @@ -379,25 +382,27 @@ void V8::link(string_view debug_name) {
case wasm::EXTERN_FUNC: {
auto it = host_functions_.find(std::string(module) + "." + std::string(name));
if (it == host_functions_.end()) {
error(std::string("Failed to load WASM module due to a missing import: ") +
error(std::string("Failed to load Wasm module due to a missing import: ") +
std::string(module) + "." + std::string(name));
break;
}
auto func = it->second.get()->callback_.get();
if (!equalValTypes(import_type->func()->params(), func->type()->params()) ||
!equalValTypes(import_type->func()->results(), func->type()->results())) {
error(std::string("Failed to load WASM module due to an import type mismatch: ") +
error(std::string("Failed to load Wasm module due to an import type mismatch: ") +
std::string(module) + "." + std::string(name) +
", want: " + printValTypes(import_type->func()->params()) + " -> " +
printValTypes(import_type->func()->results()) +
", but host exports: " + printValTypes(func->type()->params()) + " -> " +
printValTypes(func->type()->results()));
break;
}
imports.push_back(func);
} break;

case wasm::EXTERN_GLOBAL: {
// TODO(PiotrSikora): add support when/if needed.
error("Failed to load WASM module due to a missing import: " + std::string(module) + "." +
error("Failed to load Wasm module due to a missing import: " + std::string(module) + "." +
std::string(name));
} break;

Expand Down Expand Up @@ -558,6 +563,8 @@ void V8::getModuleFunctionImpl(string_view function_name,
if (!equalValTypes(func->type()->params(), convertArgsTupleToValTypes<std::tuple<Args...>>()) ||
!equalValTypes(func->type()->results(), convertArgsTupleToValTypes<std::tuple<>>())) {
error(std::string("Bad function signature for: ") + std::string(function_name));
*function = nullptr;
return;
}
*function = [func, function_name, this](ContextBase *context, Args... args) -> void {
wasm::Val params[] = {makeVal(args)...};
Expand All @@ -582,6 +589,8 @@ void V8::getModuleFunctionImpl(string_view function_name,
if (!equalValTypes(func->type()->params(), convertArgsTupleToValTypes<std::tuple<Args...>>()) ||
!equalValTypes(func->type()->results(), convertArgsTupleToValTypes<std::tuple<R>>())) {
error("Bad function signature for: " + std::string(function_name));
*function = nullptr;
return;
}
*function = [func, function_name, this](ContextBase *context, Args... args) -> R {
wasm::Val params[] = {makeVal(args)...};
Expand All @@ -591,6 +600,7 @@ void V8::getModuleFunctionImpl(string_view function_name,
if (trap) {
error("Function: " + std::string(function_name) +
" failed: " + std::string(trap->message().get(), trap->message().size()));
return R{};
}
R rvalue = results[0].get<typename ConvertWordTypeToUint32<R>::type>();
return rvalue;
Expand Down
84 changes: 55 additions & 29 deletions src/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ RegisterForeignFunction::RegisterForeignFunction(std::string name, WasmForeignFu
(*foreign_functions)[name] = f;
}

WasmBase::WasmBase(std::unique_ptr<WasmVm> wasm_vm, string_view vm_id, string_view vm_configuration,
string_view vm_key)
: vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)),
vm_configuration_(std::string(vm_configuration)) {}

WasmBase::~WasmBase() {}

void WasmBase::registerCallbacks() {
#define _REGISTER(_fn) \
wasm_vm_->registerCallback( \
Expand Down Expand Up @@ -241,7 +234,7 @@ void WasmBase::getFunctions() {
#undef _GET_PROXY

if (!malloc_) {
error("WASM missing malloc");
error("Wasm missing malloc");
}
}

Expand All @@ -255,12 +248,15 @@ WasmBase::WasmBase(const std::shared_ptr<WasmHandleBase> &base_wasm_handle, Wasm
} else {
wasm_vm_ = factory();
}
if (!initialize(base_wasm_handle->wasm()->code(),
base_wasm_handle->wasm()->allow_precompiled())) {
error("Failed to load WASM code");
}
}

WasmBase::WasmBase(std::unique_ptr<WasmVm> wasm_vm, string_view vm_id, string_view vm_configuration,
string_view vm_key)
: vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)),
vm_configuration_(std::string(vm_configuration)) {}

WasmBase::~WasmBase() {}

bool WasmBase::initialize(const std::string &code, bool allow_precompiled) {
if (!wasm_vm_) {
return false;
Expand Down Expand Up @@ -354,11 +350,13 @@ ContextBase *WasmBase::start(std::shared_ptr<PluginBase> plugin) {
auto context = std::unique_ptr<ContextBase>(createContext(plugin));
auto context_ptr = context.get();
root_contexts_[root_id] = std::move(context);
context_ptr->onStart(plugin);
if (!context_ptr->onStart(plugin)) {
return nullptr;
}
return context_ptr;
};

void WasmBase::startForTesting(std::unique_ptr<ContextBase> context,
bool WasmBase::startForTesting(std::unique_ptr<ContextBase> context,
std::shared_ptr<PluginBase> plugin) {
auto context_ptr = context.get();
if (!context->wasm_) {
Expand All @@ -367,7 +365,7 @@ void WasmBase::startForTesting(std::unique_ptr<ContextBase> context,
}
root_contexts_[plugin->root_id_] = std::move(context);
// Set the current plugin over the lifetime of the onConfigure call to the RootContext.
context_ptr->onStart(plugin);
return context_ptr->onStart(plugin) != 0;
}

uint32_t WasmBase::allocContextId() {
Expand Down Expand Up @@ -445,44 +443,72 @@ std::shared_ptr<WasmHandleBase> createWasm(std::string vm_key, std::string code,
std::shared_ptr<PluginBase> plugin,
WasmHandleFactory factory, bool allow_precompiled,
std::unique_ptr<ContextBase> root_context_for_testing) {
std::shared_ptr<WasmHandleBase> wasm;
std::shared_ptr<WasmHandleBase> wasm_handle;
{
std::lock_guard<std::mutex> guard(base_wasms_mutex);
if (!base_wasms) {
base_wasms = new std::remove_reference<decltype(*base_wasms)>::type;
}
auto it = base_wasms->find(vm_key);
if (it != base_wasms->end()) {
wasm = it->second.lock();
if (!wasm) {
wasm_handle = it->second.lock();
if (!wasm_handle) {
base_wasms->erase(it);
}
}
if (wasm)
return wasm;
wasm = factory(vm_key);
(*base_wasms)[vm_key] = wasm;
if (wasm_handle) {
return wasm_handle;
}
wasm_handle = factory(vm_key);
if (!wasm_handle) {
return nullptr;
}
(*base_wasms)[vm_key] = wasm_handle;
}

if (!wasm->wasm()->initialize(code, allow_precompiled)) {
wasm->wasm()->error("Failed to initialize WASM code");
if (!wasm_handle->wasm()->initialize(code, allow_precompiled)) {
wasm_handle->wasm()->error("Failed to initialize Wasm code");
return nullptr;
}
ContextBase *root_context = root_context_for_testing.get();
if (!root_context_for_testing) {
wasm->wasm()->start(plugin);
root_context = wasm_handle->wasm()->start(plugin);
if (!root_context) {
wasm_handle->wasm()->error("Failed to start base Wasm");
return nullptr;
}
} else {
wasm->wasm()->startForTesting(std::move(root_context_for_testing), plugin);
if (!wasm_handle->wasm()->startForTesting(std::move(root_context_for_testing), plugin)) {
wasm_handle->wasm()->error("Failed to start base Wasm");
return nullptr;
}
}
return wasm;
if (!wasm_handle->wasm()->configure(root_context, plugin)) {
wasm_handle->wasm()->error("Failed to configure base Wasm plugin");
return nullptr;
}
return wasm_handle;
};

static std::shared_ptr<WasmHandleBase>
createThreadLocalWasm(std::shared_ptr<WasmHandleBase> &base_wasm,
std::shared_ptr<PluginBase> plugin, WasmHandleCloneFactory factory) {
auto wasm_handle = factory(base_wasm);
if (!wasm_handle) {
return nullptr;
}
if (!wasm_handle->wasm()->initialize(base_wasm->wasm()->code(),
base_wasm->wasm()->allow_precompiled())) {
base_wasm->wasm()->error("Failed to load Wasm code");
return nullptr;
}
ContextBase *root_context = wasm_handle->wasm()->start(plugin);
if (!root_context) {
base_wasm->wasm()->error("Failed to start thread-local Wasm");
return nullptr;
}
if (!wasm_handle->wasm()->configure(root_context, plugin)) {
base_wasm->wasm()->error("Failed to configure WASM code");
base_wasm->wasm()->error("Failed to configure thread-local Wasm plugin");
return nullptr;
}
local_wasms[std::string(wasm_handle->wasm()->vm_key())] = wasm_handle;
Expand All @@ -508,7 +534,7 @@ getOrCreateThreadLocalWasm(std::shared_ptr<WasmHandleBase> base_wasm,
if (wasm_handle) {
auto root_context = wasm_handle->wasm()->getOrCreateRootContext(plugin);
if (!wasm_handle->wasm()->configure(root_context, plugin)) {
base_wasm->wasm()->error("Failed to configure WASM code");
base_wasm->wasm()->error("Failed to configure thread-local Wasm code");
return nullptr;
}
return wasm_handle;
Expand Down
2 changes: 1 addition & 1 deletion src/wavm/wavm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<was
return true;
}
}
vm_->error("Failed to load WASM module due to a missing import: " + std::string(module_name) +
vm_->error("Failed to load Wasm module due to a missing import: " + std::string(module_name) +
"." + std::string(export_name) + " " + asString(type));
}

Expand Down