Skip to content

Commit dcf3386

Browse files
authored
Merge pull request #1057 from pytorch/anuragd/reduce_nox_sessions
feat(nox): Replacing session with environment variable
2 parents e9e824c + b02d4da commit dcf3386

File tree

2 files changed

+95
-86
lines changed

2 files changed

+95
-86
lines changed

README.md

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,43 @@ docker run -it -v$(pwd)/..:/workspace/Torch-TensorRT build_torch_tensorrt_wheel
250250

251251
Python compilation expects using the tarball based compilation strategy from above.
252252

253+
254+
## Testing using Python backend
255+
256+
Torch-TensorRT supports testing in Python using [nox](https://nox.thea.codes/en/stable)
257+
258+
To install the nox using python-pip
259+
260+
```
261+
python3 -m pip install --upgrade nox
262+
```
263+
264+
To list supported nox sessions:
265+
266+
```
267+
nox --session -l
268+
```
269+
270+
Environment variables supported by nox
271+
272+
```
273+
PYT_PATH - To use different PYTHONPATH than system installed Python packages
274+
TOP_DIR - To set the root directory of the noxfile
275+
USE_CXX11 - To use cxx11_abi (Defaults to 0)
276+
USE_HOST_DEPS - To use host dependencies for tests (Defaults to 0)
277+
```
278+
279+
Usage example
280+
281+
```
282+
nox --session l0_api_tests
283+
```
284+
285+
Supported Python versions:
286+
```
287+
["3.7", "3.8", "3.9", "3.10"]
288+
```
289+
253290
## How do I add support for a new op...
254291

255292
### In Torch-TensorRT?
@@ -279,4 +316,4 @@ Take a look at the [CONTRIBUTING.md](CONTRIBUTING.md)
279316

280317
## License
281318

282-
The Torch-TensorRT license can be found in the LICENSE file. It is licensed with a BSD Style licence
319+
The Torch-TensorRT license can be found in the LICENSE file. It is licensed with a BSD Style licence

noxfile.py

Lines changed: 57 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# Set the USE_CXX11=1 to use cxx11_abi
1414
USE_CXX11=0 if not 'USE_CXX11' in os.environ else os.environ["USE_CXX11"]
1515

16+
# Set the USE_HOST_DEPS=1 to use host dependencies for tests
17+
USE_HOST_DEPS=0 if not 'USE_HOST_DEPS' in os.environ else os.environ["USE_HOST_DEPS"]
18+
1619
SUPPORTED_PYTHON_VERSIONS=["3.7", "3.8", "3.9", "3.10"]
1720

1821
nox.options.sessions = ["l0_api_tests-" + "{}.{}".format(sys.version_info.major, sys.version_info.minor)]
@@ -22,15 +25,14 @@ def install_deps(session):
2225
session.install("-r", os.path.join(TOP_DIR, "py", "requirements.txt"))
2326
session.install("-r", os.path.join(TOP_DIR, "tests", "py", "requirements.txt"))
2427

25-
def download_models(session, use_host_env=False):
28+
def download_models(session):
2629
print("Downloading test models")
2730
session.install("-r", os.path.join(TOP_DIR, "tests", "modules", "requirements.txt"))
2831
print(TOP_DIR)
2932
session.chdir(os.path.join(TOP_DIR, "tests", "modules"))
30-
if use_host_env:
33+
if USE_HOST_DEPS:
3134
session.run_always('python', 'hub.py', env={'PYTHONPATH': PYT_PATH})
3235
else:
33-
session.install("-r", os.path.join(TOP_DIR, "py", "requirements.txt"))
3436
session.run_always('python', 'hub.py')
3537

3638
def install_torch_trt(session):
@@ -54,9 +56,9 @@ def download_datasets(session):
5456
os.path.join(TOP_DIR, 'tests/accuracy/datasets/data/cidar-10-batches-bin'),
5557
external=True)
5658

57-
def train_model(session, use_host_env=False):
59+
def train_model(session):
5860
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
59-
if use_host_env:
61+
if USE_HOST_DEPS:
6062
session.run_always('python',
6163
'main.py',
6264
'--lr', '0.01',
@@ -83,12 +85,12 @@ def train_model(session, use_host_env=False):
8385
'export_ckpt.py',
8486
'vgg16_ckpts/ckpt_epoch25.pth')
8587

86-
def finetune_model(session, use_host_env=False):
88+
def finetune_model(session):
8789
# Install pytorch-quantization dependency
8890
session.install('pytorch-quantization', '--extra-index-url', 'https://pypi.ngc.nvidia.com')
8991
session.chdir(os.path.join(TOP_DIR, 'examples/int8/training/vgg16'))
9092

