Skip to content

Commit f012757

Browse files
authored
Handle error cases gracefully. WASM -> Wasm. (proxy-wasm#20)
* Handle error cases gracefully. WASM -> Wasm. Signed-off-by: John Plevyak <[email protected]>
1 parent 6b9a5d8 commit f012757

File tree

6 files changed

+93
-38
lines changed

6 files changed

+93
-38
lines changed

include/proxy-wasm/wasm.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
5454
bool initialize(const std::string &code, bool allow_precompiled = false);
5555
void startVm(ContextBase *root_context);
5656
bool configure(ContextBase *root_context, std::shared_ptr<PluginBase> plugin);
57-
ContextBase *start(std::shared_ptr<PluginBase> plugin); // returns the root ContextBase.
57+
// Returns the root ContextBase or nullptr if onStart returns false.
58+
ContextBase *start(std::shared_ptr<PluginBase> plugin);
5859

5960
string_view vm_id() const { return vm_id_; }
6061
string_view vm_key() const { return vm_key_; }
@@ -115,7 +116,8 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
115116

116117
// For testing.
117118
void setContext(ContextBase *context) { contexts_[context->id()] = context; }
118-
void startForTesting(std::unique_ptr<ContextBase> root_context,
119+
// Returns false if onStart returns false.
120+
bool startForTesting(std::unique_ptr<ContextBase> root_context,
119121
std::shared_ptr<PluginBase> plugin);
120122

121123
bool getEmscriptenVersion(uint32_t *emscripten_metadata_major_version,
@@ -289,6 +291,9 @@ inline const std::string &WasmBase::vm_configuration() const {
289291
}
290292

291293
inline void *WasmBase::allocMemory(uint64_t size, uint64_t *address) {
294+
if (!malloc_) {
295+
return nullptr;
296+
}
292297
Word a = malloc_(vm_context(), size);
293298
if (!a.u64_) {
294299
return nullptr;

include/proxy-wasm/word.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace proxy_wasm {
2424
// Represents a Wasm-native word-sized datum. On 32-bit VMs, the high bits are always zero.
2525
// The Wasm/VM API treats all bits as significant.
2626
struct Word {
27+
Word() : u64_(0) {}
2728
Word(uint64_t w) : u64_(w) {} // Implicit conversion into Word.
2829
Word(WasmResult r) : u64_(static_cast<uint64_t>(r)) {} // Implicit conversion into Word.
2930
uint32_t u32() const { return static_cast<uint32_t>(u64_); }

src/null/null_plugin.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<0> *f) {
3939
*f = nullptr;
4040
} else {
4141
error("Missing getFunction for: " + std::string(function_name));
42+
*f = nullptr;
4243
}
4344
}
4445

@@ -61,6 +62,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<1> *f) {
6162
};
6263
} else {
6364
error("Missing getFunction for: " + std::string(function_name));
65+
*f = nullptr;
6466
}
6567
}
6668

@@ -88,6 +90,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<2> *f) {
8890
};
8991
} else {
9092
error("Missing getFunction for: " + std::string(function_name));
93+
*f = nullptr;
9194
}
9295
}
9396

@@ -115,6 +118,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<3> *f) {
115118
};
116119
} else {
117120
error("Missing getFunction for: " + std::string(function_name));
121+
*f = nullptr;
118122
}
119123
}
120124

@@ -128,6 +132,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallVoid<5> *f) {
128132
};
129133
} else {
130134
error("Missing getFunction for: " + std::string(function_name));
135+
*f = nullptr;
131136
}
132137
}
133138

@@ -149,6 +154,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<1> *f) {
149154
};
150155
} else {
151156
error("Missing getFunction for: " + std::string(function_name));
157+
*f = nullptr;
152158
}
153159
}
154160

@@ -201,6 +207,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<2> *f) {
201207
};
202208
} else {
203209
error("Missing getFunction for: " + std::string(function_name));
210+
*f = nullptr;
204211
}
205212
}
206213

@@ -232,6 +239,7 @@ void NullPlugin::getFunction(string_view function_name, WasmCallWord<3> *f) {
232239
};
233240
} else {
234241
error("Missing getFunction for: " + std::string(function_name));
242+
*f = nullptr;
235243
}
236244
}
237245

