diff --git a/include/proxy-wasm/null_vm.h b/include/proxy-wasm/null_vm.h index 494c5f80d..3a38fd990 100644 --- a/include/proxy-wasm/null_vm.h +++ b/include/proxy-wasm/null_vm.h @@ -60,6 +60,8 @@ struct NullVm : public WasmVm { FOR_ALL_WASM_VM_IMPORTS(_REGISTER_CALLBACK) #undef _REGISTER_CALLBACK + void terminate() override {} + std::string plugin_name_; std::unique_ptr plugin_; }; diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 1f857951e..800348ac3 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -297,6 +297,11 @@ class WasmVm { FOR_ALL_WASM_VM_IMPORTS(_REGISTER_CALLBACK) #undef _REGISTER_CALLBACK + /** + * Terminate execution of this WasmVM. It shouldn't be used after being terminated. + */ + virtual void terminate() = 0; + bool isFailed() { return failed_ != FailState::Ok; } void fail(FailState fail_state, std::string_view message) { integration()->error(message); diff --git a/src/v8/v8.cc b/src/v8/v8.cc index 51969e82d..71a7483e8 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -21,12 +21,14 @@ #include #include #include +#include #include #include #include #include "include/v8-version.h" #include "include/v8.h" +#include "src/wasm/c-api.h" #include "wasm-api/wasm.hh" namespace proxy_wasm { @@ -92,6 +94,8 @@ class V8 : public WasmVm { FOR_ALL_WASM_VM_EXPORTS(_GET_MODULE_FUNCTION) #undef _GET_MODULE_FUNCTION + void terminate() override; + private: std::string getFailMessage(std::string_view function_name, wasm::own trap); @@ -657,6 +661,16 @@ void V8::getModuleFunctionImpl(std::string_view function_name, }; } +void V8::terminate() { + auto *store_impl = reinterpret_cast(store_.get()); + auto *isolate = store_impl->isolate(); + isolate->TerminateExecution(); + while (isolate->IsExecutionTerminating()) { + std::this_thread::yield(); + } + integration()->trace("[host->vm] Terminated"); +} + std::string V8::getFailMessage(std::string_view function_name, wasm::own trap) { auto message = "Function: " + std::string(function_name) + " failed: "; message += std::string(trap->message().get(), trap->message().size()); diff --git a/src/wamr/wamr.cc b/src/wamr/wamr.cc index d9b28f751..93cae4d58 100644 --- a/src/wamr/wamr.cc +++ b/src/wamr/wamr.cc @@ -85,6 +85,9 @@ class Wamr : public WasmVm { }; FOR_ALL_WASM_VM_EXPORTS(_GET_MODULE_FUNCTION) #undef _GET_MODULE_FUNCTION + + void terminate() override {} + private: template void registerHostFunctionImpl(std::string_view module_name, std::string_view function_name, diff --git a/src/wasmtime/wasmtime.cc b/src/wasmtime/wasmtime.cc index 44eb1b889..4bbc32b5f 100644 --- a/src/wasmtime/wasmtime.cc +++ b/src/wasmtime/wasmtime.cc @@ -97,6 +97,8 @@ class Wasmtime : public WasmVm { void getModuleFunctionImpl(std::string_view function_name, std::function *function); + void terminate() override {} + WasmStorePtr store_; WasmModulePtr module_; WasmSharedModulePtr shared_module_; diff --git a/src/wavm/wavm.cc b/src/wavm/wavm.cc index 6e53b4d81..1041b7fa7 100644 --- a/src/wavm/wavm.cc +++ b/src/wavm/wavm.cc @@ -229,6 +229,8 @@ struct Wavm : public WasmVm { FOR_ALL_WASM_VM_IMPORTS(_REGISTER_CALLBACK) #undef _REGISTER_CALLBACK + void terminate() override {} + IR::Module ir_module_; WAVM::Runtime::ModuleRef module_ = nullptr; WAVM::Runtime::GCPointer module_instance_; diff --git a/test/BUILD b/test/BUILD index 756a14ac6..46514b982 100644 --- a/test/BUILD +++ b/test/BUILD @@ -56,6 +56,7 @@ cc_test( data = [ "//test/test_data:abi_export.wasm", "//test/test_data:callback.wasm", + "//test/test_data:infinite_loop.wasm", "//test/test_data:trap.wasm", ], linkstatic = 1, diff --git a/test/runtime_test.cc b/test/runtime_test.cc index 649fb1069..132f8247d 100644 --- a/test/runtime_test.cc +++ b/test/runtime_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "include/proxy-wasm/context.h" @@ -245,6 +246,34 @@ TEST_P(TestVM, Callback) { ASSERT_EQ(res.u32(), 100100); // 10000 (global) + 100(in callback) } +TEST_P(TestVM, TerminateExecution) { + // TODO(chaoqin-li1123): implement execution termination for other runtime. + if (engine_ != "v8") { + return; + } + auto source = readTestWasmFile("infinite_loop.wasm"); + ASSERT_TRUE(vm_->load(source, {}, {})); + + TestContext context; + + std::thread terminate([&]() { + std::this_thread::sleep_for(std::chrono::seconds(3)); + vm_->terminate(); + }); + + ASSERT_TRUE(vm_->link("")); + WasmCallVoid<0> infinite_loop; + vm_->getFunction("infinite_loop", &infinite_loop); + ASSERT_TRUE(infinite_loop != nullptr); + infinite_loop(&context); + + terminate.join(); + + std::string exp_message = "Function: infinite_loop failed: Uncaught Error: termination_exception"; + auto *integration = dynamic_cast(vm_->integration().get()); + ASSERT_TRUE(integration->error_message_.find(exp_message) != std::string::npos); +} + TEST_P(TestVM, Trap) { auto source = readTestWasmFile("trap.wasm"); ASSERT_TRUE(vm_->load(source, {}, {})); diff --git a/test/test_data/BUILD b/test/test_data/BUILD index c6ed213e3..83f781f00 100644 --- a/test/test_data/BUILD +++ b/test/test_data/BUILD @@ -37,6 +37,11 @@ wasm_rust_binary( srcs = ["trap.rs"], ) +wasm_rust_binary( + name = "infinite_loop.wasm", + srcs = ["infinite_loop.rs"], +) + wasm_rust_binary( name = "env.wasm", srcs = ["env.rs"], diff --git a/test/test_data/infinite_loop.rs b/test/test_data/infinite_loop.rs new file mode 100644 index 000000000..c502be879 --- /dev/null +++ b/test/test_data/infinite_loop.rs @@ -0,0 +1,21 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[no_mangle] +pub extern "C" fn infinite_loop() { + let mut _count: u64 = 0; + loop { + _count += 1; + } +}