Skip to content

Commit 836cbff

Browse files
No public description
PiperOrigin-RevId: 755907301
1 parent 98c1f58 commit 836cbff

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Prediction from the Triton server."""
16+
17+
from typing import Any
18+
import cv2
19+
import numpy as np
20+
import tritonclient
21+
22+
_API_URL = 'localhost:8000'
23+
_OUTPUT_KEYS = (
24+
'detection_classes',
25+
'detection_masks',
26+
'detection_boxes',
27+
'image_info',
28+
'num_detections',
29+
'detection_scores',
30+
)
31+
32+
# Setting up the Triton client
33+
_TRITON_CLIENT = tritonclient.http.InferenceServerClient(
34+
url=_API_URL, network_timeout=1200, connection_timeout=1200
35+
)
36+
37+
# Outputs setup based on constants
38+
_OUTPUTS = [
39+
tritonclient.http.InferRequestedOutput(key, binary_data=True)
40+
for key in _OUTPUT_KEYS
41+
]
42+
43+
44+
def model_input(
45+
path: str, height: int, width: int
46+
) -> tritonclient.http.InferInput:
47+
"""Prepares an image for input to a Triton model server.
48+
49+
It reads it from a path, resizes it, normalizes it, and converts it to the
50+
format required by the server.
51+
52+
Args:
53+
path: The file path to the image that needs to be processed.
54+
height: The height of the image to be resized.
55+
width: The width of the image to be resized.
56+
57+
Returns:
58+
A Triton inference server input object containing the processed image.
59+
"""
60+
original_image = cv2.imread(path)
61+
image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
62+
image_resized = cv2.resize(
63+
image, (width, height), interpolation=cv2.INTER_AREA
64+
)
65+
expanded_image = np.expand_dims(image_resized, axis=0)
66+
inputs = tritonclient.http.InferInput(
67+
'inputs', expanded_image.shape, datatype='UINT8'
68+
)
69+
inputs.set_data_from_numpy(expanded_image, binary_data=True)
70+
return inputs, image, image_resized
71+
72+
73+
def _query_model(
74+
client: tritonclient.http.InferenceServerClient,
75+
model_name: str,
76+
inputs: tritonclient.http.InferInput,
77+
) -> tritonclient.http.InferResult:
78+
"""Sends an inference request to the Triton server.
79+
80+
Args:
81+
client: The Triton server client.
82+
model_name: Name of the model for which inference is requested.
83+
inputs: The input data for inference.
84+
85+
Returns:
86+
The result of the inference request.
87+
"""
88+
return client.infer(model_name=model_name, inputs=[inputs], outputs=_OUTPUTS)
89+
90+
91+
def prediction(
92+
model_name: str, inputs: tritonclient.http.InferInput
93+
) -> dict[str, Any]:
94+
"""Model name for prediction.
95+
96+
Args:
97+
model_name: Model name in Triton Server.
98+
inputs: The input data for inference.
99+
100+
Returns:
101+
prediction output from the model.
102+
"""
103+
result = _query_model(_TRITON_CLIENT, model_name, inputs)
104+
result_dict = {key: result.as_numpy(key) for key in _OUTPUT_KEYS}
105+
return result_dict
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from unittest import mock
17+
import numpy as np
18+
# Import the functions to be tested
19+
from official.projects.waste_identification_ml.Triton_TF_Cloud_Deployment.client import triton_server_prediction
20+
21+
22+
class TestTritonPrediction(unittest.TestCase):
23+
24+
@mock.patch("cv2.imread")
25+
@mock.patch("cv2.cvtColor")
26+
@mock.patch("cv2.resize")
27+
def test_model_input(self, mock_resize, mock_convert_color, mock_imread):
28+
"""Test the model_input function."""
29+
30+
# Mocking image loading and processing
31+
mock_imread.return_value = np.ones((500, 500, 3), dtype=np.uint8)
32+
mock_convert_color.return_value = np.ones((500, 500, 3), dtype=np.uint8)
33+
mock_resize.return_value = np.ones((224, 224, 3), dtype=np.uint8)
34+
35+
_, image, image_resized = triton_server_prediction.model_input(
36+
"dummy_path.jpg", 224, 224
37+
)
38+
39+
self.assertEqual(image.shape, (500, 500, 3))
40+
self.assertEqual(image_resized.shape, (224, 224, 3))
41+
42+
@mock.patch("your_module._query_model")
43+
def test_prediction(self, mock_query_model):
44+
"""Test the prediction function."""
45+
46+
mock_result = mock.MagicMock()
47+
mock_result.as_numpy.side_effect = lambda key: np.array(
48+
[1]
49+
)
50+
mock_query_model.return_value = mock_result
51+
52+
mock_inputs = mock.MagicMock()
53+
model_name = "dummy_model"
54+
55+
result = triton_server_prediction.prediction(model_name, mock_inputs)
56+
57+
for key in (
58+
"detection_classes",
59+
"detection_masks",
60+
"detection_boxes",
61+
"image_info",
62+
"num_detections",
63+
"detection_scores",
64+
):
65+
self.assertIn(key, result)
66+
self.assertIsInstance(result[key], np.ndarray)
67+
68+
69+
if __name__ == "__main__":
70+
unittest.main()

0 commit comments

Comments
 (0)