Skip to content

Commit 6d2e01a

Browse files
committed
fix: Add test cases and improve backend
- Add support for Input objects, add utilities - Add modeling e2e test cases for Dynamo backend - Improve defaults and settings in Dynamo class
1 parent aa0dda8 commit 6d2e01a

File tree

6 files changed

+243
-11
lines changed

6 files changed

+243
-11
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,27 @@ def _supported_input_size_type(input_size: Any) -> bool:
237237
else:
238238
return False
239239

240+
@staticmethod
241+
def _dtype_to_torch_type(dtype: _enums.dtype) -> torch.dtype:
242+
if isinstance(dtype, _enums.dtype):
243+
if dtype == _enums.dtype.long:
244+
return torch.long
245+
elif dtype == _enums.dtype.int32:
246+
return torch.int32
247+
elif dtype == _enums.dtype.half:
248+
return torch.half
249+
elif dtype == _enums.dtype.float:
250+
return torch.float
251+
elif dtype == _enums.dtype.bool:
252+
return torch.bool
253+
else:
254+
raise TypeError(
255+
"Provided an unsupported data type as an input data type (support: bool, int32, long, half, float), got: "
256+
+ str(dtype)
257+
)
258+
else:
259+
raise ValueError("Did not provide an _enums.dtype type as input.")
260+
240261
@staticmethod
241262
def _parse_dtype(dtype: Any) -> _enums.dtype:
242263
if isinstance(dtype, torch.dtype):
@@ -416,8 +437,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
416437
)
417438

