diff --git a/src/wasm.cc b/src/wasm.cc index 15127c99a..9d2566fc1 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -36,13 +37,56 @@ namespace proxy_wasm { namespace { -// Map from Wasm Key to the local Wasm instance. +// Map from Wasm key to the thread-local Wasm instance. thread_local std::unordered_map> local_wasms; +// Wasm key queue to track stale entries in `local_wasms`. +thread_local std::queue local_wasms_keys; + +// Map from plugin key to the thread-local plugin instance. thread_local std::unordered_map> local_plugins; +// Plugin key queue to track stale entries in `local_plugins`. +thread_local std::queue local_plugins_keys; + +// Check no more than `MAX_LOCAL_CACHE_GC_CHUNK_SIZE` cache entries at a time during stale entries +// cleanup. +const size_t MAX_LOCAL_CACHE_GC_CHUNK_SIZE = 64; + // Map from Wasm Key to the base Wasm instance, using a pointer to avoid the initialization fiasco. std::mutex base_wasms_mutex; std::unordered_map> *base_wasms = nullptr; +void cacheLocalWasm(const std::string &key, const std::shared_ptr &wasm_handle) { + local_wasms[key] = wasm_handle; + local_wasms_keys.emplace(key); +} + +void cacheLocalPlugin(const std::string &key, + const std::shared_ptr &plugin_handle) { + local_plugins[key] = plugin_handle; + local_plugins_keys.emplace(key); +} + +template +void removeStaleLocalCacheEntries(std::unordered_map> &cache, + std::queue &keys) { + auto num_keys_to_check = std::min(MAX_LOCAL_CACHE_GC_CHUNK_SIZE, keys.size()); + for (size_t i = 0; i < num_keys_to_check; i++) { + std::string key(keys.front()); + keys.pop(); + + const auto it = cache.find(key); + if (it == cache.end()) { + continue; + } + + if (it->second.expired()) { + cache.erase(it); + } else { + keys.push(std::move(key)); + } + } +} + } // namespace std::string makeVmKey(std::string_view vm_id, std::string_view vm_configuration, @@ -525,14 +569,15 @@ std::shared_ptr createWasm(const std::string &vm_key, const std: std::shared_ptr getThreadLocalWasm(std::string_view vm_key) { auto it = local_wasms.find(std::string(vm_key)); - if (it == local_wasms.end()) { - return nullptr; - } - auto wasm = it->second.lock(); - if (!wasm) { - local_wasms.erase(std::string(vm_key)); + if (it != local_wasms.end()) { + auto wasm = it->second.lock(); + if (wasm) { + return wasm; + } + local_wasms.erase(it); } - return wasm; + removeStaleLocalCacheEntries(local_wasms, local_wasms_keys); + return nullptr; } static std::shared_ptr @@ -546,9 +591,9 @@ getOrCreateThreadLocalWasm(const std::shared_ptr &base_handle, if (wasm_handle) { return wasm_handle; } - // Remove stale entry. - local_wasms.erase(vm_key); + local_wasms.erase(it); } + removeStaleLocalCacheEntries(local_wasms, local_wasms_keys); // Create and initialize new thread-local WasmVM. auto wasm_handle = clone_factory(base_handle); if (!wasm_handle) { @@ -560,7 +605,7 @@ getOrCreateThreadLocalWasm(const std::shared_ptr &base_handle, base_handle->wasm()->fail(FailState::UnableToInitializeCode, "Failed to initialize Wasm code"); return nullptr; } - local_wasms[vm_key] = wasm_handle; + cacheLocalWasm(vm_key, wasm_handle); wasm_handle->wasm()->wasm_vm()->addFailCallback([vm_key](proxy_wasm::FailState fail_state) { if (fail_state == proxy_wasm::FailState::RuntimeError) { // If VM failed, erase the entry so that: @@ -583,9 +628,9 @@ std::shared_ptr getOrCreateThreadLocalPlugin( if (plugin_handle) { return plugin_handle; } - // Remove stale entry. - local_plugins.erase(key); + local_plugins.erase(it); } + removeStaleLocalCacheEntries(local_plugins, local_plugins_keys); // Get thread-local WasmVM. auto wasm_handle = getOrCreateThreadLocalWasm(base_handle, clone_factory); if (!wasm_handle) { @@ -603,7 +648,7 @@ std::shared_ptr getOrCreateThreadLocalPlugin( return nullptr; } auto plugin_handle = plugin_factory(wasm_handle, plugin); - local_plugins[key] = plugin_handle; + cacheLocalPlugin(key, plugin_handle); wasm_handle->wasm()->wasm_vm()->addFailCallback([key](proxy_wasm::FailState fail_state) { if (fail_state == proxy_wasm::FailState::RuntimeError) { // If VM failed, erase the entry so that: