
Commit c2cb969

added triton deployment
1 parent cb1c0b2 commit c2cb969

2 files changed: +214 -0 lines changed


docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@ Getting Started
  * :ref:`use_from_pytorch`
  * :ref:`runtime`
  * :ref:`using_dla`
+ * :ref:`deploy_torch_tensorrt_to_triton`

.. toctree::
   :caption: Getting Started
@@ -43,6 +44,7 @@ Getting Started
   tutorials/use_from_pytorch
   tutorials/runtime
   tutorials/using_dla
+  tutorials/deploy_torch_tensorrt_to_triton

.. toctree::
   :caption: Notebooks
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
Deploying a Torch-TensorRT model (to Triton)
============================================

Optimization and deployment go hand in hand in any discussion about Machine
Learning infrastructure. For a Torch-TensorRT user, network-level optimization
to get the maximum performance is already familiar territory.

However, serving this optimized model comes with its own set of considerations
and challenges, like building an infrastructure to support concurrent model
executions, supporting clients over HTTP or gRPC, and more.

The `Triton Inference Server <https://github.com/triton-inference-server/server>`__
addresses these challenges and more. Let's discuss, step by step, the process of
optimizing a model with Torch-TensorRT, deploying it on Triton Inference
Server, and building a client to query the model.

Step 1: Optimize your model with Torch-TensorRT
-----------------------------------------------

Most Torch-TensorRT users will be familiar with this step. For the purpose of
this demonstration, we will be using a ResNet50 model from Torchhub.

Let's first pull the NGC PyTorch Docker container. You may need to create
an account and get the API key from `here <https://ngc.nvidia.com/setup/>`__.
Sign up and log in with your key (follow the instructions
`here <https://ngc.nvidia.com/setup/api-key>`__ after signing up).

::

   # <xx.xx> is the yy.mm of the publishing tag for NVIDIA's PyTorch
   # container; e.g. 22.04

   docker run -it --gpus all -v /path/to/folder:/resnet50_eg nvcr.io/nvidia/pytorch:<xx.xx>-py3

Once inside the container, we can proceed to download a ResNet model from
Torchhub and optimize it with Torch-TensorRT.

::

   import torch
   import torch_tensorrt
   torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

   # load model
   model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

   # Compile with Torch-TensorRT
   trt_model = torch_tensorrt.compile(model,
       inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
       enabled_precisions={torch.half}  # Run with FP16
   )

   # Save the model
   torch.jit.save(trt_model, "model.pt")

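Before moving on, it is worth a quick smoke test of the compiled module. The
snippet below is a minimal sketch (it assumes ``model.pt`` was saved in the
current directory as above): it reloads the TorchScript module and runs a
dummy FP32 batch to confirm the engine executes.

::

   import torch
   import torch_tensorrt  # needed so the TensorRT engine ops are registered before loading

   # Optional sanity check (a sketch; assumes model.pt from the step above)
   loaded = torch.jit.load("model.pt")
   with torch.no_grad():
       out = loaded(torch.randn(1, 3, 224, 224, device="cuda"))
   print(out.shape)  # expected: torch.Size([1, 1000])
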
The next step in the process is to set up a Triton Inference Server.

Step 2: Set Up Triton Inference Server
--------------------------------------

If you are new to the Triton Inference Server and want to learn more, we
highly recommend checking out our `GitHub
repository <https://github.com/triton-inference-server>`__.

To use Triton, we need to make a model repository. A model repository, as the
name suggests, is a repository of the models the Inference Server hosts. While
Triton can serve models from multiple repositories, in this example, we will
discuss the simplest possible form of the model repository.

The structure of this repository should look something like this:

::

   model_repository
   |
   +-- resnet50
       |
       +-- config.pbtxt
       +-- 1
           |
           +-- model.pt

There are two files that Triton requires to serve the model: the model itself
and a model configuration file, which is typically provided in ``config.pbtxt``.
For the model we prepared in step 1, the following configuration can be used:

::

   name: "resnet50"
   platform: "pytorch_libtorch"
   max_batch_size: 0
   input [
     {
       name: "input__0"
       data_type: TYPE_FP32
       dims: [ 3, 224, 224 ]
       reshape { shape: [ 1, 3, 224, 224 ] }
     }
   ]
   output [
     {
       name: "output__0"
       data_type: TYPE_FP32
       dims: [ 1, 1000, 1, 1 ]
       reshape { shape: [ 1, 1000 ] }
     }
   ]

The ``config.pbtxt`` file is used to describe the exact model configuration
with details like the names and shapes of the input and output layer(s),
datatypes, scheduling and batching details, and more. If you are new to Triton,
we highly encourage you to check out this `section of our
documentation <https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md>`__
for more details.

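As a convenience, the layout above can be created with a few lines of Python.
This is just a sketch: it assumes ``model.pt`` from Step 1 and the
``config.pbtxt`` shown above are in the current directory, and that the
repository is rooted at ``model_repository``.

::

   import os
   import shutil

   # Triton expects <repository>/<model_name>/config.pbtxt
   # and <repository>/<model_name>/<version>/model.pt
   os.makedirs("model_repository/resnet50/1", exist_ok=True)
   shutil.copy("model.pt", "model_repository/resnet50/1/model.pt")
   shutil.copy("config.pbtxt", "model_repository/resnet50/config.pbtxt")
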
With the model repository set up, we can proceed to launch the Triton server
with the Docker command below.

::

   # Make sure that the TensorRT version in the Triton container
   # and the TensorRT version in the environment used to optimize the model
   # are the same.

   docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /full/path/to/docs/examples/model_repository:/models nvcr.io/nvidia/tritonserver:<xx.yy>-py3 tritonserver --model-repository=/models

This should spin up a Triton Inference Server. The next step is building a
simple HTTP client to query the server.

Step 3: Building a Triton Client to Query the Server
----------------------------------------------------

Before proceeding, make sure to have a sample image on hand. If you don't
have one, download an example image to test inference. In this section, we
will be going over a very basic client. For a variety of more fleshed-out
examples, refer to the `Triton Client Repository <https://github.com/triton-inference-server/client/tree/main/src/python/examples>`__.

::

   wget -O img1.jpg "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"

We then need to install the dependencies for building a Python client. These will
change from client to client. For a full list of all languages supported by Triton,
please refer to `Triton's client repository <https://github.com/triton-inference-server/client>`__.

::

   pip install torchvision
   pip install attrdict
   pip install nvidia-pyindex
   pip install tritonclient[all]

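Before building the full client, we can optionally confirm that the server
launched in Step 2 is live and that our model loaded successfully. The
following is a minimal sketch, assuming the server is reachable on the
default HTTP port 8000.

::

   import tritonclient.http as httpclient

   # Probe the server started in Step 2 (assumes the default HTTP port 8000)
   client = httpclient.InferenceServerClient(url="localhost:8000")
   print(client.is_server_live())            # True once the server is up
   print(client.is_model_ready("resnet50"))  # True once the model has loaded
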
Let's jump into the client. First, we write a small preprocessing function to
resize and normalize the query image.

::

   import numpy as np
   from torchvision import transforms
   from PIL import Image
   import tritonclient.http as httpclient
   from tritonclient.utils import triton_to_np_dtype

   # preprocessing function: resize, center-crop and normalize with ImageNet statistics
   def rn50_preprocess(img_path="img1.jpg"):
       img = Image.open(img_path)
       preprocess = transforms.Compose([
           transforms.Resize(256),
           transforms.CenterCrop(224),
           transforms.ToTensor(),
           transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
       ])
       return preprocess(img).numpy()

   transformed_img = rn50_preprocess()

Building a client requires three basic steps. First, we set up a connection
with the Triton Inference Server.

::

   # Setting up client
   triton_client = httpclient.InferenceServerClient(url="localhost:8000")

Second, we specify the names of the input and output layer(s) of our model.

::

   test_input = httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")
   test_input.set_data_from_numpy(transformed_img, binary_data=True)

   test_output = httpclient.InferRequestedOutput("output__0", binary_data=True, class_count=1000)

Lastly, we send an inference request to the Triton Inference Server.

::

   # Querying the server
   results = triton_client.infer(model_name="resnet50", inputs=[test_input], outputs=[test_output])
   test_output_fin = results.as_numpy('output__0')
   print(test_output_fin[:5])

The output should look like the following:

::

   [b'12.468750:90' b'11.523438:92' b'9.664062:14' b'8.429688:136'
    b'8.234375:11']

The output format here is ``<confidence_score>:<classification_index>``.
To learn how to map these to the label names and more, refer to our
`documentation <https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_classification.md>`__.

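For a quick look at the results without a label file, the returned byte
strings can be split into (confidence, class index) pairs. This is just a
sketch that continues from the client code above.

::

   # Parse Triton's classification strings, e.g. b'12.468750:90',
   # into (confidence, class index) pairs.
   for entry in test_output_fin[:5]:
       confidence, class_idx = entry.decode("utf-8").split(":")
       print(f"class {class_idx}: confidence {float(confidence):.4f}")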