@@ -244,6 +252,7 @@ null_plugin::Context *NullPlugin::ensureContext(uint64_t context_id, uint64_t ro
244252
auto factory = registry_->context_factories[root_id];
245253
if (!factory) {
246254
error("no context factory for root_id: " + root_id);
255+
return nullptr;
247256
}
248257
e.first->second = factory(context_id, root);
249258
}
@@ -254,6 +263,7 @@ null_plugin::RootContext *NullPlugin::ensureRootContext(uint64_t context_id) {
254263
auto root_id_opt = null_plugin::getProperty({"plugin_root_id"});
255264
if (!root_id_opt) {
256265
error("unable to get root_id");
266+
return nullptr;
257267
}
258268
auto root_id = std::move(root_id_opt.value());
259269
auto it = context_map_.find(context_id);
@@ -281,6 +291,7 @@ null_plugin::ContextBase *NullPlugin::getContextBase(uint64_t context_id) {
281291
auto it = context_map_.find(context_id);
282292
if (it == context_map_.end() || !(it->second->asContext() || it->second->asRoot())) {
283293
error("no base context context_id: " + std::to_string(context_id));
294+
return nullptr;
284295
}
285296
return it->second.get();
286297
}
@@ -289,6 +300,7 @@ null_plugin::Context *NullPlugin::getContext(uint64_t context_id) {
289300
auto it = context_map_.find(context_id);
290301
if (it == context_map_.end() || !it->second->asContext()) {
291302
error("no context context_id: " + std::to_string(context_id));
303+
return nullptr;
292304
}
293305
return it->second->asContext();
294306
}
@@ -297,6 +309,7 @@ null_plugin::RootContext *NullPlugin::getRootContext(uint64_t context_id) {
297309
auto it = context_map_.find(context_id);
298310
if (it == context_map_.end() || !it->second->asRoot()) {
299311
error("no root context_id: " + std::to_string(context_id));
312+
return nullptr;
300313
}
301314
return it->second->asRoot();
302315
}

src/v8/v8.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,18 +322,21 @@ string_view V8::getCustomSection(string_view name) {
322322
const byte_t *end = source_.get() + source_.size();
323323
while (pos < end) {
324324
if (pos + 1 > end) {
325-
error("Failed to parse corrupted WASM module");
325+
error("Failed to parse corrupted Wasm module");
326+
return "";
326327
}
327328
const auto section_type = *pos++;
328329
const auto section_len = parseVarint(pos, end);
329330
if (section_len == static_cast<uint32_t>(-1) || pos + section_len > end) {
330-
error("Failed to parse corrupted WASM module");
331+
error("Failed to parse corrupted Wasm module");
332+
return "";
331333
}
332334
if (section_type == 0 /* custom section */) {
333335
const auto section_data_start = pos;
334336
const auto section_name_len = parseVarint(pos, end);
335337
if (section_name_len == static_cast<uint32_t>(-1) || pos + section_name_len > end) {
336-
error("Failed to parse corrupted WASM module");
338+
error("Failed to parse corrupted Wasm module");
339+
return "";
337340
}
338341
if (section_name_len == name.size() && ::memcmp(pos, name.data(), section_name_len) == 0) {
339342
pos += section_name_len;
@@ -379,25 +382,27 @@ void V8::link(string_view debug_name) {
379382
case wasm::EXTERN_FUNC: {
380383
auto it = host_functions_.find(std::string(module) + "." + std::string(name));
381384
if (it == host_functions_.end()) {
382-
error(std::string("Failed to load WASM module due to a missing import: ") +
385+
error(std::string("Failed to load Wasm module due to a missing import: ") +
383386
std::string(module) + "." + std::string(name));
387+
break;
384388
}
385389
auto func = it->second.get()->callback_.get();
386390
if (!equalValTypes(import_type->func()->params(), func->type()->params()) ||
387391
!equalValTypes(import_type->func()->results(), func->type()->results())) {
388-
error(std::string("Failed to load WASM module due to an import type mismatch: ") +
392+
error(std::string("Failed to load Wasm module due to an import type mismatch: ") +
389393
std::string(module) + "." + std::string(name) +
390394
", want: " + printValTypes(import_type->func()->params()) + " -> " +
391395
printValTypes(import_type->func()->results()) +
392396
", but host exports: " + printValTypes(func->type()->params()) + " -> " +
393397
printValTypes(func->type()->results()));
398+
break;
394399
}
395400
imports.push_back(func);
396401
} break;
397402

398403
case wasm::EXTERN_GLOBAL: {
399404
// TODO(PiotrSikora): add support when/if needed.
400-
error("Failed to load WASM module due to a missing import: " + std::string(module) + "." +
405+
error("Failed to load Wasm module due to a missing import: " + std::string(module) + "." +
401406
std::string(name));
402407
} break;
403408

@@ -558,6 +563,8 @@ void V8::getModuleFunctionImpl(string_view function_name,
558563
if (!equalValTypes(func->type()->params(), convertArgsTupleToValTypes<std::tuple<Args...>>()) ||
559564
!equalValTypes(func->type()->results(), convertArgsTupleToValTypes<std::tuple<>>())) {
560565
error(std::string("Bad function signature for: ") + std::string(function_name));
566+
*function = nullptr;
567+
return;
561568
}
562569
*function = [func, function_name, this](ContextBase *context, Args... args) -> void {
563570
wasm::Val params[] = {makeVal(args)...};
@@ -582,6 +589,8 @@ void V8::getModuleFunctionImpl(string_view function_name,
582589
if (!equalValTypes(func->type()->params(), convertArgsTupleToValTypes<std::tuple<Args...>>()) ||
583590
!equalValTypes(func->type()->results(), convertArgsTupleToValTypes<std::tuple<R>>())) {
584591
error("Bad function signature for: " + std::string(function_name));
592+
*function = nullptr;
593+
return;
585594
}
586595
*function = [func, function_name, this](ContextBase *context, Args... args) -> R {
587596
wasm::Val params[] = {makeVal(args)...};
@@ -591,6 +600,7 @@ void V8::getModuleFunctionImpl(string_view function_name,
591600
if (trap) {
592601
error("Function: " + std::string(function_name) +
593602
" failed: " + std::string(trap->message().get(), trap->message().size()));
603+
return R{};
594604
}
595605
R rvalue = results[0].get<typename ConvertWordTypeToUint32<R>::type>();
596606
return rvalue;

src/wasm.cc

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,6 @@ RegisterForeignFunction::RegisterForeignFunction(std::string name, WasmForeignFu
103103
(*foreign_functions)[name] = f;
104104
}
105105

106-
WasmBase::WasmBase(std::unique_ptr<WasmVm> wasm_vm, string_view vm_id, string_view vm_configuration,
107-
string_view vm_key)
108-
: vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)),
109-
vm_configuration_(std::string(vm_configuration)) {}
110-
111-
WasmBase::~WasmBase() {}
112-
113106
void WasmBase::registerCallbacks() {
114107
#define _REGISTER(_fn) \
115108
wasm_vm_->registerCallback( \
@@ -241,7 +234,7 @@ void WasmBase::getFunctions() {
241234
#undef _GET_PROXY
242235

243236
if (!malloc_) {
244-
error("WASM missing malloc");
237+
error("Wasm missing malloc");
245238
}
246239
}
247240

@@ -255,12 +248,15 @@ WasmBase::WasmBase(const std::shared_ptr<WasmHandleBase> &base_wasm_handle, Wasm
255248
} else {
256249
wasm_vm_ = factory();
257250
}
258-
if (!initialize(base_wasm_handle->wasm()->code(),
259-
base_wasm_handle->wasm()->allow_precompiled())) {
260-
error("Failed to load WASM code");
261-
}
262251
}
263252

253+
WasmBase::WasmBase(std::unique_ptr<WasmVm> wasm_vm, string_view vm_id, string_view vm_configuration,
254+
string_view vm_key)
255+
: vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)),
256+
vm_configuration_(std::string(vm_configuration)) {}
257+
258+
WasmBase::~WasmBase() {}
259+
264260
bool WasmBase::initialize(const std::string &code, bool allow_precompiled) {
265261
if (!wasm_vm_) {
266262
return false;
@@ -354,11 +350,13 @@ ContextBase *WasmBase::start(std::shared_ptr<PluginBase> plugin) {
354350
auto context = std::unique_ptr<ContextBase>(createContext(plugin));
355351
auto context_ptr = context.get();
356352
root_contexts_[root_id] = std::move(context);
357-
context_ptr->onStart(plugin);
353+
if (!context_ptr->onStart(plugin)) {
354+
return nullptr;
355+
}
358356
return context_ptr;
359357
};
360358

361-
void WasmBase::startForTesting(std::unique_ptr<ContextBase> context,
359+
bool WasmBase::startForTesting(std::unique_ptr<ContextBase> context,
362360
std::shared_ptr<PluginBase> plugin) {
363361
auto context_ptr = context.get();
364362
if (!context->wasm_) {
@@ -367,7 +365,7 @@ void WasmBase::startForTesting(std::unique_ptr<ContextBase> context,
367365
}
368366
root_contexts_[plugin->root_id_] = std::move(context);
369367
// Set the current plugin over the lifetime of the onConfigure call to the RootContext.
370-
context_ptr->onStart(plugin);
368+
return context_ptr->onStart(plugin) != 0;
371369
}
372370

373371
uint32_t WasmBase::allocContextId() {
@@ -445,44 +443,72 @@ std::shared_ptr<WasmHandleBase> createWasm(std::string vm_key, std::string code,
445443
std::shared_ptr<PluginBase> plugin,
446444
WasmHandleFactory factory, bool allow_precompiled,
447445
std::unique_ptr<ContextBase> root_context_for_testing) {
448-
std::shared_ptr<WasmHandleBase> wasm;
446+
std::shared_ptr<WasmHandleBase> wasm_handle;
449447
{
450448
std::lock_guard<std::mutex> guard(base_wasms_mutex);
451449
if (!base_wasms) {
452450
base_wasms = new std::remove_reference<decltype(*base_wasms)>::type;
453451
}
454452
auto it = base_wasms->find(vm_key);
455453
if (it != base_wasms->end()) {
456-
wasm = it->second.lock();
457-
if (!wasm) {
454+
wasm_handle = it->second.lock();
455+
if (!wasm_handle) {
458456
base_wasms->erase(it);
459457
}
460458
}
461-
if (wasm)
462-
return wasm;
463-
wasm = factory(vm_key);
464-
(*base_wasms)[vm_key] = wasm;
459+
if (wasm_handle) {
460+
return wasm_handle;
461+
}
462+
wasm_handle = factory(vm_key);
463+
if (!wasm_handle) {
464+
return nullptr;
465+
}
466+
(*base_wasms)[vm_key] = wasm_handle;
465467
}
466468

467-
if (!wasm->wasm()->initialize(code, allow_precompiled)) {
468-
wasm->wasm()->error("Failed to initialize WASM code");
469+
if (!wasm_handle->wasm()->initialize(code, allow_precompiled)) {
470+
wasm_handle->wasm()->error("Failed to initialize Wasm code");
469471
return nullptr;
470472
}
473+
ContextBase *root_context = root_context_for_testing.get();
471474
if (!root_context_for_testing) {
472-
wasm->wasm()->start(plugin);
475+
root_context = wasm_handle->wasm()->start(plugin);
476+
if (!root_context) {
477+
wasm_handle->wasm()->error("Failed to start base Wasm");
478+
return nullptr;
479+
}
473480
} else {
474-
wasm->wasm()->startForTesting(std::move(root_context_for_testing), plugin);
481+
if (!wasm_handle->wasm()->startForTesting(std::move(root_context_for_testing), plugin)) {
482+
wasm_handle->wasm()->error("Failed to start base Wasm");
483+
return nullptr;
484+
}
475485
}
476-
return wasm;
486+
if (!wasm_handle->wasm()->configure(root_context, plugin)) {
487+
wasm_handle->wasm()->error("Failed to configure base Wasm plugin");
488+
return nullptr;
489+
}
490+
return wasm_handle;
477491
};
478492

479493
static std::shared_ptr<WasmHandleBase>
480494
createThreadLocalWasm(std::shared_ptr<WasmHandleBase> &base_wasm,
481495
std::shared_ptr<PluginBase> plugin, WasmHandleCloneFactory factory) {
482496
auto wasm_handle = factory(base_wasm);
497+
if (!wasm_handle) {
498+
return nullptr;
499+
}
500+
if (!wasm_handle->wasm()->initialize(base_wasm->wasm()->code(),
501+
base_wasm->wasm()->allow_precompiled())) {
502+
base_wasm->wasm()->error("Failed to load Wasm code");
503+
return nullptr;
504+
}
483505
ContextBase *root_context = wasm_handle->wasm()->start(plugin);
506+
if (!root_context) {
507+
base_wasm->wasm()->error("Failed to start thread-local Wasm");
508+
return nullptr;
509+
}
484510
if (!wasm_handle->wasm()->configure(root_context, plugin)) {
485-
base_wasm->wasm()->error("Failed to configure WASM code");
511+
base_wasm->wasm()->error("Failed to configure thread-local Wasm plugin");
486512
return nullptr;
487513
}
488514
local_wasms[std::string(wasm_handle->wasm()->vm_key())] = wasm_handle;
@@ -508,7 +534,7 @@ getOrCreateThreadLocalWasm(std::shared_ptr<WasmHandleBase> base_wasm,
508534
if (wasm_handle) {
509535
auto root_context = wasm_handle->wasm()->getOrCreateRootContext(plugin);
510536
if (!wasm_handle->wasm()->configure(root_context, plugin)) {
511-
base_wasm->wasm()->error("Failed to configure WASM code");
537+
base_wasm->wasm()->error("Failed to configure thread-local Wasm code");
512538
return nullptr;
513539
}
514540
return wasm_handle;

src/wavm/wavm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<was
124124
return true;
125125
}
126126
}
127-
vm_->error("Failed to load WASM module due to a missing import: " + std::string(module_name) +
127+
vm_->error("Failed to load Wasm module due to a missing import: " + std::string(module_name) +
128128
"." + std::string(export_name) + " " + asString(type));
129129
}
130130

0 commit comments

Comments
 (0)