Skip to content

Commit c840588

Browse files
authored
Revert "Fix np.object_ for Python backend (triton-inference-server#36)" (triton-inference-server#37)
This reverts commit 809643f.
1 parent 809643f commit c840588

File tree

1 file changed

+28
-28
lines changed

1 file changed

+28
-28
lines changed

src/resources/startup.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -53,34 +53,29 @@
5353

5454
def serialize_byte_tensor(input_tensor):
5555
"""
56-
Serializes a bytes tensor into a flat numpy array of length prepended bytes.
57-
Can pass bytes tensor as numpy array of bytes with dtype of np.bytes_,
58-
or python strings with dtype of np.object_. np.object_ is the recommended
59-
type to be used. np.str_ and np.bytes_ remove trailing zeros at the end of
60-
byte sequence and because of this it should be avoided.
61-
62-
Parameters
63-
----------
64-
input_tensor : np.array
65-
The bytes tensor to serialize.
66-
67-
Returns
68-
-------
69-
serialized_bytes_tensor : np.array
70-
The 1-D numpy array of type uint8 containing the serialized bytes in 'C' order.
71-
72-
Raises
73-
------
74-
InferenceServerException
75-
If unable to serialize the given tensor.
76-
"""
56+
Serializes a bytes tensor into a flat numpy array of length prepend bytes.
57+
Can pass bytes tensor as numpy array of bytes with dtype of np.bytes_,
58+
numpy strings with dtype of np.str_ or python strings with dtype of np.object.
59+
Parameters
60+
----------
61+
input_tensor : np.array
62+
The bytes tensor to serialize.
63+
Returns
64+
-------
65+
serialized_bytes_tensor : np.array
66+
The 1-D numpy array of type uint8 containing the serialized bytes in 'C' order.
67+
Raises
68+
------
69+
InferenceServerException
70+
If unable to serialize the given tensor.
71+
"""
7772

7873
if input_tensor.size == 0:
7974
return np.empty([0])
8075

81-
# If the input is a tensor of string/bytes objects, then must flatten those into
82-
# a 1-dimensional array containing the 4-byte byte size followed by the
83-
# actual element bytes. All elements are concatenated together in "C"
76+
# If the input is a tensor of string/bytes objects, then must flatten those
77+
# into a 1-dimensional array containing the 4-byte byte size followed by
78+
# the actual element bytes. All elements are concatenated together in "C"
8479
# order.
8580
if (input_tensor.dtype == np.object) or (input_tensor.dtype.type
8681
== np.bytes_):
@@ -89,10 +84,13 @@ def serialize_byte_tensor(input_tensor):
8984
# If directly passing bytes to BYTES type,
9085
# don't convert it to str as Python will encode the
9186
# bytes which may distort the meaning
92-
if type(obj.item()) == bytes:
93-
s = obj.item()
87+
if obj.dtype.type == np.bytes_:
88+
if type(obj.item()) == bytes:
89+
s = obj.item()
90+
else:
91+
s = bytes(obj)
9492
else:
95-
s = bytes(obj)
93+
s = str(obj).encode('utf-8')
9694
flattened += struct.pack("<I", len(s))
9795
flattened += s
9896
flattened_array = np.asarray(flattened)
@@ -130,7 +128,7 @@ def deserialize_bytes_tensor(encoded_tensor):
130128
sb = struct.unpack_from("<{}s".format(l), val_buf, offset)[0]
131129
offset += l
132130
strs.append(sb)
133-
return (np.array(strs, dtype=np.bytes_))
131+
return (np.array(strs, dtype=bytes))
134132

135133

136134
def parse_startup_arguments():
@@ -151,6 +149,7 @@ def parse_startup_arguments():
151149
class PythonHost(PythonInterpreterServicer):
152150
"""This class handles inference request for python script.
153151
"""
152+
154153
def __init__(self, module_path, *args, **kwargs):
155154
super(PythonInterpreterServicer, self).__init__(*args, **kwargs)
156155

@@ -306,6 +305,7 @@ def Execute(self, request, context):
306305
# We need to serialize TYPE_STRING
307306
if output_np_array.dtype == np.object or output_np_array.dtype.type is np.bytes_:
308307
output_np_array = serialize_byte_tensor(output_np_array)
308+
309309
tensor = Tensor(name=output_tensor.name(),
310310
dtype=tpb_utils.numpy_to_triton_type(
311311
output_np_array.dtype.type),

0 commit comments

Comments
 (0)