Skip to content

Commit b387dd5

Browse files
committed
chore: Modify circle ci to reduce tests
Signed-off-by: Dheeraj Peri <[email protected]>
2 parents a9b0711 + 33255de commit b387dd5

File tree

17 files changed

+798
-2
lines changed

17 files changed

+798
-2
lines changed

.circleci/config.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,24 @@ commands:
727727
- store_artifacts:
728728
path: /tmp/testlogs
729729

730+
test-dynamo-torch_compile:
731+
description: "Test the Dynamo torch_compile path"
732+
steps:
733+
- run:
734+
name: Run Dynamo torch_compile E2E tests
735+
command: |
736+
cd py/torch_tensorrt/dynamo/
737+
pushd test/
738+
pip3 install timm
739+
pip3 install transformers
740+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/test_results.xml --ir torch_compile
741+
popd
742+
743+
- store_test_results:
744+
path: /tmp/artifacts
745+
- store_artifacts:
746+
path: /tmp/testlogs
747+
730748
test-dynamo-fx_ts:
731749
description: "Test the dynamo backend"
732750
steps:
@@ -947,6 +965,7 @@ jobs:
947965
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
948966
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
949967
- dump-test-env
968+
- test-dynamo-torch_compile
950969
- test-dynamo-fx_ts
951970

952971
package-x86_64-linux:

py/torch_tensorrt/_Input.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
367367
)
368368

