Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@ def proxy_wasm_cpp_host_repositories():
http_archive,
name = "com_github_wasmedge_wasmedge",
build_file = "@proxy_wasm_cpp_host//bazel/external:wasmedge.BUILD",
sha256 = "6724955a967a1457bcf5dc1787a8da95feaba45d3b498ae42768ebf48f587299",
strip_prefix = "WasmEdge-proxy-wasm-0.9.1",
url = "https://github.com/WasmEdge/WasmEdge/archive/refs/tags/proxy-wasm/0.9.1.tar.gz",
sha256 = "4cff44e8c805ed4364d326ff1dd40e3aeb21ba1a11388372386eea1ccc7f93dd",
strip_prefix = "WasmEdge-proxy-wasm-0.10.0",
url = "https://github.com/WasmEdge/WasmEdge/archive/refs/tags/proxy-wasm/0.10.0.tar.gz",
)

native.bind(
Expand Down
2 changes: 2 additions & 0 deletions src/wasmedge/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ using WasmEdgeLoaderPtr = common::CSmartPtr<WasmEdge_LoaderContext, WasmEdge_Loa
using WasmEdgeValidatorPtr = common::CSmartPtr<WasmEdge_ValidatorContext, WasmEdge_ValidatorDelete>;
using WasmEdgeExecutorPtr = common::CSmartPtr<WasmEdge_ExecutorContext, WasmEdge_ExecutorDelete>;
using WasmEdgeASTModulePtr = common::CSmartPtr<WasmEdge_ASTModuleContext, WasmEdge_ASTModuleDelete>;
using WasmEdgeModulePtr =
common::CSmartPtr<WasmEdge_ModuleInstanceContext, WasmEdge_ModuleInstanceDelete>;

} // namespace proxy_wasm::WasmEdge
78 changes: 41 additions & 37 deletions src/wasmedge/wasmedge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ using HostFuncDataPtr = std::unique_ptr<HostFuncData>;

struct HostModuleData {
HostModuleData(const std::string_view modname) {
cxt_ = WasmEdge_ImportObjectCreate(WasmEdge_StringWrap(modname.data(), modname.length()));
cxt_ = WasmEdge_ModuleInstanceCreate(WasmEdge_StringWrap(modname.data(), modname.length()));
}
~HostModuleData() { WasmEdge_ImportObjectDelete(cxt_); }
~HostModuleData() { WasmEdge_ModuleInstanceDelete(cxt_); }

WasmEdge_ImportObjectContext *cxt_;
WasmEdge_ModuleInstanceContext *cxt_;
};

using HostModuleDataPtr = std::unique_ptr<HostModuleData>;
Expand All @@ -228,6 +228,7 @@ class WasmEdge : public WasmVm {
validator_ = WasmEdge_ValidatorCreate(nullptr);
executor_ = WasmEdge_ExecutorCreate(nullptr, nullptr);
store_ = nullptr;
ast_module_ = nullptr;
module_ = nullptr;
memory_ = nullptr;
}
Expand Down Expand Up @@ -285,11 +286,12 @@ class WasmEdge : public WasmVm {
WasmEdgeValidatorPtr validator_;
WasmEdgeExecutorPtr executor_;
WasmEdgeStorePtr store_;
WasmEdgeASTModulePtr module_;
WasmEdgeASTModulePtr ast_module_;
WasmEdgeModulePtr module_;
WasmEdge_MemoryInstanceContext *memory_;

std::unordered_map<std::string, HostFuncDataPtr> host_functions_;
std::unordered_map<std::string, HostModuleDataPtr> import_objects_;
std::unordered_map<std::string, HostModuleDataPtr> host_modules_;
std::unordered_set<std::string> module_functions_;
};

Expand All @@ -303,22 +305,25 @@ bool WasmEdge::load(std::string_view bytecode, std::string_view /*precompiled*/,
}
res = WasmEdge_ValidatorValidate(validator_.get(), mod);
if (!WasmEdge_ResultOK(res)) {
WasmEdge_ASTModuleDelete(mod);
return false;
}
module_ = mod;
ast_module_ = mod;
return true;
}

