diff --git a/bazel/wasm.bzl b/bazel/wasm.bzl index f76407a2c..aba74aeb2 100644 --- a/bazel/wasm.bzl +++ b/bazel/wasm.bzl @@ -19,6 +19,11 @@ def _wasm_rust_transition_impl(settings, attr): "//command_line_option:platforms": "@rules_rust//rust/platform:wasm", } +def _wasi_rust_transition_impl(settings, attr): + return { + "//command_line_option:platforms": "@rules_rust//rust/platform:wasi", + } + wasm_rust_transition = transition( implementation = _wasm_rust_transition_impl, inputs = [], @@ -27,6 +32,14 @@ wasm_rust_transition = transition( ], ) +wasi_rust_transition = transition( + implementation = _wasi_rust_transition_impl, + inputs = [], + outputs = [ + "//command_line_option:platforms", + ], +) + def _wasm_binary_impl(ctx): out = ctx.actions.declare_file(ctx.label.name) ctx.actions.run( @@ -49,7 +62,12 @@ wasm_rust_binary_rule = rule( attrs = _wasm_attrs(wasm_rust_transition), ) -def wasm_rust_binary(name, tags = [], **kwargs): +wasi_rust_binary_rule = rule( + implementation = _wasm_binary_impl, + attrs = _wasm_attrs(wasi_rust_transition), +) + +def wasm_rust_binary(name, tags = [], wasi = False, **kwargs): wasm_name = "_wasm_" + name.replace(".", "_") kwargs.setdefault("visibility", ["//visibility:public"]) @@ -62,7 +80,11 @@ def wasm_rust_binary(name, tags = [], **kwargs): **kwargs ) - wasm_rust_binary_rule( + bin_rule = wasm_rust_binary_rule + if wasi: + bin_rule = wasi_rust_binary_rule + + bin_rule( name = name, binary = ":" + wasm_name, tags = tags + ["manual"], diff --git a/include/proxy-wasm/null_vm.h b/include/proxy-wasm/null_vm.h index bfa62d896..d731060a2 100644 --- a/include/proxy-wasm/null_vm.h +++ b/include/proxy-wasm/null_vm.h @@ -43,6 +43,7 @@ struct NullVm : public WasmVm { bool setMemory(uint64_t pointer, uint64_t size, const void *data) override; bool setWord(uint64_t pointer, Word data) override; bool getWord(uint64_t pointer, Word *data) override; + size_t getWordSize() override; std::string_view getCustomSection(std::string_view name) override; std::string_view getPrecompiledSectionName() override; diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index 255d2273a..9caf90173 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -53,6 +53,7 @@ class WasmBase : public std::enable_shared_from_this { public: WasmBase(std::unique_ptr wasm_vm, std::string_view vm_id, std::string_view vm_configuration, std::string_view vm_key, + std::unordered_map envs, AllowedCapabilitiesMap allowed_capabilities); WasmBase(const std::shared_ptr &other, WasmVmFactory factory); virtual ~WasmBase(); @@ -136,6 +137,8 @@ class WasmBase : public std::enable_shared_from_this { AbiVersion abiVersion() { return abi_version_; } + const std::unordered_map &envs() { return envs_; } + // Called to raise the flag which indicates that the context should stop iteration regardless of // returned filter status from Proxy-Wasm extensions. For example, we ignore // FilterHeadersStatus::Continue after a local reponse is sent by the host. @@ -190,6 +193,8 @@ class WasmBase : public std::enable_shared_from_this { std::unordered_map contexts_; // Contains all contexts. std::unordered_map timer_period_; // per root_id. std::unique_ptr shutdown_handle_; + std::unordered_map + envs_; // environment variables passed through wasi.environ_get WasmCallVoid<0> _initialize_; /* Emscripten v1.39.17+ */ WasmCallVoid<0> _start_; /* Emscripten v1.39.0+ */ diff --git a/include/proxy-wasm/wasm_vm.h b/include/proxy-wasm/wasm_vm.h index 9e9a00700..e88aa3784 100644 --- a/include/proxy-wasm/wasm_vm.h +++ b/include/proxy-wasm/wasm_vm.h @@ -243,6 +243,11 @@ class WasmVm { */ virtual bool setWord(uint64_t pointer, Word data) = 0; + /** + * @return the Word size in this VM. + */ + virtual size_t getWordSize() = 0; + /** * Get the contents of the custom section with the given name or "" if it does not exist. * @param name the name of the custom section to get. diff --git a/src/exports.cc b/src/exports.cc index 1ffee3228..b5cd20ae3 100644 --- a/src/exports.cc +++ b/src/exports.cc @@ -767,7 +767,28 @@ Word wasi_unstable_fd_fdstat_get(void *raw_context, Word fd, Word statOut) { } // __wasi_errno_t __wasi_environ_get(char **environ, char *environ_buf); -Word wasi_unstable_environ_get(void *, Word, Word) { +Word wasi_unstable_environ_get(void *raw_context, Word environ_array_ptr, Word environ_buf) { + auto context = WASM_CONTEXT(raw_context); + auto word_size = context->wasmVm()->getWordSize(); + auto &envs = context->wasm()->envs(); + for (auto e : envs) { + if (!context->wasmVm()->setWord(environ_array_ptr, environ_buf)) { + return 21; // __WASI_EFAULT + } + + std::string data; + data.reserve(e.first.size() + e.second.size() + 2); + data.append(e.first); + data.append("="); + data.append(e.second); + data.append("\x0"); + if (!context->wasmVm()->setMemory(environ_buf, data.size(), data.c_str())) { + return 21; // __WASI_EFAULT + } + environ_buf = environ_buf.u64_ + data.size(); + environ_array_ptr = environ_array_ptr.u64_ + word_size; + } + return 0; // __WASI_ESUCCESS } @@ -775,10 +796,17 @@ Word wasi_unstable_environ_get(void *, Word, Word) { // *environ_buf_size); Word wasi_unstable_environ_sizes_get(void *raw_context, Word count_ptr, Word buf_size_ptr) { auto context = WASM_CONTEXT(raw_context); - if (!context->wasmVm()->setWord(count_ptr, Word(0))) { + auto &envs = context->wasm()->envs(); + if (!context->wasmVm()->setWord(count_ptr, Word(envs.size()))) { return 21; // __WASI_EFAULT } - if (!context->wasmVm()->setWord(buf_size_ptr, Word(0))) { + + size_t size = 0; + for (auto e : envs) { + // len(key) + len(value) + 1('=') + 1(null terminator) + size += e.first.size() + e.second.size() + 2; + } + if (!context->wasmVm()->setWord(buf_size_ptr, Word(size))) { return 21; // __WASI_EFAULT } return 0; // __WASI_ESUCCESS diff --git a/src/null/null_vm.cc b/src/null/null_vm.cc index 2159c1223..80a2210e0 100644 --- a/src/null/null_vm.cc +++ b/src/null/null_vm.cc @@ -102,6 +102,8 @@ bool NullVm::getWord(uint64_t pointer, Word *data) { return true; } +size_t NullVm::getWordSize() { return sizeof(uint64_t); } + std::string_view NullVm::getCustomSection(std::string_view /* name */) { // Return nothing: there is no WASM file. return {}; diff --git a/src/v8/v8.cc b/src/v8/v8.cc index c43bfb77c..a873daff7 100644 --- a/src/v8/v8.cc +++ b/src/v8/v8.cc @@ -67,6 +67,7 @@ class V8 : public WasmVm { bool setMemory(uint64_t pointer, uint64_t size, const void *data) override; bool getWord(uint64_t pointer, Word *word) override; bool setWord(uint64_t pointer, Word word) override; + size_t getWordSize() override { return sizeof(uint32_t); }; #define _REGISTER_HOST_FUNCTION(T) \ void registerCallback(std::string_view module_name, std::string_view function_name, T, \ diff --git a/src/wasm.cc b/src/wasm.cc index c4b17bb37..2a6117ff0 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -193,6 +193,7 @@ WasmBase::WasmBase(const std::shared_ptr &base_wasm_handle, Wasm : std::enable_shared_from_this(*base_wasm_handle->wasm()), vm_id_(base_wasm_handle->wasm()->vm_id_), vm_key_(base_wasm_handle->wasm()->vm_key_), started_from_(base_wasm_handle->wasm()->wasm_vm()->cloneable()), + envs_(base_wasm_handle->wasm()->envs()), allowed_capabilities_(base_wasm_handle->wasm()->allowed_capabilities_), base_wasm_handle_(base_wasm_handle) { if (started_from_ != Cloneable::NotCloneable) { @@ -209,9 +210,10 @@ WasmBase::WasmBase(const std::shared_ptr &base_wasm_handle, Wasm WasmBase::WasmBase(std::unique_ptr wasm_vm, std::string_view vm_id, std::string_view vm_configuration, std::string_view vm_key, + std::unordered_map envs, AllowedCapabilitiesMap allowed_capabilities) : vm_id_(std::string(vm_id)), vm_key_(std::string(vm_key)), wasm_vm_(std::move(wasm_vm)), - allowed_capabilities_(std::move(allowed_capabilities)), + envs_(envs), allowed_capabilities_(std::move(allowed_capabilities)), vm_configuration_(std::string(vm_configuration)), vm_id_handle_(getVmIdHandle(vm_id)) { if (!wasm_vm_) { failed_ = FailState::UnableToCreateVM; diff --git a/src/wasmtime/wasmtime.cc b/src/wasmtime/wasmtime.cc index b7d648b6e..82b87b036 100644 --- a/src/wasmtime/wasmtime.cc +++ b/src/wasmtime/wasmtime.cc @@ -65,6 +65,7 @@ class Wasmtime : public WasmVm { bool setMemory(uint64_t pointer, uint64_t size, const void *data) override; bool getWord(uint64_t pointer, Word *word) override; bool setWord(uint64_t pointer, Word word) override; + size_t getWordSize() override { return sizeof(uint32_t); }; #define _REGISTER_HOST_FUNCTION(T) \ void registerCallback(std::string_view module_name, std::string_view function_name, T, \ diff --git a/src/wavm/wavm.cc b/src/wavm/wavm.cc index 00b3b7358..f96b334fb 100644 --- a/src/wavm/wavm.cc +++ b/src/wavm/wavm.cc @@ -229,6 +229,7 @@ struct Wavm : public WasmVm { bool setMemory(uint64_t pointer, uint64_t size, const void *data) override; bool getWord(uint64_t pointer, Word *data) override; bool setWord(uint64_t pointer, Word data) override; + size_t getWordSize() override { return sizeof(uint32_t); }; std::string_view getCustomSection(std::string_view name) override; std::string_view getPrecompiledSectionName() override; AbiVersion getAbiVersion() override; diff --git a/test/BUILD b/test/BUILD index ce2c279be..52b6113d0 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,4 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_test") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") load("@proxy_wasm_cpp_host//bazel:variables.bzl", "COPTS", "LINKOPTS") cc_test( @@ -23,6 +23,23 @@ cc_test( ], linkopts = LINKOPTS, deps = [ + ":utility_lib", + "//:lib", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "exports_test", + srcs = ["exports_test.cc"], + copts = COPTS, + data = [ + "//test/test_data:env.wasm", + ], + linkopts = LINKOPTS, + deps = [ + ":utility_lib", "//:lib", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -72,3 +89,17 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "utility_lib", + srcs = [ + "utility.cc", + "utility.h", + ], + hdrs = ["utility.h"], + copts = COPTS, + deps = [ + "//:lib", + "@com_google_googletest//:gtest", + ], +) diff --git a/test/exports_test.cc b/test/exports_test.cc new file mode 100644 index 000000000..2e8ad43bd --- /dev/null +++ b/test/exports_test.cc @@ -0,0 +1,95 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include + +#include "include/proxy-wasm/context.h" +#include "include/proxy-wasm/exports.h" +#include "include/proxy-wasm/wasm.h" + +#include "test/utility.h" + +namespace proxy_wasm { +namespace { + +auto test_values = testing::ValuesIn(getRuntimes()); + +INSTANTIATE_TEST_SUITE_P(Runtimes, TestVM, test_values); + +class TestContext : public ContextBase { +public: + TestContext(WasmBase *base) : ContextBase(base){}; + WasmResult log(uint32_t, std::string_view msg) override { + log_ += std::string(msg) + "\n"; + return WasmResult::Ok; + } + std::string &log_msg() { return log_; } + +private: + std::string log_; +}; + +TEST_P(TestVM, Environment) { + std::unordered_map envs = {{"KEY1", "VALUE1"}, {"KEY2", "VALUE2"}}; + initialize("env.wasm"); + + auto wasm_base = WasmBase(std::move(vm_), "vm_id", "", "", envs, {}); + ASSERT_TRUE(wasm_base.wasm_vm()->load(source_, false)); + + TestContext context(&wasm_base); + current_context_ = &context; + + wasm_base.registerCallbacks(); + + ASSERT_TRUE(wasm_base.wasm_vm()->link("")); + + WasmCallVoid<0> run; + wasm_base.wasm_vm()->getFunction("run", &run); + + run(current_context_); + + auto msg = context.log_msg(); + EXPECT_NE(std::string::npos, msg.find("KEY1: VALUE1")) << msg; + EXPECT_NE(std::string::npos, msg.find("KEY2: VALUE2")) << msg; +} + +TEST_P(TestVM, WithoutEnvironment) { + initialize("env.wasm"); + auto wasm_base = WasmBase(std::move(vm_), "vm_id", "", "", {}, {}); + ASSERT_TRUE(wasm_base.wasm_vm()->load(source_, false)); + + TestContext context(&wasm_base); + current_context_ = &context; + + wasm_base.registerCallbacks(); + + ASSERT_TRUE(wasm_base.wasm_vm()->link("")); + + WasmCallVoid<0> run; + wasm_base.wasm_vm()->getFunction("run", &run); + + run(current_context_); + + EXPECT_EQ(context.log_msg(), ""); +} + +} // namespace +} // namespace proxy_wasm diff --git a/test/runtime_test.cc b/test/runtime_test.cc index 8f5c524a1..078a1db05 100644 --- a/test/runtime_test.cc +++ b/test/runtime_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "gtest/gtest.h" + #include #include #include @@ -23,97 +24,11 @@ #include "include/proxy-wasm/context.h" #include "include/proxy-wasm/wasm.h" -#if defined(WASM_V8) -#include "include/proxy-wasm/v8.h" -#endif -#if defined(WASM_WAVM) -#include "include/proxy-wasm/wavm.h" -#endif -#if defined(WASM_WASMTIME) -#include "include/proxy-wasm/wasmtime.h" -#endif +#include "test/utility.h" namespace proxy_wasm { namespace { -struct DummyIntegration : public WasmVmIntegration { - ~DummyIntegration() override{}; - WasmVmIntegration *clone() override { return new DummyIntegration{}; } - void error(std::string_view message) override { - std::cout << "ERROR from integration: " << message << std::endl; - error_message_ = message; - } - void trace(std::string_view message) override { - std::cout << "TRACE from integration: " << message << std::endl; - trace_message_ = message; - } - bool getNullVmFunction(std::string_view function_name, bool returns_word, int number_of_arguments, - NullPlugin *plugin, void *ptr_to_function_return) override { - return false; - }; - - LogLevel getLogLevel() override { return log_level_; } - std::string error_message_; - std::string trace_message_; - LogLevel log_level_ = LogLevel::info; -}; - -class TestVM : public testing::TestWithParam { -public: - std::unique_ptr vm_; - - TestVM() : integration_(new DummyIntegration{}) { - runtime_ = GetParam(); - if (runtime_ == "") { - EXPECT_TRUE(false) << "runtime must not be empty"; -#if defined(WASM_V8) - } else if (runtime_ == "v8") { - vm_ = proxy_wasm::createV8Vm(); -#endif -#if defined(WASM_WAVM) - } else if (runtime_ == "wavm") { - vm_ = proxy_wasm::createWavmVm(); -#endif -#if defined(WASM_WASMTIME) - } else if (runtime_ == "wasmtime") { - vm_ = proxy_wasm::createWasmtimeVm(); -#endif - } - vm_->integration().reset(integration_); - } - - DummyIntegration *integration_; - - void initialize(std::string filename) { - auto path = "test/test_data/" + filename; - std::ifstream file(path, std::ios::binary); - EXPECT_FALSE(file.fail()) << "failed to open: " << path; - std::stringstream file_string_stream; - file_string_stream << file.rdbuf(); - source_ = file_string_stream.str(); - } - - std::string source_; - std::string runtime_; -}; - -static std::vector getRuntimes() { - std::vector runtimes = { -#if defined(WASM_V8) - "v8", -#endif -#if defined(WASM_WAVM) - "wavm", -#endif -#if defined(WASM_WASMTIME) - "wasmtime", -#endif - "" - }; - runtimes.pop_back(); - return runtimes; -} - auto test_values = testing::ValuesIn(getRuntimes()); INSTANTIATE_TEST_SUITE_P(Runtimes, TestVM, test_values); diff --git a/test/test_data/BUILD b/test/test_data/BUILD index 051f051cd..715ed92e3 100644 --- a/test/test_data/BUILD +++ b/test/test_data/BUILD @@ -16,3 +16,9 @@ wasm_rust_binary( name = "trap.wasm", srcs = ["trap.rs"], ) + +wasm_rust_binary( + name = "env.wasm", + srcs = ["env.rs"], + wasi = True, +) diff --git a/test/test_data/env.rs b/test/test_data/env.rs new file mode 100644 index 000000000..af7793b2e --- /dev/null +++ b/test/test_data/env.rs @@ -0,0 +1,28 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[no_mangle] +extern "C" { + fn __wasilibc_initialize_environ(); +} + +#[no_mangle] +pub extern "C" fn run() { + unsafe { + __wasilibc_initialize_environ(); + } + for (key, value) in std::env::vars() { + println!("{}: {}", key, value); + } +} diff --git a/test/utility.cc b/test/utility.cc new file mode 100644 index 000000000..7e7ba4177 --- /dev/null +++ b/test/utility.cc @@ -0,0 +1,35 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/utility.h" + +namespace proxy_wasm { + +std::vector getRuntimes() { + std::vector runtimes = { +#if defined(WASM_V8) + "v8", +#endif +#if defined(WASM_WAVM) + "wavm", +#endif +#if defined(WASM_WASMTIME) + "wasmtime", +#endif + "" + }; + runtimes.pop_back(); + return runtimes; +} +} // namespace proxy_wasm diff --git a/test/utility.h b/test/utility.h new file mode 100644 index 000000000..de844d8da --- /dev/null +++ b/test/utility.h @@ -0,0 +1,101 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include + +#include "include/proxy-wasm/context.h" +#include "include/proxy-wasm/wasm.h" + +#if defined(WASM_V8) +#include "include/proxy-wasm/v8.h" +#endif +#if defined(WASM_WAVM) +#include "include/proxy-wasm/wavm.h" +#endif +#if defined(WASM_WASMTIME) +#include "include/proxy-wasm/wasmtime.h" +#endif + +namespace proxy_wasm { + +struct DummyIntegration : public WasmVmIntegration { + ~DummyIntegration() override{}; + WasmVmIntegration *clone() override { return new DummyIntegration{}; } + void error(std::string_view message) override { + std::cout << "ERROR from integration: " << message << std::endl; + error_message_ = message; + } + void trace(std::string_view message) override { + std::cout << "TRACE from integration: " << message << std::endl; + trace_message_ = message; + } + bool getNullVmFunction(std::string_view function_name, bool returns_word, int number_of_arguments, + NullPlugin *plugin, void *ptr_to_function_return) override { + return false; + }; + + LogLevel getLogLevel() override { return log_level_; } + std::string error_message_; + std::string trace_message_; + LogLevel log_level_ = LogLevel::info; +}; + +class TestVM : public testing::TestWithParam { +public: + std::unique_ptr vm_; + + TestVM() : integration_(new DummyIntegration{}) { + runtime_ = GetParam(); + if (runtime_ == "") { + EXPECT_TRUE(false) << "runtime must not be empty"; +#if defined(WASM_V8) + } else if (runtime_ == "v8") { + vm_ = proxy_wasm::createV8Vm(); +#endif +#if defined(WASM_WAVM) + } else if (runtime_ == "wavm") { + vm_ = proxy_wasm::createWavmVm(); +#endif +#if defined(WASM_WASMTIME) + } else if (runtime_ == "wasmtime") { + vm_ = proxy_wasm::createWasmtimeVm(); +#endif + } + vm_->integration().reset(integration_); + } + + DummyIntegration *integration_; + + void initialize(std::string filename) { + auto path = "test/test_data/" + filename; + std::ifstream file(path, std::ios::binary); + EXPECT_FALSE(file.fail()) << "failed to open: " << path; + std::stringstream file_string_stream; + file_string_stream << file.rdbuf(); + source_ = file_string_stream.str(); + } + + std::string source_; + std::string runtime_; +}; + +std::vector getRuntimes(); + +} // namespace proxy_wasm