Commit 74c2f82

Fix shared memory region naming scheme (triton-inference-server#103)
1 parent e27e0d4 commit 74c2f82

File tree

1 file changed: +28 -6 lines

src/python.cc

Lines changed: 28 additions & 6 deletions
@@ -192,6 +192,8 @@ struct BackendState {
   int64_t shm_growth_byte_size;
   int64_t stub_timeout_seconds;
   int64_t shm_message_queue_size;
+  std::atomic<int> number_of_instance_inits;
+  std::string shared_memory_region_prefix;
   std::unique_ptr<EnvironmentManager> env_manager;
 };

@@ -510,8 +512,7 @@ ModelInstanceState::ExecuteBLSRequest(
   std::unique_ptr<InferRequest> infer_request;
   std::shared_ptr<std::mutex> cuda_ipc_mutex;
   infer_request = InferRequest::LoadFromSharedMemory(
-      shm_pool_, request_batch->requests, cuda_ipc_mutex,
-      cuda_ipc_mutex);
+      shm_pool_, request_batch->requests, cuda_ipc_mutex, cuda_ipc_mutex);
   std::unique_ptr<InferResponse> infer_response;

   // If the BLS inputs are in GPU an additional round trip between the

@@ -1314,11 +1315,14 @@ TRITONSERVER_Error*
 ModelInstanceState::SetupStubProcess()
 {
   std::string kind = TRITONSERVER_InstanceGroupKindString(kind_);
-  shm_region_name_ = std::string("/") + Name() + "_" +
-                     std::to_string(Model()->Version()) + "_" + kind + "_" +
-                     std::to_string(device_id_);
-
   ModelState* model_state = reinterpret_cast<ModelState*>(Model());
+
+  // Increase the stub process count to avoid shared memory region name
+  // collision
+  model_state->StateForBackend()->number_of_instance_inits++;
+  shm_region_name_ =
+      model_state->StateForBackend()->shared_memory_region_prefix +
+      std::to_string(model_state->StateForBackend()->number_of_instance_inits);
   int64_t shm_growth_size =
       model_state->StateForBackend()->shm_growth_byte_size;
   int64_t shm_default_size =
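
As context for the hunk above, the region name changes from one derived from the instance (model name, version, kind, device id) to a configurable prefix followed by a per-backend counter. The standalone sketch below is illustrative only; the NamingState struct, the NextRegionName function, and the use of fetch_add are not the backend's code, but they show how a process-wide atomic counter plus a prefix yields a distinct shared memory region name for every instance initialization.

#include <atomic>
#include <iostream>
#include <string>

// Illustrative stand-ins for the fields added to BackendState in this commit.
struct NamingState {
  std::atomic<int> number_of_instance_inits{0};
  std::string shared_memory_region_prefix{"triton_python_backend_shm_region_"};
};

// Each call returns a fresh region name; fetch_add makes the increment and the
// read a single atomic step, so concurrent initializations cannot collide.
std::string NextRegionName(NamingState& state)
{
  int count = state.number_of_instance_inits.fetch_add(1) + 1;
  return state.shared_memory_region_prefix + std::to_string(count);
}

int main()
{
  NamingState state;
  std::cout << NextRegionName(state) << "\n";  // triton_python_backend_shm_region_1
  std::cout << NextRegionName(state) << "\n";  // triton_python_backend_shm_region_2
  return 0;
}
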
@@ -1711,6 +1715,9 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend)
   backend_state->shm_growth_byte_size = 64 * 1024 * 1024;  // 64 MBs
   backend_state->stub_timeout_seconds = 30;
   backend_state->shm_message_queue_size = 1000;
+  backend_state->number_of_instance_inits = 0;
+  backend_state->shared_memory_region_prefix =
+      "triton_python_backend_shm_region_";

   if (backend_config.Find("cmdline", &cmdline)) {
     triton::common::TritonJson::Value shm_growth_size;

@@ -1752,6 +1759,21 @@ TRITONBACKEND_Initialize(TRITONBACKEND_Backend* backend)
       }
     }

+    triton::common::TritonJson::Value shm_region_prefix;
+    std::string shm_region_prefix_str;
+    if (cmdline.Find("shm-region-prefix-name", &shm_region_prefix)) {
+      RETURN_IF_ERROR(shm_region_prefix.AsString(&shm_region_prefix_str));
+      // The shared memory region prefix must not be empty.
+      if (shm_region_prefix_str.size() == 0) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INVALID_ARG,
+            (std::string("shm-region-prefix-name") +
+             " must contain at least one character.")
+                .c_str());
+      }
+      backend_state->shared_memory_region_prefix = shm_region_prefix_str;
+    }
+
     triton::common::TritonJson::Value shm_message_queue_size;
     std::string shm_message_queue_size_str;
     if (cmdline.Find("shm_message_queue_size", &shm_message_queue_size)) {
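
To make the new cmdline option concrete, here is a hedged, standalone sketch of the rule the hunk above enforces: a configured shm-region-prefix-name replaces the default prefix, and an empty value is rejected as invalid. ResolveShmRegionPrefix is a hypothetical helper introduced only for illustration; the backend itself reads the value with TritonJson as shown in the diff.

#include <optional>
#include <stdexcept>
#include <string>

// Hypothetical helper mirroring the validation above: keep the default prefix
// when no override is given, reject an empty override, otherwise use it.
std::string ResolveShmRegionPrefix(
    const std::optional<std::string>& configured_prefix)
{
  const std::string default_prefix = "triton_python_backend_shm_region_";
  if (!configured_prefix) {
    return default_prefix;  // no cmdline override supplied
  }
  if (configured_prefix->empty()) {
    throw std::invalid_argument(
        "shm-region-prefix-name must contain at least one character.");
  }
  return *configured_prefix;
}

int main()
{
  // No override -> default prefix; "prefix1_" -> used as-is; "" -> would throw.
  std::string a = ResolveShmRegionPrefix(std::nullopt);
  std::string b = ResolveShmRegionPrefix(std::string("prefix1_"));
  return (a.empty() || b.empty()) ? 1 : 0;
}

In a running server the prefix would typically be supplied through Triton's backend configuration (for example, a --backend-config=python,shm-region-prefix-name=prefix1 style option), which is what populates the cmdline object queried above.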
