Skip to content
73 changes: 59 additions & 14 deletions src/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
Expand All @@ -36,13 +37,56 @@ namespace proxy_wasm {

namespace {

// Map from Wasm Key to the local Wasm instance.
// Map from Wasm key to the thread-local Wasm instance.
thread_local std::unordered_map<std::string, std::weak_ptr<WasmHandleBase>> local_wasms;
// Wasm key queue to track stale entries in `local_wasms`.
thread_local std::queue<std::string> local_wasms_keys;

// Map from plugin key to the thread-local plugin instance.
thread_local std::unordered_map<std::string, std::weak_ptr<PluginHandleBase>> local_plugins;
// Plugin key queue to track stale entries in `local_plugins`.
thread_local std::queue<std::string> local_plugins_keys;

// Check no more than `MAX_LOCAL_CACHE_GC_CHUNK_SIZE` cache entries at a time during stale entries
// cleanup.
const size_t MAX_LOCAL_CACHE_GC_CHUNK_SIZE = 64;

// Map from Wasm Key to the base Wasm instance, using a pointer to avoid the initialization fiasco.
std::mutex base_wasms_mutex;
std::unordered_map<std::string, std::weak_ptr<WasmHandleBase>> *base_wasms = nullptr;

void cacheLocalWasm(const std::string &key, const std::shared_ptr<WasmHandleBase> &wasm_handle) {
local_wasms[key] = wasm_handle;
local_wasms_keys.emplace(key);
}

void cacheLocalPlugin(const std::string &key,
const std::shared_ptr<PluginHandleBase> &plugin_handle) {
local_plugins[key] = plugin_handle;
local_plugins_keys.emplace(key);
}

template <class T>
void removeStaleLocalCacheEntries(std::unordered_map<std::string, std::weak_ptr<T>> &cache,
std::queue<std::string> &keys) {
auto num_keys_to_check = std::min(MAX_LOCAL_CACHE_GC_CHUNK_SIZE, keys.size());
for (size_t i = 0; i < num_keys_to_check; i++) {
std::string key(keys.front());
keys.pop();

const auto it = cache.find(key);
if (it == cache.end()) {
continue;
}

if (it->second.expired()) {
cache.erase(it);
} else {
keys.push(std::move(key));
}
}
}

} // namespace

std::string makeVmKey(std::string_view vm_id, std::string_view vm_configuration,
Expand Down Expand Up @@ -525,14 +569,15 @@ std::shared_ptr<WasmHandleBase> createWasm(const std::string &vm_key, const std:

std::shared_ptr<WasmHandleBase> getThreadLocalWasm(std::string_view vm_key) {
auto it = local_wasms.find(std::string(vm_key));
if (it == local_wasms.end()) {
return nullptr;
}
auto wasm = it->second.lock();
if (!wasm) {
local_wasms.erase(std::string(vm_key));
if (it != local_wasms.end()) {
auto wasm = it->second.lock();
if (wasm) {
return wasm;
}
local_wasms.erase(it);
}
return wasm;
removeStaleLocalCacheEntries(local_wasms, local_wasms_keys);
return nullptr;
}

static std::shared_ptr<WasmHandleBase>
Expand All @@ -546,9 +591,9 @@ getOrCreateThreadLocalWasm(const std::shared_ptr<WasmHandleBase> &base_handle,
if (wasm_handle) {
return wasm_handle;
}
// Remove stale entry.
local_wasms.erase(vm_key);
local_wasms.erase(it);
}
removeStaleLocalCacheEntries(local_wasms, local_wasms_keys);
// Create and initialize new thread-local WasmVM.
auto wasm_handle = clone_factory(base_handle);
if (!wasm_handle) {
Expand All @@ -560,7 +605,7 @@ getOrCreateThreadLocalWasm(const std::shared_ptr<WasmHandleBase> &base_handle,
base_handle->wasm()->fail(FailState::UnableToInitializeCode, "Failed to initialize Wasm code");
return nullptr;
}
local_wasms[vm_key] = wasm_handle;
cacheLocalWasm(vm_key, wasm_handle);
wasm_handle->wasm()->wasm_vm()->addFailCallback([vm_key](proxy_wasm::FailState fail_state) {
if (fail_state == proxy_wasm::FailState::RuntimeError) {
// If VM failed, erase the entry so that:
Expand All @@ -583,9 +628,9 @@ std::shared_ptr<PluginHandleBase> getOrCreateThreadLocalPlugin(
if (plugin_handle) {
return plugin_handle;
}
// Remove stale entry.
local_plugins.erase(key);
local_plugins.erase(it);
}
removeStaleLocalCacheEntries(local_plugins, local_plugins_keys);
// Get thread-local WasmVM.
auto wasm_handle = getOrCreateThreadLocalWasm(base_handle, clone_factory);
if (!wasm_handle) {
Expand All @@ -603,7 +648,7 @@ std::shared_ptr<PluginHandleBase> getOrCreateThreadLocalPlugin(
return nullptr;
}
auto plugin_handle = plugin_factory(wasm_handle, plugin);
local_plugins[key] = plugin_handle;
cacheLocalPlugin(key, plugin_handle);
wasm_handle->wasm()->wasm_vm()->addFailCallback([key](proxy_wasm::FailState fail_state) {
if (fail_state == proxy_wasm::FailState::RuntimeError) {
// If VM failed, erase the entry so that:
Expand Down