Skip to content

Commit 17753fc

Browse files
committed
feat: Add testing functionality to Python API utils
- Add a utility function to check that two collection-based outputs have equivalent structure, for comparing Torch-TRT outputs against Torch outputs
- Improve end-to-end test robustness by verifying equivalent output schemas in addition to successful compilation
1 parent 655ce22 commit 17753fc

File tree

4 files changed

+58
-15
lines changed

4 files changed

+58
-15
lines changed

core/compiler.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
196196
// for collections processing
197197
if (expect_full_compilation) {
198198
for (auto torch_node : seg_block.block()->nodes()) {
199-
if (partitioning::CollectionSchemas.find(torch_node->kind().toQualString()) ==
200-
partitioning::CollectionSchemas.end()) {
201-
LOG_WARNING(
199+
if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
200+
LOG_ERROR(
202201
"Full compilation specified but node " << torch_node->kind().toQualString()
203202
<< " was executed in Torch.");
204203
}
@@ -210,7 +209,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
210209
// If full compilation is expected, cannot have more than 2 Torch segments
211210
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
212211
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
213-
LOG_WARNING(
212+
LOG_ERROR(
214213
"Full compilation specified but number of torch segments was "
215214
<< num_torch_segments << " and number of trt segments was " << num_trt_segments
216215
<< ". Was expecting at most 2 Torch segments and 1 TRT segment.");

core/partitioning/partitioning.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::
2121
// Set of schemas allowed to be executed in Torch, even with require_full_compilation=true,
2222
// as necessary for returning collections of Tensors or other complex constructs, and for
2323
// processing inputs to TRT engines
24-
const std::unordered_set<std::string> CollectionSchemas = {
25-
"prim::Constant",
26-
"aten::__getitem__",
27-
"prim::ListConstruct",
28-
"prim::ListUnpack",
29-
"prim::TupleIndex",
30-
"prim::TupleConstruct",
31-
"prim::TupleUnpack",
24+
const std::unordered_set<c10::Symbol> CollectionNodeKinds = {
25+
c10::Symbol::fromQualString("prim::Constant"),
26+
c10::Symbol::fromQualString("aten::__getitem__"),
27+
c10::Symbol::fromQualString("prim::ListConstruct"),
28+
c10::Symbol::fromQualString("prim::ListUnpack"),
29+
c10::Symbol::fromQualString("prim::TupleIndex"),
30+
c10::Symbol::fromQualString("prim::TupleConstruct"),
31+
c10::Symbol::fromQualString("prim::TupleUnpack"),
3232
};
3333

3434
ExampleIValues generateRandomInputs(

tests/py/api/test_e2e_behavior.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torchvision.models as models
55
import copy
66
from typing import Dict
7+
from utils import same_output_format
78

89

910
class TestInputTypeDefaultsFP32Model(unittest.TestCase):
@@ -109,7 +110,7 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
109110
)
110111
trt_mod(self.input)
111112

112-
def test_nested_tuple_output_with_full_compilation(self):
113+
def test_nested_combination_tuple_list_output_with_full_compilation(self):
113114
class Sample(torch.nn.Module):
114115
def __init__(self):
115116
super(Sample, self).__init__()
@@ -119,7 +120,7 @@ def forward(self, x, y, z):
119120
b = x + 2.0 * z
120121
b = y + b
121122
a = b + c
122-
return (a, (b, c))
123+
return (a, [b, c])
123124

124125
self.model = Sample().eval().to("cuda")
125126
self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
@@ -139,7 +140,11 @@ def forward(self, x, y, z):
139140
require_full_compilation=True,
140141
enabled_precisions={torch.float, torch.half},
141142
)
142-
trt_mod(self.input_1, self.input_2, self.input_3)
143+
trt_output = trt_mod(self.input_1, self.input_2, self.input_3)
144+
torch_output = self.model(self.input_1, self.input_2, self.input_3)
145+
assert same_output_format(
146+
trt_output, torch_output
147+
), "Found differing output formatting between Torch-TRT and Torch"
143148

144149

145150
if __name__ == "__main__":

tests/py/api/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,42 @@ def cosine_similarity(gt_tensor, pred_tensor):
1313
res = res.cpu().detach().item()
1414

1515
return res
16+
17+
18+
def same_output_format(trt_output, torch_output):
    """Check whether two outputs share the same collection structure.

    Recursively verifies that ``trt_output`` and ``torch_output`` agree on
    container type and size at every nesting level, without comparing the
    contained leaf values themselves. Used to confirm that a Torch-TRT
    compiled module produces the same output schema as the original Torch
    module.

    Args:
        trt_output: Output from the Torch-TRT compiled module.
        torch_output: Output from the original Torch module.

    Returns:
        bool: True if both outputs have the same (nested) format.

    Raises:
        AssertionError: If a set or frozenset is encountered, since sets
            have no defined ordering and thus no comparable output schema.
    """
    # Tuples and lists share the same structural check; only the concrete
    # container type must also match between the two outputs.
    if isinstance(trt_output, (tuple, list)):
        expected_type = tuple if isinstance(trt_output, tuple) else list
        return (
            isinstance(torch_output, expected_type)
            and (len(trt_output) == len(torch_output))
            and all(
                same_output_format(trt_entry, torch_entry)
                for trt_entry, torch_entry in zip(trt_output, torch_output)
            )
        )
    elif isinstance(trt_output, dict):
        # Dict views compare set-like, so equal keys() implies equal length;
        # values are then checked recursively per key.
        return (
            isinstance(torch_output, dict)
            and (trt_output.keys() == torch_output.keys())
            and all(
                same_output_format(trt_output[key], torch_output[key])
                for key in trt_output.keys()
            )
        )
    elif isinstance(trt_output, (set, frozenset)):
        raise AssertionError(
            "Unsupported output type 'set' encountered in output format check."
        )
    else:
        # Leaf values: formats agree exactly when the concrete types agree.
        return type(trt_output) is type(torch_output)

0 commit comments

Comments
 (0)