bool WasmEdge::link(std::string_view /*debug_name*/) {
assert(module_ != nullptr);
assert(ast_module_ != nullptr);

// Create store and register imports.
store_ = WasmEdge_StoreCreate();
if (store_ == nullptr) {
store_ = WasmEdge_StoreCreate();
}
if (store_ == nullptr) {
return false;
}
WasmEdge_Result res;
for (auto &&it : import_objects_) {
for (auto &&it : host_modules_) {
res = WasmEdge_ExecutorRegisterImport(executor_.get(), store_.get(), it.second->cxt_);
if (!WasmEdge_ResultOK(res)) {
fail(FailState::UnableToInitializeCode,
Expand All @@ -327,30 +332,33 @@ bool WasmEdge::link(std::string_view /*debug_name*/) {
}
}
// Instantiate module.
res = WasmEdge_ExecutorInstantiate(executor_.get(), store_.get(), module_.get());
WasmEdge_ModuleInstanceContext *mod = nullptr;
res = WasmEdge_ExecutorInstantiate(executor_.get(), &mod, store_.get(), ast_module_.get());
if (!WasmEdge_ResultOK(res)) {
fail(FailState::UnableToInitializeCode,
std::string("Failed to link Wasm module: ") + std::string(WasmEdge_ResultGetMessage(res)));
return false;
}
// Get the function and memory exports.
uint32_t memory_num = WasmEdge_StoreListMemoryLength(store_.get());
uint32_t memory_num = WasmEdge_ModuleInstanceListMemoryLength(mod);
if (memory_num > 0) {
WasmEdge_String name;
WasmEdge_StoreListMemory(store_.get(), &name, 1);
memory_ = WasmEdge_StoreFindMemory(store_.get(), name);
WasmEdge_ModuleInstanceListMemory(mod, &name, 1);
memory_ = WasmEdge_ModuleInstanceFindMemory(mod, name);
if (memory_ == nullptr) {
WasmEdge_ModuleInstanceDelete(mod);
return false;
}
}
uint32_t func_num = WasmEdge_StoreListFunctionLength(store_.get());
uint32_t func_num = WasmEdge_ModuleInstanceListFunctionLength(mod);
if (func_num > 0) {
std::vector<WasmEdge_String> names(func_num);
WasmEdge_StoreListFunction(store_.get(), &names[0], func_num);
WasmEdge_ModuleInstanceListFunction(mod, &names[0], func_num);
for (auto i = 0; i < func_num; i++) {
module_functions_.insert(std::string(names[i].Buf, names[i].Length));
}
}
module_ = mod;
return true;
}

Expand Down Expand Up @@ -398,10 +406,10 @@ bool WasmEdge::setWord(uint64_t pointer, Word word) {
template <typename... Args>
void WasmEdge::registerHostFunctionImpl(std::string_view module_name,
std::string_view function_name, void (*function)(Args...)) {
auto it = import_objects_.find(std::string(module_name));
if (it == import_objects_.end()) {
import_objects_.emplace(module_name, std::make_unique<HostModuleData>(module_name));
it = import_objects_.find(std::string(module_name));
auto it = host_modules_.find(std::string(module_name));
if (it == host_modules_.end()) {
host_modules_.emplace(module_name, std::make_unique<HostModuleData>(module_name));
it = host_modules_.find(std::string(module_name));
}

auto data = std::make_unique<HostFuncData>(module_name, function_name);
Expand Down Expand Up @@ -435,7 +443,7 @@ void WasmEdge::registerHostFunctionImpl(std::string_view module_name,
return;
}

WasmEdge_ImportObjectAddFunction(
WasmEdge_ModuleInstanceAddFunction(
it->second->cxt_, WasmEdge_StringWrap(function_name.data(), function_name.length()),
hostfunc_cxt);
host_functions_.insert_or_assign(std::string(module_name) + "." + std::string(function_name),
Expand All @@ -445,10 +453,10 @@ void WasmEdge::registerHostFunctionImpl(std::string_view module_name,
template <typename R, typename... Args>
void WasmEdge::registerHostFunctionImpl(std::string_view module_name,
std::string_view function_name, R (*function)(Args...)) {
auto it = import_objects_.find(std::string(module_name));
if (it == import_objects_.end()) {
import_objects_.emplace(module_name, std::make_unique<HostModuleData>(module_name));
it = import_objects_.find(std::string(module_name));
auto it = host_modules_.find(std::string(module_name));
if (it == host_modules_.end()) {
host_modules_.emplace(module_name, std::make_unique<HostModuleData>(module_name));
it = host_modules_.find(std::string(module_name));
}

auto data = std::make_unique<HostFuncData>(module_name, function_name);
Expand Down Expand Up @@ -482,7 +490,7 @@ void WasmEdge::registerHostFunctionImpl(std::string_view module_name,
return;
}

WasmEdge_ImportObjectAddFunction(
WasmEdge_ModuleInstanceAddFunction(
it->second->cxt_, WasmEdge_StringWrap(function_name.data(), function_name.length()),
hostfunc_cxt);
host_functions_.insert_or_assign(std::string(module_name) + "." + std::string(function_name),
Expand All @@ -492,8 +500,8 @@ void WasmEdge::registerHostFunctionImpl(std::string_view module_name,
template <typename... Args>
void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
std::function<void(ContextBase *, Args...)> *function) {
auto *func_cxt = WasmEdge_StoreFindFunction(
store_.get(), WasmEdge_StringWrap(function_name.data(), function_name.length()));
auto *func_cxt = WasmEdge_ModuleInstanceFindFunction(
module_.get(), WasmEdge_StringWrap(function_name.data(), function_name.length()));
if (!func_cxt) {
*function = nullptr;
return;
Expand Down Expand Up @@ -521,7 +529,7 @@ void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
return;
}

*function = [function_name, this](ContextBase *context, Args... args) -> void {
*function = [function_name, func_cxt, this](ContextBase *context, Args... args) -> void {
WasmEdge_Value params[] = {makeVal(args)...};
const bool log = cmpLogLevel(LogLevel::trace);
if (log) {
Expand All @@ -530,9 +538,7 @@ void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
}
SaveRestoreContext saved_context(context);
WasmEdge_Result res =
WasmEdge_ExecutorInvoke(executor_.get(), store_.get(),
WasmEdge_StringWrap(function_name.data(), function_name.length()),
params, sizeof...(Args), nullptr, 0);
WasmEdge_ExecutorInvoke(executor_.get(), func_cxt, params, sizeof...(Args), nullptr, 0);
if (!WasmEdge_ResultOK(res)) {
fail(FailState::RuntimeError, "Function: " + std::string(function_name) + " failed:\n" +
WasmEdge_ResultGetMessage(res));
Expand All @@ -547,8 +553,8 @@ void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
template <typename R, typename... Args>
void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
std::function<R(ContextBase *, Args...)> *function) {
auto *func_cxt = WasmEdge_StoreFindFunction(
store_.get(), WasmEdge_StringWrap(function_name.data(), function_name.length()));
auto *func_cxt = WasmEdge_ModuleInstanceFindFunction(
module_.get(), WasmEdge_StringWrap(function_name.data(), function_name.length()));
if (!func_cxt) {
*function = nullptr;
return;
Expand Down Expand Up @@ -576,7 +582,7 @@ void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
return;
}

*function = [function_name, this](ContextBase *context, Args... args) -> R {
*function = [function_name, func_cxt, this](ContextBase *context, Args... args) -> R {
WasmEdge_Value params[] = {makeVal(args)...};
WasmEdge_Value results[1];
const bool log = cmpLogLevel(LogLevel::trace);
Expand All @@ -586,9 +592,7 @@ void WasmEdge::getModuleFunctionImpl(std::string_view function_name,
}
SaveRestoreContext saved_context(context);
WasmEdge_Result res =
WasmEdge_ExecutorInvoke(executor_.get(), store_.get(),
WasmEdge_StringWrap(function_name.data(), function_name.length()),
params, sizeof...(Args), results, 1);
WasmEdge_ExecutorInvoke(executor_.get(), func_cxt, params, sizeof...(Args), results, 1);
if (!WasmEdge_ResultOK(res)) {
fail(FailState::RuntimeError, "Function: " + std::string(function_name) + " failed:\n" +
WasmEdge_ResultGetMessage(res));
Expand Down