Skip to content

Commit 2be072a

Browse files
committed
refactor: Rigging python tests in pytest for CI and Nox
Signed-off-by: Naren Dasan <[email protected]>
1 parent 5ad9826 commit 2be072a

19 files changed

+844
-838
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,6 @@ bazel-Torch-TensorRT-Preview
6262
docsrc/src/
6363
bazel-TensorRT
6464
bazel-tensorrt
65+
.pytest_cache
66+
*.cache
67+
*cifar-10-batches-py*
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
torch>=1.10.0
22
tensorboard>=1.14.0
3-
pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com
3+
nvidia-pyindex
4+
--extra-index-url https://pypi.ngc.nvidia.com
5+
pytorch-quantization>=2.1.2
6+
tqdm

noxfile.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@
55

66
# Use system installed Python packages
77
PYT_PATH='/opt/conda/lib/python3.8/site-packages' if not 'PYT_PATH' in os.environ else os.environ["PYT_PATH"]
8+
print(f"Using python path {PYT_PATH}")
89

910
# Set the root directory to the directory of the noxfile unless the user wants to
1011
# TOP_DIR
1112
TOP_DIR=os.path.dirname(os.path.realpath(__file__)) if not 'TOP_DIR' in os.environ else os.environ["TOP_DIR"]
13+
print(f"Test root directory {TOP_DIR}")
1214

1315
# Set the USE_CXX11=1 to use cxx11_abi
1416
USE_CXX11=0 if not 'USE_CXX11' in os.environ else os.environ["USE_CXX11"]
17+
if USE_CXX11:
18+
print("Using cxx11 abi")
1519

1620
# Set the USE_HOST_DEPS=1 to use host dependencies for tests
1721
USE_HOST_DEPS=0 if not 'USE_HOST_DEPS' in os.environ else os.environ["USE_HOST_DEPS"]
22+
if USE_HOST_DEPS:
23+
print("Using dependencies from host python")
1824

1925
SUPPORTED_PYTHON_VERSIONS=["3.7", "3.8", "3.9", "3.10"]
2026

@@ -58,6 +64,12 @@ def download_datasets(session):
5864

5965
def train_model(session):
6066
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
67+
session.install("-r", "requirements.txt")
68+
if os.path.exists('vgg16_ckpts/ckpt_epoch25.pth'):
69+
session.run_always('python',
70+
'export_ckpt.py',
71+
'vgg16_ckpts/ckpt_epoch25.pth')
72+
return
6173
if USE_HOST_DEPS:
6274
session.run_always('python',
6375
'main.py',
@@ -140,14 +152,14 @@ def run_base_tests(session):
140152
print("Running basic tests")
141153
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
142154
tests = [
143-
"test_api.py",
144-
"test_to_backend_api.py",
155+
"api",
156+
"integrations/test_to_backend_api.py",
145157
]
146158
for test in tests:
147159
if USE_HOST_DEPS:
148-
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
160+
session.run_always('pytest', test, env={'PYTHONPATH': PYT_PATH})
149161
else:
150-
session.run_always("python", test)
162+
session.run_always("pytest", test)
151163

152164
def run_accuracy_tests(session):
153165
print("Running accuracy tests")
@@ -169,23 +181,23 @@ def copy_model(session):
169181
session.run_always('cp',
170182
'-rpf',
171183
os.path.join(TOP_DIR, src_file),
172-
os.path.join(TOP_DIR, str('tests/py/') + file_name),
184+
os.path.join(TOP_DIR, str('tests/modules/') + file_name),
173185
external=True)
174186

175187
def run_int8_accuracy_tests(session):
176188
print("Running accuracy tests")
177189
copy_model(session)
178190
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
179191
tests = [
180-
"test_ptq_dataloader_calibrator.py",
181-
"test_ptq_to_backend.py",
182-
"test_qat_trt_accuracy.py",
192+
"ptq/test_ptq_to_backend.py",
193+
"ptq/test_ptq_dataloader_calibrator.py",
194+
"qat/",
183195
]
184196
for test in tests:
185197
if USE_HOST_DEPS:
186-
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
198+
session.run_always('pytest', test, env={'PYTHONPATH': PYT_PATH})
187199
else:
188-
session.run_always("python", test)
200+
session.run_always("pytest", test)
189201

190202
def run_trt_compatibility_tests(session):
191203
print("Running TensorRT compatibility tests")
@@ -197,9 +209,9 @@ def run_trt_compatibility_tests(session):
197209
]
198210
for test in tests:
199211
if USE_HOST_DEPS:
200-
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
212+
session.run_always('pytest', test, env={'PYTHONPATH': PYT_PATH})
201213
else:
202-
session.run_always("python", test)
214+
session.run_always("pytest", test)
203215

