Commit 7f5f32e

Add request parameters to Python models (triton-inference-server#213)
* Add request parameters to Python models
* Add documentation about the inference request parameters
1 parent: f007255

File tree: 5 files changed, +90 additions, -13 deletions

README.md (12 additions, 0 deletions)

````diff
@@ -50,6 +50,7 @@ any C++ code.
 - [Known Issues](#known-issues)
 - [`finalize`](#finalize)
 - [Model Config File](#model-config-file)
+- [Inference Request Parameters](#inference-request-parameters)
 - [Managing Python Runtime and Libraries](#managing-python-runtime-and-libraries)
 - [Building Custom Python Backend Stub](#building-custom-python-backend-stub)
 - [Creating Custom Execution Environments](#creating-custom-execution-environments)
@@ -560,6 +561,17 @@ models
 └── config.pbtxt
 ```
 
+## Inference Request Parameters
+
+You can retrieve the parameters associated with an inference request
+using the `inference_request.parameters()` function. This function
+returns a JSON object where the keys are the keys of the parameters
+object and the values are the values for the parameters field.
+
+You can read more about the inference request parameters in the [parameters
+extension](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_parameters.md)
+documentation.
+
 ## Managing Python Runtime and Libraries
 
 Python backend shipped in the [NVIDIA GPU Cloud](https://ngc.nvidia.com/)
````
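Since the pybind11 binding added in `src/pb_stub.cc` below exposes `Parameters()` as a plain string of serialized JSON, a model typically parses it with `json.loads`. A minimal `model.py` sketch of the new API (illustrative only, not part of this commit; the `scale` parameter and the tensor names are hypothetical):

```python
import json

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        responses = []
        for request in requests:
            # parameters() returns the JSON string serialized by the backend
            # in python_be.cc, e.g. '{"scale":2}'; it is "{}" when the client
            # sent no parameters.
            params = json.loads(request.parameters())
            scale = params.get("scale", 1)  # hypothetical model-specific key
            in0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
            out0 = pb_utils.Tensor("OUTPUT0", in0.as_numpy() * scale)
            responses.append(pb_utils.InferenceResponse(output_tensors=[out0]))
        return responses
```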

src/infer_request.cc (31 additions, 7 deletions)

```diff
@@ -27,6 +27,7 @@
 #include "infer_request.h"
 
 #include <boost/interprocess/sync/scoped_lock.hpp>
+
 #include "pb_utils.h"
 #include "scoped_defer.h"
 #ifdef TRITON_PB_STUB
@@ -40,12 +41,12 @@ InferRequest::InferRequest(
     const std::vector<std::shared_ptr<PbTensor>>& inputs,
     const std::set<std::string>& requested_output_names,
     const std::string& model_name, const int64_t model_version,
-    const uint32_t flags, const int32_t timeout,
+    const std::string& parameters, const uint32_t flags, const int32_t timeout,
     const intptr_t response_factory_address, const intptr_t request_address)
     : request_id_(request_id), correlation_id_(correlation_id), inputs_(inputs),
       requested_output_names_(requested_output_names), model_name_(model_name),
-      model_version_(model_version), flags_(flags), timeout_(timeout),
-      response_factory_address_(response_factory_address),
+      model_version_(model_version), parameters_(parameters), flags_(flags),
+      timeout_(timeout), response_factory_address_(response_factory_address),
       request_address_(request_address)
 {
   for (auto& input : inputs) {
@@ -79,6 +80,12 @@ InferRequest::Inputs()
   return inputs_;
 }
 
+const std::string&
+InferRequest::Parameters()
+{
+  return parameters_;
+}
+
 const std::string&
 InferRequest::RequestId()
 {
@@ -160,7 +167,8 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
           sizeof(bi::managed_external_buffer::handle_t)) +
       (Inputs().size() * sizeof(bi::managed_external_buffer::handle_t)) +
       PbString::ShmStructSize(ModelName()) +
-      PbString::ShmStructSize(RequestId()));
+      PbString::ShmStructSize(RequestId()) +
+      PbString::ShmStructSize(Parameters()));
 
   infer_request_shm_ptr_ =
       reinterpret_cast<InferRequestShm*>(infer_request_shm.data_.get());
@@ -222,10 +230,18 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
       reinterpret_cast<char*>(infer_request_shm_ptr_) + request_id_offset,
       infer_request_shm.handle_ + request_id_offset);
 
+  size_t parameters_offset =
+      request_id_offset + PbString::ShmStructSize(RequestId());
+  std::unique_ptr<PbString> parameters_shm = PbString::Create(
+      Parameters(),
+      reinterpret_cast<char*>(infer_request_shm_ptr_) + parameters_offset,
+      infer_request_shm.handle_ + parameters_offset);
+
   // Save the references to shared memory.
   infer_request_shm_ = std::move(infer_request_shm);
   request_id_shm_ = std::move(request_id_shm);
   model_name_shm_ = std::move(model_name_shm);
+  parameters_shm_ = std::move(parameters_shm);
   shm_handle_ = infer_request_shm_.handle_;
   requested_output_names_shm_ = std::move(requested_output_names_shm);
 }
@@ -286,21 +302,28 @@ InferRequest::LoadFromSharedMemory(
       request_handle + request_id_offset,
       reinterpret_cast<char*>(infer_request_shm_ptr) + request_id_offset);
 
+  size_t parameters_offset = request_id_offset + request_id_shm->Size();
+  std::unique_ptr<PbString> parameters_shm = PbString::LoadFromSharedMemory(
+      request_handle + request_id_offset,
+      reinterpret_cast<char*>(infer_request_shm_ptr) + parameters_offset);
+
   return std::unique_ptr<InferRequest>(new InferRequest(
       infer_request_shm, request_id_shm, requested_output_names_shm,
-      model_name_shm, input_tensors));
+      model_name_shm, input_tensors, parameters_shm));
 }
 
 InferRequest::InferRequest(
     AllocatedSharedMemory<char>& infer_request_shm,
     std::unique_ptr<PbString>& request_id_shm,
     std::vector<std::unique_ptr<PbString>>& requested_output_names_shm,
     std::unique_ptr<PbString>& model_name_shm,
-    std::vector<std::shared_ptr<PbTensor>>& input_tensors)
+    std::vector<std::shared_ptr<PbTensor>>& input_tensors,
+    std::unique_ptr<PbString>& parameters_shm)
     : infer_request_shm_(std::move(infer_request_shm)),
       request_id_shm_(std::move(request_id_shm)),
       requested_output_names_shm_(std::move(requested_output_names_shm)),
-      model_name_shm_(std::move(model_name_shm))
+      model_name_shm_(std::move(model_name_shm)),
+      parameters_shm_(std::move(parameters_shm))
 {
   infer_request_shm_ptr_ =
       reinterpret_cast<InferRequestShm*>(infer_request_shm_.data_.get());
@@ -325,6 +348,7 @@ InferRequest::InferRequest(
   }
 
   request_id_ = request_id_shm_->String();
+  parameters_ = parameters_shm_->String();
   requested_output_names_ = std::move(requested_output_names);
   model_name_ = model_name_shm_->String();
   flags_ = infer_request_shm_ptr_->flags;
```
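The offset arithmetic above packs the request's variable-length strings back to back in one shared-memory allocation, so the new `Parameters()` string simply lands right after `RequestId()`. A small Python sketch of that layout idea (an analogy only; the real `PbString` shared-memory header differs):

```python
import struct


def shm_struct_size(s: str) -> int:
    # Stand-in for PbString::ShmStructSize: a fixed-size header plus the
    # string bytes (the real header also stores a shared-memory handle).
    return 8 + len(s.encode())


def pack_request_strings(model_name: str, request_id: str, parameters: str):
    buf = bytearray()
    offsets = {}
    for key, value in (
        ("model_name", model_name),
        ("request_id", request_id),
        ("parameters", parameters),  # the new field, appended last
    ):
        offsets[key] = len(buf)
        data = value.encode()
        buf += struct.pack("<Q", len(data)) + data
    return bytes(buf), offsets


payload, offsets = pack_request_strings("my_model", "req-0", '{"scale": 2}')
# parameters_offset == request_id_offset + shm_struct_size(request_id),
# mirroring the offset computation in SaveToSharedMemory above.
assert offsets["parameters"] == offsets["request_id"] + shm_struct_size("req-0")
print(offsets)
```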

src/infer_request.h (7 additions, 3 deletions)

```diff
@@ -61,12 +61,13 @@ class InferRequest {
       const std::vector<std::shared_ptr<PbTensor>>& inputs,
       const std::set<std::string>& requested_output_names,
       const std::string& model_name, const int64_t model_version,
-      const uint32_t flags = 0, const int32_t timeout = 0,
-      const intptr_t response_factory_address = 0,
+      const std::string& parameters, const uint32_t flags = 0,
+      const int32_t timeout = 0, const intptr_t response_factory_address = 0,
       const intptr_t request_address = 0);
 
   const std::vector<std::shared_ptr<PbTensor>>& Inputs();
   const std::string& RequestId();
+  const std::string& Parameters();
   uint64_t CorrelationId();
   const std::string& ModelName();
   int64_t ModelVersion();
@@ -116,14 +117,16 @@ class InferRequest {
       std::unique_ptr<PbString>& request_id_shm,
       std::vector<std::unique_ptr<PbString>>& requested_output_names_shm,
       std::unique_ptr<PbString>& model_name_shm,
-      std::vector<std::shared_ptr<PbTensor>>& input_tensors);
+      std::vector<std::shared_ptr<PbTensor>>& input_tensors,
+      std::unique_ptr<PbString>& parameters_shm);
 
   std::string request_id_;
   uint64_t correlation_id_;
   std::vector<std::shared_ptr<PbTensor>> inputs_;
   std::set<std::string> requested_output_names_;
   std::string model_name_;
   int64_t model_version_;
+  std::string parameters_;
   uint32_t flags_;
   int32_t timeout_;
   intptr_t response_factory_address_;
@@ -140,6 +143,7 @@ class InferRequest {
   bi::managed_external_buffer::handle_t* output_names_handle_shm_ptr_;
   bi::managed_external_buffer::handle_t* input_tensors_handle_ptr_;
   bi::managed_external_buffer::handle_t shm_handle_;
+  std::unique_ptr<PbString> parameters_shm_;
 
 #ifdef TRITON_PB_STUB
   std::shared_ptr<ResponseSender> response_sender_;
```

src/pb_stub.cc (5 additions, 1 deletion)

```diff
@@ -29,6 +29,7 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 #include <sys/wait.h>
+
 #include <atomic>
 #include <boost/interprocess/sync/interprocess_condition.hpp>
 #include <boost/interprocess/sync/interprocess_mutex.hpp>
@@ -41,6 +42,7 @@
 #include <regex>
 #include <thread>
 #include <unordered_map>
+
 #include "infer_response.h"
 #include "pb_error.h"
 #include "pb_map.h"
@@ -1272,9 +1274,10 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
           for (auto& requested_output_name : requested_output_names) {
             requested_outputs.emplace(requested_output_name);
           }
+          // FIXME: InferenceRequest parameters are not supported in BLS now.
           return std::make_shared<InferRequest>(
               request_id, correlation_id, inputs, requested_outputs,
-              model_name, model_version, flags, timeout);
+              model_name, model_version, "" /*parameters*/, flags, timeout);
         }),
       py::arg("request_id").none(false) = "",
       py::arg("correlation_id").none(false) = 0,
@@ -1291,6 +1294,7 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
       .def("flags", &InferRequest::Flags)
       .def("set_flags", &InferRequest::SetFlags)
       .def("timeout", &InferRequest::Timeout)
+      .def("parameters", &InferRequest::Parameters)
       .def(
           "exec",
           [](std::shared_ptr<InferRequest>& infer_request,
```
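As the FIXME above notes, an `InferenceRequest` constructed from Python for BLS is always created with an empty parameters string in this commit. A short sketch of the resulting behavior inside a model (the downstream model name is hypothetical):

```python
import triton_python_backend_utils as pb_utils

# A BLS request built inside a Python model; the pybind11 constructor above
# hard-codes "" for parameters, so none can be attached here yet.
bls_request = pb_utils.InferenceRequest(
    model_name="downstream_model",  # hypothetical model name
    requested_output_names=["OUTPUT0"],
    inputs=[],
)
# Unlike server-originated requests (which carry serialized JSON such as
# "{}"), a BLS-constructed request reports an empty parameters string.
assert bls_request.parameters() == ""
```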

src/python_be.cc (35 additions, 2 deletions)

```diff
@@ -356,6 +356,39 @@ ModelInstanceState::SaveRequestsToSharedMemory(
     requested_output_names.emplace(requested_output_name);
   }
 
+  triton::common::TritonJson::Value parameters_json(
+      triton::common::TritonJson::ValueType::OBJECT);
+  uint32_t parameter_count;
+  RETURN_IF_ERROR(
+      TRITONBACKEND_RequestParameterCount(request, &parameter_count));
+  for (size_t i = 0; i < parameter_count; i++) {
+    const char* name;
+    TRITONSERVER_ParameterType type;
+    const void* vvalue;
+    RETURN_IF_ERROR(
+        TRITONBACKEND_RequestParameter(request, i, &name, &type, &vvalue));
+    if (type == TRITONSERVER_PARAMETER_INT) {
+      RETURN_IF_ERROR(parameters_json.AddInt(
+          name, *(reinterpret_cast<const int64_t*>(vvalue))));
+    } else if (type == TRITONSERVER_PARAMETER_BOOL) {
+      RETURN_IF_ERROR(parameters_json.AddBool(
+          name, *(reinterpret_cast<const bool*>(vvalue))));
+    } else if (type == TRITONSERVER_PARAMETER_STRING) {
+      std::string string = reinterpret_cast<const char*>(vvalue);
+      RETURN_IF_ERROR(parameters_json.AddString(name, string));
+    } else {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INVALID_ARG,
+          (std::string("Unsupported parameter type for parameter '") + name +
+           "'.")
+              .c_str());
+    }
+  }
+
+  triton::common::TritonJson::WriteBuffer buffer;
+  RETURN_IF_ERROR(parameters_json.Write(&buffer));
+  const auto& parameters_string = buffer.Contents();
+
   // request id
   const char* id;
   RETURN_IF_ERROR(TRITONBACKEND_RequestId(request, &id));
@@ -373,13 +406,13 @@ ModelInstanceState::SaveRequestsToSharedMemory(
     RETURN_IF_ERROR(TRITONBACKEND_ResponseFactoryNew(&factory_ptr, request));
     infer_request = std::make_unique<InferRequest>(
         id, correlation_id, pb_input_tensors, requested_output_names,
-        model_state->Name(), model_state->Version(), flags,
+        model_state->Name(), model_state->Version(), parameters_string, flags,
         0 /* BLS request timeout*/, reinterpret_cast<intptr_t>(factory_ptr),
         reinterpret_cast<intptr_t>(request));
   } else {
     infer_request = std::make_unique<InferRequest>(
         id, correlation_id, pb_input_tensors, requested_output_names,
-        model_state->Name(), model_state->Version(), flags,
+        model_state->Name(), model_state->Version(), parameters_string, flags,
         0 /* BLS request timeout*/, 0 /* response_factory_address */,
         reinterpret_cast<intptr_t>(request));
   }
```
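On the wire, these parameters arrive through the parameters extension referenced in the README change above. A client-side sketch using the raw KServe-style HTTP endpoint (assumptions: a server on localhost:8000, a hypothetical model `my_model` with one FP32 input, and the third-party `requests` package):

```python
import json

import requests  # third-party HTTP client; an assumption of this sketch

# One key per parameter type serialized above: INT, BOOL and STRING. Other
# parameter types trigger the "Unsupported parameter type" error in
# python_be.cc.
payload = {
    "inputs": [
        {"name": "INPUT0", "shape": [1], "datatype": "FP32", "data": [1.0]}
    ],
    "parameters": {"max_tokens": 64, "verbose": True, "mode": "greedy"},
}
resp = requests.post(
    "http://localhost:8000/v2/models/my_model/infer",
    data=json.dumps(payload),
)
print(resp.json())
```

Inside the model, `request.parameters()` would then hold the JSON string built by the backend, e.g. `'{"max_tokens":64,"verbose":true,"mode":"greedy"}'`.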
