Skip to content

Commit ef828ad

Browse files
committed
feat: Add testing functionality to Python API utils
- Add function to check the equivalence of two collection-based outputs for comparison across Torch-TRT and Torch outputs - Improved test robustness in end-to-end to check for equivalent output schemas in addition to successful compilation
1 parent 655ce22 commit ef828ad

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

tests/py/api/test_e2e_behavior.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torchvision.models as models
55
import copy
66
from typing import Dict
7+
from utils import same_output_format
78

89

910
class TestInputTypeDefaultsFP32Model(unittest.TestCase):
@@ -109,7 +110,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
109110
)
110111
trt_mod(self.input)
111112

112-
def test_nested_tuple_output_with_full_compilation(self):
113+
def test_nested_combination_tuple_list_output_with_full_compilation(self):
113114
class Sample(torch.nn.Module):
114115
def __init__(self):
115116
super(Sample, self).__init__()
@@ -119,7 +120,7 @@ def forward(self, x, y, z):
119120
b = x + 2.0 * z
120121
b = y + b
121122
a = b + c
122-
return (a, (b, c))
123+
return (a, [b, c])
123124

124125
self.model = Sample().eval().to("cuda")
125126
self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
@@ -139,7 +140,11 @@ def forward(self, x, y, z):
139140
require_full_compilation=True,
140141
enabled_precisions={torch.float, torch.half},
141142
)
142-
trt_mod(self.input_1, self.input_2, self.input_3)
143+
trt_output = trt_mod(self.input_1, self.input_2, self.input_3)
144+
torch_output = self.model(self.input_1, self.input_2, self.input_3)
145+
assert same_output_format(
146+
trt_output, torch_output
147+
), "Found differing output formatting between Torch-TRT and Torch"
143148

144149

145150
if __name__ == "__main__":

tests/py/api/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,42 @@ def cosine_similarity(gt_tensor, pred_tensor):
1313
res = res.cpu().detach().item()
1414

1515
return res
16+
17+
18+
def same_output_format(trt_output, torch_output):
19+
# For each encountered collection type, ensure the torch and trt outputs agree
20+
# on type and size, checking recursively through all member elements.
21+
if isinstance(trt_output, tuple):
22+
return (
23+
isinstance(torch_output, tuple)
24+
and (len(trt_output) == len(torch_output))
25+
and all(
26+
same_output_format(trt_entry, torch_entry)
27+
for trt_entry, torch_entry in zip(trt_output, torch_output)
28+
)
29+
)
30+
elif isinstance(trt_output, list):
31+
return (
32+
isinstance(torch_output, list)
33+
and (len(trt_output) == len(torch_output))
34+
and all(
35+
same_output_format(trt_entry, torch_entry)
36+
for trt_entry, torch_entry in zip(trt_output, torch_output)
37+
)
38+
)
39+
elif isinstance(trt_output, dict):
40+
return (
41+
isinstance(torch_output, dict)
42+
and (len(trt_output) == len(torch_output))
43+
and (trt_output.keys() == torch_output.keys())
44+
and all(
45+
same_output_format(trt_output[key], torch_output[key])
46+
for key in trt_output.keys()
47+
)
48+
)
49+
elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
50+
raise AssertionError(
51+
"Unsupported output type 'set' encountered in output format check."
52+
)
53+
else:
54+
return type(trt_output) is type(torch_output)

0 commit comments

Comments
 (0)