
Commit 73e8fc7

the-david-o and hemantj authored
Share Model Weights for Instances on the Same Device (triton-inference-server#54)
* Share model weights for instances on the same device
* Added destructor.
* Shortened comment
* Clarified comment.
* Addressing comments.
* Make weight sharing optional, off by default.
* Addressing comments.
* Formatting.
* Reworded comment for optionality of weight sharing.
* Let destructor remove last shared pointer
* Weight sharing documentation
* Fixing typo
* Updated copyright years

Co-authored-by: hemantj <[email protected]>
1 parent 1c4db70 commit 73e8fc7

File tree

* README.md
* src/libtorch.cc

2 files changed: 82 additions & 8 deletions


README.md

Lines changed: 19 additions & 4 deletions
@@ -1,5 +1,5 @@
 <!--
-# Copyright 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -111,7 +111,7 @@ execution of models without these optimizations. In some models, optimized execu
 does not benefit performance as seen [here](https://github.com/pytorch/pytorch/issues/19978)
 and in other cases impacts performance negatively, as seen [here](https://github.com/pytorch/pytorch/issues/53824).
 
-The section of model config file specifying this parameters will look like:
+The section of model config file specifying this parameter will look like:
 
 ```
 parameters: {
@@ -133,7 +133,7 @@ this mode gets better performance by disabling autograd.
 Please note that in some models, InferenceMode might not benefit performance
 and in fewer cases might impact performance negatively.
 
-The section of model config file specifying this parameters will look like:
+The section of model config file specifying this parameter will look like:
 
 ```
 parameters: {
@@ -153,7 +153,7 @@ Please note that in some models generated using trace in old PyTorch versions mi
 correctly with NvFuser. We recommend using scripting and a recent version of PyTorch
 to generate these models.
 
-The section of model config file specifying this parameters will look like:
+The section of model config file specifying this parameter will look like:
 
 ```
 parameters: {
@@ -164,6 +164,21 @@ key: "ENABLE_NVFUSER"
 }
 ```
 
+* `ENABLE_WEIGHT_SHARING`: Boolean flag to enable model instances on the same device to
+share weights. This optimization should not be used with stateful models. If not specified,
+weight sharing is disabled.
+
+The section of model config file specifying this parameter will look like:
+
+```
+parameters: {
+key: "ENABLE_WEIGHT_SHARING"
+value: {
+string_value:"true"
+}
+}
+```
+
 * Additional Optimizations: Three additional boolean parameters are available to disable
 certain Torch optimizations that can sometimes cause latency regressions in models with
 complex execution modes and dynamic shapes. If not specified, all are enabled by default.
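
For concreteness, a minimal model configuration combining the new parameter with multiple instances placed on one GPU might look like the sketch below. The `instance_group` block is illustrative and not part of this diff; only the `ENABLE_WEIGHT_SHARING` parameter documented above is introduced by this change.

```
instance_group [
  {
    count: 2
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]
parameters: {
  key: "ENABLE_WEIGHT_SHARING"
  value: {
    string_value: "true"
  }
}
```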

src/libtorch.cc

Lines changed: 63 additions & 4 deletions
@@ -77,7 +77,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path,
-      std::unique_ptr<torch::jit::script::Module>* torch_model);
+      std::shared_ptr<torch::jit::script::Module>* torch_model);
 
   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
   const std::pair<bool, bool>& EnabledTensorExprFuser() const
@@ -98,6 +98,8 @@ class ModelState : public BackendModel {
     return enable_nvfuser_pair_;
   }
 
+  bool EnabledWeightSharing() { return enable_weight_sharing_; }
+
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
   TRITONSERVER_Error* AutoCompleteConfig();
@@ -111,6 +113,9 @@ class ModelState : public BackendModel {
   // Flag to indicate whether inference mode is enabled. Defaults to false.
   bool enable_inference_mode_;
 
+  // Flag to indicate whether weight sharing is enabled. Defaults to false.
+  bool enable_weight_sharing_;
+
   // Flag pairs to indicate if various JIT settings are set and
   // enabled respectively. Defaults to (false, true). Default behavior
   // is to do nothing if not explicitly set. Tensor fuser flag is
@@ -122,6 +127,12 @@ class ModelState : public BackendModel {
   // Flag pair to indicate whether nvfuser is set and enabled respectively.
   // Defaults to (false, false).
   std::pair<bool, bool> enable_nvfuser_pair_;
+
+  // Model mapping for shared TorchScript model across all instances on the
+  // same device. The key is a pair of isGPU and device index.
+  std::map<
+      std::pair<bool, int64_t>, std::shared_ptr<torch::jit::script::Module>>
+      torch_models_;
 };
 
 TRITONSERVER_Error*
@@ -161,7 +172,8 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(false), enable_tensor_fuser_pair_({false, true}),
+      enable_inference_mode_(false), enable_weight_sharing_(false),
+      enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
       enable_nvfuser_pair_({false, false})
@@ -172,7 +184,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path,
-    std::unique_ptr<torch::jit::script::Module>* torch_model)
+    std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
   // configuration doesn't have an explicit model file specified then
@@ -194,6 +206,23 @@ ModelState::LoadModel(
         "' for model instance '" + Name() + "'");
   }
 
+  // If weight sharing is enabled, skip loading model if
+  // it is already available on the target device
+  std::pair<bool, int> device_pair;
+  if (enable_weight_sharing_) {
+    device_pair = std::make_pair(!device.is_cpu(), device.index());
+    auto mit = torch_models_.find(device_pair);
+    if (mit != torch_models_.end()) {
+      *torch_model = mit->second;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("Reusing TorchScript model for instance '") + Name() +
+           "'")
+              .c_str());
+      return nullptr; // success
+    }
+  }
+
   // Serialize the torch model to string
   std::string model_data_str;
   RETURN_IF_ERROR(ReadTextFile(*model_path, &model_data_str));
@@ -213,6 +242,17 @@ ModelState::LoadModel(
         ("failed to load model '" + Name() + "': " + ex.what()).c_str());
   }
 
+  if (enable_weight_sharing_) {
+    if (!((torch_models_.emplace(device_pair, *torch_model)).second)) {
+      std::string type = device.is_cpu() ? "CPU" : "GPU";
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_WARN,
+          (std::string("Model already found on target ") + type + " device " +
+           "(id " + std::to_string(device.index()) + ") for '" + Name() + "'")
+              .c_str());
+    }
+  }
+
   return nullptr; // success
 }
 
@@ -295,6 +335,25 @@ ModelState::ParseParameters()
             .c_str());
   }
 
+  // If 'ENABLE_WEIGHT_SHARING' is not present in 'parameters' then no
+  // update is made to 'enable_weight_sharing'.
+  err = ParseParameter(
+      params, "ENABLE_WEIGHT_SHARING", &enable_weight_sharing_);
+  if (err != nullptr) {
+    if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+      return err;
+    } else {
+      TRITONSERVER_ErrorDelete(err);
+    }
+  } else {
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_INFO,
+        (std::string("Weight sharing is ") +
+         (enable_weight_sharing_ ? "enabled" : "disabled") +
+         " for model instance '" + Name() + "'")
+            .c_str());
+  }
+
   // If 'ENABLE_JIT_PROFILING' is not present in 'parameters' then no update
   // is made to 'enable_jit_profiling'.
   bool enable_jit_profiling = false;
@@ -419,7 +478,7 @@ class ModelInstanceState : public BackendModelInstance {
   // The full path to the TorchScript model file.
   std::string model_path_;
 
-  std::unique_ptr<torch::jit::script::Module> torch_model_;
+  std::shared_ptr<torch::jit::script::Module> torch_model_;
   torch::Device device_;
 
   // Map from configuration name for an input to the index of
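
The core of this change is a per-device cache of `std::shared_ptr` modules keyed by `(isGPU, device index)`, consulted in `ModelState::LoadModel`. The standalone sketch below illustrates that lookup-or-load pattern; the names (`WeightSharingCache`, `GetOrLoad`, the stand-in `Module` type) are hypothetical and are not the backend's API.

```
// Illustrative sketch only: a per-device cache so that instances on the same
// device share one loaded module. "Module" stands in for
// torch::jit::script::Module; names are hypothetical, not Triton's API.
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <utility>

struct Module {};  // stand-in for torch::jit::script::Module

class WeightSharingCache {
 public:
  // Return the cached module for (is_gpu, device_index) if present;
  // otherwise call 'loader' once, cache the result, and return it.
  std::shared_ptr<Module> GetOrLoad(
      bool is_gpu, int64_t device_index,
      const std::function<std::shared_ptr<Module>()>& loader)
  {
    const auto key = std::make_pair(is_gpu, device_index);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;  // reuse: later instances share these weights
    }
    auto model = loader();       // load once per device
    cache_.emplace(key, model);  // remember it for the next instance
    return model;
  }

 private:
  // Key mirrors the commit's map: (isGPU, device index) -> shared module.
  std::map<std::pair<bool, int64_t>, std::shared_ptr<Module>> cache_;
};

int main()
{
  WeightSharingCache cache;
  auto load = [] { return std::make_shared<Module>(); };

  auto a = cache.GetOrLoad(true, 0, load);  // first instance on GPU 0: loads
  auto b = cache.GetOrLoad(true, 0, load);  // second instance on GPU 0: reuses
  std::cout << "shared: " << (a == b) << "\n";  // prints "shared: 1"
  return 0;
}
```

Because the cache holds `shared_ptr` values, a loaded module stays alive while any instance still references it; the commit message's note "Let destructor remove last shared pointer" suggests the backend relies on instance destruction to release the final reference.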