91-
if use_host_env:
93+
if USE_HOST_DEPS:
9294
session.run_always('python',
9395
'finetune_qat.py',
9496
'--lr', '0.01',
@@ -134,25 +136,25 @@ def cleanup(session):
134136
str('rm -rf ') + target,
135137
external=True)
136138

137-
def run_base_tests(session, use_host_env=False):
139+
def run_base_tests(session):
138140
print("Running basic tests")
139141
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
140142
tests = [
141143
"test_api.py",
142144
"test_to_backend_api.py",
143145
]
144146
for test in tests:
145-
if use_host_env:
147+
if USE_HOST_DEPS:
146148
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
147149
else:
148150
session.run_always("python", test)
149151

150-
def run_accuracy_tests(session, use_host_env=False):
152+
def run_accuracy_tests(session):
151153
print("Running accuracy tests")
152154
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
153155
tests = []
154156
for test in tests:
155-
if use_host_env:
157+
if USE_HOST_DEPS:
156158
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
157159
else:
158160
session.run_always("python", test)
@@ -170,7 +172,7 @@ def copy_model(session):
170172
os.path.join(TOP_DIR, str('tests/py/') + file_name),
171173
external=True)
172174

173-
def run_int8_accuracy_tests(session, use_host_env=False):
175+
def run_int8_accuracy_tests(session):
174176
print("Running accuracy tests")
175177
copy_model(session)
176178
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
@@ -180,12 +182,12 @@ def run_int8_accuracy_tests(session, use_host_env=False):
180182
"test_qat_trt_accuracy.py",
181183
]
182184
for test in tests:
183-
if use_host_env:
185+
if USE_HOST_DEPS:
184186
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
185187
else:
186188
session.run_always("python", test)
187189

188-
def run_trt_compatibility_tests(session, use_host_env=False):
190+
def run_trt_compatibility_tests(session):
189191
print("Running TensorRT compatibility tests")
190192
copy_model(session)
191193
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
@@ -194,151 +196,121 @@ def run_trt_compatibility_tests(session, use_host_env=False):
194196
"test_ptq_trt_calibrator.py",
195197
]
196198
for test in tests:
197-
if use_host_env:
199+
if USE_HOST_DEPS:
198200
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
199201
else:
200202
session.run_always("python", test)
201203

202-
def run_dla_tests(session, use_host_env=False):
204+
def run_dla_tests(session):
203205
print("Running DLA tests")
204206
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
205207
tests = [
206208
"test_api_dla.py",
207209
]
208210
for test in tests:
209-
if use_host_env:
211+
if USE_HOST_DEPS:
210212
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
211213
else:
212214
session.run_always("python", test)
213215

214-
def run_multi_gpu_tests(session, use_host_env=False):
216+
def run_multi_gpu_tests(session):
215217
print("Running multi GPU tests")
216218
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
217219
tests = [
218220
"test_multi_gpu.py",
219221
]
220222
for test in tests:
221-
if use_host_env:
223+
if USE_HOST_DEPS:
222224
session.run_always('python', test, env={'PYTHONPATH': PYT_PATH})
223225
else:
224226
session.run_always("python", test)
225227

226-
def run_l0_api_tests(session, use_host_env=False):
227-
if not use_host_env:
228+
def run_l0_api_tests(session):
229+
if not USE_HOST_DEPS:
228230
install_deps(session)
229231
install_torch_trt(session)
230-
download_models(session, use_host_env)
231-
run_base_tests(session, use_host_env)
232+
download_models(session)
233+
run_base_tests(session)
232234
cleanup(session)
233235

234-
def run_l0_dla_tests(session, use_host_env=False):
235-
if not use_host_env:
236+
def run_l0_dla_tests(session):
237+
if not USE_HOST_DEPS:
236238
install_deps(session)
237239
install_torch_trt(session)
238-
download_models(session, use_host_env)
239-
run_base_tests(session, use_host_env)
240+
download_models(session)
241+
run_base_tests(session)
240242
cleanup(session)
241243

242-
def run_l1_accuracy_tests(session, use_host_env=False):
243-
if not use_host_env:
244+
def run_l1_accuracy_tests(session):
245+
if not USE_HOST_DEPS:
244246
install_deps(session)
245247
install_torch_trt(session)
246-
download_models(session, use_host_env)
248+
download_models(session)
247249
download_datasets(session)
248-
train_model(session, use_host_env)
249-
run_accuracy_tests(session, use_host_env)
250+
train_model(session)
251+
run_accuracy_tests(session)
250252
cleanup(session)
251253

