Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2

- name: Mount bazel cache
uses: actions/cache@v1
Expand Down
2 changes: 2 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cc_library(
copts = ["-DWITHOUT_ZLIB=1"],
deps = [
":include",
"@com_google_protobuf//:protobuf_lite",
"@proxy_wasm_cpp_sdk//:api_lib",
],
)
Expand All @@ -45,6 +46,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_protobuf//:protobuf_lite",
"@proxy_wasm_cpp_sdk//:api_lib",
],
)
Expand Down
10 changes: 3 additions & 7 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ workspace(name = "proxy_wasm_cpp_host")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

http_archive(
git_repository(
name = "proxy_wasm_cpp_sdk",
sha256 = "14f66f67e8f37ec81d28d7f5307be4407d206ac5f0daaf6d22fa5536797bcac1",
strip_prefix = "proxy-wasm-cpp-sdk-31f1fc5b7e09f231fa532d2d296e479a113c3e10",
urls = ["https://github.com/proxy-wasm/proxy-wasm-cpp-sdk/archive/31f1fc5b7e09f231fa532d2d296e479a113c3e10.tar.gz"],
patch_cmds = ["rm BUILD"],
build_file = '//bazel/external:proxy-wasm-cpp-sdk.BUILD',
remote = "https://github.com/proxy-wasm/proxy-wasm-cpp-sdk",
commit = "c12553951d01bb60cb1448ba1fcfeb8f843aad62",
)

http_archive(
Expand All @@ -20,7 +17,6 @@ http_archive(
urls = ["https://github.com/abseil/abseil-cpp/archive/37dd2562ec830d547a1524bb306be313ac3f2556.tar.gz"],
)


# required by com_google_protobuf
http_archive(
name = "bazel_skylib",
Expand Down
103 changes: 43 additions & 60 deletions include/proxy-wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class WasmVm;

using Pairs = std::vector<std::pair<string_view, string_view>>;
using PairsWithStringValues = std::vector<std::pair<string_view, std::string>>;
using CallOnThreadFunction = std::function<void(std::function<void()>)>;

struct BufferInterface {
virtual ~BufferInterface() {}
Expand All @@ -50,9 +49,12 @@ struct BufferInterface {
* @param size_ptr is the location in the VM address space to place the size of the newly
* allocated memory block which contains the copied bytes (e.g. length).
* @return true on success.
* Length is guarantteed to be > 0 and the bounds have already been checked.
*/
virtual bool copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr,
uint64_t size_ptr) const = 0;
virtual WasmResult copyTo(WasmBase *wasm, size_t start, size_t length, uint64_t ptr_ptr,
uint64_t size_ptr) const = 0;

virtual WasmResult copyFrom(size_t start, size_t length, string_view data) = 0;
};

// Opaque context object.
Expand Down Expand Up @@ -94,7 +96,7 @@ class ContextBase {
uint32_t id() const { return id_; }
bool isVmContext() const { return id_ == 0; }
bool isRootContext() const { return root_context_id_ == 0; }
ContextBase *root_context() { return root_context_; }
ContextBase *root_context() const { return root_context_; }
string_view root_id() const { return isRootContext() ? root_id_ : plugin_->root_id_; }
string_view log_prefix() const {
return isRootContext() ? root_log_prefix_ : plugin_->log_prefix();
Expand All @@ -104,76 +106,50 @@ class ContextBase {
// Called before deleting the context.
virtual void destroy();

//
// VM level downcalls into the WASM code on Context(id == 0).
//
virtual bool onStart(std::shared_ptr<PluginBase> plugin);
virtual bool onConfigure(std::shared_ptr<PluginBase> plugin);

//
// Root Context downcalls into the WASM code Context(id != 0, root_context_id_ == 0);
virtual void onTick();

// Stream downcalls on Context(id > 0).
//
// General stream downcall on a new stream.
virtual void onCreate(uint32_t root_context_id);

// Network
virtual FilterStatus onNetworkNewConnection() {
unimplemented();
return FilterStatus::Continue;
}
virtual FilterStatus onDownstreamData(int /* data_length */, bool /* end_of_stream */) {
unimplemented();
return FilterStatus::Continue;
}
virtual FilterStatus onUpstreamData(int /* data_length */, bool /* end_of_stream */) {
unimplemented();
return FilterStatus::Continue;
}
virtual FilterStatus onNetworkNewConnection();
virtual FilterStatus onDownstreamData(int data_length, bool end_of_stream);
virtual FilterStatus onUpstreamData(int data_length, bool end_of_stream);
enum class PeerType : uint32_t {
Unknown = 0,
Local = 1,
Remote = 2,
};
virtual void onDownstreamConnectionClose(PeerType) { unimplemented(); }
virtual void onUpstreamConnectionClose(PeerType) { unimplemented(); }
virtual void onDownstreamConnectionClose(PeerType);
virtual void onUpstreamConnectionClose(PeerType);
// HTTP Filter Stream Request Downcalls.
virtual FilterHeadersStatus onRequestHeaders() {
unimplemented();
return FilterHeadersStatus::Continue;
}
virtual FilterDataStatus onRequestBody(int /* body_buffer_length */, bool /* end_of_stream */) {
unimplemented();
return FilterDataStatus::Continue;
}
virtual FilterTrailersStatus onRequestTrailers() {
unimplemented();
return FilterTrailersStatus::Continue;
}
virtual FilterMetadataStatus onRequestMetadata() {
unimplemented();
return FilterMetadataStatus::Continue;
}
virtual FilterHeadersStatus onRequestHeaders(uint32_t headers);
virtual FilterDataStatus onRequestBody(uint32_t body_buffer_length, bool end_of_stream);
virtual FilterTrailersStatus onRequestTrailers(uint32_t trailers);
virtual FilterMetadataStatus onRequestMetadata(uint32_t elements);
// HTTP Filter Stream Response Downcalls.
virtual FilterHeadersStatus onResponseHeaders() {
unimplemented();
return FilterHeadersStatus::Continue;
}
virtual FilterDataStatus onResponseBody(int /* body_buffer_length */, bool /* end_of_stream */) {
unimplemented();
return FilterDataStatus::Continue;
}
virtual FilterTrailersStatus onResponseTrailers() {
unimplemented();
return FilterTrailersStatus::Continue;
}
virtual FilterMetadataStatus onResponseMetadata() {
unimplemented();
return FilterMetadataStatus::Continue;
}
virtual FilterHeadersStatus onResponseHeaders(uint32_t headers);
virtual FilterDataStatus onResponseBody(uint32_t body_buffer_length, bool end_of_stream);
virtual FilterTrailersStatus onResponseTrailers(uint32_t trailers);
virtual FilterMetadataStatus onResponseMetadata(uint32_t elements);
// Async call response.
virtual void onHttpCallResponse(uint32_t /* token */, uint32_t /* headers */,
uint32_t /* body_size */, uint32_t /* trailers */) {}
virtual void onHttpCallResponse(uint32_t token, uint32_t headers, uint32_t body_size,
uint32_t trailers);
// Grpc
virtual void onGrpcReceiveInitialMetadata(uint32_t token, uint32_t elements);
virtual void onGrpcReceiveTrailingMetadata(uint32_t token, uint32_t trailers);
virtual void onGrpcReceive(uint32_t token, uint32_t response_size);
virtual void onGrpcClose(uint32_t token, uint32_t status_code);

// Inter-VM shared queue message arrival.
virtual void onQueueReady(uint32_t /* token */) { unimplemented(); }
virtual void onQueueReady(uint32_t /* token */);
// General stream downcall when the stream/vm has ended.
virtual bool onDone();
// General stream downcall for logging. Occurs after onDone().
Expand Down Expand Up @@ -212,7 +188,7 @@ class ContextBase {
}

// Buffer
virtual const BufferInterface *getBuffer(WasmBufferType /* type */) {
virtual BufferInterface *getBuffer(WasmBufferType /* type */) {
unimplemented();
return nullptr;
}
Expand All @@ -233,14 +209,16 @@ class ContextBase {
// gRPC
// Returns a token which will be used with the corresponding onGrpc and grpc calls.
virtual WasmResult grpcCall(string_view /* grpc_service */, string_view /* service_name */,
string_view /* method_name */, string_view /* request */,
string_view /* method_name */, const Pairs & /* initial_metadata */,
string_view /* request */,
const optional<std::chrono::milliseconds> & /* timeout */,
uint32_t * /* token_ptr */) {
unimplemented();
return WasmResult::Unimplemented;
}
virtual WasmResult grpcStream(string_view /* grpc_service */, string_view /* service_name */,
string_view /* method_name */, uint32_t * /* token_ptr */) {
string_view /* method_name */, const Pairs & /* initial_metadata */,
uint32_t * /* token_ptr */) {
unimplemented();
return WasmResult::Unimplemented;
}
Expand Down Expand Up @@ -342,7 +320,12 @@ class ContextBase {
protected:
friend class WasmBase;

virtual void initializeRoot(WasmBase *wasm, std::shared_ptr<PluginBase> plugin);
// NB: initializeRootBase is non-virtual and can be called in the constructor without ambiguity.
void initializeRootBase(WasmBase *wasm, std::shared_ptr<PluginBase> plugin);
// NB: initializeRoot is virtual and should be called only outside of the constructor.
virtual void initializeRoot(WasmBase *wasm, std::shared_ptr<PluginBase> plugin) {
initializeRootBase(wasm, plugin);
}
std::string makeRootLogPrefix(string_view vm_id) const;

WasmBase *wasm_{nullptr};
Expand Down
9 changes: 7 additions & 2 deletions include/proxy-wasm/exports.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ Word enqueue_shared_queue(void *raw_context, Word token, Word data_ptr, Word dat
Word get_buffer_bytes(void *raw_context, Word type, Word start, Word length, Word ptr_ptr,
Word size_ptr);
Word get_buffer_status(void *raw_context, Word type, Word length_ptr, Word flags_ptr);
Word set_buffer_bytes(void *raw_context, Word type, Word start, Word length, Word data_ptr,
Word data_size);
Word add_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size, Word value_ptr,
Word value_size);
Word get_header_map_value(void *raw_context, Word type, Word key_ptr, Word key_size,
Expand All @@ -72,10 +74,11 @@ Word record_metric(void *raw_context, Word metric_id, uint64_t value);
Word get_metric(void *raw_context, Word metric_id, Word result_uint64_ptr);
Word grpc_call(void *raw_context, Word service_ptr, Word service_size, Word service_name_ptr,
Word service_name_size, Word method_name_ptr, Word method_name_size,
Word request_ptr, Word request_size, Word timeout_milliseconds, Word token_ptr);
Word initial_metadata_ptr, Word initial_metadata_size, Word request_ptr,
Word request_size, Word timeout_milliseconds, Word token_ptr);
Word grpc_stream(void *raw_context, Word service_ptr, Word service_size, Word service_name_ptr,
Word service_name_size, Word method_name_ptr, Word method_name_size,
Word token_ptr);
Word initial_metadata_ptr, Word initial_metadata_size, Word token_ptr);
Word grpc_cancel(void *raw_context, Word token);
Word grpc_close(void *raw_context, Word token);
Word grpc_send(void *raw_context, Word token, Word message_ptr, Word message_size, Word end_stream);
Expand All @@ -92,8 +95,10 @@ Word call_foreign_function(void *raw_context, Word function_name, Word function_

Word wasi_unstable_fd_write(void *raw_context, Word fd, Word iovs, Word iovs_len,
Word nwritten_ptr);
Word wasi_unstable_fd_read(void *, Word, Word, Word, Word);
Word wasi_unstable_fd_seek(void *, Word, int64_t, Word, Word);
Word wasi_unstable_fd_close(void *, Word);
Word wasi_unstable_fd_fdstat_get(void *, Word fd, Word statOut);
Word wasi_unstable_environ_get(void *, Word, Word);
Word wasi_unstable_environ_sizes_get(void *raw_context, Word count_ptr, Word buf_size_ptr);
Word wasi_unstable_args_get(void *raw_context, Word argc_ptr, Word argv_buf_size_ptr);
Expand Down
1 change: 0 additions & 1 deletion include/proxy-wasm/null_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ class NullPlugin : public NullVmPlugin {

void onGrpcReceive(uint64_t context_id, uint64_t token, size_t body_size);
void onGrpcClose(uint64_t context_id, uint64_t token, uint64_t status_code);
void onGrpcCreateInitialMetadata(uint64_t context_id, uint64_t token, uint64_t headers);
void onGrpcReceiveInitialMetadata(uint64_t context_id, uint64_t token, uint64_t headers);
void onGrpcReceiveTrailingMetadata(uint64_t context_id, uint64_t token, uint64_t trailers);

Expand Down
49 changes: 27 additions & 22 deletions include/proxy-wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,21 @@ namespace proxy_wasm {

#include "proxy_wasm_common.h"

class ContextBase;
class WasmBase;
class WasmHandle;
class WasmHandleBase;

using WasmForeignFunction =
std::function<WasmResult(WasmBase &, string_view, std::function<void *(size_t size)>)>;
using WasmVmFactory = std::function<std::unique_ptr<WasmVm>()>;
using CallOnThreadFunction = std::function<void(std::function<void()>)>;

// Wasm execution instance. Manages the host side of the Wasm interface.
class WasmBase : public std::enable_shared_from_this<WasmBase> {
public:
WasmBase(std::unique_ptr<WasmVm> wasm_vm, string_view vm_id, string_view vm_configuration,
string_view vm_key);
WasmBase(const std::shared_ptr<WasmHandle> &other, WasmVmFactory factory);
WasmBase(const std::shared_ptr<WasmHandleBase> &other, WasmVmFactory factory);
virtual ~WasmBase();

bool initialize(const std::string &code, bool allow_precompiled = false);
Expand Down Expand Up @@ -78,11 +80,17 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
unimplemented();
return nullptr;
}
// NB: if plugin is nullptr, then a VM Context is returned.
virtual ContextBase *createContext(std::shared_ptr<PluginBase> plugin) {
if (plugin)
return new ContextBase(this, plugin);
return new ContextBase(this);
}

void setTickPeriod(uint32_t /* root_context_id */, std::chrono::milliseconds /* tick_period */) {
unimplemented();
virtual void setTickPeriod(uint32_t root_context_id, std::chrono::milliseconds tick_period) {
tick_period_[root_context_id] = tick_period;
}
void tickHandler(uint32_t /* root_context_idl */) { unimplemented(); }
void tick(uint32_t root_context_id);
void queueReady(uint32_t root_context_id, uint32_t token);

void startShutdown();
Expand Down Expand Up @@ -216,7 +224,7 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
WasmCallVoid<1> on_log_;
WasmCallVoid<1> on_delete_;

std::shared_ptr<WasmHandle> base_wasm_handle_;
std::shared_ptr<WasmHandleBase> base_wasm_handle_;

// Used by the base_wasm to enable non-clonable thread local Wasm(s) to be constructed.
std::string code_;
Expand All @@ -235,43 +243,40 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> {
uint32_t next_gauge_metric_id_ = static_cast<uint32_t>(MetricType::Gauge);
uint32_t next_histogram_metric_id_ = static_cast<uint32_t>(MetricType::Histogram);

// Foreign Functions.
std::unordered_map<std::string, WasmForeignFunction> foreign_functions_;

// Actions to be done after the call into the VM returns.
std::deque<std::function<void()>> after_vm_call_actions_;
};

// Handle which enables shutdown operations to run post deletion (e.g. post listener drain).
class WasmHandle : public std::enable_shared_from_this<WasmHandle> {
class WasmHandleBase : public std::enable_shared_from_this<WasmHandleBase> {
public:
explicit WasmHandle(std::shared_ptr<WasmBase> wasm) : wasm_(wasm) {}
~WasmHandle() { wasm_->startShutdown(); }
explicit WasmHandleBase(std::shared_ptr<WasmBase> wasm_base) : wasm_base_(wasm_base) {}
~WasmHandleBase() { wasm_base_->startShutdown(); }

std::shared_ptr<WasmBase> &wasm() { return wasm_; }
std::shared_ptr<WasmBase> &wasm() { return wasm_base_; }

private:
std::shared_ptr<WasmBase> wasm_;
protected:
std::shared_ptr<WasmBase> wasm_base_;
};

std::string makeVmKey(string_view vm_id, string_view configuration, string_view code);

using WasmHandleFactory = std::function<std::shared_ptr<WasmHandle>(string_view vm_id)>;
using WasmHandleFactory = std::function<std::shared_ptr<WasmHandleBase>(string_view vm_id)>;
using WasmHandleCloneFactory =
std::function<std::shared_ptr<WasmHandle>(std::shared_ptr<WasmHandle> wasm)>;
std::function<std::shared_ptr<WasmHandleBase>(std::shared_ptr<WasmHandleBase> wasm)>;

// Returns nullptr on failure (i.e. initialization of the VM fails).
std::shared_ptr<WasmHandle>
std::shared_ptr<WasmHandleBase>
createWasm(std::string vm_key, std::string code, std::shared_ptr<PluginBase> plugin,
WasmHandleFactory factory, bool allow_precompiled,
std::unique_ptr<ContextBase> root_context_for_testing = nullptr);
// Get an existing ThreadLocal VM matching 'vm_id' or nullptr if there isn't one.
std::shared_ptr<WasmHandle> getThreadLocalWasm(string_view vm_id);
std::shared_ptr<WasmHandleBase> getThreadLocalWasm(string_view vm_id);
// Get an existing ThreadLocal VM matching 'vm_id' or create one using 'base_wavm' by cloning or by
// using it it as a template.
std::shared_ptr<WasmHandle> getOrCreateThreadLocalWasm(std::shared_ptr<WasmHandle> &base_wasm,
std::shared_ptr<PluginBase> plugin,
WasmHandleCloneFactory factory);
std::shared_ptr<WasmHandleBase>
getOrCreateThreadLocalWasm(std::shared_ptr<WasmHandleBase> base_wasm,
std::shared_ptr<PluginBase> plugin, WasmHandleCloneFactory factory);

inline const std::string &WasmBase::vm_configuration() const {
if (base_wasm_handle_)
Expand Down
Loading