418439
if self.shape_mode == Input._ShapeMode.STATIC:
419-
return torch.randn(self.shape).to(dtype=self.dtype)
440+
return torch.rand(self.shape).to(
441+
dtype=Input._dtype_to_torch_type(self.dtype)
442+
)
420443
else:
421-
return torch.randn(self.shape[optimization_profile_field]).to(
422-
dtype=self.dtype
444+
return torch.rand(self.shape[optimization_profile_field]).to(
445+
dtype=Input._dtype_to_torch_type(self.dtype)
423446
)

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import torch
22
import logging
3+
import collections.abc
34
import torch_tensorrt
45
from functools import partial
56

6-
from typing import Sequence, Any
7+
from typing import Any
78
from torch_tensorrt import EngineCapability, Device
89
from torch_tensorrt.fx.utils import LowerPrecision
910

1011
from torch_tensorrt.dynamo._settings import CompilationSettings
12+
from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
1113
from torch_tensorrt.dynamo.backends import tensorrt_backend
1214
from torch_tensorrt.dynamo._defaults import (
1315
PRECISION,
@@ -22,7 +24,7 @@
2224

2325
def compile(
2426
gm: torch.nn.Module,
25-
example_inputs: Sequence[Any],
27+
inputs: Any,
2628
*,
2729
device=Device._current_device(),
2830
disable_tf32=False,
@@ -51,6 +53,11 @@ def compile(
5153
+ "{enabled_precisions, debug, workspace_size, max_num_trt_engines}"
5254
)
5355

56+
if not isinstance(inputs, collections.abc.Sequence):
57+
inputs = [inputs]
58+
59+
inputs = prepare_inputs(inputs, prepare_device(device))
60+
5461
if (
5562
torch.float16 in enabled_precisions
5663
or torch_tensorrt.dtype.half in enabled_precisions
@@ -79,7 +86,7 @@ def compile(
7986
model = torch.compile(gm, backend=custom_backend)
8087

8188
# Ensure compilation occurs by calling the function with provided inputs
82-
model(*example_inputs)
89+
model(*inputs)
8390

8491
return model
8592

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
PRECISION = LowerPrecision.FP32
55
DEBUG = False
66
MAX_WORKSPACE_SIZE = 20 << 30
7-
MAX_NUM_TRT_ENGINES = 10
7+
MAX_NUM_TRT_ENGINES = 200

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
@dataclass(frozen=True)
1313
class CompilationSettings:
14-
precision: LowerPrecision = (PRECISION,)
15-
debug: bool = (DEBUG,)
16-
workspace_size: int = (MAX_WORKSPACE_SIZE,)
17-
max_num_trt_engines: int = (MAX_NUM_TRT_ENGINES,)
14+
precision: LowerPrecision = PRECISION
15+
debug: bool = DEBUG
16+
workspace_size: int = MAX_WORKSPACE_SIZE
17+
max_num_trt_engines: int = MAX_NUM_TRT_ENGINES

py/torch_tensorrt/dynamo/utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
3+
from typing import Any, Union, Sequence, Dict
4+
from torch_tensorrt import _Input, Device
5+
6+
7+
def prepare_inputs(
8+
inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
9+
device: torch.device = torch.device("cuda"),
10+
) -> Any:
11+
if isinstance(inputs, _Input.Input):
12+
if isinstance(inputs.shape, dict):
13+
return inputs.example_tensor(optimization_profile_field="opt_shape").to(
14+
device
15+
)
16+
else:
17+
return inputs.example_tensor().to(device)
18+
19+
elif isinstance(inputs, torch.Tensor):
20+
return inputs
21+
22+
elif isinstance(inputs, list):
23+
prepared_input = list()
24+
25+
for input_obj in inputs:
26+
prepared_input.append(prepare_inputs(input_obj))
27+
28+
return prepared_input
29+
30+
elif isinstance(inputs, tuple):
31+
prepared_input = list()
32+
33+
for input_obj in inputs:
34+
prepared_input.append(prepare_inputs(input_obj))
35+
36+
return tuple(prepared_input)
37+
38+
elif isinstance(inputs, dict):
39+
prepared_input = dict()
40+
41+
for key, input_obj in inputs.items():
42+
prepared_input[key] = prepare_inputs(input_obj)
43+
44+
return prepared_input
45+
46+
else:
47+
raise ValueError(
48+
f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. "
49+
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
50+
)
51+
52+
53+
def prepare_device(device: Union[Device, torch.device]) -> torch.device:
54+
if isinstance(device, Device):
55+
if device.gpu_id != -1:
56+
device = torch.device(device.gpu_id)
57+
else:
58+
raise ValueError("Invalid GPU ID provided for the CUDA device provided")
59+
60+
elif isinstance(device, torch.device):
61+
device = device
62+
63+
else:
64+
raise ValueError(
65+
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
66+
)

tests/py/api/test_dynamo_backend.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import unittest
2+
import torch
3+
import timm
4+
5+
import torch_tensorrt as torchtrt
6+
import torchvision.models as models
7+
8+
from transformers import BertModel
9+
from utils import COSINE_THRESHOLD, cosine_similarity
10+
11+
12+
class TestModels(unittest.TestCase):
13+
def test_resnet18(self):
14+
self.model = models.resnet18(pretrained=True).eval().to("cuda")
15+
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
16+
17+
compile_spec = {
18+
"inputs": [
19+
torchtrt.Input(
20+
self.input.shape, dtype=torch.float, format=torch.contiguous_format
21+
)
22+
],
23+
"device": torchtrt.Device("cuda:0"),
24+
"enabled_precisions": {torch.float},
25+
}
26+
27+
trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
28+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
29+
self.assertTrue(
30+
cos_sim > COSINE_THRESHOLD,
31+
msg=f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
32+
)
33+
34+
def test_mobilenet_v2(self):
35+
self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
36+
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
37+
38+
compile_spec = {
39+
"inputs": [
40+
torchtrt.Input(
41+
self.input.shape, dtype=torch.float, format=torch.contiguous_format
42+
)
43+
],
44+
"device": torchtrt.Device("cuda:0"),
45+
"enabled_precisions": {torch.float},
46+
}
47+
48+
trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
49+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
50+
self.assertTrue(
51+
cos_sim > COSINE_THRESHOLD,
52+
msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
53+
)
54+
55+
def test_efficientnet_b0(self):
56+
self.model = (
57+
timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
58+
)
59+
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
60+
61+
compile_spec = {
62+
"inputs": [
63+
torchtrt.Input(
64+
self.input.shape, dtype=torch.float, format=torch.contiguous_format
65+
)
66+
],
67+
"device": torchtrt.Device("cuda:0"),
68+
"enabled_precisions": {torch.float},
69+
}
70+
71+
trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
72+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
73+
self.assertTrue(
74+
cos_sim > COSINE_THRESHOLD,
75+
msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
76+
)
77+
78+
def test_bert_base_uncased(self):
79+
self.model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
80+
self.input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
81+
self.input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
82+
83+
compile_spec = {
84+
"inputs": [
85+
torchtrt.Input(
86+
self.input.shape,
87+
dtype=self.input.dtype,
88+
format=torch.contiguous_format,
89+
),
90+
torchtrt.Input(
91+
self.input.shape,
92+
dtype=self.input.dtype,
93+
format=torch.contiguous_format,
94+
),
95+
],
96+
"device": torchtrt.Device("cuda:0"),
97+
"enabled_precisions": {torch.float},
98+
"truncate_long_and_double": True,
99+
"debug": True,
100+
}
101+
trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
102+
103+
model_outputs = self.model(self.input, self.input2)
104+
trt_model_outputs = trt_mod(self.input, self.input2)
105+
for key in model_outputs.keys():
106+
out, trt_out = model_outputs[key], trt_model_outputs[key]
107+
cos_sim = cosine_similarity(out, trt_out)
108+
self.assertTrue(
109+
cos_sim > COSINE_THRESHOLD,
110+
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
111+
)
112+
113+
def test_resnet18_half(self):
114+
self.model = models.resnet18(pretrained=True).eval().to("cuda").half()
115+
self.input = torch.randn((1, 3, 224, 224)).to("cuda").half()
116+
117+
compile_spec = {
118+
"inputs": [
119+
torchtrt.Input(
120+
self.input.shape, dtype=torch.half, format=torch.contiguous_format
121+
)
122+
],
123+
"device": torchtrt.Device("cuda:0"),
124+
"enabled_precisions": {torch.half},
125+
}
126+
127+
trt_mod = torchtrt.dynamo.compile(self.model, **compile_spec)
128+
cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input))
129+
self.assertTrue(
130+
cos_sim > COSINE_THRESHOLD,
131+
msg=f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
132+
)
133+
134+
135+
if __name__ == "__main__":
136+
unittest.main()

0 commit comments

Comments
 (0)