252-
def run_l1_int8_accuracy_tests(session, use_host_env=False):
253-
if not use_host_env:
254+
def run_l1_int8_accuracy_tests(session):
255+
if not USE_HOST_DEPS:
254256
install_deps(session)
255257
install_torch_trt(session)
256-
download_models(session, use_host_env)
258+
download_models(session)
257259
download_datasets(session)
258-
train_model(session, use_host_env)
259-
finetune_model(session, use_host_env)
260-
run_int8_accuracy_tests(session, use_host_env)
260+
train_model(session)
261+
finetune_model(session)
262+
run_int8_accuracy_tests(session)
261263
cleanup(session)
262264

263-
def run_l2_trt_compatibility_tests(session, use_host_env=False):
264-
if not use_host_env:
265+
def run_l2_trt_compatibility_tests(session):
266+
if not USE_HOST_DEPS:
265267
install_deps(session)
266268
install_torch_trt(session)
267-
download_models(session, use_host_env)
269+
download_models(session)
268270
download_datasets(session)
269-
train_model(session, use_host_env)
270-
run_trt_compatibility_tests(session, use_host_env)
271+
train_model(session)
272+
run_trt_compatibility_tests(session)
271273
cleanup(session)
272274

273-
def run_l2_multi_gpu_tests(session, use_host_env=False):
274-
if not use_host_env:
275+
def run_l2_multi_gpu_tests(session):
276+
if not USE_HOST_DEPS:
275277
install_deps(session)
276278
install_torch_trt(session)
277-
download_models(session, use_host_env)
278-
run_multi_gpu_tests(session, use_host_env)
279+
download_models(session)
280+
run_multi_gpu_tests(session)
279281
cleanup(session)
280282

281283
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
282284
def l0_api_tests(session):
283285
"""When a developer needs to check correctness for a PR or something"""
284-
run_l0_api_tests(session, use_host_env=False)
285-
286-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
287-
def l0_api_tests_host_deps(session):
288-
"""When a developer needs to check basic api functionality using host dependencies"""
289-
run_l0_api_tests(session, use_host_env=True)
286+
run_l0_api_tests(session)
290287

291288
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
292-
def l0_dla_tests_host_deps(session):
289+
def l0_dla_tests(session):
293290
"""When a developer needs to check basic api functionality using host dependencies"""
294-
run_l0_dla_tests(session, use_host_env=True)
291+
run_l0_dla_tests(session)
295292

296293
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
297294
def l1_accuracy_tests(session):
298295
"""Checking accuracy performance on various usecases"""
299-
run_l1_accuracy_tests(session, use_host_env=False)
300-
301-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
302-
def l1_accuracy_tests_host_deps(session):
303-
"""Checking accuracy performance on various usecases using host dependencies"""
304-
run_l1_accuracy_tests(session, use_host_env=True)
296+
run_l1_accuracy_tests(session)
305297

306298
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
307299
def l1_int8_accuracy_tests(session):
308300
"""Checking accuracy performance on various usecases"""
309-
run_l1_int8_accuracy_tests(session, use_host_env=False)
310-
311-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
312-
def l1_int8_accuracy_tests_host_deps(session):
313-
"""Checking accuracy performance on various usecases using host dependencies"""
314-
run_l1_int8_accuracy_tests(session, use_host_env=True)
301+
run_l1_int8_accuracy_tests(session)
315302

316303
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
317304
def l2_trt_compatibility_tests(session):
318305
"""Makes sure that TensorRT Python and Torch-TensorRT can work together"""
319-
run_l2_trt_compatibility_tests(session, use_host_env=False)
320-
321-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
322-
def l2_trt_compatibility_tests_host_deps(session):
323-
"""Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
324-
run_l2_trt_compatibility_tests(session, use_host_env=True)
306+
run_l2_trt_compatibility_tests(session)
325307

326308
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
327309
def l2_multi_gpu_tests(session):
328310
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
329-
run_l2_multi_gpu_tests(session, use_host_env=False)
330-
331-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
332-
def l2_multi_gpu_tests_host_deps(session):
333-
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
334-
run_l2_multi_gpu_tests(session, use_host_env=True)
311+
run_l2_multi_gpu_tests(session)
335312

336313
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
337314
def download_test_models(session):
338315
"""Grab all the models needed for testing"""
339-
download_models(session, use_host_env=False)
340-
341-
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True)
342-
def download_test_models_host_deps(session):
343-
"""Grab all the models needed for testing using host dependencies"""
344-
download_models(session, use_host_env=True)
316+
download_models(session)

0 commit comments

Comments
 (0)