Skip to content

Commit 4fc9cf6

Browse files
authored
Add JAX example (triton-inference-server#186)
* Add JAX example * Update comments * Add jax api link * Refactor README.md and add step by step guide * Add full output to README
1 parent cdfd06d commit 4fc9cf6

File tree

5 files changed

+410
-0
lines changed

5 files changed

+410
-0
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ any C++ code.
6868
- [Examples](#examples)
6969
- [AddSub in NumPy](#addsub-in-numpy)
7070
- [AddSubNet in PyTorch](#addsubnet-in-pytorch)
71+
- [AddSub in JAX](#addsub-in-jax)
7172
- [Business Logic Scripting](#business-logic-scripting-1)
7273
- [Preprocessing](#preprocessing)
7374
- [Decoupled Models](#decoupled-models)
@@ -1034,6 +1035,11 @@ Make sure that PyTorch is available in the same Python environment as other
10341035
dependencies. Alternatively, you can create a [Python Execution Environment](#using-custom-python-execution-environments).
10351036
You can find the files for this example in [examples/pytorch](examples/pytorch).
10361037

1038+
## AddSub in JAX
1039+
1040+
The JAX example shows how to serve JAX in Triton using Python Backend.
1041+
You can find the complete example instructions in [examples/jax](examples/jax/README.md).
1042+
10371043
## Business Logic Scripting
10381044

10391045
The BLS example needs the dependencies required for both of the above examples.

examples/jax/README.md

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
<!--
2+
# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
-->
28+
29+
# JAX Example
30+
31+
In this section, we demonstrate an end-to-end example for using
32+
[JAX](https://jax.readthedocs.io/en/latest/) in Python Backend.
33+
34+
## Create a JAX AddSub model repository
35+
36+
We will use the files that come with this example to create the model
37+
repository.
38+
39+
First, download the [client.py](client.py), [config.pbtxt](config.pbtxt) and
40+
[model.py](model.py) to your local machine.
41+
42+
Next, at the directory where the three files located, create the model
43+
repository with the following commands:
44+
```
45+
$ mkdir -p models/jax/1
46+
$ mv model.py models/jax/1
47+
$ mv config.pbtxt models/jax
48+
```
49+
50+
## Pull the Triton Docker images
51+
52+
We need to install Docker and NVIDIA Container Toolkit before proceeding, refer
53+
to the
54+
[installation steps](https://github.com/triton-inference-server/server/tree/main/docs#installation).
55+
56+
To pull the latest containers, run the following commands:
57+
```
58+
$ docker pull nvcr.io/nvidia/tritonserver:<yy.mm>-py3
59+
$ docker pull nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk
60+
```
61+
See the installation steps above for the `<yy.mm>` version.
62+
63+
At the time of writing, the latest version is `22.08`, which translates to the
64+
following commands:
65+
```
66+
$ docker pull nvcr.io/nvidia/tritonserver:22.08-py3
67+
$ docker pull nvcr.io/nvidia/tritonserver:22.08-py3-sdk
68+
```
69+
70+
Be sure to replace the `<yy.mm>` with the version pulled for all the remaining
71+
parts of this example.
72+
73+
## Start the Triton Server
74+
75+
At the directory where we created the JAX models (at where the "models" folder
76+
is located), run the following command:
77+
```
78+
$ docker run --gpus all -it --rm -p 8000:8000 -v `pwd`:/jax nvcr.io/nvidia/tritonserver:<yy.mm>-py3 /bin/bash
79+
```
80+
81+
Inside the container, we need to install JAX to run this example.
82+
83+
We recommend using the `pip` method mentioned in the
84+
[JAX documentation](https://github.com/google/jax#pip-installation-gpu-cuda).
85+
Make sure that JAX is available in the same Python environment as other
86+
dependencies.
87+
88+
To install for this example, run the following command:
89+
```
90+
$ pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
91+
```
92+
93+
Finally, we need to start the Triton Server, run the following command:
94+
```
95+
$ tritonserver --model-repository=/jax/models
96+
```
97+
98+
To leave the container for the next step, press: `CTRL + P + Q`.
99+
100+
## Test inference
101+
102+
At the directory where the client.py is located, run the following command:
103+
```
104+
$ docker run --rm --net=host -v `pwd`:/jax nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk python3 /jax/client.py
105+
```
106+
107+
A successful inference will print the following at the end:
108+
```
109+
INPUT0 ([0.89262384 0.645457 0.18913145 0.17099917]) + INPUT1 ([0.5703733 0.21917151 0.22854741 0.97336507]) = OUTPUT0 ([1.4629972 0.86462855 0.41767886 1.1443642 ])
110+
INPUT0 ([0.89262384 0.645457 0.18913145 0.17099917]) - INPUT1 ([0.5703733 0.21917151 0.22854741 0.97336507]) = OUTPUT0 ([ 0.32225055 0.4262855 -0.03941596 -0.8023659 ])
111+
PASS: jax
112+
```
113+
Note: You inputs can be different from the above, but the outputs always
114+
correspond to its inputs.

examples/jax/client.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
from tritonclient.utils import *
28+
import tritonclient.http as httpclient
29+
import sys
30+
import numpy as np
31+
32+
model_name = "jax"
33+
shape = [4]
34+
35+
with httpclient.InferenceServerClient("localhost:8000") as client:
36+
37+
input0_data = np.random.rand(*shape).astype(np.float32)
38+
input1_data = np.random.rand(*shape).astype(np.float32)
39+
inputs = [
40+
httpclient.InferInput("INPUT0", input0_data.shape,
41+
np_to_triton_dtype(input0_data.dtype)),
42+
httpclient.InferInput("INPUT1", input1_data.shape,
43+
np_to_triton_dtype(input1_data.dtype)),
44+
]
45+
46+
inputs[0].set_data_from_numpy(input0_data)
47+
inputs[1].set_data_from_numpy(input1_data)
48+
49+
outputs = [
50+
httpclient.InferRequestedOutput("OUTPUT0"),
51+
httpclient.InferRequestedOutput("OUTPUT1"),
52+
]
53+
54+
response = client.infer(model_name,
55+
inputs,
56+
request_id=str(1),
57+
outputs=outputs)
58+
59+
result = response.get_response()
60+
output0_data = response.as_numpy("OUTPUT0")
61+
output1_data = response.as_numpy("OUTPUT1")
62+
63+
print("INPUT0 ({}) + INPUT1 ({}) = OUTPUT0 ({})".format(
64+
input0_data, input1_data, output0_data))
65+
print("INPUT0 ({}) - INPUT1 ({}) = OUTPUT0 ({})".format(
66+
input0_data, input1_data, output1_data))
67+
68+
if not np.allclose(input0_data + input1_data, output0_data):
69+
print("jax example error: incorrect sum")
70+
sys.exit(1)
71+
72+
if not np.allclose(input0_data - input1_data, output1_data):
73+
print("jax example error: incorrect difference")
74+
sys.exit(1)
75+
76+
print('PASS: jax')
77+
sys.exit(0)

examples/jax/config.pbtxt

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
name: "jax"
28+
backend: "python"
29+
30+
input [
31+
{
32+
name: "INPUT0"
33+
data_type: TYPE_FP32
34+
dims: [ 4 ]
35+
}
36+
]
37+
input [
38+
{
39+
name: "INPUT1"
40+
data_type: TYPE_FP32
41+
dims: [ 4 ]
42+
}
43+
]
44+
output [
45+
{
46+
name: "OUTPUT0"
47+
data_type: TYPE_FP32
48+
dims: [ 4 ]
49+
}
50+
]
51+
output [
52+
{
53+
name: "OUTPUT1"
54+
data_type: TYPE_FP32
55+
dims: [ 4 ]
56+
}
57+
]
58+
59+
instance_group [{ kind: KIND_CPU }]

0 commit comments

Comments
 (0)