369369
if self.shape_mode == Input._ShapeMode.STATIC:
370-
return torch.randn(self.shape).to(
370+
return torch.rand(self.shape).to(
371371
dtype=self.dtype if not self.torch_dtype else self.torch_dtype
372372
)
373373
else:
374-
return torch.randn(self.shape[optimization_profile_field]).to(
374+
return torch.rand(self.shape[optimization_profile_field]).to(
375375
dtype=self.dtype if not self.torch_dtype else self.torch_dtype
376376
)

py/torch_tensorrt/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def _find_lib(name, paths):
9494

9595
from torch_tensorrt import fx
9696
from torch_tensorrt import dynamo
97+
from torch_tensorrt.dynamo import torch_compile
9798

9899

99100
def _register_with_torch():

py/torch_tensorrt/_compile.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class _IRType(Enum):
1616
ts = 0
1717
fx = 1
1818
fx_ts_compat = 2
19+
torch_compile = 3
1920

2021

2122
class _ModuleType(Enum):
@@ -46,6 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
4647

4748
ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
4849
ir_targets_fx = ir == "fx"
50+
ir_targets_torch_compile = ir == "torch_compile"
4951
ir_targets_fx_ts_compat = ir == "fx_ts_compat"
5052

5153
if module_is_tsable and ir_targets_torchscript:
@@ -54,6 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
5456
return _IRType.fx
5557
elif module_is_fxable and ir_targets_fx_ts_compat:
5658
return _IRType.fx_ts_compat
59+
elif module_is_fxable and ir_targets_torch_compile:
60+
return _IRType.torch_compile
5761
else:
5862
if ir == "default":
5963
# Options are listed in order of preference
@@ -152,6 +156,10 @@ def compile(
152156
dynamic_batch=False,
153157
**kwargs,
154158
)
159+
elif target_ir == _IRType.torch_compile:
160+
return torch_tensorrt.dynamo.torch_compile(
161+
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
162+
)
155163
elif target_ir == _IRType.fx_ts_compat:
156164
return torch_tensorrt.dynamo.fx_ts_compat.compile(
157165
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from torch_tensorrt.dynamo import fx_ts_compat
2+
from .torch_compile import compile as torch_compile
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
4+
def pytest_addoption(parser):
5+
parser.addoption(
6+
"--ir",
7+
metavar="Internal Representation",
8+
nargs=1,
9+
type=str,
10+
required=True,
11+
help="IR to compile with",
12+
choices=["torch_compile", "fx_ts_compat"],
13+
)
14+
15+
16+
@pytest.fixture
17+
def ir(request):
18+
return request.config.getoption("--ir")[0]
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import torch
2+
import timm
3+
import pytest
4+
5+
import torch_tensorrt as torchtrt
6+
import torchvision.models as models
7+
8+
from transformers import BertModel
9+
10+
from utils import COSINE_THRESHOLD, cosine_similarity
11+
12+
13+
@pytest.mark.unit
14+
def test_resnet18(ir):
15+
model = models.resnet18(pretrained=True).eval().to("cuda")
16+
input = torch.randn((1, 3, 224, 224)).to("cuda")
17+
18+
compile_spec = {
19+
"inputs": [
20+
torchtrt.Input(
21+
input.shape, dtype=torch.float, format=torch.contiguous_format
22+
)
23+
],
24+
"device": torchtrt.Device("cuda:0"),
25+
"enabled_precisions": {torch.float},
26+
"ir": ir,
27+
}
28+
29+
trt_mod = torchtrt.compile(model, **compile_spec)
30+
cos_sim = cosine_similarity(model(input), trt_mod(input))
31+
assert (
32+
cos_sim > COSINE_THRESHOLD,
33+
f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
34+
)
35+
36+
37+
@pytest.mark.unit
38+
def test_mobilenet_v2(ir):
39+
model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
40+
input = torch.randn((1, 3, 224, 224)).to("cuda")
41+
42+
compile_spec = {
43+
"inputs": [
44+
torchtrt.Input(
45+
input.shape, dtype=torch.float, format=torch.contiguous_format
46+
)
47+
],
48+
"device": torchtrt.Device("cuda:0"),
49+
"enabled_precisions": {torch.float},
50+
"ir": ir,
51+
}
52+
53+
trt_mod = torchtrt.compile(model, **compile_spec)
54+
cos_sim = cosine_similarity(model(input), trt_mod(input))
55+
assert (
56+
cos_sim > COSINE_THRESHOLD,
57+
f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
58+
)
59+
60+
61+
@pytest.mark.unit
62+
def test_efficientnet_b0(ir):
63+
model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
64+
input = torch.randn((1, 3, 224, 224)).to("cuda")
65+
66+
compile_spec = {
67+
"inputs": [
68+
torchtrt.Input(
69+
input.shape, dtype=torch.float, format=torch.contiguous_format
70+
)
71+
],
72+
"device": torchtrt.Device("cuda:0"),
73+
"enabled_precisions": {torch.float},
74+
"ir": ir,
75+
}
76+
77+
trt_mod = torchtrt.compile(model, **compile_spec)
78+
cos_sim = cosine_similarity(model(input), trt_mod(input))
79+
assert (
80+
cos_sim > COSINE_THRESHOLD,
81+
f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
82+
)
83+
84+
85+
@pytest.mark.unit
86+
def test_bert_base_uncased(ir):
87+
model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
88+
input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
89+
input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
90+
91+
compile_spec = {
92+
"inputs": [
93+
torchtrt.Input(
94+
input.shape,
95+
dtype=input.dtype,
96+
format=torch.contiguous_format,
97+
),
98+
torchtrt.Input(
99+
input.shape,
100+
dtype=input.dtype,
101+
format=torch.contiguous_format,
102+
),
103+
],
104+
"device": torchtrt.Device("cuda:0"),
105+
"enabled_precisions": {torch.float},
106+
"truncate_long_and_double": True,
107+
"debug": True,
108+
"ir": ir,
109+
}
110+
trt_mod = torchtrt.compile(model, **compile_spec)
111+
112+
model_outputs = model(input, input2)
113+
trt_model_outputs = trt_mod(input, input2)
114+
for key in model_outputs.keys():
115+
out, trt_out = model_outputs[key], trt_model_outputs[key]
116+
cos_sim = cosine_similarity(out, trt_out)
117+
assert (
118+
cos_sim > COSINE_THRESHOLD,
119+
f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
120+
)
121+
122+
123+
@pytest.mark.unit
124+
def test_resnet18_half(ir):
125+
model = models.resnet18(pretrained=True).eval().to("cuda").half()
126+
input = torch.randn((1, 3, 224, 224)).to("cuda").half()
127+
128+
compile_spec = {
129+
"inputs": [
130+
torchtrt.Input(
131+
input.shape, dtype=torch.half, format=torch.contiguous_format
132+
)
133+
],
134+
"device": torchtrt.Device("cuda:0"),
135+
"enabled_precisions": {torch.half},
136+
"ir": ir,
137+
}
138+
139+
trt_mod = torchtrt.compile(model, **compile_spec)
140+
cos_sim = cosine_similarity(model(input), trt_mod(input))
141+
assert (
142+
cos_sim > COSINE_THRESHOLD,
143+
f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
144+
)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
3+
COSINE_THRESHOLD = 0.99
4+
5+
6+
def cosine_similarity(gt_tensor, pred_tensor):
7+
gt_tensor = gt_tensor.flatten().to(torch.float32)
8+
pred_tensor = pred_tensor.flatten().to(torch.float32)
9+
if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
10+
if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
11+
return 1.0
12+
res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
13+
res = res.cpu().detach().item()
14+
15+
return res
16+
17+
18+
def same_output_format(trt_output, torch_output):
19+
# For each encountered collection type, ensure the torch and trt outputs agree
20+
# on type and size, checking recursively through all member elements.
21+
if isinstance(trt_output, tuple):
22+
return (
23+
isinstance(torch_output, tuple)
24+
and (len(trt_output) == len(torch_output))
25+
and all(
26+
same_output_format(trt_entry, torch_entry)
27+
for trt_entry, torch_entry in zip(trt_output, torch_output)
28+
)
29+
)
30+
elif isinstance(trt_output, list):
31+
return (
32+
isinstance(torch_output, list)
33+
and (len(trt_output) == len(torch_output))
34+
and all(
35+
same_output_format(trt_entry, torch_entry)
36+
for trt_entry, torch_entry in zip(trt_output, torch_output)
37+
)
38+
)
39+
elif isinstance(trt_output, dict):
40+
return (
41+
isinstance(torch_output, dict)
42+
and (len(trt_output) == len(torch_output))
43+
and (trt_output.keys() == torch_output.keys())
44+
and all(
45+
same_output_format(trt_output[key], torch_output[key])
46+
for key in trt_output.keys()
47+
)
48+
)
49+
elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
50+
raise AssertionError(
51+
"Unsupported output type 'set' encountered in output format check."
52+
)
53+
else:
54+
return type(trt_output) is type(torch_output)

0 commit comments

Comments
 (0)