
Commit c387f6a

Migrate to Neuron Runtime 2.X in model.py (triton-inference-server#93)
* Migrate to Neuron Runtime 2.X in model.py
* Address review comments
* Addressing review comments
* Add documentation for data parallel mode
1 parent c346b5a commit c387f6a

File tree: 2 files changed (+103, -123 lines)

inferentia/README.md

Lines changed: 17 additions & 1 deletion
````diff
@@ -96,9 +96,25 @@ Once the TorchScript model supporting Inferentia is obtained, use the [gen_trito
 An example invocation for the `gen_triton_model.py` can look like:
 
 ```
-$python3 inferentia/scripts/gen_triton_model.py --triton_input INPUT0,FP16,4x384 INPUT1,FP16,4x384 INPUT2,FP16,4x384 --triton_output OUTPUT0,FP16,4x384 OUTPUT1,FP16,4x384 --compiled_model /home/ubuntu/bert_large_mlperf_neuron_hack_bs1_dynamic.pt --triton_model_dir bert-large-mlperf-bs1x4
+$python3 inferentia/scripts/gen_triton_model.py --triton_input INPUT__0,INT64,4x384 INPUT__1,INT64,4x384 INPUT__2,INT64,4x384 --triton_output OUTPUT__0,INT64,4x384 OUTPUT__1,INT64,4x384 --compiled_model /home/ubuntu/bert_large_mlperf_neuron_hack_bs1_dynamic.pt --neuron_core_range 0:3 --triton_model_dir bert-large-mlperf-bs1x4
 ```
 
+NOTE: Because a TorchScript model carries no names for its inputs and
+outputs, the tensor names passed to the above script for both the
+inputs and the outputs must follow a specific naming convention,
+i.e. `<name>__<index>`, where `<name>` can be any string and `<index>`
+refers to the position of the corresponding input/output. This means
+that if there are two inputs and two outputs, they must be named
+"INPUT__0", "INPUT__1" and "OUTPUT__0", "OUTPUT__1", such that
+"INPUT__0" refers to the first input, "INPUT__1" refers to the second
+input, etc.
+
+Additionally, `--neuron_core_range` specifies the range of neuron
+cores to be used while serving this model. Currently, only
+`torch.neuron.DataParallel()` mode is supported. See
+[Data Parallel Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-guide/appnotes/perf/torch-neuron-dataparallel-app-note.html)
+for more information.
+
 The invocation should create a triton model directory with the following
 structure:
````
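Since the generated model advertises these tensor names in its `config.pbtxt`, a client request has to use the same `<name>__<index>` convention. A minimal client-side sketch, assuming a server running at `localhost:8000`, the `bert-large-mlperf-bs1x4` model generated above, and zero-filled dummy data:

```python
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Three INT64 inputs of shape 4x384, named positionally: INPUT__0..INPUT__2.
inputs = []
for i in range(3):
    inp = httpclient.InferInput("INPUT__{}".format(i), [4, 384], "INT64")
    inp.set_data_from_numpy(np.zeros((4, 384), dtype=np.int64))
    inputs.append(inp)

result = client.infer("bert-large-mlperf-bs1x4", inputs)
# Outputs follow the same convention: OUTPUT__0, OUTPUT__1.
print(result.as_numpy("OUTPUT__0").shape)
```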

inferentia/scripts/gen_triton_model.py

Lines changed: 86 additions & 122 deletions
```diff
@@ -43,8 +43,8 @@ def get_parameter_spec(key1, value):
     return param_spec
 
 def create_modelconfig(model_name, max_batch_size, inputs, outputs,
-                       compiled_model_path, avbl_neuron_cores_count,
-                       threads_per_core, batch_per_thread):
+                       compiled_model_path, nc_start_idx, nc_end_idx,
+                       threads_per_core):
     config = "name: \"{}\"\n".format(model_name)
     config += "backend: \"python\"\n"
     config += "max_batch_size: {}\n".format(max_batch_size)
```
```diff
@@ -70,9 +70,9 @@ def create_modelconfig(model_name, max_batch_size, inputs, outputs,
 ]\n'''.format(output_name, "TYPE_" + data_type, shape)
     config += "instance_group [ { kind: KIND_MODEL }]\n"
     config += get_parameter_spec("COMPILED_MODEL", compiled_model_path)
-    config += get_parameter_spec("AVAIL_NEURONCORES", avbl_neuron_cores_count)
-    config += get_parameter_spec("NUM_THREADS_PER_PREDICTOR", threads_per_core)
-    config += get_parameter_spec("BATCH_PER_THREAD", batch_per_thread)
+    config += get_parameter_spec("NEURON_CORE_START_INDEX", nc_start_idx)
+    config += get_parameter_spec("NEURON_CORE_END_INDEX", nc_end_idx)
+    config += get_parameter_spec("NUM_THREADS_PER_CORE", threads_per_core)
     return config
 
 def get_model_license():
```
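For reference, with the README invocation above (`--neuron_core_range 0:3`, default `--threads_per_core 1`), the parameters these calls emit into the generated `config.pbtxt` might look like the following sketch (the exact formatting comes from `get_parameter_spec`, which this commit leaves unchanged):

```
parameters: {
  key: "NEURON_CORE_START_INDEX"
  value: { string_value: "0" }
}
parameters: {
  key: "NEURON_CORE_END_INDEX"
  value: { string_value: "3" }
}
parameters: {
  key: "NUM_THREADS_PER_CORE"
  value: { string_value: "1" }
}
```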
```diff
@@ -104,88 +104,34 @@ def get_model_license():
     '''
     return lic
 
-def get_neuron_simple_data_parallel_impl():
-    neuron_sdpi = '''\n
-class NeuronSimpleDataParallel():
-
-    def __init__(self, model_file, num_neuron_cores, num_threads, batch_size):
-        # Construct a list of models
-        self.num_neuron_cores = num_neuron_cores
-        self.batch_size = batch_size
-        self.num_threads = num_threads
-
-        class SimpleWrapper():
-
-            def __init__(self, model):
-                self.model = model
-
-            def eval(self):
-                self.model.eval()
-
-            def train(self):
-                self.model.train()
-
-            def __call__(self, *inputs):
-                results = self.model(*inputs)
-                # Make the output iterable - if it is not already a tuple or list
-                if not isinstance(results, tuple) or isinstance(results, list):
-                    results = [results]
-
-                return results
-
-        self.models = [
-            SimpleWrapper(torch.jit.load(model_file))
-            for i in range(self.num_threads)
-        ]
-        nc_env = ','.join(['1'] * num_neuron_cores)
-        os.environ['NEURONCORE_GROUP_SIZES'] = nc_env
-
-        self.executor = futures.ThreadPoolExecutor(max_workers=self.num_threads)
-
-    def eval(self):
-        for m in self.models:
-            m.eval()
-
-    def train(self):
-        for m in self.models:
-            m.train()
-
-    def __call__(self, *args):
-
-        args_per_core = [None for i in range(self.num_threads)]
-        # Split args
-        for a in args:
-
-            # Based on batch size for arg
-            step_size = self.batch_size
-            for i in range(self.num_threads):
-                # Append a slice of a view
-                start = i * step_size
-                end = (i + 1) * step_size
-
-                # Slice
-                args_per_core[i] = []
-                for input in a:
-                    args_per_core[i].append(input[start:end])
-        # Call each core with their split and wait to complete
-        running = {
-            self.executor.submit(self.models[idx], *args_per_core[idx]): idx
-            for idx in range(self.num_threads)
-        }
-
-        results = [None] * self.num_threads
-
-        for future in futures.as_completed(running):
-            idx = running[future]
-            results[idx] = future.result()
-
-        return results
-
-'''
-    return neuron_sdpi
-
 def get_initialize_impl():
     init_impl = '''
+    def _validate_and_get_index(self, name):
+        parts = name.split('__')
+        if len(parts) != 2:
+            raise pb_utils.TritonModelException(
+                "tensor names are expected to be in format <name>__<index>, got {}"
+                .format(name))
+
+        if not parts[1].isnumeric():
+            raise pb_utils.TritonModelException(
+                "tensor names are expected to be in format <name>__<index> where <index> should be numeric, got {}"
+                .format(name))
+
+        return int(parts[1])
+
+    def _validate_input_dict(self, expected_count):
+        for i in range(expected_count):
+            if i not in self.input_dict:
+                raise pb_utils.TritonModelException(
+                    "input corresponding to index {} not found".format(i))
+
+    def _validate_output_dict(self, expected_count):
+        for i in range(expected_count):
+            if i not in self.output_dict:
+                raise pb_utils.TritonModelException(
+                    "output corresponding to index {} not found".format(i))
+
     def initialize(self, args):
         """`initialize` is called only once when the model is being loaded.
         Implementing `initialize` function is optional. This function allows
```
```diff
@@ -207,29 +153,46 @@ def initialize(self, args):
         self.model_config = model_config = json.loads(args['model_config'])
 
         self.input_dict = {}
+        expected_input_count = 0
         for config_input in model_config['input']:
-            self.input_dict[config_input['name']] = [
-                config_input['data_type'], config_input['dims']
+            index = self._validate_and_get_index(config_input['name'])
+            self.input_dict[index] = [
+                config_input['name'], config_input['data_type'],
+                config_input['dims']
             ]
+            expected_input_count += 1
+        self._validate_input_dict(expected_input_count)
 
         self.output_dict = {}
         for config_output in model_config['output']:
-            self.output_dict[config_output['name']] = [
-                config_output['data_type'], config_output['dims']
+            index = self._validate_and_get_index(config_output['name'])
+            self.output_dict[index] = [
+                config_output['name'], config_output['data_type'],
+                config_output['dims']
             ]
 
         params = model_config['parameters']
         compiled_model = params['COMPILED_MODEL']['string_value']
-        avbl_neuron_cores_count = int(
-            params['AVAIL_NEURONCORES']['string_value'])
-        threads_per_core = int(
-            params['NUM_THREADS_PER_PREDICTOR']['string_value'])
-        batch_per_thread = int(params['BATCH_PER_THREAD']['string_value'])
-        self.num_threads = avbl_neuron_cores_count * threads_per_core
-        self.model_neuron = NeuronSimpleDataParallel(compiled_model,
-                                                     avbl_neuron_cores_count,
-                                                     self.num_threads,
-                                                     batch_per_thread)
+        nc_start_idx = int(params['NEURON_CORE_START_INDEX']['string_value'])
+        nc_end_idx = int(params['NEURON_CORE_END_INDEX']['string_value'])
+        if nc_end_idx < nc_start_idx:
+            raise pb_utils.TritonModelException(
+                "the neuron core end index should be greater than or equal to the start index")
+
+        threads_per_core = int(params['NUM_THREADS_PER_CORE']['string_value'])
+        if threads_per_core < 1:
+            raise pb_utils.TritonModelException(
+                "the number of threads per core should be greater than or equal to 1")
+        num_threads = (nc_end_idx - nc_start_idx + 1) * threads_per_core
+
+        # FIXME: Should distribute equally for multiple instance case
+        consumed_cores_list = []
+        for i in range(nc_start_idx, (nc_end_idx + 1)):
+            consumed_cores_list.append(i)
+
+        self.model_neuron = torch.neuron.DataParallel(
+            torch.jit.load(compiled_model), device_ids=consumed_cores_list)
+        self.model_neuron.num_workers = num_threads
 
 '''
     return init_impl
```
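Concretely, the new loading path reduces to the following standalone sketch (illustrative values: `--neuron_core_range 0:3` with `--threads_per_core 2`; the model file is the one from the README example):

```python
import torch
import torch.neuron  # provided by the torch-neuron package (Neuron SDK 2.x)

# Illustrative values, as if read back from the generated config.pbtxt:
nc_start_idx, nc_end_idx, threads_per_core = 0, 3, 2

# NeuronCores 0..3 inclusive, mirroring consumed_cores_list above.
device_ids = list(range(nc_start_idx, nc_end_idx + 1))

# torch.neuron.DataParallel shards each input batch across the listed
# cores and gathers the partial results back together.
model_neuron = torch.neuron.DataParallel(
    torch.jit.load("bert_large_mlperf_neuron_hack_bs1_dynamic.pt"),
    device_ids=device_ids)

# (4 cores) x (2 threads per core) = 8 worker threads.
model_neuron.num_workers = (nc_end_idx - nc_start_idx + 1) * threads_per_core
```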
```diff
@@ -261,29 +224,28 @@ def execute(self, requests):
         responses = []
 
         for request in requests:
-            num_threads = self.num_threads
             inputs = []
-            for name in self.input_dict.keys():
+            for i in range(len(self.input_dict)):
+                name, dt, shape = self.input_dict[i]
                 tensor = pb_utils.get_input_tensor_by_name(request,
                                                            name).as_numpy()
-                inputs.append(torch.LongTensor(tensor))
-            results = self.model_neuron(inputs)
+                inputs.append(torch.as_tensor(tensor))
+
+            results = self.model_neuron(*inputs)
 
             output_tensors = []
-            for name in self.output_dict.keys():
-                result_shards = []
-                for i in range(num_threads):
-                    result_shards.append(results[i][len(output_tensors)])
-                merged_result = np.concatenate(result_shards, axis=0)
-                dt, shape = self.output_dict[name]
-                output_tensor = pb_utils.Tensor(name,
-                    merged_result.astype(pb_utils.triton_string_to_numpy(dt)))
+            for i in self.output_dict.keys():
+                name, dt, shape = self.output_dict[i]
+                output_tensor = pb_utils.Tensor(
+                    name, results[i].numpy().astype(
+                        pb_utils.triton_string_to_numpy(dt)))
 
                 output_tensors.append(output_tensor)
 
             inference_response = pb_utils.InferenceResponse(
                 output_tensors=output_tensors)
             responses.append(inference_response)
+
         return responses
 '''
     return exec_impl
```
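The rewritten `execute` relies on `torch.neuron.DataParallel` gathering the sharded results back into the model's own return structure, so position `i` of `results` corresponds to the tensor named `OUTPUT__i`. A runnable stand-in (plain PyTorch, no Neuron hardware needed) showing that assumption:

```python
import torch

# Stand-in for a compiled model with two outputs, OUTPUT__0 and OUTPUT__1.
class TwoOutputModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x - y

model = TwoOutputModel()
results = model(torch.ones(4, 384), torch.ones(4, 384))

# Positional indexing, exactly as the generated execute() does it.
out_0 = results[0].numpy()  # -> "OUTPUT__0"
out_1 = results[1].numpy()  # -> "OUTPUT__1"
```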
```diff
@@ -327,7 +289,6 @@ def create_model_file():
 import triton_python_backend_utils as pb_utils
 '''
 
-    triton_model += get_neuron_simple_data_parallel_impl()
     triton_model += get_triton_python_model_impl()
 
     return triton_model
```
```diff
@@ -369,18 +330,19 @@ def create_model_file():
                         type=str,
                         required=True,
                         help='Fullpath to the compiled model')
-    parser.add_argument('--avbl_neuron_cores_count',
-                        type=int,
-                        default=4,
-                        help='The number of available neuron cores')
+    parser.add_argument('--neuron_core_range',
+                        type=str,
+                        required=True,
+                        help='''The range of neuron core indices
+                        where the model needs to be loaded. The
+                        range should be specified in format
+                        <start_idx>:<end_idx>. For example to
+                        load model on neuron cores (0-7), specify
+                        the following: 0:7''')
     parser.add_argument('--threads_per_core',
                         type=int,
                         default=1,
                         help='The number of threads per neuron core')
-    parser.add_argument('--batch_per_thread',
-                        type=int,
-                        default=1,
-                        help='The batch size per threads')
     parser.add_argument('--triton_model_dir',
                         type=str,
                         required=True,
```
```diff
@@ -392,6 +354,8 @@ def create_model_file():
     inputs = parse_io_tensors(FLAGS.triton_input)
     outputs = parse_io_tensors(FLAGS.triton_output)
 
+    nc_start_idx, nc_end_idx = [int(i) for i in FLAGS.neuron_core_range.split(":")]
+
     model_version_dir = FLAGS.triton_model_dir + "/" + str(FLAGS.model_version)
     try:
         os.makedirs(model_version_dir)
```
```diff
@@ -400,8 +364,8 @@ def create_model_file():
 
     model_name = os.path.basename(FLAGS.triton_model_dir)
     mc = create_modelconfig(model_name, FLAGS.max_batch_size, inputs, outputs,
-                            FLAGS.compiled_model, FLAGS.avbl_neuron_cores_count,
-                            FLAGS.threads_per_core, FLAGS.batch_per_thread)
+                            FLAGS.compiled_model, nc_start_idx, nc_end_idx,
+                            FLAGS.threads_per_core)
     with open(FLAGS.triton_model_dir + "/config.pbtxt", "w") as config_file:
         config_file.write(mc)
```
