Skip to content

Commit 02c9c1c

Browse files
authored
Add PyTorch platform handler example (triton-inference-server#287)
* Add PyTorch platform handler example * Refactor docs structure * Add more comments and minor refactoring * Further break down client.py * Remove exit 0 if terminated normally * Simplify comments * Improve comment * List mug.jpg paths * Docs update * Describe the source of mug.jpg
1 parent 74722ba commit 02c9c1c

File tree

7 files changed

+1402
-135
lines changed

7 files changed

+1402
-135
lines changed

README.md

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ any C++ code.
7272
- [Input Tensor Device Placement](#input-tensor-device-placement)
7373
- [Frameworks](#frameworks)
7474
- [PyTorch](#pytorch)
75+
- [PyTorch Platform \[Experimental\]](#pytorch-platform-experimental)
7576
- [PyTorch Determinism](#pytorch-determinism)
7677
- [TensorFlow](#tensorflow)
7778
- [TensorFlow Determinism](#tensorflow-determinism)
@@ -1397,9 +1398,115 @@ this workflow.
13971398
For a simple example of using PyTorch in a Python Backend model, see the
13981399
[AddSubNet PyTorch example](#addsubnet-in-pytorch).
13991400

1400-
PyTorch models may be served directly without implementing the `model.py`, see
1401-
[Serving PyTorch models using Python Backend \[Experimental\]](src/resources/platform_handlers/pytorch/README.md)
1402-
for more details.
1401+
### PyTorch Platform \[Experimental\]
1402+
1403+
**NOTE**: *This feature is subject to change and removal, and should not
1404+
be used in production.*
1405+
1406+
Starting from 23.08, we are adding an experimental support for loading and
1407+
serving PyTorch models directly via Python backend. The model can be provided
1408+
within the triton server model repository, and a
1409+
[pre-built Python model](src/resources/platform_handlers/pytorch/model.py) will
1410+
be used to load and serve the PyTorch model.
1411+
1412+
#### Model Layout
1413+
1414+
The model repository should look like:
1415+
1416+
```
1417+
model_repository/
1418+
`-- model_directory
1419+
|-- 1
1420+
| |-- model.py
1421+
| `-- model.pt
1422+
`-- config.pbtxt
1423+
```
1424+
1425+
The `model.py` contains the class definition of the PyTorch model. The class
1426+
should extend the
1427+
[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
1428+
The `model.pt` may be optionally provided which contains the saved
1429+
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference)
1430+
of the model. For serving TorchScript models, a `model.pt` TorchScript can be
1431+
provided in place of the `model.py` file.
1432+
1433+
By default, Triton will use the
1434+
[PyTorch backend](https://github.com/triton-inference-server/pytorch_backend) to
1435+
load and serve TorchScript models. In order to serve from Python backend,
1436+
[model configuration](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md)
1437+
should explicitly provide the following settings:
1438+
1439+
```
1440+
backend: "python"
1441+
platform: "pytorch"
1442+
```
1443+
1444+
#### PyTorch Installation
1445+
1446+
This feature will take advantage of the
1447+
[`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
1448+
optimization, make sure the
1449+
[PyTorch 2.0+ pip package](https://pypi.org/project/torch/2.0.1/) is available
1450+
in the same Python environment.
1451+
1452+
```
1453+
pip install torch==2.0.1
1454+
```
1455+
Alternatively, a
1456+
[Python Execution Environment](#using-custom-python-execution-environments)
1457+
with the PyTorch dependency may be used.
1458+
1459+
#### Customization
1460+
1461+
The following PyTorch settings may be customized by setting parameters on the
1462+
`config.pbtxt`.
1463+
1464+
[`torch.set_num_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_threads.html#torch.set_num_threads)
1465+
- Key: NUM_THREADS
1466+
- Value: The number of threads used for intraop parallelism on CPU.
1467+
1468+
[`torch.set_num_interop_threads(int)`](https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads.html#torch.set_num_interop_threads)
1469+
- Key: NUM_INTEROP_THREADS
1470+
- Value: The number of threads used for interop parallelism (e.g. in JIT
1471+
interpreter) on CPU.
1472+
1473+
[`torch.compile()` parameters](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)
1474+
- Key: TORCH_COMPILE_OPTIONAL_PARAMETERS
1475+
- Value: Any of following parameter(s) encoded as a JSON object.
1476+
- fullgraph (*bool*): Whether it is ok to break model into several subgraphs.
1477+
- dynamic (*bool*): Use dynamic shape tracing.
1478+
- backend (*str*): The backend to be used.
1479+
- mode (*str*): Can be either "default", "reduce-overhead" or "max-autotune".
1480+
- options (*dict*): A dictionary of options to pass to the backend.
1481+
- disable (*bool*): Turn `torch.compile()` into a no-op for testing.
1482+
1483+
For example:
1484+
```
1485+
parameters: {
1486+
key: "NUM_THREADS"
1487+
value: { string_value: "4" }
1488+
}
1489+
parameters: {
1490+
key: "TORCH_COMPILE_OPTIONAL_PARAMETERS"
1491+
value: { string_value: "{\"disable\": true}" }
1492+
}
1493+
```
1494+
1495+
#### Example
1496+
1497+
You can find the complete example instructions in
1498+
[examples/pytorch_platform_handler](examples/pytorch_platform_handler/README.md).
1499+
1500+
#### Limitations
1501+
1502+
Following are few known limitations of this feature:
1503+
- Python functions optimizable by `torch.compile` may not be served directly in
1504+
the `model.py` file, they need to be enclosed by a class extending the
1505+
[`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module).
1506+
- Model weights cannot be shared across multiple instances on the same GPU
1507+
device.
1508+
- When using `KIND_MODEL` as model instance kind, the default device of the
1509+
first parameter on the model is used.
14031510

14041511
### PyTorch Determinism
14051512

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
<!--
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
-->
28+
29+
# PyTorch Example
30+
31+
In this section, we demonstrate an end-to-end example for using the
32+
[PyTorch Platform \[Experimental\]](../../README.md#pytorch-platform-experimental)
33+
to serve a PyTorch model directly, **without** needing to implement the
34+
`TritonPythonModel` class.
35+
36+
## Create a ResNet50 model repository
37+
38+
We will use the files that come with this example to create the model
39+
repository.
40+
41+
First, download [client.py](client.py), [config.pbtxt](config.pbtxt),
42+
[model.py](model.py),
43+
[mug.jpg](https://raw.githubusercontent.com/triton-inference-server/server/main/qa/images/mug.jpg)
44+
and [resnet50_labels.txt](resnet50_labels.txt) to your local machine.
45+
46+
Next, at the directory where the downloaded files are saved at, create a model
47+
repository with the following commands:
48+
```
49+
$ mkdir -p models/resnet50_pytorch/1
50+
$ mv model.py models/resnet50_pytorch/1
51+
$ mv config.pbtxt models/resnet50_pytorch
52+
```
53+
54+
## Pull the Triton Docker images
55+
56+
We need to install Docker and NVIDIA Container Toolkit before proceeding, refer
57+
to the
58+
[installation steps](https://github.com/triton-inference-server/server/tree/main/docs#installation).
59+
60+
To pull the latest containers, run the following commands:
61+
```
62+
$ docker pull nvcr.io/nvidia/tritonserver:<yy.mm>-py3
63+
$ docker pull nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk
64+
```
65+
See the installation steps above for the `<yy.mm>` version.
66+
67+
For example, if the version is `23.08`, then:
68+
```
69+
$ docker pull nvcr.io/nvidia/tritonserver:23.08-py3
70+
$ docker pull nvcr.io/nvidia/tritonserver:23.08-py3-sdk
71+
```
72+
73+
Be sure to replace the `<yy.mm>` with the version pulled for all the remaining
74+
parts of this example.
75+
76+
## Start the Triton Server
77+
78+
At the directory where we created the PyTorch model (at where the "models"
79+
folder is located), run the following command:
80+
```
81+
$ docker run -it --rm --gpus all --shm-size 1g -p 8000:8000 -v `pwd`:/pytorch_example nvcr.io/nvidia/tritonserver:<yy.mm>-py3 /bin/bash
82+
```
83+
84+
Inside the container, we need to install PyTorch, Pillow and Requests to run this example.
85+
We recommend using `pip` method for the installations, for example:
86+
```
87+
$ pip3 install torch Pillow requests
88+
```
89+
90+
Finally, we need to start the Triton Server, run the following command:
91+
```
92+
$ tritonserver --model-repository=/pytorch_example/models
93+
```
94+
95+
To leave the container for the next step, press: `CTRL + P + Q`.
96+
97+
## Test inference
98+
99+
At the directory where the client.py is located, run the following command:
100+
```
101+
$ docker run --rm --net=host -v `pwd`:/pytorch_example nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk python3 /pytorch_example/client.py
102+
```
103+
104+
A successful inference will print the following at the end:
105+
```
106+
Result: COFFEE MUG
107+
Expected result: COFFEE MUG
108+
PASS: PyTorch platform handler
109+
```
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions
7+
# are met:
8+
# * Redistributions of source code must retain the above copyright
9+
# notice, this list of conditions and the following disclaimer.
10+
# * Redistributions in binary form must reproduce the above copyright
11+
# notice, this list of conditions and the following disclaimer in the
12+
# documentation and/or other materials provided with the distribution.
13+
# * Neither the name of NVIDIA CORPORATION nor the names of its
14+
# contributors may be used to endorse or promote products derived
15+
# from this software without specific prior written permission.
16+
#
17+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
18+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
21+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
22+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
23+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import os
30+
import sys
31+
32+
import numpy as np
33+
from PIL import Image
34+
from tritonclient import http as httpclient
35+
from tritonclient.utils import *
36+
37+
script_directory = os.path.dirname(os.path.realpath(__file__))
38+
39+
server_url = "localhost:8000"
40+
model_name = "resnet50_pytorch"
41+
input_name = "INPUT"
42+
output_name = "OUTPUT"
43+
label_path = os.path.join(script_directory, "resnet50_labels.txt")
44+
# The 'mug.jpg' image will be present at the script_directory if the steps on
45+
# the provided README.md are followed. The image may also be found at
46+
# '/workspace/images/mug.jpg' on the SDK container or
47+
# '/opt/tritonserver/qa/images/mug.jpg' on the QA container.
48+
image_path = os.path.join(script_directory, "mug.jpg")
49+
expected_output_class = "COFFEE MUG"
50+
51+
52+
def _load_input_image():
53+
raw_image = Image.open(image_path)
54+
raw_image = raw_image.convert("RGB").resize((224, 224), Image.BILINEAR)
55+
input_image = np.array(raw_image).astype(np.float32)
56+
input_image = (input_image / 127.5) - 1
57+
input_image = np.transpose(input_image, (2, 0, 1))
58+
input_image = np.reshape(input_image, (1, 3, 224, 224))
59+
return input_image
60+
61+
62+
def _infer(input_image):
63+
with httpclient.InferenceServerClient(server_url) as client:
64+
input_tensors = httpclient.InferInput(input_name, input_image.shape, "FP32")
65+
input_tensors.set_data_from_numpy(input_image)
66+
results = client.infer(model_name=model_name, inputs=[input_tensors])
67+
output_tensors = results.as_numpy(output_name)
68+
return output_tensors
69+
70+
71+
def _check_output(output_tensors):
72+
with open(label_path) as f:
73+
labels_dict = {idx: line.strip() for idx, line in enumerate(f)}
74+
max_id = np.argmax(output_tensors, axis=1)[0]
75+
output_class = labels_dict[max_id]
76+
print("Result: " + output_class)
77+
print("Expected result: " + expected_output_class)
78+
if output_class != expected_output_class:
79+
return False
80+
return True
81+
82+
83+
if __name__ == "__main__":
84+
input_image = _load_input_image()
85+
output_tensors = _infer(input_image)
86+
result_valid = _check_output(output_tensors)
87+
88+
if not result_valid:
89+
print("PyTorch platform handler example error: Unexpected result")
90+
sys.exit(1)
91+
92+
print("PASS: PyTorch platform handler")
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
name: "resnet50_pytorch"
28+
backend: "python"
29+
platform: "pytorch"
30+
31+
max_batch_size: 128
32+
33+
input {
34+
name: "INPUT"
35+
data_type: TYPE_FP32
36+
format: FORMAT_NCHW
37+
dims: [ 3, 224, 224 ]
38+
}
39+
output {
40+
name: "OUTPUT"
41+
data_type: TYPE_FP32
42+
dims: [ 1000 ]
43+
}
44+
45+
instance_group [{ kind: KIND_CPU }]

0 commit comments

Comments
 (0)