
Commit 889585c

Add support for request rescheduling (triton-inference-server#319)
* Add support for request rescheduling
* Address comment
* Add documentation
* Fix up for doc
* Revert response sender changes
* Address comment
1 parent 60a9091 commit 889585c

File tree: 7 files changed (+249, -36 lines)

README.md

Lines changed: 97 additions & 0 deletions
@@ -50,6 +50,7 @@ any C++ code.
 - [Decoupled mode](#decoupled-mode)
 - [Use Cases](#use-cases)
 - [Known Issues](#known-issues)
+- [Request Rescheduling](#request-rescheduling)
 - [`finalize`](#finalize)
 - [Model Config File](#model-config-file)
 - [Inference Request Parameters](#inference-request-parameters)
@@ -623,6 +624,102 @@ for more details on how to host a decoupled model.
 
 * Currently, decoupled Python models can not make async infer requests.
 
+#### Request Rescheduling
+
+Starting from 23.11, the Python backend supports request rescheduling. By
+calling the `set_release_flags` function on the request object with the flag
+`pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE`, you can reschedule the
+request for further execution in a future batch. This feature is useful for
+handling generative sequences.
+
+The model config must enable generative sequence batching in order to use the
+request rescheduling API:
+
+```
+sequence_batching {
+  generative_sequence : true
+}
+```
+
+For non-decoupled models, there can only be one response for each request.
+Since the rescheduled request is the same as the original, you must append a
+`None` object to the response list for the rescheduled request. For example:
+
+```python
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    ...
+
+    def execute(self, requests):
+        responses = []
+
+        for request in requests:
+            # Explicitly reschedule the first request
+            if self.idx == 0:
+                request.set_release_flags(
+                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
+                )
+                responses.append(None)
+                self.idx += 1
+            else:
+                # inference_response is the regular response produced for
+                # requests that are not rescheduled.
+                responses.append(inference_response)
+
+        return responses
+```
+
+For decoupled models, it is required to reschedule a request *before* returning
+from the `execute` function. Below is an example of a decoupled model using
+request rescheduling. This model takes one input tensor, an INT32 [ 1 ] input
+named "IN", and produces an output tensor "OUT" with the same shape as the
+input tensor. The input value indicates the total number of responses to be
+generated and the output value indicates the number of remaining responses.
+For example, if the request input has value 2, the model will:
+- Send a response with value 1.
+- Release the request with the RESCHEDULE flag.
+- When executed again for the same request, send the last response with value 0.
+- Release the request with the ALL flag.
+
+```python
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    ...
+
+    def execute(self, requests):
+        responses = []
+
+        for request in requests:
+            in_input = pb_utils.get_input_tensor_by_name(request, "IN").as_numpy()
+
+            if self.reset_flag:
+                self.remaining_response = in_input[0]
+                self.reset_flag = False
+
+            response_sender = request.get_response_sender()
+
+            self.remaining_response -= 1
+
+            out_output = pb_utils.Tensor(
+                "OUT", np.array([self.remaining_response], np.int32)
+            )
+            response = pb_utils.InferenceResponse(output_tensors=[out_output])
+
+            if self.remaining_response <= 0:
+                response_sender.send(
+                    response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                )
+                self.reset_flag = True
+            else:
+                request.set_release_flags(
+                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
+                )
+                response_sender.send(response)
+
+        return None
+```
+
 ### `finalize`
 
 Implementing `finalize` is optional. This function allows you to do any clean
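Both README examples above elide `initialize` with `...`, yet they rely on per-model state (`self.idx`, `self.reset_flag`, `self.remaining_response`). A minimal sketch of how that state could be set up, assuming only the standard `initialize(self, args)` entry point of the Python backend; the attribute names come from the examples above, everything else is illustrative rather than part of this commit:

```python
import json


class TritonPythonModel:
    def initialize(self, args):
        # The Python backend passes the model configuration as a JSON string.
        self.model_config = json.loads(args["model_config"])

        # State used by the non-decoupled example: only the first request
        # seen by this model instance is rescheduled.
        self.idx = 0

        # State used by the decoupled example: whether the next execution
        # starts a fresh generative sequence, and how many responses remain.
        self.reset_flag = True
        self.remaining_response = 0

    def execute(self, requests):
        # See the README examples above for the rescheduling logic.
        ...
```

Because this state lives on the model instance, the toy examples effectively handle one generative sequence at a time; a production model would typically key such state by correlation or request ID instead.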

src/infer_request.cc

Lines changed: 17 additions & 1 deletion
@@ -50,7 +50,7 @@ InferRequest::InferRequest(
       model_version_(model_version), parameters_(parameters), flags_(flags),
       timeout_(timeout), response_factory_address_(response_factory_address),
       request_address_(request_address), preferred_memory_(preferred_memory),
-      trace_(trace)
+      trace_(trace), request_release_flags_(TRITONSERVER_REQUEST_RELEASE_ALL)
 {
   for (auto& input : inputs) {
     if (!input) {
@@ -175,6 +175,20 @@ InferRequest::Trace()
   return trace_;
 }
 
+uint32_t
+InferRequest::ReleaseFlags()
+{
+  request_release_flags_ = infer_request_shm_ptr_->request_release_flags;
+  return request_release_flags_;
+}
+
+void
+InferRequest::SetReleaseFlags(const uint32_t& flags)
+{
+  request_release_flags_ = flags;
+  infer_request_shm_ptr_->request_release_flags = request_release_flags_;
+}
+
 void
 InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
 {
@@ -201,6 +215,7 @@ InferRequest::SaveToSharedMemory(std::unique_ptr<SharedMemoryManager>& shm_pool)
   infer_request_shm_ptr_->timeout = timeout_;
   infer_request_shm_ptr_->preferred_memory = preferred_memory_;
   infer_request_shm_ptr_->trace = trace_;
+  infer_request_shm_ptr_->request_release_flags = request_release_flags_;
 
   output_names_handle_shm_ptr_ =
       reinterpret_cast<bi::managed_external_buffer::handle_t*>(
@@ -379,6 +394,7 @@ InferRequest::InferRequest(
   timeout_ = infer_request_shm_ptr_->timeout;
   preferred_memory_ = infer_request_shm_ptr_->preferred_memory;
   trace_ = infer_request_shm_ptr_->trace;
+  request_release_flags_ = infer_request_shm_ptr_->request_release_flags;
 
 #ifdef TRITON_PB_STUB
   pb_cancel_ =

src/infer_request.h

Lines changed: 4 additions & 0 deletions
@@ -73,6 +73,7 @@ struct InferRequestShm {
   int32_t timeout;
   PreferredMemory preferred_memory;
   InferenceTrace trace;
+  uint32_t request_release_flags;
 };
 
 class InferRequest {
@@ -104,6 +105,8 @@ class InferRequest {
   void SetIsDecoupled(const bool is_decoupled);
   PreferredMemory& GetPreferredMemory();
   InferenceTrace& Trace();
+  uint32_t ReleaseFlags();
+  void SetReleaseFlags(const uint32_t& flags);
 
 #ifdef TRITON_PB_STUB
   std::shared_ptr<InferResponse> Exec(const bool is_decoupled);
@@ -161,6 +164,7 @@ class InferRequest {
   bool is_decoupled_;
   PreferredMemory preferred_memory_;
   InferenceTrace trace_;
+  uint32_t request_release_flags_;
 
   // Shared Memory Data Structures
   AllocatedSharedMemory<char> infer_request_shm_;

src/pb_stub.cc

Lines changed: 32 additions & 17 deletions
@@ -793,26 +793,39 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr)
           std::to_string(response_size) + "\n";
       throw PythonBackendException(err);
     }
-    for (auto& response : responses) {
+
+    for (size_t i = 0; i < response_size; i++) {
       // Check the return type of execute function.
-      if (!py::isinstance<InferResponse>(response)) {
-        std::string str = py::str(response.get_type());
-        throw PythonBackendException(
-            std::string("Expected an 'InferenceResponse' object in the execute "
-                        "function return list, found type '") +
-            str + "'.");
+      InferRequest* infer_request = py_request_list[i].cast<InferRequest*>();
+      if (infer_request->ReleaseFlags() ==
+          TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) {
+        if (!py::isinstance<py::none>(responses[i])) {
+          // When the request is rescheduled in a non-decoupled model, the
+          // response must be None.
+          std::string str = py::str(responses[i].get_type());
+          throw PythonBackendException(
+              "Expected a None object in the execute function return list for "
+              "rescheduled request, found type '" +
+              str + "'.");
+        }
+      } else {
+        if (!py::isinstance<InferResponse>(responses[i])) {
+          std::string str = py::str(responses[i].get_type());
+          throw PythonBackendException(
+              std::string(
+                  "Expected an 'InferenceResponse' object in the execute "
+                  "function return list, found type '") +
+              str + "'.");
+        }
+        InferResponse* infer_response = responses[i].cast<InferResponse*>();
+        infer_response->PruneOutputTensors(
+            infer_request->RequestedOutputNames());
+        ProcessResponse(infer_response);
+        responses_shm_handle[i] = infer_response->ShmHandle();
       }
     }
     response_batch_shm_ptr->batch_size = response_size;
-
-    for (size_t i = 0; i < batch_size; i++) {
-      InferResponse* infer_response = responses[i].cast<InferResponse*>();
-      InferRequest* infer_request = py_request_list[i].cast<InferRequest*>();
-      infer_response->PruneOutputTensors(infer_request->RequestedOutputNames());
-
-      ProcessResponse(infer_response);
-      responses_shm_handle[i] = infer_response->ShmHandle();
-    }
   }
   catch (const PythonBackendException& pb_exception) {
     has_exception = true;
@@ -1675,7 +1688,9 @@ PYBIND11_EMBEDDED_MODULE(c_python_backend_utils, module)
           "requested_output_names", &InferRequest::RequestedOutputNames,
           py::return_value_policy::reference_internal)
       .def("get_response_sender", &InferRequest::GetResponseSender)
-      .def("is_cancelled", &InferRequest::IsCancelled);
+      .def("is_cancelled", &InferRequest::IsCancelled)
+      .def("set_release_flags", &InferRequest::SetReleaseFlags,
+           py::arg("flags").none(false));
 
   py::class_<PbTensor, std::shared_ptr<PbTensor>>(module, "Tensor")
       .def(py::init(&PbTensor::FromNumpy))
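The validation added to `Stub::ProcessRequests` ties each entry of the returned list to its request: an entry must be `None` exactly when its request was released with the RESCHEDULE flag, and an `InferenceResponse` otherwise. A hedged sketch of a non-decoupled `execute` that handles a mixed batch under that contract; the `sequence_finished` helper and the pass-through "OUT" tensor are illustrative assumptions, not part of this commit:

```python
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def sequence_finished(self, request):
        # Hypothetical predicate; a real model would track per-sequence
        # progress (e.g. remaining tokens) instead of this placeholder.
        return True

    def execute(self, requests):
        responses = []
        for request in requests:
            if not self.sequence_finished(request):
                # Rescheduled request: the stub expects None at this index.
                request.set_release_flags(
                    pb_utils.TRITONSERVER_REQUEST_RELEASE_RESCHEDULE
                )
                responses.append(None)
            else:
                # Finished request: exactly one InferenceResponse, and the
                # request is released with the default ALL flag.
                in_tensor = pb_utils.get_input_tensor_by_name(request, "IN")
                out_tensor = pb_utils.Tensor("OUT", in_tensor.as_numpy())
                responses.append(
                    pb_utils.InferenceResponse(output_tensors=[out_tensor])
                )
        return responses
```

If the pairing does not line up, the stub raises the `PythonBackendException`s shown in the diff above instead of silently misrouting responses.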
