Skip to content

Commit 2e7ca90

Browse files
guanxinqtensorflow-copybara
authored andcommitted
Update CreateRPC API interface.
PiperOrigin-RevId: 407363453
1 parent d455f72 commit 2e7ca90

File tree

6 files changed

+24
-6
lines changed

6 files changed

+24
-6
lines changed

tensorflow_serving/experimental/tensorflow/ops/remote_predict/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ cc_library(
2626
"@org_tensorflow//tensorflow/core:framework_headers_lib",
2727
"@org_tensorflow//tensorflow/core:protos_all_cc",
2828
"@org_tensorflow//tensorflow/core/kernels:ops_util_hdrs",
29+
"@org_tensorflow//tensorflow/core/platform:statusor",
2930
],
3031
alwayslink = 1,
3132
)
@@ -96,6 +97,7 @@ tf_kernel_library(
9697
"@org_tensorflow//tensorflow/core:protos_all_cc",
9798
"@org_tensorflow//tensorflow/core/kernels:ops_util",
9899
"@org_tensorflow//tensorflow/core/kernels:split_lib",
100+
"@org_tensorflow//tensorflow/core/platform:statusor",
99101
],
100102
)
101103

@@ -163,6 +165,7 @@ cc_library(
163165
"@com_github_grpc_grpc//:grpc++",
164166
"@com_google_absl//absl/status",
165167
"@com_google_absl//absl/time",
168+
"@org_tensorflow//tensorflow/core/platform:statusor",
166169
],
167170
)
168171

tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/prediction_service_grpc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ PredictionServiceGrpc::PredictionServiceGrpc(
4141
stub_ = tensorflow::serving::PredictionService::NewStub(channel);
4242
}
4343

44-
::grpc::ClientContext* PredictionServiceGrpc::CreateRpc(
44+
StatusOr<::grpc::ClientContext*> PredictionServiceGrpc::CreateRpc(
4545
absl::Duration max_rpc_deadline) {
4646
::grpc::ClientContext* rpc = new ::grpc::ClientContext();
4747
// TODO(b/159739577): Set deadline as the min value between

tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/prediction_service_grpc.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818

1919
#include "absl/status/status.h"
2020
#include "absl/time/time.h"
21+
#include "tensorflow/core/platform/statusor.h"
2122
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
2223

2324
namespace tensorflow {
@@ -33,7 +34,7 @@ class PredictionServiceGrpc {
3334
return ::absl::OkStatus();
3435
}
3536

36-
::grpc::ClientContext* CreateRpc(absl::Duration max_rpc_deadline);
37+
StatusOr<::grpc::ClientContext*> CreateRpc(absl::Duration max_rpc_deadline);
3738

3839
void Predict(::grpc::ClientContext* rpc, PredictRequest* request,
3940
PredictResponse* response,

tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/prediction_service_grpc_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ class PredictionServiceGrpcTest : public ::testing::Test {
3535

3636
TEST_F(PredictionServiceGrpcTest, TestSetDeadline) {
3737
const absl::Duration deadline = absl::Milliseconds(30000);
38-
rpc_.reset(grpc_stub_->CreateRpc(deadline));
38+
auto rpc_or = grpc_stub_->CreateRpc(deadline);
39+
ASSERT_TRUE(rpc_or.ok());
40+
rpc_.reset(rpc_or.ValueOrDie());
3941

4042
EXPECT_NEAR(absl::ToDoubleMilliseconds(deadline),
4143
absl::ToDoubleMilliseconds(absl::FromChrono(rpc_->deadline()) -

tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/remote_predict_op_kernel.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License.
2929
#include "tensorflow/core/kernels/ops_util.h"
3030
#include "tensorflow/core/lib/core/threadpool.h"
3131
#include "tensorflow/core/lib/gtl/cleanup.h"
32+
#include "tensorflow/core/platform/status.h"
3233
#include "tensorflow/core/protobuf/named_tensor.pb.h"
3334
#include "tensorflow_serving/apis/model.pb.h"
3435
#include "tensorflow_serving/apis/predict.pb.h"
@@ -103,9 +104,17 @@ class RemotePredictOp : public AsyncOpKernel {
103104

104105
PredictResponse* response = new PredictResponse();
105106

106-
auto rpc = prediction_service_->CreateRpc(
107+
auto rpc_or = prediction_service_->CreateRpc(
107108
absl::Milliseconds(max_rpc_deadline_millis_));
108-
109+
OP_REQUIRES_ASYNC(context, rpc_or.ok(),
110+
tensorflow::Status(rpc_or.status().code(),
111+
rpc_or.status().error_message()),
112+
[&]() {
113+
delete request;
114+
delete response;
115+
done();
116+
});
117+
auto rpc = rpc_or.ValueOrDie();
109118
auto callback = [this, context, rpc, request, response,
110119
output_tensor_aliases, done](const absl::Status& status) {
111120
PostProcessResponse(context, response, status, fail_op_on_rpc_error_,

tensorflow_serving/experimental/tensorflow/ops/remote_predict/kernels/remote_predict_op_kernel_test.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include "tensorflow/cc/ops/const_op.h"
2121
#include "tensorflow/core/framework/tensor_testutil.h"
2222
#include "tensorflow/core/lib/core/status_test_util.h"
23+
#include "tensorflow/core/platform/status.h"
2324
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
2425
#include "tensorflow_serving/experimental/tensorflow/ops/remote_predict/cc/ops/remote_predict_op.h"
2526

@@ -39,7 +40,9 @@ class MockPredictionService {
3940
return ::absl::OkStatus();
4041
}
4142

42-
MockRpc* CreateRpc(absl::Duration max_rpc_deadline) { return new MockRpc; }
43+
StatusOr<MockRpc*> CreateRpc(absl::Duration max_rpc_deadline) {
44+
return new MockRpc;
45+
}
4346

4447
// The model_name in request determines response and/or status.
4548
void Predict(MockRpc* rpc, PredictRequest* request, PredictResponse* response,

0 commit comments

Comments
 (0)