Commit 8a0dda6

Add inf2 examples for Triton on SageMaker (#4247)
* Add inf2 examples for Triton on SageMaker
* Change to use merged inf2 link on nv triton github
* Nit changes
* Enable larger timeout for LLM
* Format notebook
1 parent 0419b26 commit 8a0dda6

File tree

6 files changed: +1832 -0 lines changed

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
ARG SM_TRITON_IMAGE_URI
FROM ${SM_TRITON_IMAGE_URI}

RUN mkdir -p /mylib/udev/rules.d/
# setup.sh script from python backend picks up the rules from here.
COPY ./mylib/* /mylib/udev/rules.d/

# Install both TensorFlow and PyTorch Neuron packages
RUN git clone https://github.com/triton-inference-server/python_backend \
    && cd python_backend \
    && source inferentia/scripts/setup.sh -t -inf2 \
    && source inferentia/scripts/setup.sh -p -inf2
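For context (not part of this diff): the extended image produced by this Dockerfile would typically be pushed to ECR and then referenced when creating the SageMaker model that fronts the Triton server. A minimal boto3 sketch; the model name, role ARN, image URI, and bucket below are hypothetical placeholders, not values from this commit.

import boto3

sm = boto3.client("sagemaker")

# All names below are hypothetical placeholders. The Image is the extended
# Triton image built from the Dockerfile above and pushed to ECR.
sm.create_model(
    ModelName="triton-gptj-inf2",
    ExecutionRoleArn="arn:aws:iam::<account>:role/<sagemaker-execution-role>",
    PrimaryContainer={
        "Image": "<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>",
        "ModelDataUrl": "s3://<bucket>/triton-model-repo.tar.gz",
        # Tells the SageMaker Triton container which model to serve by default.
        "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "gptj"},
    },
)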
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
import json
import os
import sys

import numpy as np
import torch
import torch_neuronx
import triton_python_backend_utils as pb_utils

from transformers_neuronx.gptj.model import GPTJForSampling


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def _validate_and_get_index(self, name):
        parts = name.split('__')
        if len(parts) != 2:
            raise pb_utils.TritonModelException(
                "tensor names are expected to be in format <name>__<index>, got {}"
                .format(name))

        if not parts[1].isnumeric():
            raise pb_utils.TritonModelException(
                "tensor names are expected to be in format <name>__<index> where <index> should be numeric, got {}"
                .format(name))

        return int(parts[1])

    def _validate_input_dict(self, expected_count):
        for i in range(expected_count):
            if i not in self.input_dict:
                raise pb_utils.TritonModelException(
                    "input corresponding to index {} not found".format(i))

    def _validate_output_dict(self, expected_count):
        for i in range(expected_count):
            if i not in self.output_dict:
                raise pb_utils.TritonModelException(
                    "output corresponding to index {} not found".format(i))

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing the `initialize` function is optional. This function
        allows the model to initialize any state associated with this model.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        # model_config arrives as a JSON string and must be parsed here.
        self.model_config = model_config = json.loads(args['model_config'])

        if len(model_config['instance_group']) != 1:
            raise pb_utils.TritonModelException(
                "this model supports only a single instance group, got {}".
                format(len(model_config['instance_group'])))

        instance_group_config = model_config['instance_group'][0]
        instance_count = instance_group_config['count']

        # Triton names instances "<model_name>_<idx>"; the index determines
        # which slice of NeuronCores this instance will claim.
        instance_idx = 0
        if instance_count > 1:
            instance_name_parts = args['model_instance_name'].split("_")
            if not instance_name_parts[-1].isnumeric():
                raise pb_utils.TritonModelException(
                    "internal error: the model instance name should end with '_<instance_idx>', got {}"
                    .format(args['model_instance_name']))
            instance_idx = int(instance_name_parts[-1])

        params = model_config['parameters']
        compiled_model = params['COMPILED_MODEL']['string_value']
        compiled_model = os.path.join(args['model_repository'], compiled_model)

        nc_start_idx = int(params['NEURON_CORE_START_INDEX']['string_value'])
        nc_end_idx = int(params['NEURON_CORE_END_INDEX']['string_value'])
        if nc_end_idx < nc_start_idx:
            raise pb_utils.TritonModelException(
                "the neuron core end index should be greater than or equal to the start index"
            )

        threads_per_core = int(params['NUM_THREADS_PER_CORE']['string_value'])
        if threads_per_core < 1:
            raise pb_utils.TritonModelException(
                "the number of threads per core should be greater than or equal to 1"
            )
        num_threads = (nc_end_idx - nc_start_idx + 1) * threads_per_core

        total_core_count = nc_end_idx - nc_start_idx + 1
        if instance_count > total_core_count:
            raise pb_utils.TritonModelException(
                "can not distribute {} triton model instances to {} neuron cores"
                .format(instance_count, total_core_count))
        cores_per_instance = total_core_count // instance_count

        self.input_dict = {}
        expected_input_count = 0
        for config_input in model_config['input']:
            index = self._validate_and_get_index(config_input['name'])
            self.input_dict[index] = [
                config_input['name'], config_input['data_type'],
                config_input['dims']
            ]
            expected_input_count += 1
        self._validate_input_dict(expected_input_count)

        self.output_dict = {}
        for config_output in model_config['output']:
            index = self._validate_and_get_index(config_output['name'])
            self.output_dict[index] = [
                config_output['name'], config_output['data_type'],
                config_output['dims']
            ]

        # Pin this instance to its contiguous range of NeuronCores; the
        # environment variable must be set before the Neuron runtime starts.
        adjusted_nc_start_idx = (instance_idx *
                                 cores_per_instance) + nc_start_idx
        cores_range = '{}-{}'.format(
            adjusted_nc_start_idx,
            (adjusted_nc_start_idx + cores_per_instance - 1))
        os.environ["NEURON_RT_VISIBLE_CORES"] = cores_range

        consumed_cores_list = [i for i in range(cores_per_instance)]

        # Load the compiled GPT-J checkpoint with transformers-neuronx and
        # shard it across the visible NeuronCores (tensor parallel degree 4).
        # self.model_neuron = torch.jit.load(compiled_model)
        batch_size = 1
        tp_degree = 4
        n_positions = 2048
        amp = 'bf16'
        unroll = None
        self.model_neuron = GPTJForSampling.from_pretrained(
            compiled_model, batch_size=batch_size, amp=amp,
            tp_degree=tp_degree, n_positions=n_positions, unroll=unroll)
        self.model_neuron.to_neuron()

        self.model_neuron.num_workers = num_threads

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. The `execute`
        function receives a list of pb_utils.InferenceRequest as its only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. dynamic
        batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you
        can set the error argument when creating a pb_utils.InferenceResponse.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """

        responses = []
        inputs = []
        num_requests = len(requests)
        request_batch_sizes = []
        # Concatenate the tensors from all pending requests into one batch,
        # recording the cumulative batch offsets so the batched result can be
        # split back into per-request chunks below.
        for i in self.input_dict.keys():
            name, dt, shape = self.input_dict[i]
            first_tensor = torch.as_tensor(
                pb_utils.get_input_tensor_by_name(requests[0],
                                                  name).as_numpy())
            request_batch_sizes.append(first_tensor.size(dim=0))
            batched_tensor = first_tensor
            for j in range(1, num_requests):
                tensor = torch.as_tensor(
                    pb_utils.get_input_tensor_by_name(requests[j],
                                                      name).as_numpy())
                request_batch_sizes.append(request_batch_sizes[-1] +
                                           tensor.size(dim=0))
                batched_tensor = torch.cat((batched_tensor, tensor), dim=0)
            inputs.append(batched_tensor)

        # The config declares a single input, so the tensor batched above is
        # the one sampled here, up to a sequence length of 512 tokens.
        batched_results = self.model_neuron.sample(batched_tensor, 512)
        chunky_batched_results = []
        for i in self.output_dict.keys():
            batch = batched_results[i] if isinstance(
                batched_results, tuple) else batched_results
            chunky_batched_results.append(
                torch.tensor_split(batch, request_batch_sizes, dim=0))
        for i in range(num_requests):
            output_tensors = []
            for j in self.output_dict.keys():
                name, dt, shape = self.output_dict[j]
                result = chunky_batched_results[j][i]
                output_tensor = pb_utils.Tensor(
                    name, result.numpy().astype(
                        pb_utils.triton_string_to_numpy(dt)))
                output_tensors.append(output_tensor)
            inference_response = pb_utils.InferenceResponse(
                output_tensors=output_tensors)
            responses.append(inference_response)

        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing the `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
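The re-chunking in `execute` is the subtle part: cumulative batch offsets are collected while concatenating the per-request inputs, and `torch.tensor_split` uses those offsets to hand each request back its own slice of the batched result (the trailing empty chunk produced by the last offset is simply never indexed). A minimal, self-contained sketch of just that mechanism, runnable with plain CPU tensors:

import torch

# Three fake requests with batch sizes 2, 1 and 3, mirroring the loop in
# execute() above.
per_request = [torch.ones(2, 4), 2 * torch.ones(1, 4), 3 * torch.ones(3, 4)]

request_batch_sizes = [per_request[0].size(dim=0)]
batched_tensor = per_request[0]
for tensor in per_request[1:]:
    request_batch_sizes.append(request_batch_sizes[-1] + tensor.size(dim=0))
    batched_tensor = torch.cat((batched_tensor, tensor), dim=0)

# Splitting at the cumulative offsets [2, 3, 6] yields one chunk per request
# plus a trailing empty chunk that is never indexed.
chunks = torch.tensor_split(batched_tensor, request_batch_sizes, dim=0)
for i, original in enumerate(per_request):
    assert torch.equal(chunks[i], original)
print([tuple(chunk.shape) for chunk in chunks])  # [(2, 4), (1, 4), (3, 4), (0, 4)]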
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
name: "gptj"
backend: "python"
max_batch_size: 128

input [
  {
    name: "INPUT__0"
    data_type: TYPE_INT64
    dims: [128]
  }
]

output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_INT64
    dims: [2048]
  }
]

instance_group [
  {
    kind: KIND_MODEL
    count: 1
  }
]
dynamic_batching {
  preferred_batch_size: 128
}
parameters: {key: "COMPILED_MODEL", value: {string_value: "EleutherAI-gpt-j-6B-bf16-local"}}
parameters: {key: "NEURON_CORE_START_INDEX", value: {string_value: "0"}}
parameters: {key: "NEURON_CORE_END_INDEX", value: {string_value: "11"}}
parameters: {key: "NUM_THREADS_PER_CORE", value: {string_value: "1"}}
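Once deployed behind a SageMaker endpoint, a model with this configuration can be invoked using the KServe v2 inference protocol that Triton accepts on SageMaker. A hedged sketch: the endpoint name is a placeholder assumption, and the request must match INPUT__0 above (TYPE_INT64, dims [128] plus the batch dimension implied by max_batch_size).

import json

import boto3

runtime = boto3.client("sagemaker-runtime")

# INPUT__0 is declared as TYPE_INT64 with dims [128], so we send a
# [1, 128] INT64 tensor (batch of one request).
payload = {
    "inputs": [
        {
            "name": "INPUT__0",
            "shape": [1, 128],
            "datatype": "INT64",
            "data": [0] * 128,  # placeholder token ids, padded to length 128
        }
    ]
}

response = runtime.invoke_endpoint(
    EndpointName="triton-gptj-inf2",  # hypothetical endpoint name
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read()))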
