1717
1818#include < string.h>
1919
20+ #include < cassert>
2021#include < atomic>
2122#include < deque>
2223#include < map>
@@ -72,6 +73,11 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
7273 return it->second ;
7374 return nullptr ;
7475 }
76+ void clearWasmInContext () {
77+ for (auto &item : contexts_) {
78+ item.second ->clearWasm ();
79+ }
80+ }
7581 uint32_t allocContextId ();
7682 bool isFailed () { return failed_ != FailState::Ok; }
7783 FailState fail_state () { return failed_; }
@@ -174,6 +180,7 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
174180 HttpCall = 0 ,
175181 GrpcCall = 1 ,
176182 GrpcStream = 2 ,
183+ RedisCall = 3 ,
177184 };
178185 static const uint32_t kCalloutTypeMask = 0x3 ; // Enough to cover the 3 types.
179186 static const uint32_t kCalloutIncrement = 0x4 ; // Enough to cover the 3 types.
@@ -186,6 +193,9 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
186193 bool isGrpcStreamId (uint32_t callout_id) {
187194 return (callout_id & kCalloutTypeMask ) == static_cast <uint32_t >(CalloutType::GrpcStream);
188195 }
196+ bool isRedisCallId (uint32_t callout_id) {
197+ return (callout_id & kCalloutTypeMask ) == static_cast <uint32_t >(CalloutType::RedisCall);
198+ }
189199 uint32_t nextHttpCallId () {
190200 // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
191201 return next_http_call_id_ += kCalloutIncrement ;
@@ -198,6 +208,10 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
198208 // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
199209 return next_grpc_stream_id_ += kCalloutIncrement ;
200210 }
211+ uint32_t nextRedisCallId () {
212+ // TODO(PiotrSikora): re-add rollover protection (requires at least 1 billion callouts).
213+ return next_redis_call_id_ += kCalloutIncrement ;
214+ }
201215
202216protected:
203217 friend class ContextBase ;
@@ -257,6 +271,8 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
257271
258272 WasmCallVoid<5 > on_http_call_response_;
259273
274+ WasmCallVoid<4 > on_redis_call_response_;
275+
260276 WasmCallVoid<3 > on_grpc_receive_;
261277 WasmCallVoid<3 > on_grpc_close_;
262278 WasmCallVoid<3 > on_grpc_create_initial_metadata_;
@@ -276,9 +292,10 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
276292 _f (on_downstream_connection_close) _f(on_upstream_connection_close) _f(on_request_body) \
277293 _f (on_request_trailers) _f(on_request_metadata) _f(on_response_body) \
278294 _f (on_response_trailers) _f(on_response_metadata) _f(on_http_call_response) \
279- _f (on_grpc_receive) _f(on_grpc_close) _f(on_grpc_receive_initial_metadata) \
280- _f (on_grpc_receive_trailing_metadata) _f(on_queue_ready) _f(on_done) \
281- _f (on_log) _f(on_delete)
295+ _f (on_redis_call_response) _f(on_grpc_receive) _f(on_grpc_close) \
296+ _f (on_grpc_receive_initial_metadata) \
297+ _f (on_grpc_receive_trailing_metadata) _f(on_queue_ready) _f(on_done) \
298+ _f (on_log) _f(on_delete)
282299
283300 // Capabilities which are allowed to be linked to the module. If this is empty, restriction
284301 // is not enforced.
@@ -307,6 +324,7 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
307324 uint32_t next_http_call_id_ = static_cast <uint32_t >(CalloutType::HttpCall);
308325 uint32_t next_grpc_call_id_ = static_cast <uint32_t >(CalloutType::GrpcCall);
309326 uint32_t next_grpc_stream_id_ = static_cast <uint32_t >(CalloutType::GrpcStream);
327+ uint32_t next_redis_call_id_ = static_cast <uint32_t >(CalloutType::RedisCall);
310328
311329 // Actions to be done after the call into the VM returns.
312330 std::deque<std::function<void ()>> after_vm_call_actions_;
@@ -335,8 +353,26 @@ class WasmHandleBase : public std::enable_shared_from_this<WasmHandleBase> {
335353
336354 std::shared_ptr<WasmBase> &wasm () { return wasm_base_; }
337355
356+ void setRecoverVmCallback (std::function<std::shared_ptr<WasmHandleBase>()> &&f) {
357+ recover_vm_callback_ = std::move (f);
358+ }
359+
360+ // Recover the wasm vm and generate a new wasm handle
361+ bool doRecover (std::shared_ptr<WasmHandleBase> &new_handle) {
362+ assert (new_handle == nullptr );
363+ if (recover_vm_callback_ == nullptr ) {
364+ return true ;
365+ }
366+ new_handle = recover_vm_callback_ ();
367+ if (!new_handle) {
368+ return false ;
369+ }
370+ return true ;
371+ }
372+
338373protected:
339374 std::shared_ptr<WasmBase> wasm_base_;
375+ std::function<std::shared_ptr<WasmHandleBase>()> recover_vm_callback_;
340376 std::unordered_map<std::string, bool > plugin_canary_cache_;
341377};
342378
@@ -365,10 +401,35 @@ class PluginHandleBase : public std::enable_shared_from_this<PluginHandleBase> {
365401
366402 std::shared_ptr<PluginBase> &plugin () { return plugin_; }
367403 std::shared_ptr<WasmBase> &wasm () { return wasm_handle_->wasm (); }
404+ std::shared_ptr<WasmHandleBase> &wasmHandle () { return wasm_handle_; }
405+
406+ void setRecoverPluginCallback (
407+ std::function<std::shared_ptr<PluginHandleBase>(std::shared_ptr<WasmHandleBase> &)> &&f) {
408+ recover_plugin_callback_ = std::move (f);
409+ }
410+
411+ // Recover the wasm plugin and generate a new plugin handle
412+ bool doRecover (std::shared_ptr<PluginHandleBase> &new_handle) {
413+ assert (new_handle == nullptr );
414+ if (recover_plugin_callback_ == nullptr ) {
415+ return true ;
416+ }
417+ std::shared_ptr<WasmHandleBase> new_wasm_handle;
418+ if (!wasm_handle_->doRecover (new_wasm_handle)) {
419+ return false ;
420+ }
421+ new_handle = recover_plugin_callback_ (new_wasm_handle);
422+ if (!new_handle) {
423+ return false ;
424+ }
425+ return true ;
426+ }
368427
369428protected:
370429 std::shared_ptr<PluginBase> plugin_;
371430 std::shared_ptr<WasmHandleBase> wasm_handle_;
431+ std::function<std::shared_ptr<PluginHandleBase>(std::shared_ptr<WasmHandleBase> &)>
432+ recover_plugin_callback_;
372433};
373434
374435using PluginHandleFactory = std::function<std::shared_ptr<PluginHandleBase>(
0 commit comments