@@ -77,7 +77,7 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
       std::string* model_path,
-      std::unique_ptr<torch::jit::script::Module>* torch_model);
+      std::shared_ptr<torch::jit::script::Module>* torch_model);
 
   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
   const std::pair<bool, bool>& EnabledTensorExprFuser() const
@@ -98,6 +98,8 @@ class ModelState : public BackendModel {
     return enable_nvfuser_pair_;
   }
 
+  bool EnabledWeightSharing() { return enable_weight_sharing_; }
+
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
   TRITONSERVER_Error* AutoCompleteConfig();
@@ -111,6 +113,9 @@ class ModelState : public BackendModel {
   // Flag to indicate whether inference mode is enabled. Defaults to false.
   bool enable_inference_mode_;
 
+  // Flag to indicate whether weight sharing is enabled. Defaults to false.
+  bool enable_weight_sharing_;
+
   // Flag pairs to indicate if various JIT settings are set and
   // enabled respectively. Defaults to (false, true). Default behavior
   // is to do nothing if not explicitly set. Tensor fuser flag is
@@ -122,6 +127,12 @@ class ModelState : public BackendModel {
   // Flag pair to indicate whether nvfuser is set and enabled respectively.
   // Defaults to (false, false).
   std::pair<bool, bool> enable_nvfuser_pair_;
+
+  // Model mapping for shared TorchScript model across all instances on the
+  // same device. The key is a pair of isGPU and device index.
+  std::map<
+      std::pair<bool, int64_t>, std::shared_ptr<torch::jit::script::Module>>
+      torch_models_;
 };
 
 TRITONSERVER_Error*
@@ -161,7 +172,8 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(false), enable_tensor_fuser_pair_({false, true}),
+      enable_inference_mode_(false), enable_weight_sharing_(false),
+      enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
       enable_nvfuser_pair_({false, false})
@@ -172,7 +184,7 @@ TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
     std::string* model_path,
-    std::unique_ptr<torch::jit::script::Module>* torch_model)
+    std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
   // configuration doesn't have an explicit model file specified then
@@ -194,6 +206,23 @@ ModelState::LoadModel(
194206 " ' for model instance '" + Name () + " '" );
195207 }
196208
209+ // If weight sharing is enabled, skip loading model if
210+ // it is already available on the target device
211+ std::pair<bool , int > device_pair;
212+ if (enable_weight_sharing_) {
213+ device_pair = std::make_pair (!device.is_cpu (), device.index ());
214+ auto mit = torch_models_.find (device_pair);
215+ if (mit != torch_models_.end ()) {
216+ *torch_model = mit->second ;
217+ LOG_MESSAGE (
218+ TRITONSERVER_LOG_INFO,
219+ (std::string (" Reusing TorchScript model for instance '" ) + Name () +
220+ " '" )
221+ .c_str ());
222+ return nullptr ; // success
223+ }
224+ }
225+
197226 // Serialize the torch model to string
198227 std::string model_data_str;
199228 RETURN_IF_ERROR (ReadTextFile (*model_path, &model_data_str));
@@ -213,6 +242,17 @@ ModelState::LoadModel(
         ("failed to load model '" + Name() + "': " + ex.what()).c_str());
   }
 
+  if (enable_weight_sharing_) {
+    if (!((torch_models_.emplace(device_pair, *torch_model)).second)) {
+      std::string type = device.is_cpu() ? "CPU" : "GPU";
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_WARN,
+          (std::string("Model already found on target ") + type + " device " +
+           "(id " + std::to_string(device.index()) + ") for '" + Name() + "'")
+              .c_str());
+    }
+  }
+
   return nullptr;  // success
 }
 
@@ -295,6 +335,25 @@ ModelState::ParseParameters()
             .c_str());
   }
 
+  // If 'ENABLE_WEIGHT_SHARING' is not present in 'parameters' then no
+  // update is made to 'enable_weight_sharing'.
+  err = ParseParameter(
+      params, "ENABLE_WEIGHT_SHARING", &enable_weight_sharing_);
+  if (err != nullptr) {
+    if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+      return err;
+    } else {
+      TRITONSERVER_ErrorDelete(err);
+    }
+  } else {
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_INFO,
+        (std::string("Weight sharing is ") +
+         (enable_weight_sharing_ ? "enabled" : "disabled") +
+         " for model instance '" + Name() + "'")
+            .c_str());
+  }
+
   // If 'ENABLE_JIT_PROFILING' is not present in 'parameters' then no update
   // is made to 'enable_jit_profiling'.
   bool enable_jit_profiling = false;
@@ -419,7 +478,7 @@ class ModelInstanceState : public BackendModelInstance {
   // The full path to the TorchScript model file.
   std::string model_path_;
 
-  std::unique_ptr<torch::jit::script::Module> torch_model_;
+  std::shared_ptr<torch::jit::script::Module> torch_model_;
   torch::Device device_;
 
   // Map from configuration name for an input to the index of
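
Usage note: the ENABLE_WEIGHT_SHARING flag added above is read from the model configuration's "parameters" section, the same mechanism the backend's other ENABLE_* options use. A minimal config.pbtxt fragment to turn it on would presumably look like the sketch below; the parameter key is taken from the diff, while the string_value form follows the usual convention ParseParameter expects for boolean parameters, so treat the exact shape as an assumption rather than verified documentation.

parameters: {
  key: "ENABLE_WEIGHT_SHARING"
  value: {
    string_value: "true"
  }
}

With this enabled, every instance of the model placed on the same device reuses one torch::jit::script::Module through the shared_ptr map keyed by (isGPU, device index), instead of loading a separate copy of the weights per instance.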