13
13
# Set the USE_CXX11=1 to use cxx11_abi
14
14
USE_CXX11 = 0 if not 'USE_CXX11' in os .environ else os .environ ["USE_CXX11" ]
15
15
16
+ # Set the USE_HOST_DEPS=1 to use host dependencies for tests
17
+ USE_HOST_DEPS = 0 if not 'USE_HOST_DEPS' in os .environ else os .environ ["USE_HOST_DEPS" ]
18
+
16
19
SUPPORTED_PYTHON_VERSIONS = ["3.7" , "3.8" , "3.9" , "3.10" ]
17
20
18
21
nox .options .sessions = ["l0_api_tests-" + "{}.{}" .format (sys .version_info .major , sys .version_info .minor )]
@@ -22,15 +25,14 @@ def install_deps(session):
22
25
session .install ("-r" , os .path .join (TOP_DIR , "py" , "requirements.txt" ))
23
26
session .install ("-r" , os .path .join (TOP_DIR , "tests" , "py" , "requirements.txt" ))
24
27
25
- def download_models (session , use_host_env = False ):
28
+ def download_models (session ):
26
29
print ("Downloading test models" )
27
30
session .install ("-r" , os .path .join (TOP_DIR , "tests" , "modules" , "requirements.txt" ))
28
31
print (TOP_DIR )
29
32
session .chdir (os .path .join (TOP_DIR , "tests" , "modules" ))
30
- if use_host_env :
33
+ if USE_HOST_DEPS :
31
34
session .run_always ('python' , 'hub.py' , env = {'PYTHONPATH' : PYT_PATH })
32
35
else :
33
- session .install ("-r" , os .path .join (TOP_DIR , "py" , "requirements.txt" ))
34
36
session .run_always ('python' , 'hub.py' )
35
37
36
38
def install_torch_trt (session ):
@@ -54,9 +56,9 @@ def download_datasets(session):
54
56
os .path .join (TOP_DIR , 'tests/accuracy/datasets/data/cidar-10-batches-bin' ),
55
57
external = True )
56
58
57
- def train_model (session , use_host_env = False ):
59
+ def train_model (session ):
58
60
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
59
- if use_host_env :
61
+ if USE_HOST_DEPS :
60
62
session .run_always ('python' ,
61
63
'main.py' ,
62
64
'--lr' , '0.01' ,
@@ -83,12 +85,12 @@ def train_model(session, use_host_env=False):
83
85
'export_ckpt.py' ,
84
86
'vgg16_ckpts/ckpt_epoch25.pth' )
85
87
86
- def finetune_model (session , use_host_env = False ):
88
+ def finetune_model (session ):
87
89
# Install pytorch-quantization dependency
88
90
session .install ('pytorch-quantization' , '--extra-index-url' , 'https://pypi.ngc.nvidia.com' )
89
91
session .chdir (os .path .join (TOP_DIR , 'examples/int8/training/vgg16' ))
90
92
91
- if use_host_env :
93
+ if USE_HOST_DEPS :
92
94
session .run_always ('python' ,
93
95
'finetune_qat.py' ,
94
96
'--lr' , '0.01' ,
@@ -134,25 +136,25 @@ def cleanup(session):
134
136
str ('rm -rf ' ) + target ,
135
137
external = True )
136
138
137
- def run_base_tests (session , use_host_env = False ):
139
+ def run_base_tests (session ):
138
140
print ("Running basic tests" )
139
141
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
140
142
tests = [
141
143
"test_api.py" ,
142
144
"test_to_backend_api.py" ,
143
145
]
144
146
for test in tests :
145
- if use_host_env :
147
+ if USE_HOST_DEPS :
146
148
session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
147
149
else :
148
150
session .run_always ("python" , test )
149
151
150
- def run_accuracy_tests (session , use_host_env = False ):
152
+ def run_accuracy_tests (session ):
151
153
print ("Running accuracy tests" )
152
154
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
153
155
tests = []
154
156
for test in tests :
155
- if use_host_env :
157
+ if USE_HOST_DEPS :
156
158
session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
157
159
else :
158
160
session .run_always ("python" , test )
@@ -170,7 +172,7 @@ def copy_model(session):
170
172
os .path .join (TOP_DIR , str ('tests/py/' ) + file_name ),
171
173
external = True )
172
174
173
- def run_int8_accuracy_tests (session , use_host_env = False ):
175
+ def run_int8_accuracy_tests (session ):
174
176
print ("Running accuracy tests" )
175
177
copy_model (session )
176
178
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
@@ -180,12 +182,12 @@ def run_int8_accuracy_tests(session, use_host_env=False):
180
182
"test_qat_trt_accuracy.py" ,
181
183
]
182
184
for test in tests :
183
- if use_host_env :
185
+ if USE_HOST_DEPS :
184
186
session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
185
187
else :
186
188
session .run_always ("python" , test )
187
189
188
- def run_trt_compatibility_tests (session , use_host_env = False ):
190
+ def run_trt_compatibility_tests (session ):
189
191
print ("Running TensorRT compatibility tests" )
190
192
copy_model (session )
191
193
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
@@ -194,151 +196,121 @@ def run_trt_compatibility_tests(session, use_host_env=False):
194
196
"test_ptq_trt_calibrator.py" ,
195
197
]
196
198
for test in tests :
197
- if use_host_env :
199
+ if USE_HOST_DEPS :
198
200
session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
199
201
else :
200
202
session .run_always ("python" , test )
201
203
202
- def run_dla_tests (session , use_host_env = False ):
204
+ def run_dla_tests (session ):
203
205
print ("Running DLA tests" )
204
206
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
205
207
tests = [
206
208
"test_api_dla.py" ,
207
209
]
208
210
for test in tests :
209
- if use_host_env :
211
+ if USE_HOST_DEPS :
210
212
session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
211
213
else :
212
214
session .run_always ("python" , test )
213
215
214
- def run_multi_gpu_tests (session , use_host_env = False ):
216
+ def run_multi_gpu_tests (session ):
215
217
print ("Running multi GPU tests" )
216
218
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
217
219
tests = [
218
220
"test_multi_gpu.py" ,
219
221
]
220
222
for test in tests :
221
- if use_host_env :
223
+ if USE_HOST_DEPS :
222
224
session .run_always ('python' , test , env = {'PYTHONPATH' : PYT_PATH })
223
225
else :
224
226
session .run_always ("python" , test )
225
227
226
- def run_l0_api_tests (session , use_host_env = False ):
227
- if not use_host_env :
228
+ def run_l0_api_tests (session ):
229
+ if not USE_HOST_DEPS :
228
230
install_deps (session )
229
231
install_torch_trt (session )
230
- download_models (session , use_host_env )
231
- run_base_tests (session , use_host_env )
232
+ download_models (session )
233
+ run_base_tests (session )
232
234
cleanup (session )
233
235
234
- def run_l0_dla_tests (session , use_host_env = False ):
235
- if not use_host_env :
236
+ def run_l0_dla_tests (session ):
237
+ if not USE_HOST_DEPS :
236
238
install_deps (session )
237
239
install_torch_trt (session )
238
- download_models (session , use_host_env )
239
- run_base_tests (session , use_host_env )
240
+ download_models (session )
241
+ run_base_tests (session )
240
242
cleanup (session )
241
243
242
- def run_l1_accuracy_tests (session , use_host_env = False ):
243
- if not use_host_env :
244
+ def run_l1_accuracy_tests (session ):
245
+ if not USE_HOST_DEPS :
244
246
install_deps (session )
245
247
install_torch_trt (session )
246
- download_models (session , use_host_env )
248
+ download_models (session )
247
249
download_datasets (session )
248
- train_model (session , use_host_env )
249
- run_accuracy_tests (session , use_host_env )
250
+ train_model (session )
251
+ run_accuracy_tests (session )
250
252
cleanup (session )
251
253
252
- def run_l1_int8_accuracy_tests (session , use_host_env = False ):
253
- if not use_host_env :
254
+ def run_l1_int8_accuracy_tests (session ):
255
+ if not USE_HOST_DEPS :
254
256
install_deps (session )
255
257
install_torch_trt (session )
256
- download_models (session , use_host_env )
258
+ download_models (session )
257
259
download_datasets (session )
258
- train_model (session , use_host_env )
259
- finetune_model (session , use_host_env )
260
- run_int8_accuracy_tests (session , use_host_env )
260
+ train_model (session )
261
+ finetune_model (session )
262
+ run_int8_accuracy_tests (session )
261
263
cleanup (session )
262
264
263
- def run_l2_trt_compatibility_tests (session , use_host_env = False ):
264
- if not use_host_env :
265
+ def run_l2_trt_compatibility_tests (session ):
266
+ if not USE_HOST_DEPS :
265
267
install_deps (session )
266
268
install_torch_trt (session )
267
- download_models (session , use_host_env )
269
+ download_models (session )
268
270
download_datasets (session )
269
- train_model (session , use_host_env )
270
- run_trt_compatibility_tests (session , use_host_env )
271
+ train_model (session )
272
+ run_trt_compatibility_tests (session )
271
273
cleanup (session )
272
274
273
- def run_l2_multi_gpu_tests (session , use_host_env = False ):
274
- if not use_host_env :
275
+ def run_l2_multi_gpu_tests (session ):
276
+ if not USE_HOST_DEPS :
275
277
install_deps (session )
276
278
install_torch_trt (session )
277
- download_models (session , use_host_env )
278
- run_multi_gpu_tests (session , use_host_env )
279
+ download_models (session )
280
+ run_multi_gpu_tests (session )
279
281
cleanup (session )
280
282
281
283
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
282
284
def l0_api_tests (session ):
283
285
"""When a developer needs to check correctness for a PR or something"""
284
- run_l0_api_tests (session , use_host_env = False )
285
-
286
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
287
- def l0_api_tests_host_deps (session ):
288
- """When a developer needs to check basic api functionality using host dependencies"""
289
- run_l0_api_tests (session , use_host_env = True )
286
+ run_l0_api_tests (session )
290
287
291
288
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
292
- def l0_dla_tests_host_deps (session ):
289
+ def l0_dla_tests (session ):
293
290
"""When a developer needs to check basic api functionality using host dependencies"""
294
- run_l0_dla_tests (session , use_host_env = True )
291
+ run_l0_dla_tests (session )
295
292
296
293
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
297
294
def l1_accuracy_tests (session ):
298
295
"""Checking accuracy performance on various usecases"""
299
- run_l1_accuracy_tests (session , use_host_env = False )
300
-
301
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
302
- def l1_accuracy_tests_host_deps (session ):
303
- """Checking accuracy performance on various usecases using host dependencies"""
304
- run_l1_accuracy_tests (session , use_host_env = True )
296
+ run_l1_accuracy_tests (session )
305
297
306
298
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
307
299
def l1_int8_accuracy_tests (session ):
308
300
"""Checking accuracy performance on various usecases"""
309
- run_l1_int8_accuracy_tests (session , use_host_env = False )
310
-
311
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
312
- def l1_int8_accuracy_tests_host_deps (session ):
313
- """Checking accuracy performance on various usecases using host dependencies"""
314
- run_l1_int8_accuracy_tests (session , use_host_env = True )
301
+ run_l1_int8_accuracy_tests (session )
315
302
316
303
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
317
304
def l2_trt_compatibility_tests (session ):
318
305
"""Makes sure that TensorRT Python and Torch-TensorRT can work together"""
319
- run_l2_trt_compatibility_tests (session , use_host_env = False )
320
-
321
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
322
- def l2_trt_compatibility_tests_host_deps (session ):
323
- """Makes sure that TensorRT Python and Torch-TensorRT can work together using host dependencies"""
324
- run_l2_trt_compatibility_tests (session , use_host_env = True )
306
+ run_l2_trt_compatibility_tests (session )
325
307
326
308
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
327
309
def l2_multi_gpu_tests (session ):
328
310
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems"""
329
- run_l2_multi_gpu_tests (session , use_host_env = False )
330
-
331
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
332
- def l2_multi_gpu_tests_host_deps (session ):
333
- """Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
334
- run_l2_multi_gpu_tests (session , use_host_env = True )
311
+ run_l2_multi_gpu_tests (session )
335
312
336
313
@nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
337
314
def download_test_models (session ):
338
315
"""Grab all the models needed for testing"""
339
- download_models (session , use_host_env = False )
340
-
341
- @nox .session (python = SUPPORTED_PYTHON_VERSIONS , reuse_venv = True )
342
- def download_test_models_host_deps (session ):
343
- """Grab all the models needed for testing using host dependencies"""
344
- download_models (session , use_host_env = True )
316
+ download_models (session )
0 commit comments