204216
def run_dla_tests(session):
205217
print("Running DLA tests")
@@ -209,9 +221,9 @@ def run_dla_tests(session):
209221
]
210222
for test in tests:
211223
if USE_HOST_DEPS:
212-
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
224+
session.run_always('pytest', test, env={'PYTHONPATH': PYT_PATH})
213225
else:
214-
session.run_always("python", test)
226+
session.run_always("pytest", test)
215227

216228
def run_multi_gpu_tests(session):
217229
print("Running multi GPU tests")
@@ -221,9 +233,9 @@ def run_multi_gpu_tests(session):
221233
]
222234
for test in tests:
223235
if USE_HOST_DEPS:
224-
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
236+
session.run_always('pytest', test, env={'PYTHONPATH': PYT_PATH})
225237
else:
226-
session.run_always("python", test)
238+
session.run_always("pytest", test)
227239

228240
def run_l0_api_tests(session):
229241
if not USE_HOST_DEPS:
@@ -245,7 +257,6 @@ def run_l1_accuracy_tests(session):
245257
if not USE_HOST_DEPS:
246258
install_deps(session)
247259
install_torch_trt(session)
248-
download_models(session)
249260
download_datasets(session)
250261
train_model(session)
251262
run_accuracy_tests(session)
@@ -255,7 +266,6 @@ def run_l1_int8_accuracy_tests(session):
255266
if not USE_HOST_DEPS:
256267
install_deps(session)
257268
install_torch_trt(session)
258-
download_models(session)
259269
download_datasets(session)
260270
train_model(session)
261271
finetune_model(session)
@@ -313,4 +323,8 @@ def l2_multi_gpu_tests(session):
313323
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
314324
def download_test_models(session):
315325
"""Grab all the models needed for testing"""
326+
try:
327+
import torch
328+
except ModuleNotFoundError:
329+
install_deps(session)
316330
download_models(session)

