-
Notifications
You must be signed in to change notification settings - Fork 364
refactor(//tests) : Refactor the test suite #1329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5756169
c6f3103
beeac7c
7e6b36c
ed75e9d
3da78e9
0ca049f
c864096
8d8cbfd
13cc024
749048c
af20761
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,9 @@ | |
if USE_HOST_DEPS: | ||
print("Using dependencies from host python") | ||
|
||
# Set epochs to train VGG model for accuracy tests | ||
EPOCHS = 25 | ||
|
||
SUPPORTED_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"] | ||
|
||
nox.options.sessions = [ | ||
|
@@ -63,31 +66,6 @@ def install_torch_trt(session): | |
session.run("python", "setup.py", "develop") | ||
|
||
|
||
def download_datasets(session): | ||
print( | ||
"Downloading dataset to path", | ||
os.path.join(TOP_DIR, "examples/int8/training/vgg16"), | ||
) | ||
session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16")) | ||
session.run_always( | ||
"wget", "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz", external=True | ||
) | ||
session.run_always("tar", "-xvzf", "cifar-10-binary.tar.gz", external=True) | ||
session.run_always( | ||
"mkdir", | ||
"-p", | ||
os.path.join(TOP_DIR, "tests/accuracy/datasets/data"), | ||
external=True, | ||
) | ||
session.run_always( | ||
"cp", | ||
"-rpf", | ||
os.path.join(TOP_DIR, "examples/int8/training/vgg16/cifar-10-batches-bin"), | ||
os.path.join(TOP_DIR, "tests/accuracy/datasets/data/cidar-10-batches-bin"), | ||
external=True, | ||
) | ||
|
||
|
||
def train_model(session): | ||
session.chdir(os.path.join(TOP_DIR, "examples/int8/training/vgg16")) | ||
session.install("-r", "requirements.txt") | ||
|
@@ -107,14 +85,14 @@ def train_model(session): | |
"--ckpt-dir", | ||
"vgg16_ckpts", | ||
"--epochs", | ||
"25", | ||
str(EPOCHS), | ||
env={"PYTHONPATH": PYT_PATH}, | ||
) | ||
|
||
session.run_always( | ||
"python", | ||
"export_ckpt.py", | ||
"vgg16_ckpts/ckpt_epoch25.pth", | ||
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth", | ||
env={"PYTHONPATH": PYT_PATH}, | ||
) | ||
else: | ||
|
@@ -130,10 +108,12 @@ def train_model(session): | |
"--ckpt-dir", | ||
"vgg16_ckpts", | ||
"--epochs", | ||
"25", | ||
str(EPOCHS), | ||
) | ||
|
||
session.run_always("python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch25.pth") | ||
session.run_always( | ||
"python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth" | ||
) | ||
|
||
|
||
def finetune_model(session): | ||
|
@@ -156,17 +136,17 @@ def finetune_model(session): | |
"--ckpt-dir", | ||
"vgg16_ckpts", | ||
"--start-from", | ||
"25", | ||
str(EPOCHS), | ||
"--epochs", | ||
"26", | ||
str(EPOCHS + 1), | ||
env={"PYTHONPATH": PYT_PATH}, | ||
) | ||
|
||
# Export model | ||
session.run_always( | ||
"python", | ||
"export_qat.py", | ||
"vgg16_ckpts/ckpt_epoch26.pth", | ||
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth", | ||
env={"PYTHONPATH": PYT_PATH}, | ||
) | ||
else: | ||
|
@@ -182,13 +162,17 @@ def finetune_model(session): | |
"--ckpt-dir", | ||
"vgg16_ckpts", | ||
"--start-from", | ||
"25", | ||
str(EPOCHS), | ||
"--epochs", | ||
"26", | ||
str(EPOCHS + 1), | ||
) | ||
|
||
# Export model | ||
session.run_always("python", "export_qat.py", "vgg16_ckpts/ckpt_epoch26.pth") | ||
session.run_always( | ||
"python", | ||
"export_qat.py", | ||
"vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth", | ||
) | ||
|
||
|
||
def cleanup(session): | ||
|
@@ -219,6 +203,19 @@ def run_base_tests(session): | |
session.run_always("pytest", test) | ||
|
||
|
||
def run_model_tests(session): | ||
print("Running model tests") | ||
session.chdir(os.path.join(TOP_DIR, "tests/py")) | ||
tests = [ | ||
"models", | ||
] | ||
for test in tests: | ||
if USE_HOST_DEPS: | ||
session.run_always("pytest", test, env={"PYTHONPATH": PYT_PATH}) | ||
else: | ||
session.run_always("pytest", test) | ||
|
||
|
||
def run_accuracy_tests(session): | ||
print("Running accuracy tests") | ||
session.chdir(os.path.join(TOP_DIR, "tests/py")) | ||
|
@@ -268,8 +265,8 @@ def run_trt_compatibility_tests(session): | |
copy_model(session) | ||
session.chdir(os.path.join(TOP_DIR, "tests/py")) | ||
tests = [ | ||
"test_trt_intercompatibility.py", | ||
"test_ptq_trt_calibrator.py", | ||
"integrations/test_trt_intercompatibility.py", | ||
# "ptq/test_ptq_trt_calibrator.py", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this test disabled? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the test with pybind issue that has been there for a while. This used to pass in NGC containers though. |
||
] | ||
for test in tests: | ||
if USE_HOST_DEPS: | ||
|
@@ -282,7 +279,7 @@ def run_dla_tests(session): | |
print("Running DLA tests") | ||
session.chdir(os.path.join(TOP_DIR, "tests/py")) | ||
tests = [ | ||
"test_api_dla.py", | ||
"hw/test_api_dla.py", | ||
] | ||
for test in tests: | ||
if USE_HOST_DEPS: | ||
|
@@ -295,7 +292,7 @@ def run_multi_gpu_tests(session): | |
print("Running multi GPU tests") | ||
session.chdir(os.path.join(TOP_DIR, "tests/py")) | ||
tests = [ | ||
"test_multi_gpu.py", | ||
"hw/test_multi_gpu.py", | ||
] | ||
for test in tests: | ||
if USE_HOST_DEPS: | ||
|
@@ -322,21 +319,19 @@ def run_l0_dla_tests(session): | |
cleanup(session) | ||
|
||
|
||
def run_l1_accuracy_tests(session): | ||
def run_l1_model_tests(session): | ||
if not USE_HOST_DEPS: | ||
install_deps(session) | ||
install_torch_trt(session) | ||
download_datasets(session) | ||
train_model(session) | ||
run_accuracy_tests(session) | ||
download_models(session) | ||
run_model_tests(session) | ||
cleanup(session) | ||
|
||
|
||
def run_l1_int8_accuracy_tests(session): | ||
if not USE_HOST_DEPS: | ||
install_deps(session) | ||
install_torch_trt(session) | ||
download_datasets(session) | ||
train_model(session) | ||
finetune_model(session) | ||
run_int8_accuracy_tests(session) | ||
|
@@ -347,9 +342,6 @@ def run_l2_trt_compatibility_tests(session): | |
if not USE_HOST_DEPS: | ||
install_deps(session) | ||
install_torch_trt(session) | ||
download_models(session) | ||
download_datasets(session) | ||
train_model(session) | ||
run_trt_compatibility_tests(session) | ||
cleanup(session) | ||
|
||
|
@@ -376,9 +368,9 @@ def l0_dla_tests(session): | |
|
||
|
||
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) | ||
def l1_accuracy_tests(session): | ||
"""Checking accuracy performance on various usecases""" | ||
run_l1_accuracy_tests(session) | ||
def l1_model_tests(session): | ||
"""When a user needs to test the functionality of standard models compilation and results""" | ||
run_l1_model_tests(session) | ||
|
||
|
||
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) | ||
|
@@ -397,13 +389,3 @@ def l2_trt_compatibility_tests(session): | |
def l2_multi_gpu_tests(session): | ||
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems""" | ||
run_l2_multi_gpu_tests(session) | ||
|
||
|
||
@nox.session(python=SUPPORTED_PYTHON_VERSIONS, reuse_venv=True) | ||
def download_test_models(session): | ||
"""Grab all the models needed for testing""" | ||
try: | ||
import torch | ||
except ModuleNotFoundError: | ||
install_deps(session) | ||
download_models(session) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -225,7 +225,7 @@ def _parse_input_signature(input_signature: Any): | |
|
||
|
||
def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: | ||
# TODO: Remove deep copy once collections does not need partial compilation | ||
# TODO: Use deepcopy to support partial compilation of collections | ||
compile_spec = deepcopy(compile_spec_) | ||
info = _ts_C.CompileSpec() | ||
|
||
|
@@ -301,7 +301,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: | |
compile_spec["enabled_precisions"] | ||
) | ||
|
||
if "calibrator" in compile_spec: | ||
if "calibrator" in compile_spec and compile_spec["calibrator"]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is needed for 1.2 to work |
||
info.ptq_calibrator = compile_spec["calibrator"] | ||
|
||
if "sparse_weights" in compile_spec: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@peri044: Why are we getting rid of download_datasets?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This wasn't being used. The torch.datasets.CIFAR10 call (https://github.com/pytorch/TensorRT/blob/master/examples/int8/training/vgg16/main.py#L89-L103) downloads the data automatically, so the tests were basically downloading the data twice before.