Skip to content

Commit d5a3028

Browse files
committed
fix wasm memory leak
1 parent 27df305 commit d5a3028

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

include/proxy-wasm/wasm.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,13 +399,16 @@ class PluginHandleBase : public std::enable_shared_from_this<PluginHandleBase> {
399399
~PluginHandleBase() {
400400
if (wasm_handle_) {
401401
wasm_handle_->wasm()->startShutdown(plugin_->key());
402+
wasm_handle_->wasm()->wasm_vm()->removeFailCallback(plugin_handle_key_);
402403
}
403404
}
404405

405406
std::shared_ptr<PluginBase> &plugin() { return plugin_; }
406407
std::shared_ptr<WasmBase> &wasm() { return wasm_handle_->wasm(); }
407408
std::shared_ptr<WasmHandleBase> &wasmHandle() { return wasm_handle_; }
408409

410+
void setPluginHandleKey(std::string_view key) { plugin_handle_key_ = std::string(key); }
411+
409412
void setRecoverPluginCallback(
410413
std::function<std::shared_ptr<PluginHandleBase>(std::shared_ptr<WasmHandleBase> &)> &&f) {
411414
recover_plugin_callback_ = std::move(f);
@@ -433,6 +436,10 @@ class PluginHandleBase : public std::enable_shared_from_this<PluginHandleBase> {
433436
std::shared_ptr<WasmHandleBase> wasm_handle_;
434437
std::function<std::shared_ptr<PluginHandleBase>(std::shared_ptr<WasmHandleBase> &)>
435438
recover_plugin_callback_;
439+
440+
private:
441+
// key for the plugin handle, used to identify the key in fail callbacks
442+
std::string plugin_handle_key_;
436443
};
437444

438445
using PluginHandleFactory = std::function<std::shared_ptr<PluginHandleBase>(

include/proxy-wasm/wasm_vm.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,12 +328,27 @@ class WasmVm {
328328
void fail(FailState fail_state, std::string_view message) {
329329
integration()->error(message);
330330
failed_ = fail_state;
331-
for (auto &callback : fail_callbacks_) {
331+
for (auto & [key, callback] : fail_callbacks_) {
332332
callback(fail_state);
333333
}
334334
}
335-
void addFailCallback(std::function<void(FailState)> fail_callback) {
336-
fail_callbacks_.push_back(fail_callback);
335+
336+
/**
337+
* Generates id for fail callbacks allowing direct insertion of the function.
338+
* Note: if fail callback needs to be removed later, must provide specific key.
339+
*/
340+
void addFailCallback(std::function<void(FailState)> fail_callback) {
341+
static int id = 0;
342+
std::string key = std::to_string(id++);
343+
addFailCallback(key, std::move(fail_callback));
344+
}
345+
346+
void addFailCallback(const std::string& key, std::function<void(FailState)> fail_callback) {
347+
fail_callbacks_[key] = std::move(fail_callback);
348+
}
349+
350+
void removeFailCallback(const std::string& key) {
351+
fail_callbacks_.erase(key);
337352
}
338353

339354
bool isHostFunctionAllowed(const std::string &name) {
@@ -353,7 +368,7 @@ class WasmVm {
353368
protected:
354369
std::unique_ptr<WasmVmIntegration> integration_;
355370
FailState failed_ = FailState::Ok;
356-
std::vector<std::function<void(FailState)>> fail_callbacks_;
371+
std::unordered_map<std::string, std::function<void(FailState)>> fail_callbacks_;
357372

358373
private:
359374
bool restricted_callback_{false};

src/wasm.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ getOrCreateThreadLocalWasm(const std::shared_ptr<WasmHandleBase> &base_handle,
729729

730730
void setPluginFailCallback(const std::string &key,
731731
const std::shared_ptr<WasmHandleBase> &wasm_handle) {
732-
wasm_handle->wasm()->wasm_vm()->addFailCallback([key](proxy_wasm::FailState fail_state) {
732+
wasm_handle->wasm()->wasm_vm()->addFailCallback(key, [key](proxy_wasm::FailState fail_state) {
733733
if (fail_state == proxy_wasm::FailState::RuntimeError) {
734734
// If VM failed, erase the entry so that:
735735
// 1) we can recreate the new thread local plugin from the same base_wasm.
@@ -819,6 +819,7 @@ std::shared_ptr<PluginHandleBase> getOrCreateThreadLocalPlugin(
819819
}
820820
auto plugin_handle = plugin_factory(wasm_handle, plugin);
821821
cacheLocalPlugin(key, plugin_handle);
822+
plugin_handle->setPluginHandleKey(key);
822823
setPluginFailCallback(key, wasm_handle);
823824
setPluginRecoverCallback(key, plugin_handle, base_handle, plugin, plugin_factory);
824825
return plugin_handle;

0 commit comments

Comments
 (0)