tests/modules/hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
"model": timm.create_model('vit_base_patch16_224', pretrained=True),
8181
"path": "script"
8282
},
83-
"pool": {
83+
"pooling": {
8484
"model": cm.Pool(),
8585
"path": "trace"
8686
},
@@ -104,7 +104,7 @@
104104
"model": cm.FallbackInplaceOPIf(),
105105
"path": "script"
106106
},
107-
"bert-base-uncased": {
107+
"bert_base_uncased": {
108108
"model": cm.BertModule(),
109109
"path": "trace"
110110
}

tests/modules/requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
-f https://download.pytorch.org/whl/torch_stable.html
2-
#torch==1.11.0+cu113
31
timm==v0.4.12
42
transformers==4.17.0

tests/py/api/test_classes.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import unittest
2+
import torch_tensorrt as torchtrt
3+
import torch
4+
import torchvision.models as models
5+
import copy
6+
from typing import Dict
7+
8+
class TestDevice(unittest.TestCase):
9+
10+
def test_from_string_constructor(self):
11+
device = torchtrt.Device("cuda:0")
12+
self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
13+
self.assertEqual(device.gpu_id, 0)
14+
15+
device = torchtrt.Device("gpu:1")
16+
self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
17+
self.assertEqual(device.gpu_id, 1)
18+
19+
def test_from_string_constructor_dla(self):
20+
device = torchtrt.Device("dla:0")
21+
self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
22+
self.assertEqual(device.gpu_id, 0)
23+
self.assertEqual(device.dla_core, 0)
24+
25+
device = torchtrt.Device("dla:1", allow_gpu_fallback=True)
26+
self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
27+
self.assertEqual(device.gpu_id, 0)
28+
self.assertEqual(device.dla_core, 1)
29+
self.assertEqual(device.allow_gpu_fallback, True)
30+
31+
def test_kwargs_gpu(self):
32+
device = torchtrt.Device(gpu_id=0)
33+
self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
34+
self.assertEqual(device.gpu_id, 0)
35+
36+
def test_kwargs_dla_and_settings(self):
37+
device = torchtrt.Device(dla_core=1, allow_gpu_fallback=False)
38+
self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
39+
self.assertEqual(device.gpu_id, 0)
40+
self.assertEqual(device.dla_core, 1)
41+
self.assertEqual(device.allow_gpu_fallback, False)
42+
43+
device = torchtrt.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
44+
self.assertEqual(device.device_type, torchtrt.DeviceType.DLA)
45+
self.assertEqual(device.gpu_id, 1)
46+
self.assertEqual(device.dla_core, 0)
47+
self.assertEqual(device.allow_gpu_fallback, True)
48+
49+
def test_from_torch(self):
50+
device = torchtrt.Device._from_torch_device(torch.device("cuda:0"))
51+
self.assertEqual(device.device_type, torchtrt.DeviceType.GPU)
52+
self.assertEqual(device.gpu_id, 0)
53+
54+
55+
class TestInput(unittest.TestCase):
56+
57+
def _verify_correctness(self, struct: torchtrt.Input, target: Dict) -> bool:
58+
internal = struct._to_internal()
59+
60+
list_eq = lambda al, bl: all([a == b for (a, b) in zip(al, bl)])
61+
62+
eq = lambda a, b: a == b
63+
64+
def field_is_correct(field, equal_fn, a1, a2):
65+
equal = equal_fn(a1, a2)
66+
if not equal:
67+
print("\nField {} is incorrect: {} != {}".format(field, a1, a2))
68+
return equal
69+
70+
min_ = field_is_correct("min", list_eq, internal.min, target["min"])
71+
opt_ = field_is_correct("opt", list_eq, internal.opt, target["opt"])
72+
max_ = field_is_correct("max", list_eq, internal.max, target["max"])
73+
is_dynamic_ = field_is_correct("is_dynamic", eq, internal.input_is_dynamic, target["input_is_dynamic"])
74+
explicit_set_dtype_ = field_is_correct("explicit_dtype", eq, internal._explicit_set_dtype,
75+
target["explicit_set_dtype"])
76+
dtype_ = field_is_correct("dtype", eq, int(internal.dtype), int(target["dtype"]))
77+
format_ = field_is_correct("format", eq, int(internal.format), int(target["format"]))
78+
79+
return all([min_, opt_, max_, is_dynamic_, explicit_set_dtype_, dtype_, format_])
80+
81+
def test_infer_from_example_tensor(self):
82+
shape = [1, 3, 255, 255]
83+
target = {
84+
"min": shape,
85+
"opt": shape,
86+
"max": shape,
87+
"input_is_dynamic": False,
88+
"dtype": torchtrt.dtype.half,
89+
"format": torchtrt.TensorFormat.contiguous,
90+
"explicit_set_dtype": True
91+
}
92+
93+
example_tensor = torch.randn(shape).half()
94+
i = torchtrt.Input._from_tensor(example_tensor)
95+
self.assertTrue(self._verify_correctness(i, target))
96+
97+
def test_static_shape(self):
98+
shape = [1, 3, 255, 255]
99+
target = {
100+
"min": shape,
101+
"opt": shape,
102+
"max": shape,
103+
"input_is_dynamic": False,
104+
"dtype": torchtrt.dtype.unknown,
105+
"format": torchtrt.TensorFormat.contiguous,
106+
"explicit_set_dtype": False
107+
}
108+
109+
i = torchtrt.Input(shape)
110+
self.assertTrue(self._verify_correctness(i, target))
111+
112+
i = torchtrt.Input(tuple(shape))
113+
self.assertTrue(self._verify_correctness(i, target))
114+
115+
i = torchtrt.Input(torch.randn(shape).shape)
116+
self.assertTrue(self._verify_correctness(i, target))
117+
118+
i = torchtrt.Input(shape=shape)
119+
self.assertTrue(self._verify_correctness(i, target))
120+
121+
i = torchtrt.Input(shape=tuple(shape))
122+
self.assertTrue(self._verify_correctness(i, target))
123+
124+
i = torchtrt.Input(shape=torch.randn(shape).shape)
125+
self.assertTrue(self._verify_correctness(i, target))
126+
127+
def test_data_type(self):
128+
shape = [1, 3, 255, 255]
129+
target = {
130+
"min": shape,
131+
"opt": shape,
132+
"max": shape,
133+
"input_is_dynamic": False,
134+
"dtype": torchtrt.dtype.half,
135+
"format": torchtrt.TensorFormat.contiguous,
136+
"explicit_set_dtype": True
137+
}
138+
139+
i = torchtrt.Input(shape, dtype=torchtrt.dtype.half)
140+
self.assertTrue(self._verify_correctness(i, target))
141+
142+
i = torchtrt.Input(shape, dtype=torch.half)
143+
self.assertTrue(self._verify_correctness(i, target))
144+
145+
def test_tensor_format(self):
146+
shape = [1, 3, 255, 255]
147+
target = {
148+
"min": shape,
149+
"opt": shape,
150+
"max": shape,
151+
"input_is_dynamic": False,
152+
"dtype": torchtrt.dtype.unknown,
153+
"format": torchtrt.TensorFormat.channels_last,
154+
"explicit_set_dtype": False
155+
}
156+
157+
i = torchtrt.Input(shape, format=torchtrt.TensorFormat.channels_last)
158+
self.assertTrue(self._verify_correctness(i, target))
159+
160+
i = torchtrt.Input(shape, format=torch.channels_last)
161+
self.assertTrue(self._verify_correctness(i, target))
162+
163+
def test_dynamic_shape(self):
164+
min_shape = [1, 3, 128, 128]
165+
opt_shape = [1, 3, 256, 256]
166+
max_shape = [1, 3, 512, 512]
167+
target = {
168+
"min": min_shape,
169+
"opt": opt_shape,
170+
"max": max_shape,
171+
"input_is_dynamic": True,
172+
"dtype": torchtrt.dtype.unknown,
173+
"format": torchtrt.TensorFormat.contiguous,
174+
"explicit_set_dtype": False
175+
}
176+
177+
i = torchtrt.Input(min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape)
178+
self.assertTrue(self._verify_correctness(i, target))
179+
180+
i = torchtrt.Input(min_shape=tuple(min_shape), opt_shape=tuple(opt_shape), max_shape=tuple(max_shape))
181+
self.assertTrue(self._verify_correctness(i, target))
182+
183+
tensor_shape = lambda shape: torch.randn(shape).shape
184+
i = torchtrt.Input(min_shape=tensor_shape(min_shape),
185+
opt_shape=tensor_shape(opt_shape),
186+
max_shape=tensor_shape(max_shape))
187+
self.assertTrue(self._verify_correctness(i, target))
188+
189+
if __name__ == "__main__":
190+
unittest.main()

0 commit comments

Comments
 (0)