From 1b2ffb093df1f22e13819a936370b86ced22e178 Mon Sep 17 00:00:00 2001
From: Sandeep Kumar Behera <64504172+sandeepkumar-skb@users.noreply.github.com>
Date: Thu, 2 Nov 2023 11:01:13 -0700
Subject: [PATCH] Create model_gpu.py

---
 examples/pytorch/model_gpu.py | 158 ++++++++++++++++++++++++++++++++++
 1 file changed, 158 insertions(+)
 create mode 100644 examples/pytorch/model_gpu.py

diff --git a/examples/pytorch/model_gpu.py b/examples/pytorch/model_gpu.py
new file mode 100644
index 00000000..5d181e4b
--- /dev/null
+++ b/examples/pytorch/model_gpu.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+# triton_python_backend_utils is available in every Triton Python model. You
+# need to use this module to create inference requests and responses. It also
+# contains some utility functions for extracting information from model_config
+# and converting Triton input/output types to numpy types.
+import triton_python_backend_utils as pb_utils
+from torch import nn
+from torch.utils.dlpack import to_dlpack, from_dlpack
+
+
+class AddSubNet(nn.Module):
+    """
+    Simple AddSub network in PyTorch. This network outputs the sum and
+    difference of the inputs.
+    """
+
+    def __init__(self):
+        super(AddSubNet, self).__init__()
+
+    def forward(self, input0, input1):
+        return (input0 + input1), (input0 - input1)
+
+
+class TritonPythonModel:
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def initialize(self, args):
+        """`initialize` is called only once when the model is being loaded.
+        Implementing the `initialize` function is optional. This function
+        allows the model to initialize any state associated with this model.
+
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing the model instance kind
+          * model_instance_device_id: A string containing the model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
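+
+        # The instance kind and device ID above make it possible to pin this
+        # instance to the GPU Triton assigned to it. A minimal sketch (not
+        # used below, where `.cuda()` simply targets the default GPU):
+        #
+        #   import torch
+        #   self.device = torch.device(
+        #       f"cuda:{args['model_instance_device_id']}"
+        #       if args["model_instance_kind"] == "GPU"
+        #       else "cpu"
+        #   )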
+
+        # You must parse model_config; the JSON string is not parsed
+        # automatically.
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        # Get OUTPUT0 configuration
+        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")
+
+        # Get OUTPUT1 configuration
+        output1_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT1")
+
+        # Convert Triton types to numpy types
+        self.output0_dtype = pb_utils.triton_string_to_numpy(
+            output0_config["data_type"]
+        )
+        self.output1_dtype = pb_utils.triton_string_to_numpy(
+            output1_config["data_type"]
+        )
+
+        # Instantiate the PyTorch model
+        self.add_sub_model = AddSubNet()
+
+    def execute(self, requests):
+        """`execute` must be implemented in every Python model. The `execute`
+        function receives a list of pb_utils.InferenceRequest as its only
+        argument. This function is called when inference is requested for
+        this model. Depending on the batching configuration (e.g. dynamic
+        batching), `requests` may contain multiple requests. Every Python
+        model must create one pb_utils.InferenceResponse for every
+        pb_utils.InferenceRequest in `requests`. If there is an error, you
+        can set the error argument when creating a pb_utils.InferenceResponse.
+
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+
+        responses = []
+
+        # Every Python model must iterate over every one of the requests
+        # and create a pb_utils.InferenceResponse for each of them.
+        for request in requests:
+            # Get INPUT0 and move it to the GPU (a no-op if the tensor is
+            # already resident there)
+            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+            in_0_t = from_dlpack(in_0.to_dlpack()).cuda()
+            # Get INPUT1 and move it to the GPU
+            in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1")
+            in_1_t = from_dlpack(in_1.to_dlpack()).cuda()
+
+            out_0, out_1 = self.add_sub_model(in_0_t, in_1_t)
+
+            # Create output tensors directly from the DLPack capsules of the
+            # PyTorch outputs. Unlike the CPU example, there is no numpy
+            # round-trip and no dtype cast here, so the outputs are assumed
+            # to already match the configured OUTPUT0/OUTPUT1 data types.
+            out_tensor_0 = pb_utils.Tensor.from_dlpack("OUTPUT0", to_dlpack(out_0))
+            out_tensor_1 = pb_utils.Tensor.from_dlpack("OUTPUT1", to_dlpack(out_1))
+
+            # Create InferenceResponse. You can set an error here in case
+            # there was a problem with handling this inference request.
+            # Below is an example of how you can set errors in inference
+            # response:
+            #
+            # pb_utils.InferenceResponse(
+            #     output_tensors=...,
+            #     error=pb_utils.TritonError("An error occurred"))
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[out_tensor_0, out_tensor_1]
+            )
+            responses.append(inference_response)
+
+        # You should return a list of pb_utils.InferenceResponse. The length
+        # of this list must match the length of the `requests` list.
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded.
+        Implementing the `finalize` function is optional. This function
+        allows the model to perform any necessary cleanup before exit.
+        """
+        print("Cleaning up...")
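
The patch does not include the companion config.pbtxt: to serve this model, the script goes into a model repository entry (stored as model.py) whose config names the python backend, declares INPUT0/INPUT1 and OUTPUT0/OUTPUT1, and uses a KIND_GPU instance group so the `.cuda()` calls have a device to target. Assuming such a hypothetical deployment named add_sub_gpu with two FP32 inputs of shape [4] (none of these specifics is fixed by the patch itself), a minimal client-side sketch using tritonclient might look like:

    import numpy as np
    import tritonclient.http as httpclient

    # Hypothetical deployment: repository entry "add_sub_gpu" with two FP32
    # inputs of shape [4]. Adjust names, shapes, and types to match the
    # actual config.pbtxt.
    client = httpclient.InferenceServerClient(url="localhost:8000")

    input0 = np.array([1, 2, 3, 4], dtype=np.float32)
    input1 = np.array([4, 3, 2, 1], dtype=np.float32)

    inputs = [
        httpclient.InferInput("INPUT0", list(input0.shape), "FP32"),
        httpclient.InferInput("INPUT1", list(input1.shape), "FP32"),
    ]
    inputs[0].set_data_from_numpy(input0)
    inputs[1].set_data_from_numpy(input1)

    result = client.infer(model_name="add_sub_gpu", inputs=inputs)
    print("OUTPUT0:", result.as_numpy("OUTPUT0"))  # input0 + input1
    print("OUTPUT1:", result.as_numpy("OUTPUT1"))  # input0 - input1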