@@ -69,7 +69,11 @@ def auto_complete_config(auto_complete_model_config):
6969 "optional" : True ,
7070 },
7171 ]
72- outputs = [{"name" : "text_output" , "data_type" : "TYPE_STRING" , "dims" : [- 1 ]}]
72+ outputs = [
73+ {"name" : "text_output" , "data_type" : "TYPE_STRING" , "dims" : [- 1 ]},
74+ {"name" : "input_tokens" , "data_type" : "TYPE_INT32" , "dims" : [- 1 ]},
75+ {"name" : "output_tokens" , "data_type" : "TYPE_INT32" , "dims" : [- 1 ]},
76+ ]
7377
7478 # Store the model configuration as a dictionary.
7579 config = auto_complete_model_config .as_dict ()
@@ -108,6 +112,14 @@ def initialize(self, args):
108112 self .model_config , "text_output"
109113 )
110114 self .output_dtype = pb_utils .triton_string_to_numpy (output_config ["data_type" ])
115+ output_tokens_config = pb_utils .get_output_config_by_name (
116+ self .model_config , "output_tokens"
117+ )
118+ self .output_tokens_dtype = pb_utils .triton_string_to_numpy (output_tokens_config ["data_type" ])
119+ input_tokens_config = pb_utils .get_output_config_by_name (
120+ self .model_config , "input_tokens"
121+ )
122+ self .input_tokens_dtype = pb_utils .triton_string_to_numpy (input_tokens_config ["data_type" ])
111123
112124 # Prepare vLLM engine
113125 self .init_engine ()
@@ -313,10 +325,17 @@ def create_response(self, vllm_output, prepend_input):
313325 text_outputs = [
314326 (prompt + output .text ).encode ("utf-8" ) for output in vllm_output .outputs
315327 ]
328+ output_tokens = sum ([len (output .token_ids ) for output in vllm_output .outputs ])
316329 triton_output_tensor = pb_utils .Tensor (
317- "text_output" , np .asarray (text_outputs , dtype = self .output_dtype )
330+ "text_output" , np .asarray (text_outputs , dtype = self .output_dtype ),
318331 )
319- return pb_utils .InferenceResponse (output_tensors = [triton_output_tensor ])
332+ triton_tokens_tensor = pb_utils .Tensor (
333+ "output_tokens" , np .asarray (output_tokens , dtype = self .output_tokens_dtype ),
334+ )
335+ triton_input_tokens_tensor = pb_utils .Tensor (
336+ "input_tokens" , np .asarray (len (vllm_output .prompt_token_ids ), dtype = self .input_tokens_dtype ),
337+ )
338+ return pb_utils .InferenceResponse (output_tensors = [triton_output_tensor , triton_tokens_tensor , triton_input_tokens_tensor ])
320339
321340 def create_stream_response (self , vllm_output , previous_outputs_lengths ):
322341 """
0 commit comments