diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index 873531afd..23ed3c1da 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -392,7 +392,15 @@ inline void *WasmBase::allocMemory(uint64_t size, uint64_t *address) { if (!malloc_) { return nullptr; } + wasm_vm_->setRestrictedCallback( + true, {// logging (Proxy-Wasm) + "env.proxy_log", + // logging (stdout/stderr) + "wasi_unstable.fd_write", "wasi_snapshot_preview1.fd_write", + // time + "wasi_unstable.clock_time_get", "wasi_snapshot_preview1.clock_time_get"}); Word a = malloc_(vm_context(), size); + wasm_vm_->setRestrictedCallback(false); if (!a.u64_) { return nullptr; } diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 800348ac3..879e200d1 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "include/proxy-wasm/word.h" @@ -314,6 +315,16 @@ class WasmVm { fail_callbacks_.push_back(fail_callback); } + bool isHostFunctionAllowed(const std::string &name) { + return !restricted_callback_ || allowed_hostcalls_.find(name) != allowed_hostcalls_.end(); + } + + void setRestrictedCallback(bool restricted, + std::unordered_set allowed_hostcalls = {}) { + restricted_callback_ = restricted; + allowed_hostcalls_ = std::move(allowed_hostcalls); + } + // Integrator operations. std::unique_ptr &integration() { return integration_; } bool cmpLogLevel(proxy_wasm::LogLevel level) { return integration_->getLogLevel() <= level; } @@ -322,6 +333,10 @@ class WasmVm { std::unique_ptr integration_; FailState failed_ = FailState::Ok; std::vector> fail_callbacks_; + +private: + bool restricted_callback_{false}; + std::unordered_set allowed_hostcalls_{}; }; // Thread local state set during a call into a WASM VM so that calls coming out of the diff --git a/src/v8/v8.cc b/src/v8/v8.cc index 7603ea671..ad43e0799 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -107,6 +107,8 @@ class V8 : public WasmVm { void terminate() override; private: + wasm::own trap(std::string message); + std::string getFailMessage(std::string_view function_name, wasm::own trap); template @@ -503,6 +505,10 @@ bool V8::setWord(uint64_t pointer, Word word) { return true; } +wasm::own V8::trap(std::string message) { + return wasm::Trap::make(store_.get(), wasm::Message::make(std::move(message))); +} + template void V8::registerHostFunctionImpl(std::string_view module_name, std::string_view function_name, void (*function)(Args...)) { @@ -519,6 +525,9 @@ void V8::registerHostFunctionImpl(std::string_view module_name, std::string_view func_data->vm_->integration()->trace("[vm->host] " + func_data->name_ + "(" + printValues(params, sizeof...(Args)) + ")"); } + if (!func_data->vm_->isHostFunctionAllowed(func_data->name_)) { + return dynamic_cast(func_data->vm_)->trap("restricted_callback"); + } auto args = convertValTypesToArgsTuple>(params); auto function = reinterpret_cast(func_data->raw_func_); std::apply(function, args); @@ -552,6 +561,9 @@ void V8::registerHostFunctionImpl(std::string_view module_name, std::string_view func_data->vm_->integration()->trace("[vm->host] " + func_data->name_ + "(" + printValues(params, sizeof...(Args)) + ")"); } + if (!func_data->vm_->isHostFunctionAllowed(func_data->name_)) { + return dynamic_cast(func_data->vm_)->trap("restricted_callback"); + } auto args = convertValTypesToArgsTuple>(params); auto function = reinterpret_cast(func_data->raw_func_); R rvalue = std::apply(function, args); diff --git a/src/wasm.cc b/src/wasm.cc index fe21d7925..5519b3e77 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -355,6 +355,25 @@ ContextBase *WasmBase::getRootContext(const std::shared_ptr &plugin, } void WasmBase::startVm(ContextBase *root_context) { + // wasi_snapshot_preview1.clock_time_get + wasm_vm_->setRestrictedCallback( + true, {// logging (Proxy-Wasm) + "env.proxy_log", + // logging (stdout/stderr) + "wasi_unstable.fd_write", "wasi_snapshot_preview1.fd_write", + // args + "wasi_unstable.args_sizes_get", "wasi_snapshot_preview1.args_sizes_get", + "wasi_unstable.args_get", "wasi_snapshot_preview1.args_get", + // environment variables + "wasi_unstable.environ_sizes_get", "wasi_snapshot_preview1.environ_sizes_get", + "wasi_unstable.environ_get", "wasi_snapshot_preview1.environ_get", + // preopened files/directories + "wasi_unstable.fd_prestat_get", "wasi_snapshot_preview1.fd_prestat_get", + "wasi_unstable.fd_prestat_dir_name", "wasi_snapshot_preview1.fd_prestat_dir_name", + // time + "wasi_unstable.clock_time_get", "wasi_snapshot_preview1.clock_time_get", + // random + "wasi_unstable.random_get", "wasi_snapshot_preview1.random_get"}); if (_initialize_) { // WASI reactor. _initialize_(root_context); @@ -370,6 +389,7 @@ void WasmBase::startVm(ContextBase *root_context) { // WASI command. _start_(root_context); } + wasm_vm_->setRestrictedCallback(false); } bool WasmBase::configure(ContextBase *root_context, std::shared_ptr plugin) { diff --git a/test/BUILD b/test/BUILD index 196cbc590..42748b92d 100644 --- a/test/BUILD +++ b/test/BUILD @@ -102,6 +102,21 @@ cc_test( ], ) +cc_test( + name = "security_test", + srcs = ["security_test.cc"], + data = [ + "//test/test_data:bad_malloc.wasm", + ], + linkstatic = 1, + deps = [ + ":utility_lib", + "//:lib", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "shared_data", srcs = ["shared_data_test.cc"], diff --git a/test/security_test.cc b/test/security_test.cc new file mode 100644 index 000000000..a077e4ae6 --- /dev/null +++ b/test/security_test.cc @@ -0,0 +1,106 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include +#include + +#include "include/proxy-wasm/context.h" +#include "include/proxy-wasm/wasm.h" + +#include "test/utility.h" + +namespace proxy_wasm { + +INSTANTIATE_TEST_SUITE_P(WasmEngines, TestVm, testing::ValuesIn(getWasmEngines()), + [](const testing::TestParamInfo &info) { + return info.param; + }); + +TEST_P(TestVm, MallocNoHostcalls) { + if (engine_ != "v8") { + return; + } + auto source = readTestWasmFile("bad_malloc.wasm"); + ASSERT_FALSE(source.empty()); + auto wasm = TestWasm(std::move(vm_)); + ASSERT_TRUE(wasm.load(source, false)); + ASSERT_TRUE(wasm.initialize()); + + uint64_t ptr = 0; + void *result = wasm.allocMemory(0x1000, &ptr); + EXPECT_NE(result, nullptr); + EXPECT_FALSE(wasm.isFailed()); + + // Check application logs. + auto *context = dynamic_cast(wasm.vm_context()); + EXPECT_TRUE(context->isLogEmpty()); + // Check integration logs. + auto *integration = dynamic_cast(wasm.wasm_vm()->integration().get()); + EXPECT_FALSE(integration->isErrorLogged("Function: proxy_on_memory_allocate failed")); + EXPECT_FALSE(integration->isErrorLogged("restricted_callback")); +} + +TEST_P(TestVm, MallocWithLog) { + if (engine_ != "v8") { + return; + } + auto source = readTestWasmFile("bad_malloc.wasm"); + ASSERT_FALSE(source.empty()); + auto wasm = TestWasm(std::move(vm_)); + ASSERT_TRUE(wasm.load(source, false)); + ASSERT_TRUE(wasm.initialize()); + + uint64_t ptr = 0; + // 0xAAAA => hostcall to proxy_log (allowed). + void *result = wasm.allocMemory(0xAAAA, &ptr); + EXPECT_NE(result, nullptr); + EXPECT_FALSE(wasm.isFailed()); + + // Check application logs. + auto *context = dynamic_cast(wasm.vm_context()); + EXPECT_TRUE(context->isLogged("this is fine")); + // Check integration logs. + auto *integration = dynamic_cast(wasm.wasm_vm()->integration().get()); + EXPECT_FALSE(integration->isErrorLogged("Function: proxy_on_memory_allocate failed")); + EXPECT_FALSE(integration->isErrorLogged("restricted_callback")); +} + +TEST_P(TestVm, MallocWithHostcall) { + if (engine_ != "v8") { + return; + } + auto source = readTestWasmFile("bad_malloc.wasm"); + ASSERT_FALSE(source.empty()); + auto wasm = TestWasm(std::move(vm_)); + ASSERT_TRUE(wasm.load(source, false)); + ASSERT_TRUE(wasm.initialize()); + + uint64_t ptr = 0; + // 0xBBBB => hostcall to proxy_done (not allowed). + void *result = wasm.allocMemory(0xBBBB, &ptr); + EXPECT_EQ(result, nullptr); + EXPECT_TRUE(wasm.isFailed()); + + // Check application logs. + auto *context = dynamic_cast(wasm.vm_context()); + EXPECT_TRUE(context->isLogEmpty()); + // Check integration logs. + auto *integration = dynamic_cast(wasm.wasm_vm()->integration().get()); + EXPECT_TRUE(integration->isErrorLogged("Function: proxy_on_memory_allocate failed")); + EXPECT_TRUE(integration->isErrorLogged("restricted_callback")); +} + +} // namespace proxy_wasm diff --git a/test/test_data/BUILD b/test/test_data/BUILD index c6892954e..38612ca15 100644 --- a/test/test_data/BUILD +++ b/test/test_data/BUILD @@ -28,6 +28,11 @@ wasm_rust_binary( ], ) +wasm_rust_binary( + name = "bad_malloc.wasm", + srcs = ["bad_malloc.rs"], +) + wasm_rust_binary( name = "callback.wasm", srcs = ["callback.rs"], diff --git a/test/test_data/bad_malloc.rs b/test/test_data/bad_malloc.rs new file mode 100644 index 000000000..43c60da02 --- /dev/null +++ b/test/test_data/bad_malloc.rs @@ -0,0 +1,47 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::mem::MaybeUninit; + +extern "C" { + fn proxy_log(level: u32, message_data: *const u8, message_size: usize) -> u32; + fn proxy_done() -> u32; +} + +#[no_mangle] +pub extern "C" fn proxy_abi_version_0_2_0() {} + +#[no_mangle] +pub extern "C" fn proxy_on_memory_allocate(size: usize) -> *mut u8 { + let mut vec: Vec> = Vec::with_capacity(size); + unsafe { + vec.set_len(size); + } + match size { + 0xAAAA => { + let message = "this is fine"; + unsafe { + proxy_log(0, message.as_ptr(), message.len()); + } + } + 0xBBBB => { + unsafe { + proxy_done(); + } + } + _ => {} + } + let slice = vec.into_boxed_slice(); + Box::into_raw(slice) as *mut u8 +}